From fa93415e860cce590c3392079e93d3ae21ffc83c Mon Sep 17 00:00:00 2001 From: Neelesh Srinivas Salian Date: Sat, 8 Aug 2015 19:51:52 -0700 Subject: [PATCH 001/253] Added yarn-deploy-mode alternative --- docs/submitting-applications.md | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index e58645274e525..015cef9460629 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -48,6 +48,20 @@ Some of the commonly used options are: * `application-jar`: Path to a bundled jar including your application and all dependencies. The URL must be globally visible inside of your cluster, for instance, an `hdfs://` path or a `file://` path that is present on all nodes. * `application-arguments`: Arguments passed to the main method of your main class, if any +Alternatively, for submitting on yarn, + +{% highlight bash %} +./bin/spark-submit \ + --class + --master + --conf = \ + ... # other options + \ + [application-arguments] +{% endhighlight %} + +* `--master`: The --master parameter is either `yarn-client` or `yarn-cluster`. Defaults to `yarn-client` + A common deployment strategy is to submit your application from a gateway machine that is physically co-located with your worker machines (e.g. Master node in a standalone EC2 cluster). @@ -99,7 +113,7 @@ run it with `--help`. Here are a few examples of common options: /path/to/examples.jar \ 1000 -# Run on a YARN cluster +# Run on a YARN cluster without --deploy mode export HADOOP_CONF_DIR=XXX ./bin/spark-submit \ --class org.apache.spark.examples.SparkPi \ From 437a4d451147f179617628a672eaa795b3b76ea0 Mon Sep 17 00:00:00 2001 From: Neelesh Srinivas Salian Date: Sat, 8 Aug 2015 19:54:04 -0700 Subject: [PATCH 002/253] Moved Master URLs closer above before the examples --- docs/submitting-applications.md | 49 ++++++++++++++++----------------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index 015cef9460629..4a564ed0a765b 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -62,6 +62,30 @@ Alternatively, for submitting on yarn, * `--master`: The --master parameter is either `yarn-client` or `yarn-cluster`. Defaults to `yarn-client` +# Master URLs + +The master URL passed to Spark can be in one of the following formats: + + + + + + + + + + +
Master URLMeaning
local Run Spark locally with one worker thread (i.e. no parallelism at all).
local[K] Run Spark locally with K worker threads (ideally, set this to the number of cores on your machine).
local[*] Run Spark locally with as many worker threads as logical cores on your machine.
spark://HOST:PORT Connect to the given Spark standalone + cluster master. The port must be whichever one your master is configured to use, which is 7077 by default. +
mesos://HOST:PORT Connect to the given Mesos cluster. + The port must be whichever one your is configured to use, which is 5050 by default. + Or, for a Mesos cluster using ZooKeeper, use mesos://zk://.... +
yarn-client Connect to a YARN cluster in +client mode. The cluster location will be found based on the HADOOP_CONF_DIR or YARN_CONF_DIR variable. +
yarn-cluster Connect to a YARN cluster in +cluster mode. The cluster location will be found based on the HADOOP_CONF_DIR or YARN_CONF_DIR variable. +
+ A common deployment strategy is to submit your application from a gateway machine that is physically co-located with your worker machines (e.g. Master node in a standalone EC2 cluster). @@ -130,31 +154,6 @@ export HADOOP_CONF_DIR=XXX 1000 {% endhighlight %} -# Master URLs - -The master URL passed to Spark can be in one of the following formats: - - - - - - - - - - -
Master URLMeaning
local Run Spark locally with one worker thread (i.e. no parallelism at all).
local[K] Run Spark locally with K worker threads (ideally, set this to the number of cores on your machine).
local[*] Run Spark locally with as many worker threads as logical cores on your machine.
spark://HOST:PORT Connect to the given Spark standalone - cluster master. The port must be whichever one your master is configured to use, which is 7077 by default. -
mesos://HOST:PORT Connect to the given Mesos cluster. - The port must be whichever one your is configured to use, which is 5050 by default. - Or, for a Mesos cluster using ZooKeeper, use mesos://zk://.... -
yarn-client Connect to a YARN cluster in -client mode. The cluster location will be found based on the HADOOP_CONF_DIR or YARN_CONF_DIR variable. -
yarn-cluster Connect to a YARN cluster in -cluster mode. The cluster location will be found based on the HADOOP_CONF_DIR or YARN_CONF_DIR variable. -
- - # Loading Configuration from a File The `spark-submit` script can load default [Spark configuration values](configuration.html) from a From 05fe708c24f07f9661a558dfbe51970aa940e4e5 Mon Sep 17 00:00:00 2001 From: Neelesh Srinivas Salian Date: Mon, 10 Aug 2015 10:14:59 -0700 Subject: [PATCH 003/253] Removed the addition section --- docs/submitting-applications.md | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index 4a564ed0a765b..a15f83bde24be 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -48,20 +48,6 @@ Some of the commonly used options are: * `application-jar`: Path to a bundled jar including your application and all dependencies. The URL must be globally visible inside of your cluster, for instance, an `hdfs://` path or a `file://` path that is present on all nodes. * `application-arguments`: Arguments passed to the main method of your main class, if any -Alternatively, for submitting on yarn, - -{% highlight bash %} -./bin/spark-submit \ - --class - --master - --conf = \ - ... # other options - \ - [application-arguments] -{% endhighlight %} - -* `--master`: The --master parameter is either `yarn-client` or `yarn-cluster`. Defaults to `yarn-client` - # Master URLs The master URL passed to Spark can be in one of the following formats: From 98624e89c6b303db4fc30408e14705df021ca591 Mon Sep 17 00:00:00 2001 From: Neelesh Srinivas Salian Date: Mon, 10 Aug 2015 10:16:14 -0700 Subject: [PATCH 004/253] Added a section for alternative submission. Distinguished from the shifting of Master URLS --- docs/submitting-applications.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index a15f83bde24be..4a564ed0a765b 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -48,6 +48,20 @@ Some of the commonly used options are: * `application-jar`: Path to a bundled jar including your application and all dependencies. The URL must be globally visible inside of your cluster, for instance, an `hdfs://` path or a `file://` path that is present on all nodes. * `application-arguments`: Arguments passed to the main method of your main class, if any +Alternatively, for submitting on yarn, + +{% highlight bash %} +./bin/spark-submit \ + --class + --master + --conf = \ + ... # other options + \ + [application-arguments] +{% endhighlight %} + +* `--master`: The --master parameter is either `yarn-client` or `yarn-cluster`. Defaults to `yarn-client` + # Master URLs The master URL passed to Spark can be in one of the following formats: From 00b655cced637e1c3b750c19266086b9dcd7c158 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Mon, 10 Aug 2015 11:01:45 -0700 Subject: [PATCH 005/253] [SPARK-9755] [MLLIB] Add docs to MultivariateOnlineSummarizer methods Adds method documentations back to `MultivariateOnlineSummarizer`, which were present in 1.4 but disappeared somewhere along the way to 1.5. jkbradley Author: Feynman Liang Closes #8045 from feynmanliang/SPARK-9755 and squashes the following commits: af67fde [Feynman Liang] Add MultivariateOnlineSummarizer docs --- .../stat/MultivariateOnlineSummarizer.scala | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 62da9f2ef22a3..64e4be0ebb97e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -153,6 +153,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** + * Sample mean of each dimension. + * * @since 1.1.0 */ override def mean: Vector = { @@ -168,6 +170,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** + * Sample variance of each dimension. + * * @since 1.1.0 */ override def variance: Vector = { @@ -193,11 +197,15 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** + * Sample size. + * * @since 1.1.0 */ override def count: Long = totalCnt /** + * Number of nonzero elements in each dimension. + * * @since 1.1.0 */ override def numNonzeros: Vector = { @@ -207,6 +215,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** + * Maximum value of each dimension. + * * @since 1.1.0 */ override def max: Vector = { @@ -221,6 +231,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** + * Minimum value of each dimension. + * * @since 1.1.0 */ override def min: Vector = { @@ -235,6 +247,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** + * L2 (Euclidian) norm of each dimension. + * * @since 1.2.0 */ override def normL2: Vector = { @@ -252,6 +266,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** + * L1 norm of each dimension. + * * @since 1.2.0 */ override def normL1: Vector = { From d285212756168200383bf4df2c951bd80a492a7c Mon Sep 17 00:00:00 2001 From: Mahmoud Lababidi Date: Mon, 10 Aug 2015 13:02:01 -0700 Subject: [PATCH 006/253] Fixed AtmoicReference<> Example Author: Mahmoud Lababidi Closes #8076 from lababidi/master and squashes the following commits: af4553b [Mahmoud Lababidi] Fixed AtmoicReference<> Example --- docs/streaming-kafka-integration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index 775d508d4879b..7571e22575efd 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -152,7 +152,7 @@ Next, we discuss how to use this approach in your streaming application.
// Hold a reference to the current offset ranges, so it can be used downstream - final AtomicReference offsetRanges = new AtomicReference(); + final AtomicReference offsetRanges = new AtomicReference<>(); directKafkaStream.transformToPair( new Function, JavaPairRDD>() { From 0fe66744f16854fc8cd8a72174de93a788e3cf6c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 10 Aug 2015 13:05:03 -0700 Subject: [PATCH 007/253] [SPARK-9784] [SQL] Exchange.isUnsafe should check whether codegen and unsafe are enabled Exchange.isUnsafe should check whether codegen and unsafe are enabled. Author: Josh Rosen Closes #8073 from JoshRosen/SPARK-9784 and squashes the following commits: 7a1019f [Josh Rosen] [SPARK-9784] Exchange.isUnsafe should check whether codegen and unsafe are enabled --- .../main/scala/org/apache/spark/sql/execution/Exchange.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index b89e634761eb1..029f2264a6a27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -46,7 +46,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una * Returns true iff we can support the data type, and we are not doing range partitioning. */ private lazy val tungstenMode: Boolean = { - GenerateUnsafeProjection.canSupport(child.schema) && + unsafeEnabled && codegenEnabled && GenerateUnsafeProjection.canSupport(child.schema) && !newPartitioning.isInstanceOf[RangePartitioning] } From 40ed2af587cedadc6e5249031857a922b3b234ca Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 10 Aug 2015 13:49:23 -0700 Subject: [PATCH 008/253] [SPARK-9763][SQL] Minimize exposure of internal SQL classes. There are a few changes in this pull request: 1. Moved all data sources to execution.datasources, except the public JDBC APIs. 2. In order to maintain backward compatibility from 1, added a backward compatibility translation map in data source resolution. 3. Moved ui and metric package into execution. 4. Added more documentation on some internal classes. 5. Renamed DataSourceRegister.format -> shortName. 6. Added "override" modifier on shortName. 7. Removed IntSQLMetric. Author: Reynold Xin Closes #8056 from rxin/SPARK-9763 and squashes the following commits: 9df4801 [Reynold Xin] Removed hardcoded name in test cases. d9babc6 [Reynold Xin] Shorten. e484419 [Reynold Xin] Removed VisibleForTesting. 171b812 [Reynold Xin] MimaExcludes. 2041389 [Reynold Xin] Compile ... 79dda42 [Reynold Xin] Compile. 0818ba3 [Reynold Xin] Removed IntSQLMetric. c46884f [Reynold Xin] Two more fixes. f9aa88d [Reynold Xin] [SPARK-9763][SQL] Minimize exposure of internal SQL classes. --- project/MimaExcludes.scala | 24 +- ...pache.spark.sql.sources.DataSourceRegister | 6 +- .../ui/static/spark-sql-viz.css | 0 .../ui/static/spark-sql-viz.js | 0 .../org/apache/spark/sql/DataFrame.scala | 2 +- .../apache/spark/sql/DataFrameReader.scala | 6 +- .../apache/spark/sql/DataFrameWriter.scala | 6 +- .../org/apache/spark/sql/SQLContext.scala | 2 +- .../spark/sql/execution/SQLExecution.scala | 2 +- .../spark/sql/execution/SparkPlan.scala | 8 +- .../spark/sql/execution/basicOperators.scala | 2 +- .../sql/execution/datasources/DDLParser.scala | 185 +++++++++ .../execution/datasources/DefaultSource.scala | 64 ++++ .../datasources/InsertIntoDataSource.scala | 23 +- .../datasources/ResolvedDataSource.scala | 204 ++++++++++ .../spark/sql/execution/datasources/ddl.scala | 352 +----------------- .../datasources/jdbc/DefaultSource.scala | 62 +++ .../datasources/jdbc/DriverRegistry.scala | 60 +++ .../datasources/jdbc/DriverWrapper.scala | 48 +++ .../datasources}/jdbc/JDBCRDD.scala | 9 +- .../datasources}/jdbc/JDBCRelation.scala | 41 +- .../datasources/jdbc/JdbcUtils.scala | 219 +++++++++++ .../datasources}/json/InferSchema.scala | 4 +- .../datasources}/json/JSONRelation.scala | 7 +- .../datasources}/json/JacksonGenerator.scala | 2 +- .../datasources}/json/JacksonParser.scala | 4 +- .../datasources}/json/JacksonUtils.scala | 2 +- .../parquet/CatalystReadSupport.scala | 2 +- .../parquet/CatalystRecordMaterializer.scala | 2 +- .../parquet/CatalystRowConverter.scala | 2 +- .../parquet/CatalystSchemaConverter.scala | 4 +- .../DirectParquetOutputCommitter.scala | 2 +- .../parquet/ParquetConverter.scala | 2 +- .../datasources}/parquet/ParquetFilters.scala | 2 +- .../parquet/ParquetRelation.scala | 4 +- .../parquet/ParquetTableSupport.scala | 2 +- .../parquet/ParquetTypesConverter.scala | 2 +- .../{ => execution}/metric/SQLMetrics.scala | 36 +- .../apache/spark/sql/execution/package.scala | 8 +- .../ui/AllExecutionsPage.scala | 2 +- .../{ => execution}/ui/ExecutionPage.scala | 2 +- .../sql/{ => execution}/ui/SQLListener.scala | 7 +- .../spark/sql/{ => execution}/ui/SQLTab.scala | 6 +- .../{ => execution}/ui/SparkPlanGraph.scala | 4 +- .../org/apache/spark/sql/jdbc/JdbcUtils.scala | 52 --- .../org/apache/spark/sql/jdbc/jdbc.scala | 250 ------------- .../apache/spark/sql/sources/interfaces.scala | 15 +- .../parquet/test/avro/CompatibilityTest.java | 4 +- .../parquet/test/avro/Nested.java | 30 +- .../parquet/test/avro/ParquetAvroCompat.java | 106 +++--- .../org/apache/spark/sql/DataFrameSuite.scala | 4 +- .../datasources}/json/JsonSuite.scala | 4 +- .../datasources}/json/TestJsonData.scala | 2 +- .../ParquetAvroCompatibilitySuite.scala | 4 +- .../parquet/ParquetCompatibilityTest.scala | 2 +- .../parquet/ParquetFilterSuite.scala | 2 +- .../datasources}/parquet/ParquetIOSuite.scala | 4 +- .../ParquetPartitionDiscoverySuite.scala | 2 +- .../parquet/ParquetQuerySuite.scala | 2 +- .../parquet/ParquetSchemaSuite.scala | 2 +- .../datasources}/parquet/ParquetTest.scala | 2 +- .../ParquetThriftCompatibilitySuite.scala | 2 +- .../metric/SQLMetricsSuite.scala | 12 +- .../{ => execution}/ui/SQLListenerSuite.scala | 4 +- .../sources/CreateTableAsSelectSuite.scala | 20 +- .../sql/sources/DDLSourceLoadSuite.scala | 59 +-- .../sql/sources/ResolvedDataSourceSuite.scala | 39 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 2 +- .../spark/sql/hive/orc/OrcRelation.scala | 6 +- .../spark/sql/hive/HiveParquetSuite.scala | 2 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 2 +- .../hive/ParquetHiveCompatibilitySuite.scala | 2 +- .../sql/hive/execution/SQLQuerySuite.scala | 2 +- .../apache/spark/sql/hive/parquetSuites.scala | 2 +- .../ParquetHadoopFsRelationSuite.scala | 4 +- .../SimpleTextHadoopFsRelationSuite.scala | 2 +- 76 files changed, 1114 insertions(+), 966 deletions(-) rename sql/core/src/main/resources/org/apache/spark/sql/{ => execution}/ui/static/spark-sql-viz.css (100%) rename sql/core/src/main/resources/org/apache/spark/sql/{ => execution}/ui/static/spark-sql-viz.js (100%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DefaultSource.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverWrapper.scala rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/jdbc/JDBCRDD.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/jdbc/JDBCRelation.scala (71%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/json/InferSchema.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/json/JSONRelation.scala (97%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/json/JacksonGenerator.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/json/JacksonParser.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/json/JacksonUtils.scala (95%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/CatalystReadSupport.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/CatalystRecordMaterializer.scala (96%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/CatalystRowConverter.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/CatalystSchemaConverter.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/DirectParquetOutputCommitter.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetConverter.scala (96%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetFilters.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetRelation.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetTableSupport.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetTypesConverter.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/metric/SQLMetrics.scala (77%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/ui/AllExecutionsPage.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/ui/ExecutionPage.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/ui/SQLListener.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/ui/SQLTab.scala (90%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/ui/SparkPlanGraph.scala (97%) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala rename sql/core/src/test/gen-java/org/apache/spark/sql/{ => execution/datasources}/parquet/test/avro/CompatibilityTest.java (93%) rename sql/core/src/test/gen-java/org/apache/spark/sql/{ => execution/datasources}/parquet/test/avro/Nested.java (78%) rename sql/core/src/test/gen-java/org/apache/spark/sql/{ => execution/datasources}/parquet/test/avro/ParquetAvroCompat.java (83%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/json/JsonSuite.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/json/TestJsonData.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetAvroCompatibilitySuite.scala (96%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetCompatibilityTest.scala (97%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetFilterSuite.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetIOSuite.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetPartitionDiscoverySuite.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetQuerySuite.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetSchemaSuite.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetTest.scala (98%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetThriftCompatibilitySuite.scala (98%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/metric/SQLMetricsSuite.scala (92%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/ui/SQLListenerSuite.scala (99%) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b60ae784c3798..90261ca3d61aa 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -62,8 +62,6 @@ object MimaExcludes { "org.apache.spark.ml.classification.LogisticCostFun.this"), // SQL execution is considered private. excludePackage("org.apache.spark.sql.execution"), - // Parquet support is considered private. - excludePackage("org.apache.spark.sql.parquet"), // The old JSON RDD is removed in favor of streaming Jackson ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD"), @@ -155,7 +153,27 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLException") + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLException"), + // SPARK-9763 Minimize exposure of internal SQL classes + excludePackage("org.apache.spark.sql.parquet"), + excludePackage("org.apache.spark.sql.json"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$DecimalConversion$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartition"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JdbcUtils$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$DecimalConversion"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartitioningInfo$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartition$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$JDBCConversion"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package$DriverWrapper"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartitioningInfo"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JdbcUtils"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DefaultSource"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRelation$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRelation") ) ++ Seq( // SPARK-4751 Dynamic allocation for standalone mode ProblemFilters.exclude[MissingMethodProblem]( diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index cc32d4b72748e..ca50000b4756e 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,3 +1,3 @@ -org.apache.spark.sql.jdbc.DefaultSource -org.apache.spark.sql.json.DefaultSource -org.apache.spark.sql.parquet.DefaultSource +org.apache.spark.sql.execution.datasources.jdbc.DefaultSource +org.apache.spark.sql.execution.datasources.json.DefaultSource +org.apache.spark.sql.execution.datasources.parquet.DefaultSource diff --git a/sql/core/src/main/resources/org/apache/spark/sql/ui/static/spark-sql-viz.css b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css similarity index 100% rename from sql/core/src/main/resources/org/apache/spark/sql/ui/static/spark-sql-viz.css rename to sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css diff --git a/sql/core/src/main/resources/org/apache/spark/sql/ui/static/spark-sql-viz.js b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.js similarity index 100% rename from sql/core/src/main/resources/org/apache/spark/sql/ui/static/spark-sql-viz.js rename to sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.js diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 570b8b2d5928d..27b994f1f0caf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} -import org.apache.spark.sql.json.JacksonGenerator +import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.sources.HadoopFsRelation import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 85f33c5e99523..9ea955b010017 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -25,10 +25,10 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} +import org.apache.spark.sql.execution.datasources.json.JSONRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} -import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} -import org.apache.spark.sql.json.JSONRelation -import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.types.StructType import org.apache.spark.{Logging, Partition} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 2a4992db09bc2..5fa11da4c38cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -23,8 +23,8 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable +import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, ResolvedDataSource} -import org.apache.spark.sql.jdbc.{JDBCWriteDetails, JdbcUtils} import org.apache.spark.sql.sources.HadoopFsRelation @@ -264,7 +264,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { // Create the table if the table didn't exist. if (!tableExists) { - val schema = JDBCWriteDetails.schemaString(df, url) + val schema = JdbcUtils.schemaString(df, url) val sql = s"CREATE TABLE $table ($schema)" conn.prepareStatement(sql).executeUpdate() } @@ -272,7 +272,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { conn.close() } - JDBCWriteDetails.saveTable(df, url, table, connectionProperties) + JdbcUtils.saveTable(df, url, table, connectionProperties) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 832572571cabd..f73bb0488c984 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.ui.{SQLListener, SQLTab} +import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 97f1323e97835..cee58218a885b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicLong import org.apache.spark.SparkContext import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.ui.SparkPlanGraph +import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.util.Utils private[sql] object SQLExecution { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 1915496d16205..9ba5cf2d2b39e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.metric.{IntSQLMetric, LongSQLMetric, SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetric, SQLMetrics} import org.apache.spark.sql.types.DataType object SparkPlan { @@ -98,12 +98,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ */ private[sql] def metrics: Map[String, SQLMetric[_, _]] = defaultMetrics - /** - * Return a IntSQLMetric according to the name. - */ - private[sql] def intMetric(name: String): IntSQLMetric = - metrics(name).asInstanceOf[IntSQLMetric] - /** * Return a LongSQLMetric according to the name. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 24950f26061f7..bf2de244c8e4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.metric.SQLMetrics +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.StructType import org.apache.spark.util.collection.ExternalSorter import org.apache.spark.util.collection.unsafe.sort.PrefixComparator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala new file mode 100644 index 0000000000000..6c462fa30461b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala @@ -0,0 +1,185 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.datasources + +import scala.language.implicitConversions +import scala.util.matching.Regex + +import org.apache.spark.Logging +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.catalyst.{TableIdentifier, AbstractSparkSQLParser} +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.types._ + + +/** + * A parser for foreign DDL commands. + */ +class DDLParser(parseQuery: String => LogicalPlan) + extends AbstractSparkSQLParser with DataTypeParser with Logging { + + def parse(input: String, exceptionOnError: Boolean): LogicalPlan = { + try { + parse(input) + } catch { + case ddlException: DDLException => throw ddlException + case _ if !exceptionOnError => parseQuery(input) + case x: Throwable => throw x + } + } + + // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` + // properties via reflection the class in runtime for constructing the SqlLexical object + protected val CREATE = Keyword("CREATE") + protected val TEMPORARY = Keyword("TEMPORARY") + protected val TABLE = Keyword("TABLE") + protected val IF = Keyword("IF") + protected val NOT = Keyword("NOT") + protected val EXISTS = Keyword("EXISTS") + protected val USING = Keyword("USING") + protected val OPTIONS = Keyword("OPTIONS") + protected val DESCRIBE = Keyword("DESCRIBE") + protected val EXTENDED = Keyword("EXTENDED") + protected val AS = Keyword("AS") + protected val COMMENT = Keyword("COMMENT") + protected val REFRESH = Keyword("REFRESH") + + protected lazy val ddl: Parser[LogicalPlan] = createTable | describeTable | refreshTable + + protected def start: Parser[LogicalPlan] = ddl + + /** + * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] + * USING org.apache.spark.sql.avro + * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` + * or + * `CREATE [TEMPORARY] TABLE avroTable(intField int, stringField string...) [IF NOT EXISTS] + * USING org.apache.spark.sql.avro + * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` + * or + * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] + * USING org.apache.spark.sql.avro + * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` + * AS SELECT ... + */ + protected lazy val createTable: Parser[LogicalPlan] = { + // TODO: Support database.table. + (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ ident ~ + tableCols.? ~ (USING ~> className) ~ (OPTIONS ~> options).? ~ (AS ~> restInput).? ^^ { + case temp ~ allowExisting ~ tableName ~ columns ~ provider ~ opts ~ query => + if (temp.isDefined && allowExisting.isDefined) { + throw new DDLException( + "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.") + } + + val options = opts.getOrElse(Map.empty[String, String]) + if (query.isDefined) { + if (columns.isDefined) { + throw new DDLException( + "a CREATE TABLE AS SELECT statement does not allow column definitions.") + } + // When IF NOT EXISTS clause appears in the query, the save mode will be ignore. + val mode = if (allowExisting.isDefined) { + SaveMode.Ignore + } else if (temp.isDefined) { + SaveMode.Overwrite + } else { + SaveMode.ErrorIfExists + } + + val queryPlan = parseQuery(query.get) + CreateTableUsingAsSelect(tableName, + provider, + temp.isDefined, + Array.empty[String], + mode, + options, + queryPlan) + } else { + val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields))) + CreateTableUsing( + tableName, + userSpecifiedSchema, + provider, + temp.isDefined, + options, + allowExisting.isDefined, + managedIfNoPath = false) + } + } + } + + protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" + + /* + * describe [extended] table avroTable + * This will display all columns of table `avroTable` includes column_name,column_type,comment + */ + protected lazy val describeTable: Parser[LogicalPlan] = + (DESCRIBE ~> opt(EXTENDED)) ~ (ident <~ ".").? ~ ident ^^ { + case e ~ db ~ tbl => + val tblIdentifier = db match { + case Some(dbName) => + Seq(dbName, tbl) + case None => + Seq(tbl) + } + DescribeCommand(UnresolvedRelation(tblIdentifier, None), e.isDefined) + } + + protected lazy val refreshTable: Parser[LogicalPlan] = + REFRESH ~> TABLE ~> (ident <~ ".").? ~ ident ^^ { + case maybeDatabaseName ~ tableName => + RefreshTable(TableIdentifier(tableName, maybeDatabaseName)) + } + + protected lazy val options: Parser[Map[String, String]] = + "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } + + protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} + + override implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch( + s"identifier matching regex $regex", { + case lexical.Identifier(str) if regex.unapplySeq(str).isDefined => str + case lexical.Keyword(str) if regex.unapplySeq(str).isDefined => str + } + ) + + protected lazy val optionPart: Parser[String] = "[_a-zA-Z][_a-zA-Z0-9]*".r ^^ { + case name => name + } + + protected lazy val optionName: Parser[String] = repsep(optionPart, ".") ^^ { + case parts => parts.mkString(".") + } + + protected lazy val pair: Parser[(String, String)] = + optionName ~ stringLit ^^ { case k ~ v => (k, v) } + + protected lazy val column: Parser[StructField] = + ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm => + val meta = cm match { + case Some(comment) => + new MetadataBuilder().putString(COMMENT.str.toLowerCase, comment).build() + case None => Metadata.empty + } + + StructField(columnName, typ, nullable = true, meta) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DefaultSource.scala new file mode 100644 index 0000000000000..6e4cc4de7f651 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DefaultSource.scala @@ -0,0 +1,64 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.datasources + +import java.util.Properties + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCRelation, JDBCPartitioningInfo, DriverRegistry} +import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider} + + +class DefaultSource extends RelationProvider with DataSourceRegister { + + override def shortName(): String = "jdbc" + + /** Returns a new base relation with the given parameters. */ + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + val url = parameters.getOrElse("url", sys.error("Option 'url' not specified")) + val driver = parameters.getOrElse("driver", null) + val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified")) + val partitionColumn = parameters.getOrElse("partitionColumn", null) + val lowerBound = parameters.getOrElse("lowerBound", null) + val upperBound = parameters.getOrElse("upperBound", null) + val numPartitions = parameters.getOrElse("numPartitions", null) + + if (driver != null) DriverRegistry.register(driver) + + if (partitionColumn != null + && (lowerBound == null || upperBound == null || numPartitions == null)) { + sys.error("Partitioning incompletely specified") + } + + val partitionInfo = if (partitionColumn == null) { + null + } else { + JDBCPartitioningInfo( + partitionColumn, + lowerBound.toLong, + upperBound.toLong, + numPartitions.toInt) + } + val parts = JDBCRelation.columnPartition(partitionInfo) + val properties = new Properties() // Additional properties that we will pass to getConnection + parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) + JDBCRelation(url, table, parts, properties)(sqlContext) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala index 6ccde7693bd34..3b7dc2e8d0210 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala @@ -17,27 +17,10 @@ package org.apache.spark.sql.execution.datasources -import java.io.IOException -import java.util.{Date, UUID} - -import scala.collection.JavaConversions.asScalaIterator - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter, FileOutputFormat} -import org.apache.spark._ -import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.execution.{RunnableCommand, SQLExecution} -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.StringType -import org.apache.spark.util.{Utils, SerializableConfiguration} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.sources.InsertableRelation /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala new file mode 100644 index 0000000000000..7770bbd712f04 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -0,0 +1,204 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.datasources + +import java.util.ServiceLoader + +import scala.collection.JavaConversions._ +import scala.language.{existentials, implicitConversions} +import scala.util.{Success, Failure, Try} + +import org.apache.hadoop.fs.Path + +import org.apache.spark.Logging +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.{DataFrame, SaveMode, AnalysisException, SQLContext} +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.{CalendarIntervalType, StructType} +import org.apache.spark.util.Utils + + +case class ResolvedDataSource(provider: Class[_], relation: BaseRelation) + + +object ResolvedDataSource extends Logging { + + /** A map to maintain backward compatibility in case we move data sources around. */ + private val backwardCompatibilityMap = Map( + "org.apache.spark.sql.jdbc" -> classOf[jdbc.DefaultSource].getCanonicalName, + "org.apache.spark.sql.jdbc.DefaultSource" -> classOf[jdbc.DefaultSource].getCanonicalName, + "org.apache.spark.sql.json" -> classOf[json.DefaultSource].getCanonicalName, + "org.apache.spark.sql.json.DefaultSource" -> classOf[json.DefaultSource].getCanonicalName, + "org.apache.spark.sql.parquet" -> classOf[parquet.DefaultSource].getCanonicalName, + "org.apache.spark.sql.parquet.DefaultSource" -> classOf[parquet.DefaultSource].getCanonicalName + ) + + /** Given a provider name, look up the data source class definition. */ + def lookupDataSource(provider0: String): Class[_] = { + val provider = backwardCompatibilityMap.getOrElse(provider0, provider0) + val provider2 = s"$provider.DefaultSource" + val loader = Utils.getContextOrSparkClassLoader + val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader) + + serviceLoader.iterator().filter(_.shortName().equalsIgnoreCase(provider)).toList match { + /** the provider format did not match any given registered aliases */ + case Nil => Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match { + case Success(dataSource) => dataSource + case Failure(error) => + if (provider.startsWith("org.apache.spark.sql.hive.orc")) { + throw new ClassNotFoundException( + "The ORC data source must be used with Hive support enabled.", error) + } else { + throw new ClassNotFoundException( + s"Failed to load class for data source: $provider.", error) + } + } + /** there is exactly one registered alias */ + case head :: Nil => head.getClass + /** There are multiple registered aliases for the input */ + case sources => sys.error(s"Multiple sources found for $provider, " + + s"(${sources.map(_.getClass.getName).mkString(", ")}), " + + "please specify the fully qualified class name.") + } + } + + /** Create a [[ResolvedDataSource]] for reading data in. */ + def apply( + sqlContext: SQLContext, + userSpecifiedSchema: Option[StructType], + partitionColumns: Array[String], + provider: String, + options: Map[String, String]): ResolvedDataSource = { + val clazz: Class[_] = lookupDataSource(provider) + def className: String = clazz.getCanonicalName + val relation = userSpecifiedSchema match { + case Some(schema: StructType) => clazz.newInstance() match { + case dataSource: SchemaRelationProvider => + dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema) + case dataSource: HadoopFsRelationProvider => + val maybePartitionsSchema = if (partitionColumns.isEmpty) { + None + } else { + Some(partitionColumnsSchema(schema, partitionColumns)) + } + + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + val paths = { + val patternPath = new Path(caseInsensitiveOptions("path")) + val fs = patternPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val qualifiedPattern = patternPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + SparkHadoopUtil.get.globPathIfNecessary(qualifiedPattern).map(_.toString).toArray + } + + val dataSchema = + StructType(schema.filterNot(f => partitionColumns.contains(f.name))).asNullable + + dataSource.createRelation( + sqlContext, + paths, + Some(dataSchema), + maybePartitionsSchema, + caseInsensitiveOptions) + case dataSource: org.apache.spark.sql.sources.RelationProvider => + throw new AnalysisException(s"$className does not allow user-specified schemas.") + case _ => + throw new AnalysisException(s"$className is not a RelationProvider.") + } + + case None => clazz.newInstance() match { + case dataSource: RelationProvider => + dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options)) + case dataSource: HadoopFsRelationProvider => + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + val paths = { + val patternPath = new Path(caseInsensitiveOptions("path")) + val fs = patternPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val qualifiedPattern = patternPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + SparkHadoopUtil.get.globPathIfNecessary(qualifiedPattern).map(_.toString).toArray + } + dataSource.createRelation(sqlContext, paths, None, None, caseInsensitiveOptions) + case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => + throw new AnalysisException( + s"A schema needs to be specified when using $className.") + case _ => + throw new AnalysisException( + s"$className is neither a RelationProvider nor a FSBasedRelationProvider.") + } + } + new ResolvedDataSource(clazz, relation) + } + + private def partitionColumnsSchema( + schema: StructType, + partitionColumns: Array[String]): StructType = { + StructType(partitionColumns.map { col => + schema.find(_.name == col).getOrElse { + throw new RuntimeException(s"Partition column $col not found in schema $schema") + } + }).asNullable + } + + /** Create a [[ResolvedDataSource]] for saving the content of the given DataFrame. */ + def apply( + sqlContext: SQLContext, + provider: String, + partitionColumns: Array[String], + mode: SaveMode, + options: Map[String, String], + data: DataFrame): ResolvedDataSource = { + if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { + throw new AnalysisException("Cannot save interval data type into external storage.") + } + val clazz: Class[_] = lookupDataSource(provider) + val relation = clazz.newInstance() match { + case dataSource: CreatableRelationProvider => + dataSource.createRelation(sqlContext, mode, options, data) + case dataSource: HadoopFsRelationProvider => + // Don't glob path for the write path. The contracts here are: + // 1. Only one output path can be specified on the write path; + // 2. Output path must be a legal HDFS style file system path; + // 3. It's OK that the output path doesn't exist yet; + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + val outputPath = { + val path = new Path(caseInsensitiveOptions("path")) + val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + path.makeQualified(fs.getUri, fs.getWorkingDirectory) + } + val dataSchema = StructType(data.schema.filterNot(f => partitionColumns.contains(f.name))) + val r = dataSource.createRelation( + sqlContext, + Array(outputPath.toString), + Some(dataSchema.asNullable), + Some(partitionColumnsSchema(data.schema, partitionColumns)), + caseInsensitiveOptions) + + // For partitioned relation r, r.schema's column ordering can be different from the column + // ordering of data.logicalPlan (partition columns are all moved after data column). This + // will be adjusted within InsertIntoHadoopFsRelation. + sqlContext.executePlan( + InsertIntoHadoopFsRelation( + r, + data.logicalPlan, + mode)).toRdd + r + case _ => + sys.error(s"${clazz.getCanonicalName} does not allow create table as select.") + } + ResolvedDataSource(clazz, relation) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 8c2f297e42458..ecd304c30cdee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -17,340 +17,12 @@ package org.apache.spark.sql.execution.datasources -import java.util.ServiceLoader - -import scala.collection.Iterator -import scala.collection.JavaConversions._ -import scala.language.{existentials, implicitConversions} -import scala.util.{Failure, Success, Try} -import scala.util.matching.Regex - -import org.apache.hadoop.fs.Path - -import org.apache.spark.Logging -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, TableIdentifier} import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SQLContext, SaveMode} -import org.apache.spark.util.Utils - -/** - * A parser for foreign DDL commands. - */ -private[sql] class DDLParser( - parseQuery: String => LogicalPlan) - extends AbstractSparkSQLParser with DataTypeParser with Logging { - - def parse(input: String, exceptionOnError: Boolean): LogicalPlan = { - try { - parse(input) - } catch { - case ddlException: DDLException => throw ddlException - case _ if !exceptionOnError => parseQuery(input) - case x: Throwable => throw x - } - } - - // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` - // properties via reflection the class in runtime for constructing the SqlLexical object - protected val CREATE = Keyword("CREATE") - protected val TEMPORARY = Keyword("TEMPORARY") - protected val TABLE = Keyword("TABLE") - protected val IF = Keyword("IF") - protected val NOT = Keyword("NOT") - protected val EXISTS = Keyword("EXISTS") - protected val USING = Keyword("USING") - protected val OPTIONS = Keyword("OPTIONS") - protected val DESCRIBE = Keyword("DESCRIBE") - protected val EXTENDED = Keyword("EXTENDED") - protected val AS = Keyword("AS") - protected val COMMENT = Keyword("COMMENT") - protected val REFRESH = Keyword("REFRESH") - - protected lazy val ddl: Parser[LogicalPlan] = createTable | describeTable | refreshTable - - protected def start: Parser[LogicalPlan] = ddl - - /** - * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * or - * `CREATE [TEMPORARY] TABLE avroTable(intField int, stringField string...) [IF NOT EXISTS] - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * or - * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * AS SELECT ... - */ - protected lazy val createTable: Parser[LogicalPlan] = - // TODO: Support database.table. - (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ ident ~ - tableCols.? ~ (USING ~> className) ~ (OPTIONS ~> options).? ~ (AS ~> restInput).? ^^ { - case temp ~ allowExisting ~ tableName ~ columns ~ provider ~ opts ~ query => - if (temp.isDefined && allowExisting.isDefined) { - throw new DDLException( - "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.") - } - - val options = opts.getOrElse(Map.empty[String, String]) - if (query.isDefined) { - if (columns.isDefined) { - throw new DDLException( - "a CREATE TABLE AS SELECT statement does not allow column definitions.") - } - // When IF NOT EXISTS clause appears in the query, the save mode will be ignore. - val mode = if (allowExisting.isDefined) { - SaveMode.Ignore - } else if (temp.isDefined) { - SaveMode.Overwrite - } else { - SaveMode.ErrorIfExists - } - - val queryPlan = parseQuery(query.get) - CreateTableUsingAsSelect(tableName, - provider, - temp.isDefined, - Array.empty[String], - mode, - options, - queryPlan) - } else { - val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields))) - CreateTableUsing( - tableName, - userSpecifiedSchema, - provider, - temp.isDefined, - options, - allowExisting.isDefined, - managedIfNoPath = false) - } - } - - protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" - - /* - * describe [extended] table avroTable - * This will display all columns of table `avroTable` includes column_name,column_type,comment - */ - protected lazy val describeTable: Parser[LogicalPlan] = - (DESCRIBE ~> opt(EXTENDED)) ~ (ident <~ ".").? ~ ident ^^ { - case e ~ db ~ tbl => - val tblIdentifier = db match { - case Some(dbName) => - Seq(dbName, tbl) - case None => - Seq(tbl) - } - DescribeCommand(UnresolvedRelation(tblIdentifier, None), e.isDefined) - } - - protected lazy val refreshTable: Parser[LogicalPlan] = - REFRESH ~> TABLE ~> (ident <~ ".").? ~ ident ^^ { - case maybeDatabaseName ~ tableName => - RefreshTable(TableIdentifier(tableName, maybeDatabaseName)) - } - - protected lazy val options: Parser[Map[String, String]] = - "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } - - protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} - - override implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch( - s"identifier matching regex $regex", { - case lexical.Identifier(str) if regex.unapplySeq(str).isDefined => str - case lexical.Keyword(str) if regex.unapplySeq(str).isDefined => str - } - ) - - protected lazy val optionPart: Parser[String] = "[_a-zA-Z][_a-zA-Z0-9]*".r ^^ { - case name => name - } - - protected lazy val optionName: Parser[String] = repsep(optionPart, ".") ^^ { - case parts => parts.mkString(".") - } - - protected lazy val pair: Parser[(String, String)] = - optionName ~ stringLit ^^ { case k ~ v => (k, v) } - - protected lazy val column: Parser[StructField] = - ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm => - val meta = cm match { - case Some(comment) => - new MetadataBuilder().putString(COMMENT.str.toLowerCase, comment).build() - case None => Metadata.empty - } - - StructField(columnName, typ, nullable = true, meta) - } -} - -private[sql] object ResolvedDataSource extends Logging { - - /** Given a provider name, look up the data source class definition. */ - def lookupDataSource(provider: String): Class[_] = { - val provider2 = s"$provider.DefaultSource" - val loader = Utils.getContextOrSparkClassLoader - val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader) - - serviceLoader.iterator().filter(_.format().equalsIgnoreCase(provider)).toList match { - /** the provider format did not match any given registered aliases */ - case Nil => Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match { - case Success(dataSource) => dataSource - case Failure(error) => if (provider.startsWith("org.apache.spark.sql.hive.orc")) { - throw new ClassNotFoundException( - "The ORC data source must be used with Hive support enabled.", error) - } else { - throw new ClassNotFoundException( - s"Failed to load class for data source: $provider", error) - } - } - /** there is exactly one registered alias */ - case head :: Nil => head.getClass - /** There are multiple registered aliases for the input */ - case sources => sys.error(s"Multiple sources found for $provider, " + - s"(${sources.map(_.getClass.getName).mkString(", ")}), " + - "please specify the fully qualified class name") - } - } - - /** Create a [[ResolvedDataSource]] for reading data in. */ - def apply( - sqlContext: SQLContext, - userSpecifiedSchema: Option[StructType], - partitionColumns: Array[String], - provider: String, - options: Map[String, String]): ResolvedDataSource = { - val clazz: Class[_] = lookupDataSource(provider) - def className: String = clazz.getCanonicalName - val relation = userSpecifiedSchema match { - case Some(schema: StructType) => clazz.newInstance() match { - case dataSource: SchemaRelationProvider => - dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema) - case dataSource: HadoopFsRelationProvider => - val maybePartitionsSchema = if (partitionColumns.isEmpty) { - None - } else { - Some(partitionColumnsSchema(schema, partitionColumns)) - } - - val caseInsensitiveOptions = new CaseInsensitiveMap(options) - val paths = { - val patternPath = new Path(caseInsensitiveOptions("path")) - val fs = patternPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - val qualifiedPattern = patternPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - SparkHadoopUtil.get.globPathIfNecessary(qualifiedPattern).map(_.toString).toArray - } - - val dataSchema = - StructType(schema.filterNot(f => partitionColumns.contains(f.name))).asNullable - - dataSource.createRelation( - sqlContext, - paths, - Some(dataSchema), - maybePartitionsSchema, - caseInsensitiveOptions) - case dataSource: org.apache.spark.sql.sources.RelationProvider => - throw new AnalysisException(s"$className does not allow user-specified schemas.") - case _ => - throw new AnalysisException(s"$className is not a RelationProvider.") - } - - case None => clazz.newInstance() match { - case dataSource: RelationProvider => - dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options)) - case dataSource: HadoopFsRelationProvider => - val caseInsensitiveOptions = new CaseInsensitiveMap(options) - val paths = { - val patternPath = new Path(caseInsensitiveOptions("path")) - val fs = patternPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - val qualifiedPattern = patternPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - SparkHadoopUtil.get.globPathIfNecessary(qualifiedPattern).map(_.toString).toArray - } - dataSource.createRelation(sqlContext, paths, None, None, caseInsensitiveOptions) - case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => - throw new AnalysisException( - s"A schema needs to be specified when using $className.") - case _ => - throw new AnalysisException( - s"$className is neither a RelationProvider nor a FSBasedRelationProvider.") - } - } - new ResolvedDataSource(clazz, relation) - } - - private def partitionColumnsSchema( - schema: StructType, - partitionColumns: Array[String]): StructType = { - StructType(partitionColumns.map { col => - schema.find(_.name == col).getOrElse { - throw new RuntimeException(s"Partition column $col not found in schema $schema") - } - }).asNullable - } - - /** Create a [[ResolvedDataSource]] for saving the content of the given [[DataFrame]]. */ - def apply( - sqlContext: SQLContext, - provider: String, - partitionColumns: Array[String], - mode: SaveMode, - options: Map[String, String], - data: DataFrame): ResolvedDataSource = { - if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { - throw new AnalysisException("Cannot save interval data type into external storage.") - } - val clazz: Class[_] = lookupDataSource(provider) - val relation = clazz.newInstance() match { - case dataSource: CreatableRelationProvider => - dataSource.createRelation(sqlContext, mode, options, data) - case dataSource: HadoopFsRelationProvider => - // Don't glob path for the write path. The contracts here are: - // 1. Only one output path can be specified on the write path; - // 2. Output path must be a legal HDFS style file system path; - // 3. It's OK that the output path doesn't exist yet; - val caseInsensitiveOptions = new CaseInsensitiveMap(options) - val outputPath = { - val path = new Path(caseInsensitiveOptions("path")) - val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - path.makeQualified(fs.getUri, fs.getWorkingDirectory) - } - val dataSchema = StructType(data.schema.filterNot(f => partitionColumns.contains(f.name))) - val r = dataSource.createRelation( - sqlContext, - Array(outputPath.toString), - Some(dataSchema.asNullable), - Some(partitionColumnsSchema(data.schema, partitionColumns)), - caseInsensitiveOptions) - - // For partitioned relation r, r.schema's column ordering can be different from the column - // ordering of data.logicalPlan (partition columns are all moved after data column). This - // will be adjusted within InsertIntoHadoopFsRelation. - sqlContext.executePlan( - InsertIntoHadoopFsRelation( - r, - data.logicalPlan, - mode)).toRdd - r - case _ => - sys.error(s"${clazz.getCanonicalName} does not allow create table as select.") - } - new ResolvedDataSource(clazz, relation) - } -} - -private[sql] case class ResolvedDataSource(provider: Class[_], relation: BaseRelation) +import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} /** * Returned for the "DESCRIBE [EXTENDED] [dbName.]tableName" command. @@ -358,11 +30,12 @@ private[sql] case class ResolvedDataSource(provider: Class[_], relation: BaseRel * @param isExtended True if "DESCRIBE EXTENDED" is used. Otherwise, false. * It is effective only when the table is a Hive table. */ -private[sql] case class DescribeCommand( +case class DescribeCommand( table: LogicalPlan, isExtended: Boolean) extends LogicalPlan with Command { override def children: Seq[LogicalPlan] = Seq.empty + override val output: Seq[Attribute] = Seq( // Column names are based on Hive. AttributeReference("col_name", StringType, nullable = false, @@ -370,7 +43,8 @@ private[sql] case class DescribeCommand( AttributeReference("data_type", StringType, nullable = false, new MetadataBuilder().putString("comment", "data type of the column").build())(), AttributeReference("comment", StringType, nullable = false, - new MetadataBuilder().putString("comment", "comment of the column").build())()) + new MetadataBuilder().putString("comment", "comment of the column").build())() + ) } /** @@ -378,7 +52,7 @@ private[sql] case class DescribeCommand( * @param allowExisting If it is true, we will do nothing when the table already exists. * If it is false, an exception will be thrown */ -private[sql] case class CreateTableUsing( +case class CreateTableUsing( tableName: String, userSpecifiedSchema: Option[StructType], provider: String, @@ -397,7 +71,7 @@ private[sql] case class CreateTableUsing( * can analyze the logical plan that will be used to populate the table. * So, [[PreWriteCheck]] can detect cases that are not allowed. */ -private[sql] case class CreateTableUsingAsSelect( +case class CreateTableUsingAsSelect( tableName: String, provider: String, temporary: Boolean, @@ -410,7 +84,7 @@ private[sql] case class CreateTableUsingAsSelect( // override lazy val resolved = databaseName != None && childrenResolved } -private[sql] case class CreateTempTableUsing( +case class CreateTempTableUsing( tableName: String, userSpecifiedSchema: Option[StructType], provider: String, @@ -425,7 +99,7 @@ private[sql] case class CreateTempTableUsing( } } -private[sql] case class CreateTempTableUsingAsSelect( +case class CreateTempTableUsingAsSelect( tableName: String, provider: String, partitionColumns: Array[String], @@ -443,7 +117,7 @@ private[sql] case class CreateTempTableUsingAsSelect( } } -private[sql] case class RefreshTable(tableIdent: TableIdentifier) +case class RefreshTable(tableIdent: TableIdentifier) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { @@ -472,7 +146,7 @@ private[sql] case class RefreshTable(tableIdent: TableIdentifier) /** * Builds a map in which keys are case insensitive */ -protected[sql] class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] +class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] with Serializable { val baseMap = map.map(kv => kv.copy(_1 = kv._1.toLowerCase)) @@ -490,4 +164,4 @@ protected[sql] class CaseInsensitiveMap(map: Map[String, String]) extends Map[St /** * The exception thrown from the DDL parser. */ -protected[sql] class DDLException(message: String) extends Exception(message) +class DDLException(message: String) extends RuntimeException(message) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala new file mode 100644 index 0000000000000..6773afc794f9c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala @@ -0,0 +1,62 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.datasources.jdbc + +import java.util.Properties + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, DataSourceRegister} + +class DefaultSource extends RelationProvider with DataSourceRegister { + + override def shortName(): String = "jdbc" + + /** Returns a new base relation with the given parameters. */ + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + val url = parameters.getOrElse("url", sys.error("Option 'url' not specified")) + val driver = parameters.getOrElse("driver", null) + val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified")) + val partitionColumn = parameters.getOrElse("partitionColumn", null) + val lowerBound = parameters.getOrElse("lowerBound", null) + val upperBound = parameters.getOrElse("upperBound", null) + val numPartitions = parameters.getOrElse("numPartitions", null) + + if (driver != null) DriverRegistry.register(driver) + + if (partitionColumn != null + && (lowerBound == null || upperBound == null || numPartitions == null)) { + sys.error("Partitioning incompletely specified") + } + + val partitionInfo = if (partitionColumn == null) { + null + } else { + JDBCPartitioningInfo( + partitionColumn, + lowerBound.toLong, + upperBound.toLong, + numPartitions.toInt) + } + val parts = JDBCRelation.columnPartition(partitionInfo) + val properties = new Properties() // Additional properties that we will pass to getConnection + parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) + JDBCRelation(url, table, parts, properties)(sqlContext) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala new file mode 100644 index 0000000000000..7ccd61ed469e9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.jdbc + +import java.sql.{Driver, DriverManager} + +import scala.collection.mutable + +import org.apache.spark.Logging +import org.apache.spark.util.Utils + +/** + * java.sql.DriverManager is always loaded by bootstrap classloader, + * so it can't load JDBC drivers accessible by Spark ClassLoader. + * + * To solve the problem, drivers from user-supplied jars are wrapped into thin wrapper. + */ +object DriverRegistry extends Logging { + + private val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty + + def register(className: String): Unit = { + val cls = Utils.getContextOrSparkClassLoader.loadClass(className) + if (cls.getClassLoader == null) { + logTrace(s"$className has been loaded with bootstrap ClassLoader, wrapper is not required") + } else if (wrapperMap.get(className).isDefined) { + logTrace(s"Wrapper for $className already exists") + } else { + synchronized { + if (wrapperMap.get(className).isEmpty) { + val wrapper = new DriverWrapper(cls.newInstance().asInstanceOf[Driver]) + DriverManager.registerDriver(wrapper) + wrapperMap(className) = wrapper + logTrace(s"Wrapper for $className registered") + } + } + } + } + + def getDriverClassName(url: String): String = DriverManager.getDriver(url) match { + case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName + case driver => driver.getClass.getCanonicalName + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverWrapper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverWrapper.scala new file mode 100644 index 0000000000000..18263fe227d04 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverWrapper.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.jdbc + +import java.sql.{Connection, Driver, DriverPropertyInfo, SQLFeatureNotSupportedException} +import java.util.Properties + +/** + * A wrapper for a JDBC Driver to work around SPARK-6913. + * + * The problem is in `java.sql.DriverManager` class that can't access drivers loaded by + * Spark ClassLoader. + */ +class DriverWrapper(val wrapped: Driver) extends Driver { + override def acceptsURL(url: String): Boolean = wrapped.acceptsURL(url) + + override def jdbcCompliant(): Boolean = wrapped.jdbcCompliant() + + override def getPropertyInfo(url: String, info: Properties): Array[DriverPropertyInfo] = { + wrapped.getPropertyInfo(url, info) + } + + override def getMinorVersion: Int = wrapped.getMinorVersion + + def getParentLogger: java.util.logging.Logger = { + throw new SQLFeatureNotSupportedException( + s"${this.getClass.getName}.getParentLogger is not yet implemented.") + } + + override def connect(url: String, info: Properties): Connection = wrapped.connect(url, info) + + override def getMajorVersion: Int = wrapped.getMajorVersion +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 3cf70db6b7b09..8eab6a0adccc4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.jdbc +package org.apache.spark.sql.execution.datasources.jdbc import java.sql.{Connection, DriverManager, ResultSet, ResultSetMetaData, SQLException} import java.util.Properties @@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -180,9 +181,8 @@ private[sql] object JDBCRDD extends Logging { try { if (driver != null) DriverRegistry.register(driver) } catch { - case e: ClassNotFoundException => { - logWarning(s"Couldn't find class $driver", e); - } + case e: ClassNotFoundException => + logWarning(s"Couldn't find class $driver", e) } DriverManager.getConnection(url, properties) } @@ -344,7 +344,6 @@ private[sql] class JDBCRDD( }).toArray } - /** * Runs the SQL query against the JDBC driver. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala similarity index 71% rename from sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index 48d97ced9ca0a..f9300dc2cb529 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.jdbc +package org.apache.spark.sql.execution.datasources.jdbc import java.util.Properties @@ -77,45 +77,6 @@ private[sql] object JDBCRelation { } } -private[sql] class DefaultSource extends RelationProvider with DataSourceRegister { - - def format(): String = "jdbc" - - /** Returns a new base relation with the given parameters. */ - override def createRelation( - sqlContext: SQLContext, - parameters: Map[String, String]): BaseRelation = { - val url = parameters.getOrElse("url", sys.error("Option 'url' not specified")) - val driver = parameters.getOrElse("driver", null) - val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified")) - val partitionColumn = parameters.getOrElse("partitionColumn", null) - val lowerBound = parameters.getOrElse("lowerBound", null) - val upperBound = parameters.getOrElse("upperBound", null) - val numPartitions = parameters.getOrElse("numPartitions", null) - - if (driver != null) DriverRegistry.register(driver) - - if (partitionColumn != null - && (lowerBound == null || upperBound == null || numPartitions == null)) { - sys.error("Partitioning incompletely specified") - } - - val partitionInfo = if (partitionColumn == null) { - null - } else { - JDBCPartitioningInfo( - partitionColumn, - lowerBound.toLong, - upperBound.toLong, - numPartitions.toInt) - } - val parts = JDBCRelation.columnPartition(partitionInfo) - val properties = new Properties() // Additional properties that we will pass to getConnection - parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) - JDBCRelation(url, table, parts, properties)(sqlContext) - } -} - private[sql] case class JDBCRelation( url: String, table: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala new file mode 100644 index 0000000000000..039c13bf163ca --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.jdbc + +import java.sql.{Connection, DriverManager, PreparedStatement} +import java.util.Properties + +import scala.util.Try + +import org.apache.spark.Logging +import org.apache.spark.sql.jdbc.JdbcDialects +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, Row} + +/** + * Util functions for JDBC tables. + */ +object JdbcUtils extends Logging { + + /** + * Establishes a JDBC connection. + */ + def createConnection(url: String, connectionProperties: Properties): Connection = { + DriverManager.getConnection(url, connectionProperties) + } + + /** + * Returns true if the table already exists in the JDBC database. + */ + def tableExists(conn: Connection, table: String): Boolean = { + // Somewhat hacky, but there isn't a good way to identify whether a table exists for all + // SQL database systems, considering "table" could also include the database name. + Try(conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1").executeQuery().next()).isSuccess + } + + /** + * Drops a table from the JDBC database. + */ + def dropTable(conn: Connection, table: String): Unit = { + conn.prepareStatement(s"DROP TABLE $table").executeUpdate() + } + + /** + * Returns a PreparedStatement that inserts a row into table via conn. + */ + def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = { + val sql = new StringBuilder(s"INSERT INTO $table VALUES (") + var fieldsLeft = rddSchema.fields.length + while (fieldsLeft > 0) { + sql.append("?") + if (fieldsLeft > 1) sql.append(", ") else sql.append(")") + fieldsLeft = fieldsLeft - 1 + } + conn.prepareStatement(sql.toString()) + } + + /** + * Saves a partition of a DataFrame to the JDBC database. This is done in + * a single database transaction in order to avoid repeatedly inserting + * data as much as possible. + * + * It is still theoretically possible for rows in a DataFrame to be + * inserted into the database more than once if a stage somehow fails after + * the commit occurs but before the stage can return successfully. + * + * This is not a closure inside saveTable() because apparently cosmetic + * implementation changes elsewhere might easily render such a closure + * non-Serializable. Instead, we explicitly close over all variables that + * are used. + */ + def savePartition( + getConnection: () => Connection, + table: String, + iterator: Iterator[Row], + rddSchema: StructType, + nullTypes: Array[Int]): Iterator[Byte] = { + val conn = getConnection() + var committed = false + try { + conn.setAutoCommit(false) // Everything in the same db transaction. + val stmt = insertStatement(conn, table, rddSchema) + try { + while (iterator.hasNext) { + val row = iterator.next() + val numFields = rddSchema.fields.length + var i = 0 + while (i < numFields) { + if (row.isNullAt(i)) { + stmt.setNull(i + 1, nullTypes(i)) + } else { + rddSchema.fields(i).dataType match { + case IntegerType => stmt.setInt(i + 1, row.getInt(i)) + case LongType => stmt.setLong(i + 1, row.getLong(i)) + case DoubleType => stmt.setDouble(i + 1, row.getDouble(i)) + case FloatType => stmt.setFloat(i + 1, row.getFloat(i)) + case ShortType => stmt.setInt(i + 1, row.getShort(i)) + case ByteType => stmt.setInt(i + 1, row.getByte(i)) + case BooleanType => stmt.setBoolean(i + 1, row.getBoolean(i)) + case StringType => stmt.setString(i + 1, row.getString(i)) + case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i)) + case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i)) + case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) + case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i)) + case _ => throw new IllegalArgumentException( + s"Can't translate non-null value for field $i") + } + } + i = i + 1 + } + stmt.executeUpdate() + } + } finally { + stmt.close() + } + conn.commit() + committed = true + } finally { + if (!committed) { + // The stage must fail. We got here through an exception path, so + // let the exception through unless rollback() or close() want to + // tell the user about another problem. + conn.rollback() + conn.close() + } else { + // The stage must succeed. We cannot propagate any exception close() might throw. + try { + conn.close() + } catch { + case e: Exception => logWarning("Transaction succeeded, but closing failed", e) + } + } + } + Array[Byte]().iterator + } + + /** + * Compute the schema string for this RDD. + */ + def schemaString(df: DataFrame, url: String): String = { + val sb = new StringBuilder() + val dialect = JdbcDialects.get(url) + df.schema.fields foreach { field => { + val name = field.name + val typ: String = + dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse( + field.dataType match { + case IntegerType => "INTEGER" + case LongType => "BIGINT" + case DoubleType => "DOUBLE PRECISION" + case FloatType => "REAL" + case ShortType => "INTEGER" + case ByteType => "BYTE" + case BooleanType => "BIT(1)" + case StringType => "TEXT" + case BinaryType => "BLOB" + case TimestampType => "TIMESTAMP" + case DateType => "DATE" + case t: DecimalType => s"DECIMAL(${t.precision}},${t.scale}})" + case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") + }) + val nullable = if (field.nullable) "" else "NOT NULL" + sb.append(s", $name $typ $nullable") + }} + if (sb.length < 2) "" else sb.substring(2) + } + + /** + * Saves the RDD to the database in a single transaction. + */ + def saveTable( + df: DataFrame, + url: String, + table: String, + properties: Properties = new Properties()) { + val dialect = JdbcDialects.get(url) + val nullTypes: Array[Int] = df.schema.fields.map { field => + dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse( + field.dataType match { + case IntegerType => java.sql.Types.INTEGER + case LongType => java.sql.Types.BIGINT + case DoubleType => java.sql.Types.DOUBLE + case FloatType => java.sql.Types.REAL + case ShortType => java.sql.Types.INTEGER + case ByteType => java.sql.Types.INTEGER + case BooleanType => java.sql.Types.BIT + case StringType => java.sql.Types.CLOB + case BinaryType => java.sql.Types.BLOB + case TimestampType => java.sql.Types.TIMESTAMP + case DateType => java.sql.Types.DATE + case t: DecimalType => java.sql.Types.DECIMAL + case _ => throw new IllegalArgumentException( + s"Can't translate null value for field $field") + }) + } + + val rddSchema = df.schema + val driver: String = DriverRegistry.getDriverClassName(url) + val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties) + df.foreachPartition { iterator => + savePartition(getConnection, table, iterator, rddSchema, nullTypes) + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index ec5668c6b95a1..b6f3410bad690 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import com.fasterxml.jackson.core._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion -import org.apache.spark.sql.json.JacksonUtils.nextUntil +import org.apache.spark.sql.execution.datasources.json.JacksonUtils.nextUntil import org.apache.spark.sql.types._ private[sql] object InferSchema { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 5bb9e62310a50..114c8b211891e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import java.io.CharArrayWriter @@ -39,9 +39,10 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.util.SerializableConfiguration -private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { - def format(): String = "json" +class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { + + override def shortName(): String = "json" override def createRelation( sqlContext: SQLContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala index d734e7e8904bd..37c2b5a296c15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import org.apache.spark.sql.catalyst.InternalRow diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index b8fd3b9cc150e..cd68bd667c5c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import java.io.ByteArrayOutputStream @@ -27,7 +27,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.json.JacksonUtils.nextUntil +import org.apache.spark.sql.execution.datasources.json.JacksonUtils.nextUntil import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonUtils.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonUtils.scala index fde96852ce68e..005546f37dda0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonUtils.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import com.fasterxml.jackson.core.{JsonParser, JsonToken} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystReadSupport.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala index 975fec101d9c2..4049795ed3bad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.util.{Map => JMap} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRecordMaterializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRecordMaterializer.scala similarity index 96% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRecordMaterializer.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRecordMaterializer.scala index 84f1dccfeb788..ed9e0aa65977b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRecordMaterializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRecordMaterializer.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import org.apache.parquet.io.api.{GroupConverter, RecordMaterializer} import org.apache.parquet.schema.MessageType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 4fe8a39f20abd..3542dfbae1292 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.math.{BigDecimal, BigInteger} import java.nio.ByteOrder diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index b12149dcf1c92..a3fc74cf7929b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import scala.collection.JavaConversions._ @@ -25,7 +25,7 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.parquet.schema.Type.Repetition._ import org.apache.parquet.schema._ -import org.apache.spark.sql.parquet.CatalystSchemaConverter.{MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64, maxPrecisionForBytes} +import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64, maxPrecisionForBytes} import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLConf} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala index 1551afd7b7bf2..2c6b914328b60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetConverter.scala similarity index 96% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetConverter.scala index 6ed3580af0729..ccd7ebf319af9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetConverter.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.{MapData, ArrayData} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index d57b789f5c1c7..9e2e232f50167 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.io.Serializable import java.nio.ByteBuffer diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index b6db71b5b8a62..4086a139bed72 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.net.URI import java.util.logging.{Level, Logger => JLogger} @@ -51,7 +51,7 @@ import org.apache.spark.util.{SerializableConfiguration, Utils} private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { - def format(): String = "parquet" + override def shortName(): String = "parquet" override def createRelation( sqlContext: SQLContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTableSupport.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTableSupport.scala index 9cd0250f9c510..3191cf3d121bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTableSupport.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.math.BigInteger import java.nio.{ByteBuffer, ByteOrder} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypesConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypesConverter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala index 3854f5bd39fb1..019db34fc666d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypesConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.io.IOException diff --git a/sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala similarity index 77% rename from sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 3b907e5da7897..1b51a5e5c8a8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.metric +package org.apache.spark.sql.execution.metric import org.apache.spark.{Accumulable, AccumulableParam, SparkContext} @@ -93,22 +93,6 @@ private[sql] class LongSQLMetric private[metric](name: String) } } -/** - * A specialized int Accumulable to avoid boxing and unboxing when using Accumulator's - * `+=` and `add`. - */ -private[sql] class IntSQLMetric private[metric](name: String) - extends SQLMetric[IntSQLMetricValue, Int](name, IntSQLMetricParam) { - - override def +=(term: Int): Unit = { - localValue.add(term) - } - - override def add(term: Int): Unit = { - localValue.add(term) - } -} - private object LongSQLMetricParam extends SQLMetricParam[LongSQLMetricValue, Long] { override def addAccumulator(r: LongSQLMetricValue, t: Long): LongSQLMetricValue = r.add(t) @@ -121,26 +105,8 @@ private object LongSQLMetricParam extends SQLMetricParam[LongSQLMetricValue, Lon override def zero: LongSQLMetricValue = new LongSQLMetricValue(0L) } -private object IntSQLMetricParam extends SQLMetricParam[IntSQLMetricValue, Int] { - - override def addAccumulator(r: IntSQLMetricValue, t: Int): IntSQLMetricValue = r.add(t) - - override def addInPlace(r1: IntSQLMetricValue, r2: IntSQLMetricValue): IntSQLMetricValue = - r1.add(r2.value) - - override def zero(initialValue: IntSQLMetricValue): IntSQLMetricValue = zero - - override def zero: IntSQLMetricValue = new IntSQLMetricValue(0) -} - private[sql] object SQLMetrics { - def createIntMetric(sc: SparkContext, name: String): IntSQLMetric = { - val acc = new IntSQLMetric(name) - sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) - acc - } - def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = { val acc = new LongSQLMetric(name) sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala index 66237f8f1314b..28fa231e722d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala @@ -18,12 +18,6 @@ package org.apache.spark.sql /** - * :: DeveloperApi :: - * An execution engine for relational query plans that runs on top Spark and returns RDDs. - * - * Note that the operators in this package are created automatically by a query planner using a - * [[SQLContext]] and are not intended to be used directly by end users of Spark SQL. They are - * documented here in order to make it easier for others to understand the performance - * characteristics of query plans that are generated by Spark SQL. + * The physical execution component of Spark SQL. Note that this is a private package. */ package object execution diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/ui/AllExecutionsPage.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index cb7ca60b2fe48..49646a99d68c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.ui +package org.apache.spark.sql.execution.ui import javax.servlet.http.HttpServletRequest diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/ui/ExecutionPage.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index 52ddf99e9266a..f0b56c2eb7a53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.ui +package org.apache.spark.sql.execution.ui import javax.servlet.http.HttpServletRequest diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 2fd4fc658d068..0b9bad987c488 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.ui +package org.apache.spark.sql.execution.ui import scala.collection.mutable @@ -26,7 +26,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.sql.SQLContext import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.metric.{SQLMetricParam, SQLMetricValue} +import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener with Logging { @@ -51,17 +51,14 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit private val completedExecutions = mutable.ListBuffer[SQLExecutionUIData]() - @VisibleForTesting def executionIdToData: Map[Long, SQLExecutionUIData] = synchronized { _executionIdToData.toMap } - @VisibleForTesting def jobIdToExecutionId: Map[Long, Long] = synchronized { _jobIdToExecutionId.toMap } - @VisibleForTesting def stageIdToStageMetrics: Map[Long, SQLStageMetrics] = synchronized { _stageIdToStageMetrics.toMap } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLTab.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala similarity index 90% rename from sql/core/src/main/scala/org/apache/spark/sql/ui/SQLTab.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala index 3bba0afaf14eb..0b0867f67eb6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLTab.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.ui +package org.apache.spark.sql.execution.ui import java.util.concurrent.atomic.AtomicInteger @@ -38,12 +38,12 @@ private[sql] class SQLTab(sqlContext: SQLContext, sparkUI: SparkUI) private[sql] object SQLTab { - private val STATIC_RESOURCE_DIR = "org/apache/spark/sql/ui/static" + private val STATIC_RESOURCE_DIR = "org/apache/spark/sql/execution/ui/static" private val nextTabId = new AtomicInteger(0) private def nextTabName: String = { val nextId = nextTabId.getAndIncrement() - if (nextId == 0) "SQL" else s"SQL${nextId}" + if (nextId == 0) "SQL" else s"SQL$nextId" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 1ba50b95becc1..ae3d752dde348 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.ui +package org.apache.spark.sql.execution.ui import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.metric.{SQLMetricParam, SQLMetricValue} +import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} /** * A graph used for storing information of an executionPlan of DataFrame. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala deleted file mode 100644 index cc918c237192b..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.jdbc - -import java.sql.{Connection, DriverManager} -import java.util.Properties - -import scala.util.Try - -/** - * Util functions for JDBC tables. - */ -private[sql] object JdbcUtils { - - /** - * Establishes a JDBC connection. - */ - def createConnection(url: String, connectionProperties: Properties): Connection = { - DriverManager.getConnection(url, connectionProperties) - } - - /** - * Returns true if the table already exists in the JDBC database. - */ - def tableExists(conn: Connection, table: String): Boolean = { - // Somewhat hacky, but there isn't a good way to identify whether a table exists for all - // SQL database systems, considering "table" could also include the database name. - Try(conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1").executeQuery().next()).isSuccess - } - - /** - * Drops a table from the JDBC database. - */ - def dropTable(conn: Connection, table: String): Unit = { - conn.prepareStatement(s"DROP TABLE $table").executeUpdate() - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala deleted file mode 100644 index 035e0510080ff..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ /dev/null @@ -1,250 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.sql.{Connection, Driver, DriverManager, DriverPropertyInfo, PreparedStatement, SQLFeatureNotSupportedException} -import java.util.Properties - -import scala.collection.mutable - -import org.apache.spark.Logging -import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils - -package object jdbc { - private[sql] object JDBCWriteDetails extends Logging { - /** - * Returns a PreparedStatement that inserts a row into table via conn. - */ - def insertStatement(conn: Connection, table: String, rddSchema: StructType): - PreparedStatement = { - val sql = new StringBuilder(s"INSERT INTO $table VALUES (") - var fieldsLeft = rddSchema.fields.length - while (fieldsLeft > 0) { - sql.append("?") - if (fieldsLeft > 1) sql.append(", ") else sql.append(")") - fieldsLeft = fieldsLeft - 1 - } - conn.prepareStatement(sql.toString) - } - - /** - * Saves a partition of a DataFrame to the JDBC database. This is done in - * a single database transaction in order to avoid repeatedly inserting - * data as much as possible. - * - * It is still theoretically possible for rows in a DataFrame to be - * inserted into the database more than once if a stage somehow fails after - * the commit occurs but before the stage can return successfully. - * - * This is not a closure inside saveTable() because apparently cosmetic - * implementation changes elsewhere might easily render such a closure - * non-Serializable. Instead, we explicitly close over all variables that - * are used. - */ - def savePartition( - getConnection: () => Connection, - table: String, - iterator: Iterator[Row], - rddSchema: StructType, - nullTypes: Array[Int]): Iterator[Byte] = { - val conn = getConnection() - var committed = false - try { - conn.setAutoCommit(false) // Everything in the same db transaction. - val stmt = insertStatement(conn, table, rddSchema) - try { - while (iterator.hasNext) { - val row = iterator.next() - val numFields = rddSchema.fields.length - var i = 0 - while (i < numFields) { - if (row.isNullAt(i)) { - stmt.setNull(i + 1, nullTypes(i)) - } else { - rddSchema.fields(i).dataType match { - case IntegerType => stmt.setInt(i + 1, row.getInt(i)) - case LongType => stmt.setLong(i + 1, row.getLong(i)) - case DoubleType => stmt.setDouble(i + 1, row.getDouble(i)) - case FloatType => stmt.setFloat(i + 1, row.getFloat(i)) - case ShortType => stmt.setInt(i + 1, row.getShort(i)) - case ByteType => stmt.setInt(i + 1, row.getByte(i)) - case BooleanType => stmt.setBoolean(i + 1, row.getBoolean(i)) - case StringType => stmt.setString(i + 1, row.getString(i)) - case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i)) - case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i)) - case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) - case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i)) - case _ => throw new IllegalArgumentException( - s"Can't translate non-null value for field $i") - } - } - i = i + 1 - } - stmt.executeUpdate() - } - } finally { - stmt.close() - } - conn.commit() - committed = true - } finally { - if (!committed) { - // The stage must fail. We got here through an exception path, so - // let the exception through unless rollback() or close() want to - // tell the user about another problem. - conn.rollback() - conn.close() - } else { - // The stage must succeed. We cannot propagate any exception close() might throw. - try { - conn.close() - } catch { - case e: Exception => logWarning("Transaction succeeded, but closing failed", e) - } - } - } - Array[Byte]().iterator - } - - /** - * Compute the schema string for this RDD. - */ - def schemaString(df: DataFrame, url: String): String = { - val sb = new StringBuilder() - val dialect = JdbcDialects.get(url) - df.schema.fields foreach { field => { - val name = field.name - val typ: String = - dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse( - field.dataType match { - case IntegerType => "INTEGER" - case LongType => "BIGINT" - case DoubleType => "DOUBLE PRECISION" - case FloatType => "REAL" - case ShortType => "INTEGER" - case ByteType => "BYTE" - case BooleanType => "BIT(1)" - case StringType => "TEXT" - case BinaryType => "BLOB" - case TimestampType => "TIMESTAMP" - case DateType => "DATE" - case t: DecimalType => s"DECIMAL(${t.precision}},${t.scale}})" - case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") - }) - val nullable = if (field.nullable) "" else "NOT NULL" - sb.append(s", $name $typ $nullable") - }} - if (sb.length < 2) "" else sb.substring(2) - } - - /** - * Saves the RDD to the database in a single transaction. - */ - def saveTable( - df: DataFrame, - url: String, - table: String, - properties: Properties = new Properties()) { - val dialect = JdbcDialects.get(url) - val nullTypes: Array[Int] = df.schema.fields.map { field => - dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse( - field.dataType match { - case IntegerType => java.sql.Types.INTEGER - case LongType => java.sql.Types.BIGINT - case DoubleType => java.sql.Types.DOUBLE - case FloatType => java.sql.Types.REAL - case ShortType => java.sql.Types.INTEGER - case ByteType => java.sql.Types.INTEGER - case BooleanType => java.sql.Types.BIT - case StringType => java.sql.Types.CLOB - case BinaryType => java.sql.Types.BLOB - case TimestampType => java.sql.Types.TIMESTAMP - case DateType => java.sql.Types.DATE - case t: DecimalType => java.sql.Types.DECIMAL - case _ => throw new IllegalArgumentException( - s"Can't translate null value for field $field") - }) - } - - val rddSchema = df.schema - val driver: String = DriverRegistry.getDriverClassName(url) - val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties) - df.foreachPartition { iterator => - JDBCWriteDetails.savePartition(getConnection, table, iterator, rddSchema, nullTypes) - } - } - - } - - private [sql] class DriverWrapper(val wrapped: Driver) extends Driver { - override def acceptsURL(url: String): Boolean = wrapped.acceptsURL(url) - - override def jdbcCompliant(): Boolean = wrapped.jdbcCompliant() - - override def getPropertyInfo(url: String, info: Properties): Array[DriverPropertyInfo] = { - wrapped.getPropertyInfo(url, info) - } - - override def getMinorVersion: Int = wrapped.getMinorVersion - - def getParentLogger: java.util.logging.Logger = - throw new SQLFeatureNotSupportedException( - s"${this.getClass().getName}.getParentLogger is not yet implemented.") - - override def connect(url: String, info: Properties): Connection = wrapped.connect(url, info) - - override def getMajorVersion: Int = wrapped.getMajorVersion - } - - /** - * java.sql.DriverManager is always loaded by bootstrap classloader, - * so it can't load JDBC drivers accessible by Spark ClassLoader. - * - * To solve the problem, drivers from user-supplied jars are wrapped - * into thin wrapper. - */ - private [sql] object DriverRegistry extends Logging { - - private val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty - - def register(className: String): Unit = { - val cls = Utils.getContextOrSparkClassLoader.loadClass(className) - if (cls.getClassLoader == null) { - logTrace(s"$className has been loaded with bootstrap ClassLoader, wrapper is not required") - } else if (wrapperMap.get(className).isDefined) { - logTrace(s"Wrapper for $className already exists") - } else { - synchronized { - if (wrapperMap.get(className).isEmpty) { - val wrapper = new DriverWrapper(cls.newInstance().asInstanceOf[Driver]) - DriverManager.registerDriver(wrapper) - wrapperMap(className) = wrapper - logTrace(s"Wrapper for $className registered") - } - } - } - } - - def getDriverClassName(url: String): String = DriverManager.getDriver(url) match { - case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName - case driver => driver.getClass.getCanonicalName - } - } - -} // package object jdbc diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 6bcabbab4f77b..2f8417a48d32e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -43,19 +43,24 @@ import org.apache.spark.util.SerializableConfiguration * This allows users to give the data source alias as the format type over the fully qualified * class name. * - * ex: parquet.DefaultSource.format = "parquet". - * * A new instance of this class with be instantiated each time a DDL call is made. + * + * @since 1.5.0 */ @DeveloperApi trait DataSourceRegister { /** * The string that represents the format that this data source provider uses. This is - * overridden by children to provide a nice alias for the data source, - * ex: override def format(): String = "parquet" + * overridden by children to provide a nice alias for the data source. For example: + * + * {{{ + * override def format(): String = "parquet" + * }}} + * + * @since 1.5.0 */ - def format(): String + def shortName(): String } /** diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/CompatibilityTest.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java similarity index 93% rename from sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/CompatibilityTest.java rename to sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java index daec65a5bbe57..70dec1a9d3c92 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/CompatibilityTest.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java @@ -3,7 +3,7 @@ * * DO NOT EDIT DIRECTLY */ -package org.apache.spark.sql.parquet.test.avro; +package org.apache.spark.sql.execution.datasources.parquet.test.avro; @SuppressWarnings("all") @org.apache.avro.specific.AvroGenerated @@ -12,6 +12,6 @@ public interface CompatibilityTest { @SuppressWarnings("all") public interface Callback extends CompatibilityTest { - public static final org.apache.avro.Protocol PROTOCOL = org.apache.spark.sql.parquet.test.avro.CompatibilityTest.PROTOCOL; + public static final org.apache.avro.Protocol PROTOCOL = org.apache.spark.sql.execution.datasources.parquet.test.avro.CompatibilityTest.PROTOCOL; } } \ No newline at end of file diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/Nested.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java similarity index 78% rename from sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/Nested.java rename to sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java index 051f1ee903863..a0a406bcd10c1 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/Nested.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java @@ -3,7 +3,7 @@ * * DO NOT EDIT DIRECTLY */ -package org.apache.spark.sql.parquet.test.avro; +package org.apache.spark.sql.execution.datasources.parquet.test.avro; @SuppressWarnings("all") @org.apache.avro.specific.AvroGenerated public class Nested extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { @@ -77,18 +77,18 @@ public void setNestedStringColumn(java.lang.String value) { } /** Creates a new Nested RecordBuilder */ - public static org.apache.spark.sql.parquet.test.avro.Nested.Builder newBuilder() { - return new org.apache.spark.sql.parquet.test.avro.Nested.Builder(); + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder(); } /** Creates a new Nested RecordBuilder by copying an existing Builder */ - public static org.apache.spark.sql.parquet.test.avro.Nested.Builder newBuilder(org.apache.spark.sql.parquet.test.avro.Nested.Builder other) { - return new org.apache.spark.sql.parquet.test.avro.Nested.Builder(other); + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder(other); } /** Creates a new Nested RecordBuilder by copying an existing Nested instance */ - public static org.apache.spark.sql.parquet.test.avro.Nested.Builder newBuilder(org.apache.spark.sql.parquet.test.avro.Nested other) { - return new org.apache.spark.sql.parquet.test.avro.Nested.Builder(other); + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder(other); } /** @@ -102,11 +102,11 @@ public static class Builder extends org.apache.avro.specific.SpecificRecordBuild /** Creates a new Builder */ private Builder() { - super(org.apache.spark.sql.parquet.test.avro.Nested.SCHEMA$); + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.SCHEMA$); } /** Creates a Builder by copying an existing Builder */ - private Builder(org.apache.spark.sql.parquet.test.avro.Nested.Builder other) { + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder other) { super(other); if (isValidValue(fields()[0], other.nested_ints_column)) { this.nested_ints_column = data().deepCopy(fields()[0].schema(), other.nested_ints_column); @@ -119,8 +119,8 @@ private Builder(org.apache.spark.sql.parquet.test.avro.Nested.Builder other) { } /** Creates a Builder by copying an existing Nested instance */ - private Builder(org.apache.spark.sql.parquet.test.avro.Nested other) { - super(org.apache.spark.sql.parquet.test.avro.Nested.SCHEMA$); + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.SCHEMA$); if (isValidValue(fields()[0], other.nested_ints_column)) { this.nested_ints_column = data().deepCopy(fields()[0].schema(), other.nested_ints_column); fieldSetFlags()[0] = true; @@ -137,7 +137,7 @@ public java.util.List getNestedIntsColumn() { } /** Sets the value of the 'nested_ints_column' field */ - public org.apache.spark.sql.parquet.test.avro.Nested.Builder setNestedIntsColumn(java.util.List value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder setNestedIntsColumn(java.util.List value) { validate(fields()[0], value); this.nested_ints_column = value; fieldSetFlags()[0] = true; @@ -150,7 +150,7 @@ public boolean hasNestedIntsColumn() { } /** Clears the value of the 'nested_ints_column' field */ - public org.apache.spark.sql.parquet.test.avro.Nested.Builder clearNestedIntsColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder clearNestedIntsColumn() { nested_ints_column = null; fieldSetFlags()[0] = false; return this; @@ -162,7 +162,7 @@ public java.lang.String getNestedStringColumn() { } /** Sets the value of the 'nested_string_column' field */ - public org.apache.spark.sql.parquet.test.avro.Nested.Builder setNestedStringColumn(java.lang.String value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder setNestedStringColumn(java.lang.String value) { validate(fields()[1], value); this.nested_string_column = value; fieldSetFlags()[1] = true; @@ -175,7 +175,7 @@ public boolean hasNestedStringColumn() { } /** Clears the value of the 'nested_string_column' field */ - public org.apache.spark.sql.parquet.test.avro.Nested.Builder clearNestedStringColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder clearNestedStringColumn() { nested_string_column = null; fieldSetFlags()[1] = false; return this; diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/ParquetAvroCompat.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java similarity index 83% rename from sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/ParquetAvroCompat.java rename to sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java index 354c9d73cca31..6198b00b1e3ca 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/ParquetAvroCompat.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java @@ -3,7 +3,7 @@ * * DO NOT EDIT DIRECTLY */ -package org.apache.spark.sql.parquet.test.avro; +package org.apache.spark.sql.execution.datasources.parquet.test.avro; @SuppressWarnings("all") @org.apache.avro.specific.AvroGenerated public class ParquetAvroCompat extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { @@ -25,7 +25,7 @@ public class ParquetAvroCompat extends org.apache.avro.specific.SpecificRecordBa @Deprecated public java.lang.String maybe_string_column; @Deprecated public java.util.List strings_column; @Deprecated public java.util.Map string_to_int_column; - @Deprecated public java.util.Map> complex_column; + @Deprecated public java.util.Map> complex_column; /** * Default constructor. Note that this does not initialize fields @@ -37,7 +37,7 @@ public ParquetAvroCompat() {} /** * All-args constructor. */ - public ParquetAvroCompat(java.lang.Boolean bool_column, java.lang.Integer int_column, java.lang.Long long_column, java.lang.Float float_column, java.lang.Double double_column, java.nio.ByteBuffer binary_column, java.lang.String string_column, java.lang.Boolean maybe_bool_column, java.lang.Integer maybe_int_column, java.lang.Long maybe_long_column, java.lang.Float maybe_float_column, java.lang.Double maybe_double_column, java.nio.ByteBuffer maybe_binary_column, java.lang.String maybe_string_column, java.util.List strings_column, java.util.Map string_to_int_column, java.util.Map> complex_column) { + public ParquetAvroCompat(java.lang.Boolean bool_column, java.lang.Integer int_column, java.lang.Long long_column, java.lang.Float float_column, java.lang.Double double_column, java.nio.ByteBuffer binary_column, java.lang.String string_column, java.lang.Boolean maybe_bool_column, java.lang.Integer maybe_int_column, java.lang.Long maybe_long_column, java.lang.Float maybe_float_column, java.lang.Double maybe_double_column, java.nio.ByteBuffer maybe_binary_column, java.lang.String maybe_string_column, java.util.List strings_column, java.util.Map string_to_int_column, java.util.Map> complex_column) { this.bool_column = bool_column; this.int_column = int_column; this.long_column = long_column; @@ -101,7 +101,7 @@ public void put(int field$, java.lang.Object value$) { case 13: maybe_string_column = (java.lang.String)value$; break; case 14: strings_column = (java.util.List)value$; break; case 15: string_to_int_column = (java.util.Map)value$; break; - case 16: complex_column = (java.util.Map>)value$; break; + case 16: complex_column = (java.util.Map>)value$; break; default: throw new org.apache.avro.AvroRuntimeException("Bad index"); } } @@ -349,7 +349,7 @@ public void setStringToIntColumn(java.util.Map> getComplexColumn() { + public java.util.Map> getComplexColumn() { return complex_column; } @@ -357,23 +357,23 @@ public java.util.Map> value) { + public void setComplexColumn(java.util.Map> value) { this.complex_column = value; } /** Creates a new ParquetAvroCompat RecordBuilder */ - public static org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder newBuilder() { - return new org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder(); + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder(); } /** Creates a new ParquetAvroCompat RecordBuilder by copying an existing Builder */ - public static org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder newBuilder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder other) { - return new org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder(other); + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder(other); } /** Creates a new ParquetAvroCompat RecordBuilder by copying an existing ParquetAvroCompat instance */ - public static org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder newBuilder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat other) { - return new org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder(other); + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder(other); } /** @@ -398,15 +398,15 @@ public static class Builder extends org.apache.avro.specific.SpecificRecordBuild private java.lang.String maybe_string_column; private java.util.List strings_column; private java.util.Map string_to_int_column; - private java.util.Map> complex_column; + private java.util.Map> complex_column; /** Creates a new Builder */ private Builder() { - super(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.SCHEMA$); + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.SCHEMA$); } /** Creates a Builder by copying an existing Builder */ - private Builder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder other) { + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder other) { super(other); if (isValidValue(fields()[0], other.bool_column)) { this.bool_column = data().deepCopy(fields()[0].schema(), other.bool_column); @@ -479,8 +479,8 @@ private Builder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder } /** Creates a Builder by copying an existing ParquetAvroCompat instance */ - private Builder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat other) { - super(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.SCHEMA$); + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.SCHEMA$); if (isValidValue(fields()[0], other.bool_column)) { this.bool_column = data().deepCopy(fields()[0].schema(), other.bool_column); fieldSetFlags()[0] = true; @@ -557,7 +557,7 @@ public java.lang.Boolean getBoolColumn() { } /** Sets the value of the 'bool_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setBoolColumn(boolean value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setBoolColumn(boolean value) { validate(fields()[0], value); this.bool_column = value; fieldSetFlags()[0] = true; @@ -570,7 +570,7 @@ public boolean hasBoolColumn() { } /** Clears the value of the 'bool_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearBoolColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearBoolColumn() { fieldSetFlags()[0] = false; return this; } @@ -581,7 +581,7 @@ public java.lang.Integer getIntColumn() { } /** Sets the value of the 'int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setIntColumn(int value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setIntColumn(int value) { validate(fields()[1], value); this.int_column = value; fieldSetFlags()[1] = true; @@ -594,7 +594,7 @@ public boolean hasIntColumn() { } /** Clears the value of the 'int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearIntColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearIntColumn() { fieldSetFlags()[1] = false; return this; } @@ -605,7 +605,7 @@ public java.lang.Long getLongColumn() { } /** Sets the value of the 'long_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setLongColumn(long value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setLongColumn(long value) { validate(fields()[2], value); this.long_column = value; fieldSetFlags()[2] = true; @@ -618,7 +618,7 @@ public boolean hasLongColumn() { } /** Clears the value of the 'long_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearLongColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearLongColumn() { fieldSetFlags()[2] = false; return this; } @@ -629,7 +629,7 @@ public java.lang.Float getFloatColumn() { } /** Sets the value of the 'float_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setFloatColumn(float value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setFloatColumn(float value) { validate(fields()[3], value); this.float_column = value; fieldSetFlags()[3] = true; @@ -642,7 +642,7 @@ public boolean hasFloatColumn() { } /** Clears the value of the 'float_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearFloatColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearFloatColumn() { fieldSetFlags()[3] = false; return this; } @@ -653,7 +653,7 @@ public java.lang.Double getDoubleColumn() { } /** Sets the value of the 'double_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setDoubleColumn(double value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setDoubleColumn(double value) { validate(fields()[4], value); this.double_column = value; fieldSetFlags()[4] = true; @@ -666,7 +666,7 @@ public boolean hasDoubleColumn() { } /** Clears the value of the 'double_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearDoubleColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearDoubleColumn() { fieldSetFlags()[4] = false; return this; } @@ -677,7 +677,7 @@ public java.nio.ByteBuffer getBinaryColumn() { } /** Sets the value of the 'binary_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setBinaryColumn(java.nio.ByteBuffer value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setBinaryColumn(java.nio.ByteBuffer value) { validate(fields()[5], value); this.binary_column = value; fieldSetFlags()[5] = true; @@ -690,7 +690,7 @@ public boolean hasBinaryColumn() { } /** Clears the value of the 'binary_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearBinaryColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearBinaryColumn() { binary_column = null; fieldSetFlags()[5] = false; return this; @@ -702,7 +702,7 @@ public java.lang.String getStringColumn() { } /** Sets the value of the 'string_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setStringColumn(java.lang.String value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setStringColumn(java.lang.String value) { validate(fields()[6], value); this.string_column = value; fieldSetFlags()[6] = true; @@ -715,7 +715,7 @@ public boolean hasStringColumn() { } /** Clears the value of the 'string_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearStringColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearStringColumn() { string_column = null; fieldSetFlags()[6] = false; return this; @@ -727,7 +727,7 @@ public java.lang.Boolean getMaybeBoolColumn() { } /** Sets the value of the 'maybe_bool_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeBoolColumn(java.lang.Boolean value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setMaybeBoolColumn(java.lang.Boolean value) { validate(fields()[7], value); this.maybe_bool_column = value; fieldSetFlags()[7] = true; @@ -740,7 +740,7 @@ public boolean hasMaybeBoolColumn() { } /** Clears the value of the 'maybe_bool_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeBoolColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeBoolColumn() { maybe_bool_column = null; fieldSetFlags()[7] = false; return this; @@ -752,7 +752,7 @@ public java.lang.Integer getMaybeIntColumn() { } /** Sets the value of the 'maybe_int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeIntColumn(java.lang.Integer value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setMaybeIntColumn(java.lang.Integer value) { validate(fields()[8], value); this.maybe_int_column = value; fieldSetFlags()[8] = true; @@ -765,7 +765,7 @@ public boolean hasMaybeIntColumn() { } /** Clears the value of the 'maybe_int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeIntColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeIntColumn() { maybe_int_column = null; fieldSetFlags()[8] = false; return this; @@ -777,7 +777,7 @@ public java.lang.Long getMaybeLongColumn() { } /** Sets the value of the 'maybe_long_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeLongColumn(java.lang.Long value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setMaybeLongColumn(java.lang.Long value) { validate(fields()[9], value); this.maybe_long_column = value; fieldSetFlags()[9] = true; @@ -790,7 +790,7 @@ public boolean hasMaybeLongColumn() { } /** Clears the value of the 'maybe_long_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeLongColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeLongColumn() { maybe_long_column = null; fieldSetFlags()[9] = false; return this; @@ -802,7 +802,7 @@ public java.lang.Float getMaybeFloatColumn() { } /** Sets the value of the 'maybe_float_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeFloatColumn(java.lang.Float value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setMaybeFloatColumn(java.lang.Float value) { validate(fields()[10], value); this.maybe_float_column = value; fieldSetFlags()[10] = true; @@ -815,7 +815,7 @@ public boolean hasMaybeFloatColumn() { } /** Clears the value of the 'maybe_float_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeFloatColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeFloatColumn() { maybe_float_column = null; fieldSetFlags()[10] = false; return this; @@ -827,7 +827,7 @@ public java.lang.Double getMaybeDoubleColumn() { } /** Sets the value of the 'maybe_double_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeDoubleColumn(java.lang.Double value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setMaybeDoubleColumn(java.lang.Double value) { validate(fields()[11], value); this.maybe_double_column = value; fieldSetFlags()[11] = true; @@ -840,7 +840,7 @@ public boolean hasMaybeDoubleColumn() { } /** Clears the value of the 'maybe_double_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeDoubleColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeDoubleColumn() { maybe_double_column = null; fieldSetFlags()[11] = false; return this; @@ -852,7 +852,7 @@ public java.nio.ByteBuffer getMaybeBinaryColumn() { } /** Sets the value of the 'maybe_binary_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeBinaryColumn(java.nio.ByteBuffer value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setMaybeBinaryColumn(java.nio.ByteBuffer value) { validate(fields()[12], value); this.maybe_binary_column = value; fieldSetFlags()[12] = true; @@ -865,7 +865,7 @@ public boolean hasMaybeBinaryColumn() { } /** Clears the value of the 'maybe_binary_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeBinaryColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeBinaryColumn() { maybe_binary_column = null; fieldSetFlags()[12] = false; return this; @@ -877,7 +877,7 @@ public java.lang.String getMaybeStringColumn() { } /** Sets the value of the 'maybe_string_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeStringColumn(java.lang.String value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setMaybeStringColumn(java.lang.String value) { validate(fields()[13], value); this.maybe_string_column = value; fieldSetFlags()[13] = true; @@ -890,7 +890,7 @@ public boolean hasMaybeStringColumn() { } /** Clears the value of the 'maybe_string_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeStringColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeStringColumn() { maybe_string_column = null; fieldSetFlags()[13] = false; return this; @@ -902,7 +902,7 @@ public java.util.List getStringsColumn() { } /** Sets the value of the 'strings_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setStringsColumn(java.util.List value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setStringsColumn(java.util.List value) { validate(fields()[14], value); this.strings_column = value; fieldSetFlags()[14] = true; @@ -915,7 +915,7 @@ public boolean hasStringsColumn() { } /** Clears the value of the 'strings_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearStringsColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearStringsColumn() { strings_column = null; fieldSetFlags()[14] = false; return this; @@ -927,7 +927,7 @@ public java.util.Map getStringToIntColumn() } /** Sets the value of the 'string_to_int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setStringToIntColumn(java.util.Map value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setStringToIntColumn(java.util.Map value) { validate(fields()[15], value); this.string_to_int_column = value; fieldSetFlags()[15] = true; @@ -940,19 +940,19 @@ public boolean hasStringToIntColumn() { } /** Clears the value of the 'string_to_int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearStringToIntColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearStringToIntColumn() { string_to_int_column = null; fieldSetFlags()[15] = false; return this; } /** Gets the value of the 'complex_column' field */ - public java.util.Map> getComplexColumn() { + public java.util.Map> getComplexColumn() { return complex_column; } /** Sets the value of the 'complex_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setComplexColumn(java.util.Map> value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setComplexColumn(java.util.Map> value) { validate(fields()[16], value); this.complex_column = value; fieldSetFlags()[16] = true; @@ -965,7 +965,7 @@ public boolean hasComplexColumn() { } /** Clears the value of the 'complex_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearComplexColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearComplexColumn() { complex_column = null; fieldSetFlags()[16] = false; return this; @@ -991,7 +991,7 @@ public ParquetAvroCompat build() { record.maybe_string_column = fieldSetFlags()[13] ? this.maybe_string_column : (java.lang.String) defaultValue(fields()[13]); record.strings_column = fieldSetFlags()[14] ? this.strings_column : (java.util.List) defaultValue(fields()[14]); record.string_to_int_column = fieldSetFlags()[15] ? this.string_to_int_column : (java.util.Map) defaultValue(fields()[15]); - record.complex_column = fieldSetFlags()[16] ? this.complex_column : (java.util.Map>) defaultValue(fields()[16]); + record.complex_column = fieldSetFlags()[16] ? this.complex_column : (java.util.Map>) defaultValue(fields()[16]); return record; } catch (Exception e) { throw new org.apache.avro.AvroRuntimeException(e); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index c49f256be5501..adbd95197d7ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -25,8 +25,8 @@ import scala.util.Random import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.functions._ -import org.apache.spark.sql.json.JSONRelation -import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.execution.datasources.json.JSONRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.types._ import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SQLTestUtils} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 92022ff23d2c3..73d5621897819 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.sql.{Date, Timestamp} @@ -28,7 +28,7 @@ import org.apache.spark.sql.{SQLContext, QueryTest, Row, SQLConf} import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} -import org.apache.spark.sql.json.InferSchema.compatibleType +import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType import org.apache.spark.sql.types._ import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index 369df5653060b..6b62c9a003df6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetAvroCompatibilitySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala index bfa427349ff6a..4d9c07bb7a570 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetAvroCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.nio.ByteBuffer import java.util.{List => JList, Map => JMap} @@ -25,7 +25,7 @@ import scala.collection.JavaConversions._ import org.apache.hadoop.fs.Path import org.apache.parquet.avro.AvroParquetWriter -import org.apache.spark.sql.parquet.test.avro.{Nested, ParquetAvroCompat} +import org.apache.spark.sql.execution.datasources.parquet.test.avro.{Nested, ParquetAvroCompat} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.{Row, SQLContext} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala similarity index 97% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala index 57478931cd509..68f35b1f3aa83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.io.File import scala.collection.JavaConversions._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index b6a7c4fbddbdc..7dd9680d8cd65 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import org.apache.parquet.filter2.predicate.Operators._ import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index b415da5b8c136..ee925afe08508 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import scala.collection.JavaConversions._ import scala.reflect.ClassTag @@ -373,7 +373,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest { // _temporary should be missing if direct output committer works. try { configuration.set("spark.sql.parquet.output.committer.class", - "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") + classOf[DirectParquetOutputCommitter].getCanonicalName) sqlContext.udf.register("div0", (x: Int) => x / 0) withTempPath { dir => intercept[org.apache.spark.SparkException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 2eef10189f11c..73152de244759 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.io.File import java.math.BigInteger diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 5c65a8ec57f00..5e6d9c1cd44a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.io.File diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 4a0b3b60f419d..8f06de7ce7c4f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 64e94056f209a..3c6e54db4bca7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.io.File diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetThriftCompatibilitySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala index 1c532d78790d2..92b1d822172d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetThriftCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.{Row, SQLContext} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala similarity index 92% rename from sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index d22160f5384f4..953284c98b208 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.metric +package org.apache.spark.sql.execution.metric import java.io.{ByteArrayInputStream, ByteArrayOutputStream} @@ -41,16 +41,6 @@ class SQLMetricsSuite extends SparkFunSuite { } } - test("IntSQLMetric should not box Int") { - val l = SQLMetrics.createIntMetric(TestSQLContext.sparkContext, "Int") - val f = () => { l += 1 } - BoxingFinder.getClassReader(f.getClass).foreach { cl => - val boxingFinder = new BoxingFinder() - cl.accept(boxingFinder, 0) - assert(boxingFinder.boxingInvokes.isEmpty, s"Found boxing: ${boxingFinder.boxingInvokes}") - } - } - test("Normal accumulator should do boxing") { // We need this test to make sure BoxingFinder works. val l = TestSQLContext.sparkContext.accumulator(0L) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 69a561e16aa17..41dd1896c15df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.ui +package org.apache.spark.sql.execution.ui import java.util.Properties import org.apache.spark.{SparkException, SparkContext, SparkConf, SparkFunSuite} import org.apache.spark.executor.TaskMetrics -import org.apache.spark.sql.metric.LongSQLMetricValue +import org.apache.spark.sql.execution.metric.LongSQLMetricValue import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.execution.SQLExecution diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 1907e643c85dd..562c279067048 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -51,7 +51,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -75,7 +75,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -92,7 +92,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -107,7 +107,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -122,7 +122,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -139,7 +139,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -158,7 +158,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -175,7 +175,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable (a int, b string) - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -188,7 +188,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -199,7 +199,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala index 1a4d41b02ca68..392da0b0826b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala @@ -20,9 +20,37 @@ package org.apache.spark.sql.sources import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types.{StringType, StructField, StructType} + +// please note that the META-INF/services had to be modified for the test directory for this to work +class DDLSourceLoadSuite extends DataSourceTest { + + test("data sources with the same name") { + intercept[RuntimeException] { + caseInsensitiveContext.read.format("Fluet da Bomb").load() + } + } + + test("load data source from format alias") { + caseInsensitiveContext.read.format("gathering quorum").load().schema == + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } + + test("specify full classname with duplicate formats") { + caseInsensitiveContext.read.format("org.apache.spark.sql.sources.FakeSourceOne") + .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false))) + } + + test("should fail to load ORC without HiveContext") { + intercept[ClassNotFoundException] { + caseInsensitiveContext.read.format("orc").load() + } + } +} + + class FakeSourceOne extends RelationProvider with DataSourceRegister { - def format(): String = "Fluet da Bomb" + def shortName(): String = "Fluet da Bomb" override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = new BaseRelation { @@ -35,7 +63,7 @@ class FakeSourceOne extends RelationProvider with DataSourceRegister { class FakeSourceTwo extends RelationProvider with DataSourceRegister { - def format(): String = "Fluet da Bomb" + def shortName(): String = "Fluet da Bomb" override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = new BaseRelation { @@ -48,7 +76,7 @@ class FakeSourceTwo extends RelationProvider with DataSourceRegister { class FakeSourceThree extends RelationProvider with DataSourceRegister { - def format(): String = "gathering quorum" + def shortName(): String = "gathering quorum" override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = new BaseRelation { @@ -58,28 +86,3 @@ class FakeSourceThree extends RelationProvider with DataSourceRegister { StructType(Seq(StructField("stringType", StringType, nullable = false))) } } -// please note that the META-INF/services had to be modified for the test directory for this to work -class DDLSourceLoadSuite extends DataSourceTest { - - test("data sources with the same name") { - intercept[RuntimeException] { - caseInsensitiveContext.read.format("Fluet da Bomb").load() - } - } - - test("load data source from format alias") { - caseInsensitiveContext.read.format("gathering quorum").load().schema == - StructType(Seq(StructField("stringType", StringType, nullable = false))) - } - - test("specify full classname with duplicate formats") { - caseInsensitiveContext.read.format("org.apache.spark.sql.sources.FakeSourceOne") - .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false))) - } - - test("Loading Orc") { - intercept[ClassNotFoundException] { - caseInsensitiveContext.read.format("orc").load() - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 3cbf5467b253a..27d1cd92fca1a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -22,14 +22,39 @@ import org.apache.spark.sql.execution.datasources.ResolvedDataSource class ResolvedDataSourceSuite extends SparkFunSuite { - test("builtin sources") { - assert(ResolvedDataSource.lookupDataSource("jdbc") === - classOf[org.apache.spark.sql.jdbc.DefaultSource]) + test("jdbc") { + assert( + ResolvedDataSource.lookupDataSource("jdbc") === + classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.jdbc") === + classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.jdbc") === + classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) + } - assert(ResolvedDataSource.lookupDataSource("json") === - classOf[org.apache.spark.sql.json.DefaultSource]) + test("json") { + assert( + ResolvedDataSource.lookupDataSource("json") === + classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.json") === + classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.json") === + classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) + } - assert(ResolvedDataSource.lookupDataSource("parquet") === - classOf[org.apache.spark.sql.parquet.DefaultSource]) + test("parquet") { + assert( + ResolvedDataSource.lookupDataSource("parquet") === + classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.parquet") === + classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.parquet") === + classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 7198a32df4a02..ac9aaed19d566 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier} import org.apache.spark.sql.execution.datasources import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} import org.apache.spark.sql.hive.client._ -import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 0c344c63fde3f..9f4f8b5789afe 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -32,7 +32,6 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.Logging -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.{HadoopRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow @@ -49,9 +48,9 @@ import scala.collection.JavaConversions._ private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { - def format(): String = "orc" + override def shortName(): String = "orc" - def createRelation( + override def createRelation( sqlContext: SQLContext, paths: Array[String], dataSchema: Option[StructType], @@ -144,7 +143,6 @@ private[orc] class OrcOutputWriter( } } -@DeveloperApi private[sql] class OrcRelation( override val paths: Array[String], maybeDataSchema: Option[StructType], diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index a45c2d957278f..1fa005d5f9a15 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.parquet.ParquetTest +import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.{QueryTest, Row} case class Cases(lower: String, UPPER: String) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index b73d6665755d0..7f36a483a3965 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index f00d3754c364a..80eb9f122ad90 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.conf.HiveConf import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.parquet.ParquetCompatibilityTest +import org.apache.spark.sql.execution.datasources.parquet.ParquetCompatibilityTest import org.apache.spark.sql.{Row, SQLConf, SQLContext} class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 2fa7ae3fa2e12..79a136ae6f619 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} -import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index c4bc60086f6e1..50f02432dacce 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.hive.execution.HiveTableScan import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index d280543a071d9..cb4cedddbfddd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -23,12 +23,12 @@ import com.google.common.io.Files import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.{AnalysisException, SaveMode, parquet} +import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { - override val dataSourceName: String = classOf[parquet.DefaultSource].getCanonicalName + override val dataSourceName: String = "parquet" import sqlContext._ import sqlContext.implicits._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index 1813cc33226d1..48c37a1fa1022 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -53,7 +53,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = - classOf[org.apache.spark.sql.json.DefaultSource].getCanonicalName + classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource].getCanonicalName import sqlContext._ From fe2fb7fb7189d183a4273ad27514af4b6b461f26 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 10 Aug 2015 13:52:18 -0700 Subject: [PATCH 009/253] [SPARK-9620] [SQL] generated UnsafeProjection should support many columns or large exressions Currently, generated UnsafeProjection can reach 64k byte code limit of Java. This patch will split the generated expressions into multiple functions, to avoid the limitation. After this patch, we can work well with table that have up to 64k columns (hit max number of constants limit in Java), it should be enough in practice. cc rxin Author: Davies Liu Closes #8044 from davies/wider_table and squashes the following commits: 9192e6c [Davies Liu] fix generated safe projection d1ef81a [Davies Liu] fix failed tests 737b3d3 [Davies Liu] Merge branch 'master' of github.com:apache/spark into wider_table ffcd132 [Davies Liu] address comments 1b95be4 [Davies Liu] put the generated class into sql package 77ed72d [Davies Liu] address comments 4518e17 [Davies Liu] Merge branch 'master' of github.com:apache/spark into wider_table 75ccd01 [Davies Liu] Merge branch 'master' of github.com:apache/spark into wider_table 495e932 [Davies Liu] support wider table with more than 1k columns for generated projections --- .../expressions/codegen/CodeGenerator.scala | 48 ++++++- .../codegen/GenerateMutableProjection.scala | 43 +----- .../codegen/GenerateSafeProjection.scala | 52 ++------ .../codegen/GenerateUnsafeProjection.scala | 122 +++++++++--------- .../codegen/GenerateUnsafeRowJoiner.scala | 2 +- .../codegen/GeneratedProjectionSuite.scala | 82 ++++++++++++ 6 files changed, 207 insertions(+), 142 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 7b41c9a3f3b8e..c21f4d626a74e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import scala.language.existentials import com.google.common.cache.{CacheBuilder, CacheLoader} @@ -265,6 +266,45 @@ class CodeGenContext { def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt) def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt)) + + /** + * Splits the generated code of expressions into multiple functions, because function has + * 64kb code size limit in JVM + * + * @param row the variable name of row that is used by expressions + */ + def splitExpressions(row: String, expressions: Seq[String]): String = { + val blocks = new ArrayBuffer[String]() + val blockBuilder = new StringBuilder() + for (code <- expressions) { + // We can't know how many byte code will be generated, so use the number of bytes as limit + if (blockBuilder.length > 64 * 1000) { + blocks.append(blockBuilder.toString()) + blockBuilder.clear() + } + blockBuilder.append(code) + } + blocks.append(blockBuilder.toString()) + + if (blocks.length == 1) { + // inline execution if only one block + blocks.head + } else { + val apply = freshName("apply") + val functions = blocks.zipWithIndex.map { case (body, i) => + val name = s"${apply}_$i" + val code = s""" + |private void $name(InternalRow $row) { + | $body + |} + """.stripMargin + addNewFunction(name, code) + name + } + + functions.map(name => s"$name($row);").mkString("\n") + } + } } /** @@ -289,15 +329,15 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected def declareMutableStates(ctx: CodeGenContext): String = { ctx.mutableStates.map { case (javaType, variableName, _) => s"private $javaType $variableName;" - }.mkString + }.mkString("\n") } protected def initMutableStates(ctx: CodeGenContext): String = { - ctx.mutableStates.map(_._3).mkString + ctx.mutableStates.map(_._3).mkString("\n") } protected def declareAddedFunctions(ctx: CodeGenContext): String = { - ctx.addedFuntions.map { case (funcName, funcCode) => funcCode }.mkString + ctx.addedFuntions.map { case (funcName, funcCode) => funcCode }.mkString("\n") } /** @@ -328,6 +368,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin private[this] def doCompile(code: String): GeneratedClass = { val evaluator = new ClassBodyEvaluator() evaluator.setParentClassLoader(getClass.getClassLoader) + // Cannot be under package codegen, or fail with java.lang.InstantiationException + evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass") evaluator.setDefaultImports(Array( classOf[PlatformDependent].getName, classOf[InternalRow].getName, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index ac58423cd884d..b4d4df8934bd4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -40,7 +40,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { val ctx = newCodeGenContext() - val projectionCode = expressions.zipWithIndex.map { + val projectionCodes = expressions.zipWithIndex.map { case (NoOp, _) => "" case (e, i) => val evaluationCode = e.gen(ctx) @@ -65,49 +65,21 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu """ } } - // collect projections into blocks as function has 64kb codesize limit in JVM - val projectionBlocks = new ArrayBuffer[String]() - val blockBuilder = new StringBuilder() - for (projection <- projectionCode) { - if (blockBuilder.length > 16 * 1000) { - projectionBlocks.append(blockBuilder.toString()) - blockBuilder.clear() - } - blockBuilder.append(projection) - } - projectionBlocks.append(blockBuilder.toString()) - - val (projectionFuns, projectionCalls) = { - // inline execution if codesize limit was not broken - if (projectionBlocks.length == 1) { - ("", projectionBlocks.head) - } else { - ( - projectionBlocks.zipWithIndex.map { case (body, i) => - s""" - |private void apply$i(InternalRow i) { - | $body - |} - """.stripMargin - }.mkString, - projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n") - ) - } - } + val allProjections = ctx.splitExpressions("i", projectionCodes) val code = s""" public Object generate($exprType[] expr) { - return new SpecificProjection(expr); + return new SpecificMutableProjection(expr); } - class SpecificProjection extends ${classOf[BaseMutableProjection].getName} { + class SpecificMutableProjection extends ${classOf[BaseMutableProjection].getName} { private $exprType[] expressions; private $mutableRowType mutableRow; ${declareMutableStates(ctx)} ${declareAddedFunctions(ctx)} - public SpecificProjection($exprType[] expr) { + public SpecificMutableProjection($exprType[] expr) { expressions = expr; mutableRow = new $genericMutableRowType(${expressions.size}); ${initMutableStates(ctx)} @@ -123,12 +95,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu return (InternalRow) mutableRow; } - $projectionFuns - public Object apply(Object _i) { InternalRow i = (InternalRow) _i; - $projectionCalls - + $allProjections return mutableRow; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index ef08ddf041afc..7ad352d7ce3e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp import org.apache.spark.sql.types._ @@ -43,6 +41,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val tmp = ctx.freshName("tmp") val output = ctx.freshName("safeRow") val values = ctx.freshName("values") + // These expressions could be splitted into multiple functions + ctx.addMutableState("Object[]", values, s"this.$values = null;") + val rowClass = classOf[GenericInternalRow].getName val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => @@ -53,12 +54,12 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] $values[$i] = ${converter.primitive}; } """ - }.mkString("\n") - + } + val allFields = ctx.splitExpressions(tmp, fieldWriters) val code = s""" final InternalRow $tmp = $input; - final Object[] $values = new Object[${schema.length}]; - $fieldWriters + this.$values = new Object[${schema.length}]; + $allFields final InternalRow $output = new $rowClass($values); """ @@ -128,7 +129,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] protected def create(expressions: Seq[Expression]): Projection = { val ctx = newCodeGenContext() - val projectionCode = expressions.zipWithIndex.map { + val expressionCodes = expressions.zipWithIndex.map { case (NoOp, _) => "" case (e, i) => val evaluationCode = e.gen(ctx) @@ -143,36 +144,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] } """ } - // collect projections into blocks as function has 64kb codesize limit in JVM - val projectionBlocks = new ArrayBuffer[String]() - val blockBuilder = new StringBuilder() - for (projection <- projectionCode) { - if (blockBuilder.length > 16 * 1000) { - projectionBlocks.append(blockBuilder.toString()) - blockBuilder.clear() - } - blockBuilder.append(projection) - } - projectionBlocks.append(blockBuilder.toString()) - - val (projectionFuns, projectionCalls) = { - // inline it if we have only one block - if (projectionBlocks.length == 1) { - ("", projectionBlocks.head) - } else { - ( - projectionBlocks.zipWithIndex.map { case (body, i) => - s""" - |private void apply$i(InternalRow i) { - | $body - |} - """.stripMargin - }.mkString, - projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n") - ) - } - } - + val allExpressions = ctx.splitExpressions("i", expressionCodes) val code = s""" public Object generate($exprType[] expr) { return new SpecificSafeProjection(expr); @@ -183,6 +155,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] private $exprType[] expressions; private $mutableRowType mutableRow; ${declareMutableStates(ctx)} + ${declareAddedFunctions(ctx)} public SpecificSafeProjection($exprType[] expr) { expressions = expr; @@ -190,12 +163,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ${initMutableStates(ctx)} } - $projectionFuns - public Object apply(Object _i) { InternalRow i = (InternalRow) _i; - $projectionCalls - + $allExpressions return mutableRow; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index d8912df694a10..29f6a7b981752 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.PlatformDependent /** * Generates a [[Projection]] that returns an [[UnsafeRow]]. @@ -41,8 +40,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private val ArrayWriter = classOf[UnsafeRowWriters.ArrayWriter].getName private val MapWriter = classOf[UnsafeRowWriters.MapWriter].getName - private val PlatformDependent = classOf[PlatformDependent].getName - /** Returns true iff we support this data type. */ def canSupport(dataType: DataType): Boolean = dataType match { case NullType => true @@ -56,19 +53,19 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro def genAdditionalSize(dt: DataType, ev: GeneratedExpressionCode): String = dt match { case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => - s" + $DecimalWriter.getSize(${ev.primitive})" + s"$DecimalWriter.getSize(${ev.primitive})" case StringType => - s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))" + s"${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive})" case BinaryType => - s" + (${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive}))" + s"${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive})" case CalendarIntervalType => - s" + (${ev.isNull} ? 0 : 16)" + s"${ev.isNull} ? 0 : 16" case _: StructType => - s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))" + s"${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive})" case _: ArrayType => - s" + (${ev.isNull} ? 0 : $ArrayWriter.getSize(${ev.primitive}))" + s"${ev.isNull} ? 0 : $ArrayWriter.getSize(${ev.primitive})" case _: MapType => - s" + (${ev.isNull} ? 0 : $MapWriter.getSize(${ev.primitive}))" + s"${ev.isNull} ? 0 : $MapWriter.getSize(${ev.primitive})" case _ => "" } @@ -125,64 +122,69 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro */ private def createCodeForStruct( ctx: CodeGenContext, + row: String, inputs: Seq[GeneratedExpressionCode], inputTypes: Seq[DataType]): GeneratedExpressionCode = { + val fixedSize = 8 * inputTypes.length + UnsafeRow.calculateBitSetWidthInBytes(inputTypes.length) + val output = ctx.freshName("convertedStruct") - ctx.addMutableState("UnsafeRow", output, s"$output = new UnsafeRow();") + ctx.addMutableState("UnsafeRow", output, s"this.$output = new UnsafeRow();") val buffer = ctx.freshName("buffer") - ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") - val numBytes = ctx.freshName("numBytes") + ctx.addMutableState("byte[]", buffer, s"this.$buffer = new byte[$fixedSize];") val cursor = ctx.freshName("cursor") + ctx.addMutableState("int", cursor, s"this.$cursor = 0;") + val tmp = ctx.freshName("tmpBuffer") - val convertedFields = inputTypes.zip(inputs).map { case (dt, input) => - createConvertCode(ctx, input, dt) - } - - val fixedSize = 8 * inputTypes.length + UnsafeRow.calculateBitSetWidthInBytes(inputTypes.length) - val additionalSize = inputTypes.zip(convertedFields).map { case (dt, ev) => - genAdditionalSize(dt, ev) - }.mkString("") - - val fieldWriters = inputTypes.zip(convertedFields).zipWithIndex.map { case ((dt, ev), i) => - val update = genFieldWriter(ctx, dt, ev, output, i, cursor) - if (dt.isInstanceOf[DecimalType]) { - // Can't call setNullAt() for DecimalType + val convertedFields = inputTypes.zip(inputs).zipWithIndex.map { case ((dt, input), i) => + val ev = createConvertCode(ctx, input, dt) + val growBuffer = if (!UnsafeRow.isFixedLength(dt)) { + val numBytes = ctx.freshName("numBytes") s""" + int $numBytes = $cursor + (${genAdditionalSize(dt, ev)}); + if ($buffer.length < $numBytes) { + // This will not happen frequently, because the buffer is re-used. + byte[] $tmp = new byte[$numBytes * 2]; + PlatformDependent.copyMemory($buffer, PlatformDependent.BYTE_ARRAY_OFFSET, + $tmp, PlatformDependent.BYTE_ARRAY_OFFSET, $buffer.length); + $buffer = $tmp; + } + $output.pointTo($buffer, PlatformDependent.BYTE_ARRAY_OFFSET, + ${inputTypes.length}, $numBytes); + """ + } else { + "" + } + val update = dt match { + case dt: DecimalType if dt.precision > Decimal.MAX_LONG_DIGITS => + // Can't call setNullAt() for DecimalType + s""" if (${ev.isNull}) { - $cursor += $DecimalWriter.write($output, $i, $cursor, null); + $cursor += $DecimalWriter.write($output, $i, $cursor, null); } else { - $update; + ${genFieldWriter(ctx, dt, ev, output, i, cursor)}; } """ - } else { - s""" + case _ => + s""" if (${ev.isNull}) { $output.setNullAt($i); } else { - $update; + ${genFieldWriter(ctx, dt, ev, output, i, cursor)}; } """ } - }.mkString("\n") + s""" + ${ev.code} + $growBuffer + $update + """ + } val code = s""" - ${convertedFields.map(_.code).mkString("\n")} - - final int $numBytes = $fixedSize $additionalSize; - if ($numBytes > $buffer.length) { - $buffer = new byte[$numBytes]; - } - - $output.pointTo( - $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET, - ${inputTypes.length}, - $numBytes); - - int $cursor = $fixedSize; - - $fieldWriters + $cursor = $fixedSize; + $output.pointTo($buffer, PlatformDependent.BYTE_ARRAY_OFFSET, ${inputTypes.length}, $cursor); + ${ctx.splitExpressions(row, convertedFields)} """ GeneratedExpressionCode(code, "false", output) } @@ -265,17 +267,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Should we do word align? val elementSize = elementType.defaultSize s""" - $PlatformDependent.UNSAFE.put${ctx.primitiveTypeName(elementType)}( + PlatformDependent.UNSAFE.put${ctx.primitiveTypeName(elementType)}( $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, + PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, ${convertedElement.primitive}); $cursor += $elementSize; """ case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => s""" - $PlatformDependent.UNSAFE.putLong( + PlatformDependent.UNSAFE.putLong( $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, + PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, ${convertedElement.primitive}.toUnscaledLong()); $cursor += 8; """ @@ -284,7 +286,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" $cursor += $writer.write( $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, + PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, $elements[$index]); """ } @@ -318,14 +320,14 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro for (int $index = 0; $index < $numElements; $index++) { if ($checkNull) { // If element is null, write the negative value address into offset region. - $PlatformDependent.UNSAFE.putInt( + PlatformDependent.UNSAFE.putInt( $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index, + PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index, -$cursor); } else { - $PlatformDependent.UNSAFE.putInt( + PlatformDependent.UNSAFE.putInt( $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index, + PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index, $cursor); $writeElement @@ -334,7 +336,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $output.pointTo( $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET, + PlatformDependent.BYTE_ARRAY_OFFSET, $numElements, $numBytes); } @@ -400,7 +402,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val fieldIsNull = s"$tmp.isNullAt($i)" GeneratedExpressionCode("", fieldIsNull, getFieldCode) } - val converter = createCodeForStruct(ctx, fieldEvals, fieldTypes) + val converter = createCodeForStruct(ctx, tmp, fieldEvals, fieldTypes) val code = s""" ${input.code} UnsafeRow $output = null; @@ -427,7 +429,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro def createCode(ctx: CodeGenContext, expressions: Seq[Expression]): GeneratedExpressionCode = { val exprEvals = expressions.map(e => e.gen(ctx)) val exprTypes = expressions.map(_.dataType) - createCodeForStruct(ctx, exprEvals, exprTypes) + createCodeForStruct(ctx, "i", exprEvals, exprTypes) } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index 30b51dd83fa9a..8aaa5b4300044 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -155,7 +155,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U |$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32)); """.stripMargin } - }.mkString + }.mkString("\n") // ------------------------ Finally, put everything together --------------------------- // val code = s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala new file mode 100644 index 0000000000000..8c7ee8720f7bb --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{StringType, IntegerType, StructField, StructType} +import org.apache.spark.unsafe.types.UTF8String + +/** + * A test suite for generated projections + */ +class GeneratedProjectionSuite extends SparkFunSuite { + + test("generated projections on wider table") { + val N = 1000 + val wideRow1 = new GenericInternalRow((1 to N).toArray[Any]) + val schema1 = StructType((1 to N).map(i => StructField("", IntegerType))) + val wideRow2 = new GenericInternalRow( + (1 to N).map(i => UTF8String.fromString(i.toString)).toArray[Any]) + val schema2 = StructType((1 to N).map(i => StructField("", StringType))) + val joined = new JoinedRow(wideRow1, wideRow2) + val joinedSchema = StructType(schema1 ++ schema2) + val nested = new JoinedRow(InternalRow(joined, joined), joined) + val nestedSchema = StructType( + Seq(StructField("", joinedSchema), StructField("", joinedSchema)) ++ joinedSchema) + + // test generated UnsafeProjection + val unsafeProj = UnsafeProjection.create(nestedSchema) + val unsafe: UnsafeRow = unsafeProj(nested) + (0 until N).foreach { i => + val s = UTF8String.fromString((i + 1).toString) + assert(i + 1 === unsafe.getInt(i + 2)) + assert(s === unsafe.getUTF8String(i + 2 + N)) + assert(i + 1 === unsafe.getStruct(0, N * 2).getInt(i)) + assert(s === unsafe.getStruct(0, N * 2).getUTF8String(i + N)) + assert(i + 1 === unsafe.getStruct(1, N * 2).getInt(i)) + assert(s === unsafe.getStruct(1, N * 2).getUTF8String(i + N)) + } + + // test generated SafeProjection + val safeProj = FromUnsafeProjection(nestedSchema) + val result = safeProj(unsafe) + // Can't compare GenericInternalRow with JoinedRow directly + (0 until N).foreach { i => + val r = i + 1 + val s = UTF8String.fromString((i + 1).toString) + assert(r === result.getInt(i + 2)) + assert(s === result.getUTF8String(i + 2 + N)) + assert(r === result.getStruct(0, N * 2).getInt(i)) + assert(s === result.getStruct(0, N * 2).getUTF8String(i + N)) + assert(r === result.getStruct(1, N * 2).getInt(i)) + assert(s === result.getStruct(1, N * 2).getUTF8String(i + N)) + } + + // test generated MutableProjection + val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) => + BoundReference(i, f.dataType, true) + } + val mutableProj = GenerateMutableProjection.generate(exprs)() + val row1 = mutableProj(result) + assert(result === row1) + val row2 = mutableProj(result) + assert(result === row2) + } +} From c4fd2a242228ee101904770446e3f37d49e39b76 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 10 Aug 2015 13:55:11 -0700 Subject: [PATCH 010/253] [SPARK-9759] [SQL] improve decimal.times() and cast(int, decimalType) This patch optimize two things: 1. passing MathContext to JavaBigDecimal.multiply/divide/reminder to do right rounding, because java.math.BigDecimal.apply(MathContext) is expensive 2. Cast integer/short/byte to decimal directly (without double) This two optimizations could speed up the end-to-end time of a aggregation (SUM(short * decimal(5, 2)) 75% (from 19s -> 10.8s) Author: Davies Liu Closes #8052 from davies/optimize_decimal and squashes the following commits: 225efad [Davies Liu] improve decimal.times() and cast(int, decimalType) --- .../spark/sql/catalyst/expressions/Cast.scala | 42 +++++++------------ .../org/apache/spark/sql/types/Decimal.scala | 12 +++--- 2 files changed, 22 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 946c5a9c04f14..616b9e0e65b78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -155,7 +155,7 @@ case class Cast(child: Expression, dataType: DataType) case ByteType => buildCast[Byte](_, _ != 0) case DecimalType() => - buildCast[Decimal](_, _ != Decimal.ZERO) + buildCast[Decimal](_, !_.isZero) case DoubleType => buildCast[Double](_, _ != 0) case FloatType => @@ -315,13 +315,13 @@ case class Cast(child: Expression, dataType: DataType) case TimestampType => // Note that we lose precision here. buildCast[Long](_, t => changePrecision(Decimal(timestampToDouble(t)), target)) - case DecimalType() => + case dt: DecimalType => b => changePrecision(b.asInstanceOf[Decimal].clone(), target) - case LongType => - b => changePrecision(Decimal(b.asInstanceOf[Long]), target) - case x: NumericType => // All other numeric types can be represented precisely as Doubles + case t: IntegralType => + b => changePrecision(Decimal(t.integral.asInstanceOf[Integral[Any]].toLong(b)), target) + case x: FractionalType => b => try { - changePrecision(Decimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)), target) + changePrecision(Decimal(x.fractional.asInstanceOf[Fractional[Any]].toDouble(b)), target) } catch { case _: NumberFormatException => null } @@ -534,10 +534,7 @@ case class Cast(child: Expression, dataType: DataType) (c, evPrim, evNull) => s""" try { - org.apache.spark.sql.types.Decimal tmpDecimal = - new org.apache.spark.sql.types.Decimal().set( - new scala.math.BigDecimal( - new java.math.BigDecimal($c.toString()))); + Decimal tmpDecimal = Decimal.apply(new java.math.BigDecimal($c.toString())); ${changePrecision("tmpDecimal", target, evPrim, evNull)} } catch (java.lang.NumberFormatException e) { $evNull = true; @@ -546,12 +543,7 @@ case class Cast(child: Expression, dataType: DataType) case BooleanType => (c, evPrim, evNull) => s""" - org.apache.spark.sql.types.Decimal tmpDecimal = null; - if ($c) { - tmpDecimal = new org.apache.spark.sql.types.Decimal().set(1); - } else { - tmpDecimal = new org.apache.spark.sql.types.Decimal().set(0); - } + Decimal tmpDecimal = $c ? Decimal.apply(1) : Decimal.apply(0); ${changePrecision("tmpDecimal", target, evPrim, evNull)} """ case DateType => @@ -561,32 +553,28 @@ case class Cast(child: Expression, dataType: DataType) // Note that we lose precision here. (c, evPrim, evNull) => s""" - org.apache.spark.sql.types.Decimal tmpDecimal = - new org.apache.spark.sql.types.Decimal().set( - scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)})); + Decimal tmpDecimal = Decimal.apply( + scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)})); ${changePrecision("tmpDecimal", target, evPrim, evNull)} """ case DecimalType() => (c, evPrim, evNull) => s""" - org.apache.spark.sql.types.Decimal tmpDecimal = $c.clone(); + Decimal tmpDecimal = $c.clone(); ${changePrecision("tmpDecimal", target, evPrim, evNull)} """ - case LongType => + case x: IntegralType => (c, evPrim, evNull) => s""" - org.apache.spark.sql.types.Decimal tmpDecimal = - new org.apache.spark.sql.types.Decimal().set($c); + Decimal tmpDecimal = Decimal.apply((long) $c); ${changePrecision("tmpDecimal", target, evPrim, evNull)} """ - case x: NumericType => + case x: FractionalType => // All other numeric types can be represented precisely as Doubles (c, evPrim, evNull) => s""" try { - org.apache.spark.sql.types.Decimal tmpDecimal = - new org.apache.spark.sql.types.Decimal().set( - scala.math.BigDecimal.valueOf((double) $c)); + Decimal tmpDecimal = Decimal.apply(scala.math.BigDecimal.valueOf((double) $c)); ${changePrecision("tmpDecimal", target, evPrim, evNull)} } catch (java.lang.NumberFormatException e) { $evNull = true; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 624c3f3d7fa97..d95805c24521c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -139,9 +139,9 @@ final class Decimal extends Ordered[Decimal] with Serializable { def toBigDecimal: BigDecimal = { if (decimalVal.ne(null)) { - decimalVal(MATH_CONTEXT) + decimalVal } else { - BigDecimal(longVal, _scale)(MATH_CONTEXT) + BigDecimal(longVal, _scale) } } @@ -280,13 +280,15 @@ final class Decimal extends Ordered[Decimal] with Serializable { } // HiveTypeCoercion will take care of the precision, scale of result - def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal) + def * (that: Decimal): Decimal = + Decimal(toJavaBigDecimal.multiply(that.toJavaBigDecimal, MATH_CONTEXT)) def / (that: Decimal): Decimal = - if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal) + if (that.isZero) null else Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal, MATH_CONTEXT)) def % (that: Decimal): Decimal = - if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal) + if (that.isZero) null + else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal, MATH_CONTEXT)) def remainder(that: Decimal): Decimal = this % that From 853809e948e7c5092643587a30738115b6591a59 Mon Sep 17 00:00:00 2001 From: Prabeesh K Date: Mon, 10 Aug 2015 16:33:23 -0700 Subject: [PATCH 011/253] [SPARK-5155] [PYSPARK] [STREAMING] Mqtt streaming support in Python This PR is based on #4229, thanks prabeesh. Closes #4229 Author: Prabeesh K Author: zsxwing Author: prabs Author: Prabeesh K Closes #7833 from zsxwing/pr4229 and squashes the following commits: 9570bec [zsxwing] Fix the variable name and check null in finally 4a9c79e [zsxwing] Fix pom.xml indentation abf5f18 [zsxwing] Merge branch 'master' into pr4229 935615c [zsxwing] Fix the flaky MQTT tests 47278c5 [zsxwing] Include the project class files 478f844 [zsxwing] Add unpack 5f8a1d4 [zsxwing] Make the maven build generate the test jar for Python MQTT tests 734db99 [zsxwing] Merge branch 'master' into pr4229 126608a [Prabeesh K] address the comments b90b709 [Prabeesh K] Merge pull request #1 from zsxwing/pr4229 d07f454 [zsxwing] Register StreamingListerner before starting StreamingContext; Revert unncessary changes; fix the python unit test a6747cb [Prabeesh K] wait for starting the receiver before publishing data 87fc677 [Prabeesh K] address the comments: 97244ec [zsxwing] Make sbt build the assembly test jar for streaming mqtt 80474d1 [Prabeesh K] fix 1f0cfe9 [Prabeesh K] python style fix e1ee016 [Prabeesh K] scala style fix a5a8f9f [Prabeesh K] added Python test 9767d82 [Prabeesh K] implemented Python-friendly class a11968b [Prabeesh K] fixed python style 795ec27 [Prabeesh K] address comments ee387ae [Prabeesh K] Fix assembly jar location of mqtt-assembly 3f4df12 [Prabeesh K] updated version b34c3c1 [prabs] adress comments 3aa7fff [prabs] Added Python streaming mqtt word count example b7d42ff [prabs] Mqtt streaming support in Python --- dev/run-tests.py | 2 + dev/sparktestsupport/modules.py | 2 + docs/streaming-programming-guide.md | 2 +- .../main/python/streaming/mqtt_wordcount.py | 58 +++++++++ external/mqtt-assembly/pom.xml | 102 +++++++++++++++ external/mqtt/pom.xml | 28 +++++ external/mqtt/src/main/assembly/assembly.xml | 44 +++++++ .../spark/streaming/mqtt/MQTTUtils.scala | 16 +++ .../streaming/mqtt/MQTTStreamSuite.scala | 118 +++--------------- .../spark/streaming/mqtt/MQTTTestUtils.scala | 111 ++++++++++++++++ pom.xml | 1 + project/SparkBuild.scala | 12 +- python/pyspark/streaming/mqtt.py | 72 +++++++++++ python/pyspark/streaming/tests.py | 106 +++++++++++++++- 14 files changed, 565 insertions(+), 109 deletions(-) create mode 100644 examples/src/main/python/streaming/mqtt_wordcount.py create mode 100644 external/mqtt-assembly/pom.xml create mode 100644 external/mqtt/src/main/assembly/assembly.xml create mode 100644 external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala create mode 100644 python/pyspark/streaming/mqtt.py diff --git a/dev/run-tests.py b/dev/run-tests.py index d1852b95bb292..f689425ee40b6 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -303,6 +303,8 @@ def build_spark_sbt(hadoop_version): "assembly/assembly", "streaming-kafka-assembly/assembly", "streaming-flume-assembly/assembly", + "streaming-mqtt-assembly/assembly", + "streaming-mqtt/test:assembly", "streaming-kinesis-asl-assembly/assembly"] profiles_and_goals = build_profiles + sbt_goals diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index a9717ff9569c7..d82c0cca37bc6 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -181,6 +181,7 @@ def contains_file(self, filename): dependencies=[streaming], source_file_regexes=[ "external/mqtt", + "external/mqtt-assembly", ], sbt_test_goals=[ "streaming-mqtt/test", @@ -306,6 +307,7 @@ def contains_file(self, filename): streaming, streaming_kafka, streaming_flume_assembly, + streaming_mqtt, streaming_kinesis_asl ], source_file_regexes=[ diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index dbfdb619f89e2..c59d936b43c88 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -683,7 +683,7 @@ for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.strea {:.no_toc} Python API As of Spark {{site.SPARK_VERSION_SHORT}}, -out of these sources, *only* Kafka and Flume are available in the Python API. We will add more advanced sources in the Python API in future. +out of these sources, *only* Kafka, Flume and MQTT are available in the Python API. We will add more advanced sources in the Python API in future. This category of sources require interfacing with external non-Spark libraries, some of them with complex dependencies (e.g., Kafka and Flume). Hence, to minimize issues related to version conflicts diff --git a/examples/src/main/python/streaming/mqtt_wordcount.py b/examples/src/main/python/streaming/mqtt_wordcount.py new file mode 100644 index 0000000000000..617ce5ea6775e --- /dev/null +++ b/examples/src/main/python/streaming/mqtt_wordcount.py @@ -0,0 +1,58 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + A sample wordcount with MqttStream stream + Usage: mqtt_wordcount.py + + To run this in your local machine, you need to setup a MQTT broker and publisher first, + Mosquitto is one of the open source MQTT Brokers, see + http://mosquitto.org/ + Eclipse paho project provides number of clients and utilities for working with MQTT, see + http://www.eclipse.org/paho/#getting-started + + and then run the example + `$ bin/spark-submit --jars external/mqtt-assembly/target/scala-*/\ + spark-streaming-mqtt-assembly-*.jar examples/src/main/python/streaming/mqtt_wordcount.py \ + tcp://localhost:1883 foo` +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.mqtt import MQTTUtils + +if __name__ == "__main__": + if len(sys.argv) != 3: + print >> sys.stderr, "Usage: mqtt_wordcount.py " + exit(-1) + + sc = SparkContext(appName="PythonStreamingMQTTWordCount") + ssc = StreamingContext(sc, 1) + + brokerUrl = sys.argv[1] + topic = sys.argv[2] + + lines = MQTTUtils.createStream(ssc, brokerUrl, topic) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/external/mqtt-assembly/pom.xml b/external/mqtt-assembly/pom.xml new file mode 100644 index 0000000000000..9c94473053d96 --- /dev/null +++ b/external/mqtt-assembly/pom.xml @@ -0,0 +1,102 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.5.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-streaming-mqtt-assembly_2.10 + jar + Spark Project External MQTT Assembly + http://spark.apache.org/ + + + streaming-mqtt-assembly + + + + + org.apache.spark + spark-streaming-mqtt_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + provided + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-mqtt-assembly-${project.version}.jar + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + log4j.properties + + + + + + + + + + + diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 0e41e5781784b..69b309876a0db 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -78,5 +78,33 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes + + + + + org.apache.maven.plugins + maven-assembly-plugin + + + test-jar-with-dependencies + package + + single + + + + spark-streaming-mqtt-test-${project.version} + ${project.build.directory}/scala-${scala.binary.version}/ + false + + false + + src/main/assembly/assembly.xml + + + + + + diff --git a/external/mqtt/src/main/assembly/assembly.xml b/external/mqtt/src/main/assembly/assembly.xml new file mode 100644 index 0000000000000..ecab5b360eb3e --- /dev/null +++ b/external/mqtt/src/main/assembly/assembly.xml @@ -0,0 +1,44 @@ + + + test-jar-with-dependencies + + jar + + false + + + + ${project.build.directory}/scala-${scala.binary.version}/test-classes + / + + + + + + true + test + true + + org.apache.hadoop:*:jar + org.apache.zookeeper:*:jar + org.apache.avro:*:jar + + + + + diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala index 1142d0f56ba34..38a1114863d15 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala @@ -74,3 +74,19 @@ object MQTTUtils { createStream(jssc.ssc, brokerUrl, topic, storageLevel) } } + +/** + * This is a helper class that wraps the methods in MQTTUtils into more Python-friendly class and + * function so that it can be easily instantiated and called from Python's MQTTUtils. + */ +private class MQTTUtilsPythonHelper { + + def createStream( + jssc: JavaStreamingContext, + brokerUrl: String, + topic: String, + storageLevel: StorageLevel + ): JavaDStream[String] = { + MQTTUtils.createStream(jssc, brokerUrl, topic, storageLevel) + } +} diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index c4bf5aa7869bb..a6a9249db8ed7 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -17,46 +17,30 @@ package org.apache.spark.streaming.mqtt -import java.net.{URI, ServerSocket} -import java.util.concurrent.CountDownLatch -import java.util.concurrent.TimeUnit - import scala.concurrent.duration._ import scala.language.postfixOps -import org.apache.activemq.broker.{TransportConnector, BrokerService} -import org.apache.commons.lang3.RandomUtils -import org.eclipse.paho.client.mqttv3._ -import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence - import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually -import org.apache.spark.streaming.{Milliseconds, StreamingContext} -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.streaming.scheduler.StreamingListener -import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.util.Utils +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Milliseconds, StreamingContext} class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter { private val batchDuration = Milliseconds(500) private val master = "local[2]" private val framework = this.getClass.getSimpleName - private val freePort = findFreePort() - private val brokerUri = "//localhost:" + freePort private val topic = "def" - private val persistenceDir = Utils.createTempDir() private var ssc: StreamingContext = _ - private var broker: BrokerService = _ - private var connector: TransportConnector = _ + private var mqttTestUtils: MQTTTestUtils = _ before { ssc = new StreamingContext(master, framework, batchDuration) - setupMQTT() + mqttTestUtils = new MQTTTestUtils + mqttTestUtils.setup() } after { @@ -64,14 +48,17 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter ssc.stop() ssc = null } - Utils.deleteRecursively(persistenceDir) - tearDownMQTT() + if (mqttTestUtils != null) { + mqttTestUtils.teardown() + mqttTestUtils = null + } } test("mqtt input stream") { val sendMessage = "MQTT demo for spark streaming" - val receiveStream = - MQTTUtils.createStream(ssc, "tcp:" + brokerUri, topic, StorageLevel.MEMORY_ONLY) + val receiveStream = MQTTUtils.createStream(ssc, "tcp://" + mqttTestUtils.brokerUri, topic, + StorageLevel.MEMORY_ONLY) + @volatile var receiveMessage: List[String] = List() receiveStream.foreachRDD { rdd => if (rdd.collect.length > 0) { @@ -79,89 +66,14 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter receiveMessage } } - ssc.start() - // wait for the receiver to start before publishing data, or we risk failing - // the test nondeterministically. See SPARK-4631 - waitForReceiverToStart() + ssc.start() - publishData(sendMessage) + // Retry it because we don't know when the receiver will start. eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { + mqttTestUtils.publishData(topic, sendMessage) assert(sendMessage.equals(receiveMessage(0))) } ssc.stop() } - - private def setupMQTT() { - broker = new BrokerService() - broker.setDataDirectoryFile(Utils.createTempDir()) - connector = new TransportConnector() - connector.setName("mqtt") - connector.setUri(new URI("mqtt:" + brokerUri)) - broker.addConnector(connector) - broker.start() - } - - private def tearDownMQTT() { - if (broker != null) { - broker.stop() - broker = null - } - if (connector != null) { - connector.stop() - connector = null - } - } - - private def findFreePort(): Int = { - val candidatePort = RandomUtils.nextInt(1024, 65536) - Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { - val socket = new ServerSocket(trialPort) - socket.close() - (null, trialPort) - }, new SparkConf())._2 - } - - def publishData(data: String): Unit = { - var client: MqttClient = null - try { - val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath) - client = new MqttClient("tcp:" + brokerUri, MqttClient.generateClientId(), persistence) - client.connect() - if (client.isConnected) { - val msgTopic = client.getTopic(topic) - val message = new MqttMessage(data.getBytes("utf-8")) - message.setQos(1) - message.setRetained(true) - - for (i <- 0 to 10) { - try { - msgTopic.publish(message) - } catch { - case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => - // wait for Spark streaming to consume something from the message queue - Thread.sleep(50) - } - } - } - } finally { - client.disconnect() - client.close() - client = null - } - } - - /** - * Block until at least one receiver has started or timeout occurs. - */ - private def waitForReceiverToStart() = { - val latch = new CountDownLatch(1) - ssc.addStreamingListener(new StreamingListener { - override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) { - latch.countDown() - } - }) - - assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.") - } } diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala new file mode 100644 index 0000000000000..1a371b7008824 --- /dev/null +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.mqtt + +import java.net.{ServerSocket, URI} + +import scala.language.postfixOps + +import com.google.common.base.Charsets.UTF_8 +import org.apache.activemq.broker.{BrokerService, TransportConnector} +import org.apache.commons.lang3.RandomUtils +import org.eclipse.paho.client.mqttv3._ +import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence + +import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkConf} + +/** + * Share codes for Scala and Python unit tests + */ +private class MQTTTestUtils extends Logging { + + private val persistenceDir = Utils.createTempDir() + private val brokerHost = "localhost" + private val brokerPort = findFreePort() + + private var broker: BrokerService = _ + private var connector: TransportConnector = _ + + def brokerUri: String = { + s"$brokerHost:$brokerPort" + } + + def setup(): Unit = { + broker = new BrokerService() + broker.setDataDirectoryFile(Utils.createTempDir()) + connector = new TransportConnector() + connector.setName("mqtt") + connector.setUri(new URI("mqtt://" + brokerUri)) + broker.addConnector(connector) + broker.start() + } + + def teardown(): Unit = { + if (broker != null) { + broker.stop() + broker = null + } + if (connector != null) { + connector.stop() + connector = null + } + Utils.deleteRecursively(persistenceDir) + } + + private def findFreePort(): Int = { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { + val socket = new ServerSocket(trialPort) + socket.close() + (null, trialPort) + }, new SparkConf())._2 + } + + def publishData(topic: String, data: String): Unit = { + var client: MqttClient = null + try { + val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath) + client = new MqttClient("tcp://" + brokerUri, MqttClient.generateClientId(), persistence) + client.connect() + if (client.isConnected) { + val msgTopic = client.getTopic(topic) + val message = new MqttMessage(data.getBytes(UTF_8)) + message.setQos(1) + message.setRetained(true) + + for (i <- 0 to 10) { + try { + msgTopic.publish(message) + } catch { + case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => + // wait for Spark streaming to consume something from the message queue + Thread.sleep(50) + } + } + } + } finally { + if (client != null) { + client.disconnect() + client.close() + client = null + } + } + } + +} diff --git a/pom.xml b/pom.xml index 2bcc55b040a26..8942836a7da16 100644 --- a/pom.xml +++ b/pom.xml @@ -104,6 +104,7 @@ external/flume-sink external/flume-assembly external/mqtt + external/mqtt-assembly external/zeromq examples repl diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 9a33baa7c6ce1..41a85fa9de778 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -45,8 +45,8 @@ object BuildCommons { sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", "kinesis-asl").map(ProjectRef(buildLocation, _)) - val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKinesisAslAssembly) = - Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-kinesis-asl-assembly") + val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingMqttAssembly, streamingKinesisAslAssembly) = + Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-mqtt-assembly", "streaming-kinesis-asl-assembly") .map(ProjectRef(buildLocation, _)) val tools = ProjectRef(buildLocation, "tools") @@ -212,6 +212,9 @@ object SparkBuild extends PomBuild { /* Enable Assembly for all assembly projects */ assemblyProjects.foreach(enable(Assembly.settings)) + /* Enable Assembly for streamingMqtt test */ + enable(inConfig(Test)(Assembly.settings))(streamingMqtt) + /* Package pyspark artifacts in a separate zip file for YARN. */ enable(PySparkAssembly.settings)(assembly) @@ -382,13 +385,16 @@ object Assembly { .getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String]) }, jarName in assembly <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) => - if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly") || mName.contains("streaming-kinesis-asl-assembly")) { + if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly") || mName.contains("streaming-mqtt-assembly") || mName.contains("streaming-kinesis-asl-assembly")) { // This must match the same name used in maven (see external/kafka-assembly/pom.xml) s"${mName}-${v}.jar" } else { s"${mName}-${v}-hadoop${hv}.jar" } }, + jarName in (Test, assembly) <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) => + s"${mName}-test-${v}.jar" + }, mergeStrategy in assembly := { case PathList("org", "datanucleus", xs @ _*) => MergeStrategy.discard case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard diff --git a/python/pyspark/streaming/mqtt.py b/python/pyspark/streaming/mqtt.py new file mode 100644 index 0000000000000..f06598971c548 --- /dev/null +++ b/python/pyspark/streaming/mqtt.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from py4j.java_gateway import Py4JJavaError + +from pyspark.storagelevel import StorageLevel +from pyspark.serializers import UTF8Deserializer +from pyspark.streaming import DStream + +__all__ = ['MQTTUtils'] + + +class MQTTUtils(object): + + @staticmethod + def createStream(ssc, brokerUrl, topic, + storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2): + """ + Create an input stream that pulls messages from a Mqtt Broker. + :param ssc: StreamingContext object + :param brokerUrl: Url of remote mqtt publisher + :param topic: topic name to subscribe to + :param storageLevel: RDD storage level. + :return: A DStream object + """ + jlevel = ssc._sc._getJavaStorageLevel(storageLevel) + + try: + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.mqtt.MQTTUtilsPythonHelper") + helper = helperClass.newInstance() + jstream = helper.createStream(ssc._jssc, brokerUrl, topic, jlevel) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + MQTTUtils._printErrorMsg(ssc.sparkContext) + raise e + + return DStream(jstream, ssc, UTF8Deserializer()) + + @staticmethod + def _printErrorMsg(sc): + print(""" +________________________________________________________________________________________________ + + Spark Streaming's MQTT libraries not found in class path. Try one of the following. + + 1. Include the MQTT library and its dependencies with in the + spark-submit command as + + $ bin/spark-submit --packages org.apache.spark:spark-streaming-mqtt:%s ... + + 2. Download the JAR of the artifact from Maven Central http://search.maven.org/, + Group Id = org.apache.spark, Artifact Id = spark-streaming-mqtt-assembly, Version = %s. + Then, include the jar in the spark-submit command as + + $ bin/spark-submit --jars ... +________________________________________________________________________________________________ +""" % (sc.version, sc.version)) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 5cd544b2144ef..66ae3345f468f 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -40,6 +40,7 @@ from pyspark.streaming.context import StreamingContext from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition from pyspark.streaming.flume import FlumeUtils +from pyspark.streaming.mqtt import MQTTUtils from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream @@ -893,6 +894,68 @@ def test_flume_polling_multiple_hosts(self): self._testMultipleTimes(self._testFlumePollingMultipleHosts) +class MQTTStreamTests(PySparkStreamingTestCase): + timeout = 20 # seconds + duration = 1 + + def setUp(self): + super(MQTTStreamTests, self).setUp() + + MQTTTestUtilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.mqtt.MQTTTestUtils") + self._MQTTTestUtils = MQTTTestUtilsClz.newInstance() + self._MQTTTestUtils.setup() + + def tearDown(self): + if self._MQTTTestUtils is not None: + self._MQTTTestUtils.teardown() + self._MQTTTestUtils = None + + super(MQTTStreamTests, self).tearDown() + + def _randomTopic(self): + return "topic-%d" % random.randint(0, 10000) + + def _startContext(self, topic): + # Start the StreamingContext and also collect the result + stream = MQTTUtils.createStream(self.ssc, "tcp://" + self._MQTTTestUtils.brokerUri(), topic) + result = [] + + def getOutput(_, rdd): + for data in rdd.collect(): + result.append(data) + + stream.foreachRDD(getOutput) + self.ssc.start() + return result + + def test_mqtt_stream(self): + """Test the Python MQTT stream API.""" + sendData = "MQTT demo for spark streaming" + topic = self._randomTopic() + result = self._startContext(topic) + + def retry(): + self._MQTTTestUtils.publishData(topic, sendData) + # Because "publishData" sends duplicate messages, here we should use > 0 + self.assertTrue(len(result) > 0) + self.assertEqual(sendData, result[0]) + + # Retry it because we don't know when the receiver will start. + self._retry_or_timeout(retry) + + def _retry_or_timeout(self, test_func): + start_time = time.time() + while True: + try: + test_func() + break + except: + if time.time() - start_time > self.timeout: + raise + time.sleep(0.01) + + class KinesisStreamTests(PySparkStreamingTestCase): def test_kinesis_stream_api(self): @@ -985,7 +1048,42 @@ def search_flume_assembly_jar(): "'build/mvn package' before running this test") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Flume assembly JARs in %s; please " - "remove all but one") % flume_assembly_dir) + "remove all but one") % flume_assembly_dir) + else: + return jars[0] + + +def search_mqtt_assembly_jar(): + SPARK_HOME = os.environ["SPARK_HOME"] + mqtt_assembly_dir = os.path.join(SPARK_HOME, "external/mqtt-assembly") + jars = glob.glob( + os.path.join(mqtt_assembly_dir, "target/scala-*/spark-streaming-mqtt-assembly-*.jar")) + if not jars: + raise Exception( + ("Failed to find Spark Streaming MQTT assembly jar in %s. " % mqtt_assembly_dir) + + "You need to build Spark with " + "'build/sbt assembly/assembly streaming-mqtt-assembly/assembly' or " + "'build/mvn package' before running this test") + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming MQTT assembly JARs in %s; please " + "remove all but one") % mqtt_assembly_dir) + else: + return jars[0] + + +def search_mqtt_test_jar(): + SPARK_HOME = os.environ["SPARK_HOME"] + mqtt_test_dir = os.path.join(SPARK_HOME, "external/mqtt") + jars = glob.glob( + os.path.join(mqtt_test_dir, "target/scala-*/spark-streaming-mqtt-test-*.jar")) + if not jars: + raise Exception( + ("Failed to find Spark Streaming MQTT test jar in %s. " % mqtt_test_dir) + + "You need to build Spark with " + "'build/sbt assembly/assembly streaming-mqtt/test:assembly'") + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming MQTT test JARs in %s; please " + "remove all but one") % mqtt_test_dir) else: return jars[0] @@ -1012,8 +1110,12 @@ def search_kinesis_asl_assembly_jar(): if __name__ == "__main__": kafka_assembly_jar = search_kafka_assembly_jar() flume_assembly_jar = search_flume_assembly_jar() + mqtt_assembly_jar = search_mqtt_assembly_jar() + mqtt_test_jar = search_mqtt_test_jar() kinesis_asl_assembly_jar = search_kinesis_asl_assembly_jar() - jars = "%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, kinesis_asl_assembly_jar) + + jars = "%s,%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, kinesis_asl_assembly_jar, + mqtt_assembly_jar, mqtt_test_jar) os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars unittest.main() From 3c9802d9400bea802984456683b2736a450ee17e Mon Sep 17 00:00:00 2001 From: Hao Zhu Date: Mon, 10 Aug 2015 17:17:22 -0700 Subject: [PATCH 012/253] [SPARK-9801] [STREAMING] Check if file exists before deleting temporary files. Spark streaming deletes the temp file and backup files without checking if they exist or not Author: Hao Zhu Closes #8082 from viadea/master and squashes the following commits: 242d05f [Hao Zhu] [SPARK-9801][Streaming]No need to check the existence of those files fd143f2 [Hao Zhu] [SPARK-9801][Streaming]Check if backupFile exists before deleting backupFile files. 087daf0 [Hao Zhu] SPARK-9801 --- .../scala/org/apache/spark/streaming/Checkpoint.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 2780d5b6adbcf..6f6b449accc3c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -192,7 +192,9 @@ class CheckpointWriter( + "'") // Write checkpoint to temp file - fs.delete(tempFile, true) // just in case it exists + if (fs.exists(tempFile)) { + fs.delete(tempFile, true) // just in case it exists + } val fos = fs.create(tempFile) Utils.tryWithSafeFinally { fos.write(bytes) @@ -203,7 +205,9 @@ class CheckpointWriter( // If the checkpoint file exists, back it up // If the backup exists as well, just delete it, otherwise rename will fail if (fs.exists(checkpointFile)) { - fs.delete(backupFile, true) // just in case it exists + if (fs.exists(backupFile)){ + fs.delete(backupFile, true) // just in case it exists + } if (!fs.rename(checkpointFile, backupFile)) { logWarning("Could not rename " + checkpointFile + " to " + backupFile) } From 071bbad5db1096a548c886762b611a8484a52753 Mon Sep 17 00:00:00 2001 From: Damian Guy Date: Tue, 11 Aug 2015 12:46:33 +0800 Subject: [PATCH 013/253] [SPARK-9340] [SQL] Fixes converting unannotated Parquet lists This PR is inspired by #8063 authored by dguy. Especially, testing Parquet files added here are all taken from that PR. **Committer who merges this PR should attribute it to "Damian Guy ".** ---- SPARK-6776 and SPARK-6777 followed `parquet-avro` to implement backwards-compatibility rules defined in `parquet-format` spec. However, both Spark SQL and `parquet-avro` neglected the following statement in `parquet-format`: > This does not affect repeated fields that are not annotated: A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor annotated by `LIST` or `MAP` should be interpreted as a required list of required elements where the element type is the type of the field. One of the consequences is that, Parquet files generated by `parquet-protobuf` containing unannotated repeated fields are not correctly converted to Catalyst arrays. This PR fixes this issue by 1. Handling unannotated repeated fields in `CatalystSchemaConverter`. 2. Converting this kind of special repeated fields to Catalyst arrays in `CatalystRowConverter`. Two special converters, `RepeatedPrimitiveConverter` and `RepeatedGroupConverter`, are added. They delegate actual conversion work to a child `elementConverter` and accumulates elements in an `ArrayBuffer`. Two extra methods, `start()` and `end()`, are added to `ParentContainerUpdater`. So that they can be used to initialize new `ArrayBuffer`s for unannotated repeated fields, and propagate converted array values to upstream. Author: Cheng Lian Closes #8070 from liancheng/spark-9340/unannotated-parquet-list and squashes the following commits: ace6df7 [Cheng Lian] Moves ParquetProtobufCompatibilitySuite f1c7bfd [Cheng Lian] Updates .rat-excludes 420ad2b [Cheng Lian] Fixes converting unannotated Parquet lists --- .rat-excludes | 1 + .../parquet/CatalystRowConverter.scala | 151 ++++++++++++++---- .../parquet/CatalystSchemaConverter.scala | 7 +- .../resources/nested-array-struct.parquet | Bin 0 -> 775 bytes .../test/resources/old-repeated-int.parquet | Bin 0 -> 389 bytes .../resources/old-repeated-message.parquet | Bin 0 -> 600 bytes .../src/test/resources/old-repeated.parquet | Bin 0 -> 432 bytes .../parquet-thrift-compat.snappy.parquet | Bin .../resources/proto-repeated-string.parquet | Bin 0 -> 411 bytes .../resources/proto-repeated-struct.parquet | Bin 0 -> 608 bytes .../proto-struct-with-array-many.parquet | Bin 0 -> 802 bytes .../resources/proto-struct-with-array.parquet | Bin 0 -> 1576 bytes .../ParquetProtobufCompatibilitySuite.scala | 91 +++++++++++ .../parquet/ParquetSchemaSuite.scala | 30 ++++ 14 files changed, 247 insertions(+), 33 deletions(-) create mode 100644 sql/core/src/test/resources/nested-array-struct.parquet create mode 100644 sql/core/src/test/resources/old-repeated-int.parquet create mode 100644 sql/core/src/test/resources/old-repeated-message.parquet create mode 100644 sql/core/src/test/resources/old-repeated.parquet mode change 100755 => 100644 sql/core/src/test/resources/parquet-thrift-compat.snappy.parquet create mode 100644 sql/core/src/test/resources/proto-repeated-string.parquet create mode 100644 sql/core/src/test/resources/proto-repeated-struct.parquet create mode 100644 sql/core/src/test/resources/proto-struct-with-array-many.parquet create mode 100644 sql/core/src/test/resources/proto-struct-with-array.parquet create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala diff --git a/.rat-excludes b/.rat-excludes index 72771465846b8..9165872b9fb27 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -94,3 +94,4 @@ INDEX gen-java.* .*avpr org.apache.spark.sql.sources.DataSourceRegister +.*parquet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 3542dfbae1292..ab5a6ddd41cfc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -21,11 +21,11 @@ import java.math.{BigDecimal, BigInteger} import java.nio.ByteOrder import scala.collection.JavaConversions._ -import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.parquet.column.Dictionary import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} +import org.apache.parquet.schema.OriginalType.LIST import org.apache.parquet.schema.Type.Repetition import org.apache.parquet.schema.{GroupType, PrimitiveType, Type} @@ -42,6 +42,12 @@ import org.apache.spark.unsafe.types.UTF8String * values to an [[ArrayBuffer]]. */ private[parquet] trait ParentContainerUpdater { + /** Called before a record field is being converted */ + def start(): Unit = () + + /** Called after a record field is being converted */ + def end(): Unit = () + def set(value: Any): Unit = () def setBoolean(value: Boolean): Unit = set(value) def setByte(value: Byte): Unit = set(value) @@ -55,6 +61,32 @@ private[parquet] trait ParentContainerUpdater { /** A no-op updater used for root converter (who doesn't have a parent). */ private[parquet] object NoopUpdater extends ParentContainerUpdater +private[parquet] trait HasParentContainerUpdater { + def updater: ParentContainerUpdater +} + +/** + * A convenient converter class for Parquet group types with an [[HasParentContainerUpdater]]. + */ +private[parquet] abstract class CatalystGroupConverter(val updater: ParentContainerUpdater) + extends GroupConverter with HasParentContainerUpdater + +/** + * Parquet converter for Parquet primitive types. Note that not all Spark SQL atomic types + * are handled by this converter. Parquet primitive types are only a subset of those of Spark + * SQL. For example, BYTE, SHORT, and INT in Spark SQL are all covered by INT32 in Parquet. + */ +private[parquet] class CatalystPrimitiveConverter(val updater: ParentContainerUpdater) + extends PrimitiveConverter with HasParentContainerUpdater { + + override def addBoolean(value: Boolean): Unit = updater.setBoolean(value) + override def addInt(value: Int): Unit = updater.setInt(value) + override def addLong(value: Long): Unit = updater.setLong(value) + override def addFloat(value: Float): Unit = updater.setFloat(value) + override def addDouble(value: Double): Unit = updater.setDouble(value) + override def addBinary(value: Binary): Unit = updater.set(value.getBytes) +} + /** * A [[CatalystRowConverter]] is used to convert Parquet "structs" into Spark SQL [[InternalRow]]s. * Since any Parquet record is also a struct, this converter can also be used as root converter. @@ -70,7 +102,7 @@ private[parquet] class CatalystRowConverter( parquetType: GroupType, catalystType: StructType, updater: ParentContainerUpdater) - extends GroupConverter { + extends CatalystGroupConverter(updater) { /** * Updater used together with field converters within a [[CatalystRowConverter]]. It propagates @@ -89,13 +121,11 @@ private[parquet] class CatalystRowConverter( /** * Represents the converted row object once an entire Parquet record is converted. - * - * @todo Uses [[UnsafeRow]] for better performance. */ val currentRow = new SpecificMutableRow(catalystType.map(_.dataType)) // Converters for each field. - private val fieldConverters: Array[Converter] = { + private val fieldConverters: Array[Converter with HasParentContainerUpdater] = { parquetType.getFields.zip(catalystType).zipWithIndex.map { case ((parquetFieldType, catalystField), ordinal) => // Converted field value should be set to the `ordinal`-th cell of `currentRow` @@ -105,11 +135,19 @@ private[parquet] class CatalystRowConverter( override def getConverter(fieldIndex: Int): Converter = fieldConverters(fieldIndex) - override def end(): Unit = updater.set(currentRow) + override def end(): Unit = { + var i = 0 + while (i < currentRow.numFields) { + fieldConverters(i).updater.end() + i += 1 + } + updater.set(currentRow) + } override def start(): Unit = { var i = 0 while (i < currentRow.numFields) { + fieldConverters(i).updater.start() currentRow.setNullAt(i) i += 1 } @@ -122,20 +160,20 @@ private[parquet] class CatalystRowConverter( private def newConverter( parquetType: Type, catalystType: DataType, - updater: ParentContainerUpdater): Converter = { + updater: ParentContainerUpdater): Converter with HasParentContainerUpdater = { catalystType match { case BooleanType | IntegerType | LongType | FloatType | DoubleType | BinaryType => new CatalystPrimitiveConverter(updater) case ByteType => - new PrimitiveConverter { + new CatalystPrimitiveConverter(updater) { override def addInt(value: Int): Unit = updater.setByte(value.asInstanceOf[ByteType#InternalType]) } case ShortType => - new PrimitiveConverter { + new CatalystPrimitiveConverter(updater) { override def addInt(value: Int): Unit = updater.setShort(value.asInstanceOf[ShortType#InternalType]) } @@ -148,7 +186,7 @@ private[parquet] class CatalystRowConverter( case TimestampType => // TODO Implements `TIMESTAMP_MICROS` once parquet-mr has that. - new PrimitiveConverter { + new CatalystPrimitiveConverter(updater) { // Converts nanosecond timestamps stored as INT96 override def addBinary(value: Binary): Unit = { assert( @@ -164,13 +202,23 @@ private[parquet] class CatalystRowConverter( } case DateType => - new PrimitiveConverter { + new CatalystPrimitiveConverter(updater) { override def addInt(value: Int): Unit = { // DateType is not specialized in `SpecificMutableRow`, have to box it here. updater.set(value.asInstanceOf[DateType#InternalType]) } } + // A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor + // annotated by `LIST` or `MAP` should be interpreted as a required list of required + // elements where the element type is the type of the field. + case t: ArrayType if parquetType.getOriginalType != LIST => + if (parquetType.isPrimitive) { + new RepeatedPrimitiveConverter(parquetType, t.elementType, updater) + } else { + new RepeatedGroupConverter(parquetType, t.elementType, updater) + } + case t: ArrayType => new CatalystArrayConverter(parquetType.asGroupType(), t, updater) @@ -195,27 +243,11 @@ private[parquet] class CatalystRowConverter( } } - /** - * Parquet converter for Parquet primitive types. Note that not all Spark SQL atomic types - * are handled by this converter. Parquet primitive types are only a subset of those of Spark - * SQL. For example, BYTE, SHORT, and INT in Spark SQL are all covered by INT32 in Parquet. - */ - private final class CatalystPrimitiveConverter(updater: ParentContainerUpdater) - extends PrimitiveConverter { - - override def addBoolean(value: Boolean): Unit = updater.setBoolean(value) - override def addInt(value: Int): Unit = updater.setInt(value) - override def addLong(value: Long): Unit = updater.setLong(value) - override def addFloat(value: Float): Unit = updater.setFloat(value) - override def addDouble(value: Double): Unit = updater.setDouble(value) - override def addBinary(value: Binary): Unit = updater.set(value.getBytes) - } - /** * Parquet converter for strings. A dictionary is used to minimize string decoding cost. */ private final class CatalystStringConverter(updater: ParentContainerUpdater) - extends PrimitiveConverter { + extends CatalystPrimitiveConverter(updater) { private var expandedDictionary: Array[UTF8String] = null @@ -242,7 +274,7 @@ private[parquet] class CatalystRowConverter( private final class CatalystDecimalConverter( decimalType: DecimalType, updater: ParentContainerUpdater) - extends PrimitiveConverter { + extends CatalystPrimitiveConverter(updater) { // Converts decimals stored as INT32 override def addInt(value: Int): Unit = { @@ -306,7 +338,7 @@ private[parquet] class CatalystRowConverter( parquetSchema: GroupType, catalystSchema: ArrayType, updater: ParentContainerUpdater) - extends GroupConverter { + extends CatalystGroupConverter(updater) { private var currentArray: ArrayBuffer[Any] = _ @@ -383,7 +415,7 @@ private[parquet] class CatalystRowConverter( parquetType: GroupType, catalystType: MapType, updater: ParentContainerUpdater) - extends GroupConverter { + extends CatalystGroupConverter(updater) { private var currentKeys: ArrayBuffer[Any] = _ private var currentValues: ArrayBuffer[Any] = _ @@ -446,4 +478,61 @@ private[parquet] class CatalystRowConverter( } } } + + private trait RepeatedConverter { + private var currentArray: ArrayBuffer[Any] = _ + + protected def newArrayUpdater(updater: ParentContainerUpdater) = new ParentContainerUpdater { + override def start(): Unit = currentArray = ArrayBuffer.empty[Any] + override def end(): Unit = updater.set(new GenericArrayData(currentArray.toArray)) + override def set(value: Any): Unit = currentArray += value + } + } + + /** + * A primitive converter for converting unannotated repeated primitive values to required arrays + * of required primitives values. + */ + private final class RepeatedPrimitiveConverter( + parquetType: Type, + catalystType: DataType, + parentUpdater: ParentContainerUpdater) + extends PrimitiveConverter with RepeatedConverter with HasParentContainerUpdater { + + val updater: ParentContainerUpdater = newArrayUpdater(parentUpdater) + + private val elementConverter: PrimitiveConverter = + newConverter(parquetType, catalystType, updater).asPrimitiveConverter() + + override def addBoolean(value: Boolean): Unit = elementConverter.addBoolean(value) + override def addInt(value: Int): Unit = elementConverter.addInt(value) + override def addLong(value: Long): Unit = elementConverter.addLong(value) + override def addFloat(value: Float): Unit = elementConverter.addFloat(value) + override def addDouble(value: Double): Unit = elementConverter.addDouble(value) + override def addBinary(value: Binary): Unit = elementConverter.addBinary(value) + + override def setDictionary(dict: Dictionary): Unit = elementConverter.setDictionary(dict) + override def hasDictionarySupport: Boolean = elementConverter.hasDictionarySupport + override def addValueFromDictionary(id: Int): Unit = elementConverter.addValueFromDictionary(id) + } + + /** + * A group converter for converting unannotated repeated group values to required arrays of + * required struct values. + */ + private final class RepeatedGroupConverter( + parquetType: Type, + catalystType: DataType, + parentUpdater: ParentContainerUpdater) + extends GroupConverter with HasParentContainerUpdater with RepeatedConverter { + + val updater: ParentContainerUpdater = newArrayUpdater(parentUpdater) + + private val elementConverter: GroupConverter = + newConverter(parquetType, catalystType, updater).asGroupConverter() + + override def getConverter(field: Int): Converter = elementConverter.getConverter(field) + override def end(): Unit = elementConverter.end() + override def start(): Unit = elementConverter.start() + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index a3fc74cf7929b..275646e8181ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -100,8 +100,11 @@ private[parquet] class CatalystSchemaConverter( StructField(field.getName, convertField(field), nullable = false) case REPEATED => - throw new AnalysisException( - s"REPEATED not supported outside LIST or MAP. Type: $field") + // A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor + // annotated by `LIST` or `MAP` should be interpreted as a required list of required + // elements where the element type is the type of the field. + val arrayType = ArrayType(convertField(field), containsNull = false) + StructField(field.getName, arrayType, nullable = false) } } diff --git a/sql/core/src/test/resources/nested-array-struct.parquet b/sql/core/src/test/resources/nested-array-struct.parquet new file mode 100644 index 0000000000000000000000000000000000000000..41a43fa35d39685e56ba4849a16cba4bb1aa86ae GIT binary patch literal 775 zcmaKr-%G+!6vvO#=8u;i;*JSE3{j~lAyWtuV%0Q3*U(Y)B(q&>u(@?NBZ;2+Px=dc z?6I>s%N6u+cQ4=bIp1@3cBjdsBLbvCDhGte15a`Q8~~)V;d2WY3V@LYX{-@GMj#!6 z`;fvdgDblto20oxMhs#hdKzt*4*3w}iuUE6PW?b*Zs1NAv%1WfvAnT@2NhLn_L#fy zHFWX={q7*H z93@^0*O&w#e53@lJ`hFEV2=wL)V**Pb(8vc%<=-4iJz&t;n22J{%<(t!px$!DZLaV zDaOCYR1UR;Go`F89pTwFrqpgr1NlrDOs+J&f2GO;)PtpmW%OH3neOEccXHoWyO`6Bl5({+ea14&qL7DtETw`(h_426$BxCYApN S1!5siKXe$p;U(Ab7x)6M)yB5~ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/old-repeated-int.parquet b/sql/core/src/test/resources/old-repeated-int.parquet new file mode 100644 index 0000000000000000000000000000000000000000..520922f73ebb75950c4e65dec6689af817cb7a33 GIT binary patch literal 389 zcmZWmO-sW-5FMkeB_3tN1_Ca@_7p=~Z@EPbSg5juTs)Ocvz0>9#NEw7ivQhdDcI1< z%;U}1booNUUY~PehCwzvumZho_zD!@TVtKoQHivd-PzgS{O4oWwfk)ZsDnB!ByvMUB7gt@Rk50{GMw}6DWn-w!#F0CGZwP` zh81XVct9peJU!6=O6yx`ZywS;EGVOwOOIsCr3p)d)y(vgv`4bce)5D7UiKlH0l6Ga0+m-EijTt zM!X)Rd#pSj~EkCao0il-K=62)db~64ZC7r8d1AwIhzFkoG;e&AbpZW#pJ&{5H literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/old-repeated.parquet b/sql/core/src/test/resources/old-repeated.parquet new file mode 100644 index 0000000000000000000000000000000000000000..213f1a90291b30a8a3161b51c38f008f3ae9f6e5 GIT binary patch literal 432 zcmZWm!D@p*5ZxNF!5+)X3PMGioUA16Ei?y9g$B|h;-#ms>Ld--Xm{5`DgF13A<&42 znSH!BJNtMWhsm50I-@h68VC$(I7}ZALYRJm-NGUo*2p;a%Z@xEJgH{;FE=Sj6^mNc zS-TAqXn-pyRtNP8Qt};84d*60yAuBru{7JUo$1)Y6%&KlJ(Uv6u!JS1zOP)cw zaM$5ewB9699EEB0jJ*18aDDn7N1N4K`fzXlnuJ~V?c^nwk}Yeo3wXox4+#3Y!pMU2 V+-`?%2{TWZ?kYh(G4~k%>JK8=aDe~- literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/parquet-thrift-compat.snappy.parquet b/sql/core/src/test/resources/parquet-thrift-compat.snappy.parquet old mode 100755 new mode 100644 diff --git a/sql/core/src/test/resources/proto-repeated-string.parquet b/sql/core/src/test/resources/proto-repeated-string.parquet new file mode 100644 index 0000000000000000000000000000000000000000..8a7eea601d0164b1177ab90efb7116c6e4c900da GIT binary patch literal 411 zcmY*W!D_-l5S^$E62wc{umKMtR4+{f-b!vM4Q)Y6&|G?w#H^aKansF;gi?C#*Yq1Z zv6e;{_C4Mk<_)t^FrN}2UmBK6hDddy19SkO`+9soFOY8;=b|A8A$itAvJoQdBBnKK zKy>)#|`4$W^3YtqL)WN5pTmWh1ZGv$>{f|s#sCG%1VN%LJ&FyD4siH@<(8PDu@ z!?sWEU$)ao`yyr1x2MQ?k}~ewv*0eAE$3kr261?gx~fYY8oxy0auLs;o*#@41L)=X h7Au}q6}>(e6`sLs-{PvZ8BpWYeN#xd)c_*=mLHLcZu|fM literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/proto-repeated-struct.parquet b/sql/core/src/test/resources/proto-repeated-struct.parquet new file mode 100644 index 0000000000000000000000000000000000000000..c29eee35c350e84efa2818febb2528309a8ac3ea GIT binary patch literal 608 zcma))-D<)x6vvP811a8(bSZdI%9Js*tjca=2puciK%r=F1_P}cr%>B2jf^q&4ts<> z%r1PaC9^8^%A4e$li&GFTzg<)z+K#J;DQh(TmnDm&bFDCfsEak0$H6=|yp$CW-$_F@l={DK5j1GEpn8)DX!>A+15G`Fpg} zMZREE-l#~cYPa=r6<4%c3AD?tzx2bP7Sypiu9pGoi(^0p+XD*$Y;s4$HpQOVL*dnD3MsZWzL<}ml5ZY`6p``6p3uzOR6cOQizQ3hI@;TGm z8hy+Rm>Hl!&aiC8oEI4i7`u<#dC`r5%q<5r!9eDg1IFy*c2RU=AalzBO)!wT<$y7e zS0C?=T^uJ)6ePiDIW^oM?BO`}o-pLWHOS&h@*H8x zD56?ZuNu`Fl-0Tj)YHv=x(@J>C=j)+#mj%Z_4kgdp}D_<3b zc(o7;z363$6CCr9}+-E j<-Z;KUL2!lIhl~NDb+dIHUN;6ire!D{E~S&5gg5S?rsql%Icnq5|qgD{N<#x?oCY2!n{ZAEKv64iDOJsH{F!~)4uB-v0( z@BJM;_owufA5^-l4@Z{ekVAVh>zOxi-n`LDMyq>_0q^0x8b74l?^SjJrp%q&06h8-y4iMdSJ@MDH4c~HjX3j($=&sN1 zW|q&!OYxG3d&~^8@dlzhDa$1b0`rz(6tkBD*J153G=T1;gzF$B0g1WSKnPOy6M$~KQ|W5 z5A#DO!$$Zca>TK`;67WBib&?m76{4rqTmNgwH)UCc)*uPhjcg;fc!=Tfl{N?GyS_6 z3+tZPeSOS=k#BjS>(f75Q`2EhwX*hMsK_@Kv&ZT;SydBky3mDR6_J}cL*_TtV}7>H zA+wumr}b9v46coS`}(TY;qmaR$9wg^82X@n)jvIvzps*~J`|Fl Date: Mon, 10 Aug 2015 22:04:41 -0700 Subject: [PATCH 014/253] [SPARK-9729] [SPARK-9363] [SQL] Use sort merge join for left and right outer join This patch adds a new `SortMergeOuterJoin` operator that performs left and right outer joins using sort merge join. It also refactors `SortMergeJoin` in order to improve performance and code clarity. Along the way, I also performed a couple pieces of minor cleanup and optimization: - Rename the `HashJoin` physical planner rule to `EquiJoinSelection`, since it's also used for non-hash joins. - Rewrite the comment at the top of `HashJoin` to better explain the precedence for choosing join operators. - Update `JoinSuite` to use `SqlTestUtils.withConf` for changing SQLConf settings. This patch incorporates several ideas from adrian-wang's patch, #5717. Closes #5717. [Review on Reviewable](https://reviewable.io/reviews/apache/spark/7904) Author: Josh Rosen Author: Daoyuan Wang Closes #7904 from JoshRosen/outer-join-smj and squashes 1 commits. --- .../sql/catalyst/expressions/JoinedRow.scala | 6 +- .../org/apache/spark/sql/SQLContext.scala | 2 +- .../spark/sql/execution/RowIterator.scala | 93 +++++ .../spark/sql/execution/SparkStrategies.scala | 45 ++- .../joins/BroadcastNestedLoopJoin.scala | 5 +- .../sql/execution/joins/SortMergeJoin.scala | 331 +++++++++++++----- .../execution/joins/SortMergeOuterJoin.scala | 251 +++++++++++++ .../org/apache/spark/sql/JoinSuite.scala | 132 ++++--- .../sql/execution/joins/InnerJoinSuite.scala | 180 ++++++++++ .../sql/execution/joins/OuterJoinSuite.scala | 310 +++++++++++----- .../sql/execution/joins/SemiJoinSuite.scala | 125 ++++--- .../apache/spark/sql/test/SQLTestUtils.scala | 2 +- .../apache/spark/sql/hive/HiveContext.scala | 2 +- 13 files changed, 1165 insertions(+), 319 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala index b76757c93523d..d3560df0792eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala @@ -37,20 +37,20 @@ class JoinedRow extends InternalRow { } /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: InternalRow, r2: InternalRow): InternalRow = { + def apply(r1: InternalRow, r2: InternalRow): JoinedRow = { row1 = r1 row2 = r2 this } /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: InternalRow): InternalRow = { + def withLeft(newLeft: InternalRow): JoinedRow = { row1 = newLeft this } /** Updates this JoinedRow by updating its right base row. Returns itself. */ - def withRight(newRight: InternalRow): InternalRow = { + def withRight(newRight: InternalRow): JoinedRow = { row2 = newRight this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index f73bb0488c984..4bf00b3399e7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -873,7 +873,7 @@ class SQLContext(@transient val sparkContext: SparkContext) HashAggregation :: Aggregation :: LeftSemiJoin :: - HashJoin :: + EquiJoinSelection :: InMemoryScans :: BasicOperators :: CartesianProduct :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala new file mode 100644 index 0000000000000..7462dbc4eba3a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.util.NoSuchElementException + +import org.apache.spark.sql.catalyst.InternalRow + +/** + * An internal iterator interface which presents a more restrictive API than + * [[scala.collection.Iterator]]. + * + * One major departure from the Scala iterator API is the fusing of the `hasNext()` and `next()` + * calls: Scala's iterator allows users to call `hasNext()` without immediately advancing the + * iterator to consume the next row, whereas RowIterator combines these calls into a single + * [[advanceNext()]] method. + */ +private[sql] abstract class RowIterator { + /** + * Advance this iterator by a single row. Returns `false` if this iterator has no more rows + * and `true` otherwise. If this returns `true`, then the new row can be retrieved by calling + * [[getRow]]. + */ + def advanceNext(): Boolean + + /** + * Retrieve the row from this iterator. This method is idempotent. It is illegal to call this + * method after [[advanceNext()]] has returned `false`. + */ + def getRow: InternalRow + + /** + * Convert this RowIterator into a [[scala.collection.Iterator]]. + */ + def toScala: Iterator[InternalRow] = new RowIteratorToScala(this) +} + +object RowIterator { + def fromScala(scalaIter: Iterator[InternalRow]): RowIterator = { + scalaIter match { + case wrappedRowIter: RowIteratorToScala => wrappedRowIter.rowIter + case _ => new RowIteratorFromScala(scalaIter) + } + } +} + +private final class RowIteratorToScala(val rowIter: RowIterator) extends Iterator[InternalRow] { + private [this] var hasNextWasCalled: Boolean = false + private [this] var _hasNext: Boolean = false + override def hasNext: Boolean = { + // Idempotency: + if (!hasNextWasCalled) { + _hasNext = rowIter.advanceNext() + hasNextWasCalled = true + } + _hasNext + } + override def next(): InternalRow = { + if (!hasNext) throw new NoSuchElementException + hasNextWasCalled = false + rowIter.getRow + } +} + +private final class RowIteratorFromScala(scalaIter: Iterator[InternalRow]) extends RowIterator { + private[this] var _next: InternalRow = null + override def advanceNext(): Boolean = { + if (scalaIter.hasNext) { + _next = scalaIter.next() + true + } else { + _next = null + false + } + } + override def getRow: InternalRow = _next + override def toScala: Iterator[InternalRow] = scalaIter +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c4b9b5acea4de..1fc870d44b578 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -63,19 +63,23 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } /** - * Uses the ExtractEquiJoinKeys pattern to find joins where at least some of the predicates can be - * evaluated by matching hash keys. + * Uses the [[ExtractEquiJoinKeys]] pattern to find joins where at least some of the predicates + * can be evaluated by matching join keys. * - * This strategy applies a simple optimization based on the estimates of the physical sizes of - * the two join sides. When planning a [[joins.BroadcastHashJoin]], if one side has an - * estimated physical size smaller than the user-settable threshold - * [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]], the planner would mark it as the - * ''build'' relation and mark the other relation as the ''stream'' side. The build table will be - * ''broadcasted'' to all of the executors involved in the join, as a - * [[org.apache.spark.broadcast.Broadcast]] object. If both estimates exceed the threshold, they - * will instead be used to decide the build side in a [[joins.ShuffledHashJoin]]. + * Join implementations are chosen with the following precedence: + * + * - Broadcast: if one side of the join has an estimated physical size that is smaller than the + * user-configurable [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold + * or if that side has an explicit broadcast hint (e.g. the user applied the + * [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame), then that side + * of the join will be broadcasted and the other side will be streamed, with no shuffling + * performed. If both sides of the join are eligible to be broadcasted then the + * - Sort merge: if the matching join keys are sortable and + * [[org.apache.spark.sql.SQLConf.SORTMERGE_JOIN]] is enabled (default), then sort merge join + * will be used. + * - Hash: will be chosen if neither of the above optimizations apply to this join. */ - object HashJoin extends Strategy with PredicateHelper { + object EquiJoinSelection extends Strategy with PredicateHelper { private[this] def makeBroadcastHashJoin( leftKeys: Seq[Expression], @@ -90,14 +94,15 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + + // --- Inner joins -------------------------------------------------------------------------- + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) - // If the sort merge join option is set, we want to use sort merge join prior to hashjoin - // for now let's support inner join first, then add outer join case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => val mergeJoin = @@ -115,6 +120,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil + // --- Outer joins -------------------------------------------------------------------------- + case ExtractEquiJoinKeys( LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => joins.BroadcastHashOuterJoin( @@ -125,10 +132,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { joins.BroadcastHashOuterJoin( leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil + case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right) + if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => + joins.SortMergeOuterJoin( + leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil + + case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right) + if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => + joins.SortMergeOuterJoin( + leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => joins.ShuffledHashOuterJoin( leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil + // --- Cases where this strategy does not apply --------------------------------------------- + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 23aebf4b068b4..017a44b9ca863 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -65,8 +65,9 @@ case class BroadcastNestedLoopJoin( left.output.map(_.withNullability(true)) ++ right.output case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case _ => - left.output ++ right.output + case x => + throw new IllegalArgumentException( + s"BroadcastNestedLoopJoin should not take $x as the JoinType") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 4ae23c186cf7b..6d656ea2849a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.execution.joins -import java.util.NoSuchElementException +import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} -import org.apache.spark.util.collection.CompactBuffer +import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} /** * :: DeveloperApi :: @@ -38,8 +37,6 @@ case class SortMergeJoin( left: SparkPlan, right: SparkPlan) extends BinaryNode { - override protected[sql] val trackNumOfRowsEnabled = true - override def output: Seq[Attribute] = left.output ++ right.output override def outputPartitioning: Partitioning = @@ -56,117 +53,265 @@ case class SortMergeJoin( @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) + protected[this] def isUnsafeMode: Boolean = { + (codegenEnabled && unsafeEnabled + && UnsafeProjection.canSupport(leftKeys) + && UnsafeProjection.canSupport(rightKeys) + && UnsafeProjection.canSupport(schema)) + } + + override def outputsUnsafeRows: Boolean = isUnsafeMode + override def canProcessUnsafeRows: Boolean = isUnsafeMode + override def canProcessSafeRows: Boolean = !isUnsafeMode + private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. keys.map(SortOrder(_, Ascending)) } protected override def doExecute(): RDD[InternalRow] = { - val leftResults = left.execute().map(_.copy()) - val rightResults = right.execute().map(_.copy()) - - leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => - new Iterator[InternalRow] { + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + new RowIterator { // An ordering that can be used to compare keys from both sides. private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) - // Mutable per row objects. + private[this] var currentLeftRow: InternalRow = _ + private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _ + private[this] var currentMatchIdx: Int = -1 + private[this] val smjScanner = new SortMergeJoinScanner( + leftKeyGenerator, + rightKeyGenerator, + keyOrdering, + RowIterator.fromScala(leftIter), + RowIterator.fromScala(rightIter) + ) private[this] val joinRow = new JoinedRow - private[this] var leftElement: InternalRow = _ - private[this] var rightElement: InternalRow = _ - private[this] var leftKey: InternalRow = _ - private[this] var rightKey: InternalRow = _ - private[this] var rightMatches: CompactBuffer[InternalRow] = _ - private[this] var rightPosition: Int = -1 - private[this] var stop: Boolean = false - private[this] var matchKey: InternalRow = _ - - // initialize iterator - initialize() - - override final def hasNext: Boolean = nextMatchingPair() - - override final def next(): InternalRow = { - if (hasNext) { - // we are using the buffered right rows and run down left iterator - val joinedRow = joinRow(leftElement, rightMatches(rightPosition)) - rightPosition += 1 - if (rightPosition >= rightMatches.size) { - rightPosition = 0 - fetchLeft() - if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) { - stop = false - rightMatches = null - } - } - joinedRow + private[this] val resultProjection: (InternalRow) => InternalRow = { + if (isUnsafeMode) { + UnsafeProjection.create(schema) } else { - // no more result - throw new NoSuchElementException + identity[InternalRow] } } - private def fetchLeft() = { - if (leftIter.hasNext) { - leftElement = leftIter.next() - leftKey = leftKeyGenerator(leftElement) - } else { - leftElement = null + override def advanceNext(): Boolean = { + if (currentMatchIdx == -1 || currentMatchIdx == currentRightMatches.length) { + if (smjScanner.findNextInnerJoinRows()) { + currentRightMatches = smjScanner.getBufferedMatches + currentLeftRow = smjScanner.getStreamedRow + currentMatchIdx = 0 + } else { + currentRightMatches = null + currentLeftRow = null + currentMatchIdx = -1 + } } - } - - private def fetchRight() = { - if (rightIter.hasNext) { - rightElement = rightIter.next() - rightKey = rightKeyGenerator(rightElement) + if (currentLeftRow != null) { + joinRow(currentLeftRow, currentRightMatches(currentMatchIdx)) + currentMatchIdx += 1 + true } else { - rightElement = null + false } } - private def initialize() = { - fetchLeft() - fetchRight() + override def getRow: InternalRow = resultProjection(joinRow) + }.toScala + } + } +} + +/** + * Helper class that is used to implement [[SortMergeJoin]] and [[SortMergeOuterJoin]]. + * + * To perform an inner (outer) join, users of this class call [[findNextInnerJoinRows()]] + * ([[findNextOuterJoinRows()]]), which returns `true` if a result has been produced and `false` + * otherwise. If a result has been produced, then the caller may call [[getStreamedRow]] to return + * the matching row from the streamed input and may call [[getBufferedMatches]] to return the + * sequence of matching rows from the buffered input (in the case of an outer join, this will return + * an empty sequence if there are no matches from the buffered input). For efficiency, both of these + * methods return mutable objects which are re-used across calls to the `findNext*JoinRows()` + * methods. + * + * @param streamedKeyGenerator a projection that produces join keys from the streamed input. + * @param bufferedKeyGenerator a projection that produces join keys from the buffered input. + * @param keyOrdering an ordering which can be used to compare join keys. + * @param streamedIter an input whose rows will be streamed. + * @param bufferedIter an input whose rows will be buffered to construct sequences of rows that + * have the same join key. + */ +private[joins] class SortMergeJoinScanner( + streamedKeyGenerator: Projection, + bufferedKeyGenerator: Projection, + keyOrdering: Ordering[InternalRow], + streamedIter: RowIterator, + bufferedIter: RowIterator) { + private[this] var streamedRow: InternalRow = _ + private[this] var streamedRowKey: InternalRow = _ + private[this] var bufferedRow: InternalRow = _ + // Note: this is guaranteed to never have any null columns: + private[this] var bufferedRowKey: InternalRow = _ + /** + * The join key for the rows buffered in `bufferedMatches`, or null if `bufferedMatches` is empty + */ + private[this] var matchJoinKey: InternalRow = _ + /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */ + private[this] val bufferedMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] + + // Initialization (note: do _not_ want to advance streamed here). + advancedBufferedToRowWithNullFreeJoinKey() + + // --- Public methods --------------------------------------------------------------------------- + + def getStreamedRow: InternalRow = streamedRow + + def getBufferedMatches: ArrayBuffer[InternalRow] = bufferedMatches + + /** + * Advances both input iterators, stopping when we have found rows with matching join keys. + * @return true if matching rows have been found and false otherwise. If this returns true, then + * [[getStreamedRow]] and [[getBufferedMatches]] can be called to construct the join + * results. + */ + final def findNextInnerJoinRows(): Boolean = { + while (advancedStreamed() && streamedRowKey.anyNull) { + // Advance the streamed side of the join until we find the next row whose join key contains + // no nulls or we hit the end of the streamed iterator. + } + if (streamedRow == null) { + // We have consumed the entire streamed iterator, so there can be no more matches. + matchJoinKey = null + bufferedMatches.clear() + false + } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) { + // The new streamed row has the same join key as the previous row, so return the same matches. + true + } else if (bufferedRow == null) { + // The streamed row's join key does not match the current batch of buffered rows and there are + // no more rows to read from the buffered iterator, so there can be no more matches. + matchJoinKey = null + bufferedMatches.clear() + false + } else { + // Advance both the streamed and buffered iterators to find the next pair of matching rows. + var comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) + do { + if (streamedRowKey.anyNull) { + advancedStreamed() + } else { + assert(!bufferedRowKey.anyNull) + comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) + if (comp > 0) advancedBufferedToRowWithNullFreeJoinKey() + else if (comp < 0) advancedStreamed() } + } while (streamedRow != null && bufferedRow != null && comp != 0) + if (streamedRow == null || bufferedRow == null) { + // We have either hit the end of one of the iterators, so there can be no more matches. + matchJoinKey = null + bufferedMatches.clear() + false + } else { + // The streamed row's join key matches the current buffered row's join, so walk through the + // buffered iterator to buffer the rest of the matching rows. + assert(comp == 0) + bufferMatchingRows() + true + } + } + } - /** - * Searches the right iterator for the next rows that have matches in left side, and store - * them in a buffer. - * - * @return true if the search is successful, and false if the right iterator runs out of - * tuples. - */ - private def nextMatchingPair(): Boolean = { - if (!stop && rightElement != null) { - // run both side to get the first match pair - while (!stop && leftElement != null && rightElement != null) { - val comparing = keyOrdering.compare(leftKey, rightKey) - // for inner join, we need to filter those null keys - stop = comparing == 0 && !leftKey.anyNull - if (comparing > 0 || rightKey.anyNull) { - fetchRight() - } else if (comparing < 0 || leftKey.anyNull) { - fetchLeft() - } - } - rightMatches = new CompactBuffer[InternalRow]() - if (stop) { - stop = false - // iterate the right side to buffer all rows that matches - // as the records should be ordered, exit when we meet the first that not match - while (!stop && rightElement != null) { - rightMatches += rightElement - fetchRight() - stop = keyOrdering.compare(leftKey, rightKey) != 0 - } - if (rightMatches.size > 0) { - rightPosition = 0 - matchKey = leftKey - } - } + /** + * Advances the streamed input iterator and buffers all rows from the buffered input that + * have matching keys. + * @return true if the streamed iterator returned a row, false otherwise. If this returns true, + * then [getStreamedRow and [[getBufferedMatches]] can be called to produce the outer + * join results. + */ + final def findNextOuterJoinRows(): Boolean = { + if (!advancedStreamed()) { + // We have consumed the entire streamed iterator, so there can be no more matches. + matchJoinKey = null + bufferedMatches.clear() + false + } else { + if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) { + // Matches the current group, so do nothing. + } else { + // The streamed row does not match the current group. + matchJoinKey = null + bufferedMatches.clear() + if (bufferedRow != null && !streamedRowKey.anyNull) { + // The buffered iterator could still contain matching rows, so we'll need to walk through + // it until we either find matches or pass where they would be found. + var comp = 1 + do { + comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) + } while (comp > 0 && advancedBufferedToRowWithNullFreeJoinKey()) + if (comp == 0) { + // We have found matches, so buffer them (this updates matchJoinKey) + bufferMatchingRows() + } else { + // We have overshot the position where the row would be found, hence no matches. } - rightMatches != null && rightMatches.size > 0 } } + // If there is a streamed input then we always return true + true } } + + // --- Private methods -------------------------------------------------------------------------- + + /** + * Advance the streamed iterator and compute the new row's join key. + * @return true if the streamed iterator returned a row and false otherwise. + */ + private def advancedStreamed(): Boolean = { + if (streamedIter.advanceNext()) { + streamedRow = streamedIter.getRow + streamedRowKey = streamedKeyGenerator(streamedRow) + true + } else { + streamedRow = null + streamedRowKey = null + false + } + } + + /** + * Advance the buffered iterator until we find a row with join key that does not contain nulls. + * @return true if the buffered iterator returned a row and false otherwise. + */ + private def advancedBufferedToRowWithNullFreeJoinKey(): Boolean = { + var foundRow: Boolean = false + while (!foundRow && bufferedIter.advanceNext()) { + bufferedRow = bufferedIter.getRow + bufferedRowKey = bufferedKeyGenerator(bufferedRow) + foundRow = !bufferedRowKey.anyNull + } + if (!foundRow) { + bufferedRow = null + bufferedRowKey = null + false + } else { + true + } + } + + /** + * Called when the streamed and buffered join keys match in order to buffer the matching rows. + */ + private def bufferMatchingRows(): Unit = { + assert(streamedRowKey != null) + assert(!streamedRowKey.anyNull) + assert(bufferedRowKey != null) + assert(!bufferedRowKey.anyNull) + assert(keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0) + // This join key may have been produced by a mutable projection, so we need to make a copy: + matchJoinKey = streamedRowKey.copy() + bufferedMatches.clear() + do { + bufferedMatches += bufferedRow.copy() // need to copy mutable rows before buffering them + advancedBufferedToRowWithNullFreeJoinKey() + } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala new file mode 100644 index 0000000000000..5326966b07a66 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} + +/** + * :: DeveloperApi :: + * Performs an sort merge outer join of two child relations. + * + * Note: this does not support full outer join yet; see SPARK-9730 for progress on this. + */ +@DeveloperApi +case class SortMergeOuterJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode { + + override def output: Seq[Attribute] = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case x => + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $x as the JoinType") + } + } + + override def outputPartitioning: Partitioning = joinType match { + // For left and right outer joins, the output is partitioned by the streamed input's join keys. + case LeftOuter => left.outputPartitioning + case RightOuter => right.outputPartitioning + case x => + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $x as the JoinType") + } + + override def outputOrdering: Seq[SortOrder] = joinType match { + // For left and right outer joins, the output is ordered by the streamed input's join keys. + case LeftOuter => requiredOrders(leftKeys) + case RightOuter => requiredOrders(rightKeys) + case x => throw new IllegalArgumentException( + s"SortMergeOuterJoin should not take $x as the JoinType") + } + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil + + private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { + // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. + keys.map(SortOrder(_, Ascending)) + } + + private def isUnsafeMode: Boolean = { + (codegenEnabled && unsafeEnabled + && UnsafeProjection.canSupport(leftKeys) + && UnsafeProjection.canSupport(rightKeys) + && UnsafeProjection.canSupport(schema)) + } + + override def outputsUnsafeRows: Boolean = isUnsafeMode + override def canProcessUnsafeRows: Boolean = isUnsafeMode + override def canProcessSafeRows: Boolean = !isUnsafeMode + + private def createLeftKeyGenerator(): Projection = { + if (isUnsafeMode) { + UnsafeProjection.create(leftKeys, left.output) + } else { + newProjection(leftKeys, left.output) + } + } + + private def createRightKeyGenerator(): Projection = { + if (isUnsafeMode) { + UnsafeProjection.create(rightKeys, right.output) + } else { + newProjection(rightKeys, right.output) + } + } + + override def doExecute(): RDD[InternalRow] = { + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + // An ordering that can be used to compare keys from both sides. + val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) + val boundCondition: (InternalRow) => Boolean = { + condition.map { cond => + newPredicate(cond, left.output ++ right.output) + }.getOrElse { + (r: InternalRow) => true + } + } + val resultProj: InternalRow => InternalRow = { + if (isUnsafeMode) { + UnsafeProjection.create(schema) + } else { + identity[InternalRow] + } + } + + joinType match { + case LeftOuter => + val smjScanner = new SortMergeJoinScanner( + streamedKeyGenerator = createLeftKeyGenerator(), + bufferedKeyGenerator = createRightKeyGenerator(), + keyOrdering, + streamedIter = RowIterator.fromScala(leftIter), + bufferedIter = RowIterator.fromScala(rightIter) + ) + val rightNullRow = new GenericInternalRow(right.output.length) + new LeftOuterIterator(smjScanner, rightNullRow, boundCondition, resultProj).toScala + + case RightOuter => + val smjScanner = new SortMergeJoinScanner( + streamedKeyGenerator = createRightKeyGenerator(), + bufferedKeyGenerator = createLeftKeyGenerator(), + keyOrdering, + streamedIter = RowIterator.fromScala(rightIter), + bufferedIter = RowIterator.fromScala(leftIter) + ) + val leftNullRow = new GenericInternalRow(left.output.length) + new RightOuterIterator(smjScanner, leftNullRow, boundCondition, resultProj).toScala + + case x => + throw new IllegalArgumentException( + s"SortMergeOuterJoin should not take $x as the JoinType") + } + } + } +} + + +private class LeftOuterIterator( + smjScanner: SortMergeJoinScanner, + rightNullRow: InternalRow, + boundCondition: InternalRow => Boolean, + resultProj: InternalRow => InternalRow + ) extends RowIterator { + private[this] val joinedRow: JoinedRow = new JoinedRow() + private[this] var rightIdx: Int = 0 + assert(smjScanner.getBufferedMatches.length == 0) + + private def advanceLeft(): Boolean = { + rightIdx = 0 + if (smjScanner.findNextOuterJoinRows()) { + joinedRow.withLeft(smjScanner.getStreamedRow) + if (smjScanner.getBufferedMatches.isEmpty) { + // There are no matching right rows, so return nulls for the right row + joinedRow.withRight(rightNullRow) + } else { + // Find the next row from the right input that satisfied the bound condition + if (!advanceRightUntilBoundConditionSatisfied()) { + joinedRow.withRight(rightNullRow) + } + } + true + } else { + // Left input has been exhausted + false + } + } + + private def advanceRightUntilBoundConditionSatisfied(): Boolean = { + var foundMatch: Boolean = false + while (!foundMatch && rightIdx < smjScanner.getBufferedMatches.length) { + foundMatch = boundCondition(joinedRow.withRight(smjScanner.getBufferedMatches(rightIdx))) + rightIdx += 1 + } + foundMatch + } + + override def advanceNext(): Boolean = { + advanceRightUntilBoundConditionSatisfied() || advanceLeft() + } + + override def getRow: InternalRow = resultProj(joinedRow) +} + +private class RightOuterIterator( + smjScanner: SortMergeJoinScanner, + leftNullRow: InternalRow, + boundCondition: InternalRow => Boolean, + resultProj: InternalRow => InternalRow + ) extends RowIterator { + private[this] val joinedRow: JoinedRow = new JoinedRow() + private[this] var leftIdx: Int = 0 + assert(smjScanner.getBufferedMatches.length == 0) + + private def advanceRight(): Boolean = { + leftIdx = 0 + if (smjScanner.findNextOuterJoinRows()) { + joinedRow.withRight(smjScanner.getStreamedRow) + if (smjScanner.getBufferedMatches.isEmpty) { + // There are no matching left rows, so return nulls for the left row + joinedRow.withLeft(leftNullRow) + } else { + // Find the next row from the left input that satisfied the bound condition + if (!advanceLeftUntilBoundConditionSatisfied()) { + joinedRow.withLeft(leftNullRow) + } + } + true + } else { + // Right input has been exhausted + false + } + } + + private def advanceLeftUntilBoundConditionSatisfied(): Boolean = { + var foundMatch: Boolean = false + while (!foundMatch && leftIdx < smjScanner.getBufferedMatches.length) { + foundMatch = boundCondition(joinedRow.withLeft(smjScanner.getBufferedMatches(leftIdx))) + leftIdx += 1 + } + foundMatch + } + + override def advanceNext(): Boolean = { + advanceLeftUntilBoundConditionSatisfied() || advanceRight() + } + + override def getRow: InternalRow = resultProj(joinedRow) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 5bef1d8966031..ae07eaf91c872 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -22,13 +22,14 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.types.BinaryType +import org.apache.spark.sql.test.SQLTestUtils -class JoinSuite extends QueryTest with BeforeAndAfterEach { +class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { // Ensures tables are loaded. TestData + override def sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext lazy val ctx = org.apache.spark.sql.test.TestSQLContext import ctx.implicits._ import ctx.logicalPlanToSparkQuery @@ -37,7 +38,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan - val planned = ctx.planner.HashJoin(join) + val planned = ctx.planner.EquiJoinSelection(join) assert(planned.size === 1) } @@ -55,6 +56,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case j: BroadcastNestedLoopJoin => j case j: BroadcastLeftSemiJoinHash => j case j: SortMergeJoin => j + case j: SortMergeOuterJoin => j } assert(operators.size === 1) @@ -66,7 +68,6 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("join operator selection") { ctx.cacheManager.clearCache() - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]), @@ -83,11 +84,11 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]), - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[ShuffledHashOuterJoin]), + classOf[SortMergeOuterJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[ShuffledHashOuterJoin]), + classOf[SortMergeOuterJoin]), ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)", @@ -97,82 +98,75 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { Seq( - ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]) + ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", + classOf[ShuffledHashJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", + classOf[ShuffledHashJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData full outer join testData2 ON key = a", + classOf[ShuffledHashOuterJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } } test("SortMergeJoin shouldn't work on unsortable columns") { - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { Seq( ("SELECT * FROM arrayData JOIN complexData ON data = a", classOf[ShuffledHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } } test("broadcasted hash join operator selection") { ctx.cacheManager.clearCache() ctx.sql("CACHE TABLE testData") - - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled - Seq( - ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a where key = 2", - classOf[BroadcastHashJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) - Seq( - ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a and key = 2", - classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a where key = 2", - classOf[BroadcastHashJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) + for (sortMergeJoinEnabled <- Seq(true, false)) { + withClue(s"sortMergeJoinEnabled=$sortMergeJoinEnabled") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> s"$sortMergeJoinEnabled") { + Seq( + ("SELECT * FROM testData join testData2 ON key = a", + classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a and key = 2", + classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a where key = 2", + classOf[BroadcastHashJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } + } } - ctx.sql("UNCACHE TABLE testData") } test("broadcasted hash outer join operator selection") { ctx.cacheManager.clearCache() ctx.sql("CACHE TABLE testData") - - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled - Seq( - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), - ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[BroadcastHashOuterJoin]), - ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[BroadcastHashOuterJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { Seq( - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", + classOf[SortMergeOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[BroadcastHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[BroadcastHashOuterJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + Seq( + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", + classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", classOf[BroadcastHashOuterJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", classOf[BroadcastHashOuterJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } - ctx.sql("UNCACHE TABLE testData") } @@ -180,7 +174,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan - val planned = ctx.planner.HashJoin(join) + val planned = ctx.planner.EquiJoinSelection(join) assert(planned.size === 1) } @@ -457,25 +451,24 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("broadcasted left semi join operator selection") { ctx.cacheManager.clearCache() ctx.sql("CACHE TABLE testData") - val tmp = ctx.conf.autoBroadcastJoinThreshold - ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=1000000000") - Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", - classOf[BroadcastLeftSemiJoinHash]) - ).foreach { - case (query, joinClass) => assertJoin(query, joinClass) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") { + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", + classOf[BroadcastLeftSemiJoinHash]) + ).foreach { + case (query, joinClass) => assertJoin(query, joinClass) + } } - ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1") - - Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) - ).foreach { - case (query, joinClass) => assertJoin(query, joinClass) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) + ).foreach { + case (query, joinClass) => assertJoin(query, joinClass) + } } - ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp) ctx.sql("UNCACHE TABLE testData") } @@ -488,6 +481,5 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala new file mode 100644 index 0000000000000..ddff7cebcc17d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} +import org.apache.spark.sql.{SQLConf, execution, Row, DataFrame} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution._ + +class InnerJoinSuite extends SparkPlanTest with SQLTestUtils { + + private def testInnerJoin( + testName: String, + leftRows: DataFrame, + rightRows: DataFrame, + condition: Expression, + expectedAnswer: Seq[Product]): Unit = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join).foreach { + case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) => + + def makeBroadcastHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = { + val broadcastHashJoin = + execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, left, right) + boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) + } + + def makeShuffledHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = { + val shuffledHashJoin = + execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, left, right) + val filteredJoin = + boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin) + EnsureRequirements(sqlContext).apply(filteredJoin) + } + + def makeSortMergeJoin(left: SparkPlan, right: SparkPlan) = { + val sortMergeJoin = + execution.joins.SortMergeJoin(leftKeys, rightKeys, left, right) + val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin) + EnsureRequirements(sqlContext).apply(filteredJoin) + } + + test(s"$testName using BroadcastHashJoin (build=left)") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + makeBroadcastHashJoin(left, right, joins.BuildLeft), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using BroadcastHashJoin (build=right)") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + makeBroadcastHashJoin(left, right, joins.BuildRight), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using ShuffledHashJoin (build=left)") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + makeShuffledHashJoin(left, right, joins.BuildLeft), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using ShuffledHashJoin (build=right)") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + makeShuffledHashJoin(left, right, joins.BuildRight), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using SortMergeJoin") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + makeSortMergeJoin(left, right), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + } + + { + val upperCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + Row(1, "A"), + Row(2, "B"), + Row(3, "C"), + Row(4, "D"), + Row(5, "E"), + Row(6, "F"), + Row(null, "G") + )), new StructType().add("N", IntegerType).add("L", StringType)) + + val lowerCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + Row(1, "a"), + Row(2, "b"), + Row(3, "c"), + Row(4, "d"), + Row(null, "e") + )), new StructType().add("n", IntegerType).add("l", StringType)) + + testInnerJoin( + "inner join, one match per row", + upperCaseData, + lowerCaseData, + (upperCaseData.col("N") === lowerCaseData.col("n")).expr, + Seq( + (1, "A", 1, "a"), + (2, "B", 2, "b"), + (3, "C", 3, "c"), + (4, "D", 4, "d") + ) + ) + } + + private val testData2 = Seq( + (1, 1), + (1, 2), + (2, 1), + (2, 2), + (3, 1), + (3, 2) + ).toDF("a", "b") + + { + val left = testData2.where("a = 1") + val right = testData2.where("a = 1") + testInnerJoin( + "inner join, multiple matches", + left, + right, + (left.col("a") === right.col("a")).expr, + Seq( + (1, 1, 1, 1), + (1, 1, 1, 2), + (1, 2, 1, 1), + (1, 2, 1, 2) + ) + ) + } + + { + val left = testData2.where("a = 1") + val right = testData2.where("a = 2") + testInnerJoin( + "inner join, no matches", + left, + right, + (left.col("a") === right.col("a")).expr, + Seq.empty + ) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 2c27da596bc4f..e16f5e39aa2f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -1,89 +1,221 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Expression, LessThan} -import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter} -import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} - -class OuterJoinSuite extends SparkPlanTest { - - val left = Seq( - (1, 2.0), - (2, 1.0), - (3, 3.0) - ).toDF("a", "b") - - val right = Seq( - (2, 3.0), - (3, 2.0), - (4, 1.0) - ).toDF("c", "d") - - val leftKeys: List[Expression] = 'a :: Nil - val rightKeys: List[Expression] = 'c :: Nil - val condition = Some(LessThan('b, 'd)) - - test("shuffled hash outer join") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - ShuffledHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right), - Seq( - (1, 2.0, null, null), - (2, 1.0, 2, 3.0), - (3, 3.0, null, null) - ).map(Row.fromTuple)) - - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - ShuffledHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right), - Seq( - (2, 1.0, 2, 3.0), - (null, null, 3, 2.0), - (null, null, 4, 1.0) - ).map(Row.fromTuple)) - - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - ShuffledHashOuterJoin(leftKeys, rightKeys, FullOuter, condition, left, right), - Seq( - (1, 2.0, null, null), - (2, 1.0, 2, 3.0), - (3, 3.0, null, null), - (null, null, 3, 2.0), - (null, null, 4, 1.0) - ).map(Row.fromTuple)) - } - - test("broadcast hash outer join") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - BroadcastHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right), - Seq( - (1, 2.0, null, null), - (2, 1.0, 2, 3.0), - (3, 3.0, null, null) - ).map(Row.fromTuple)) - - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - BroadcastHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right), - Seq( - (2, 1.0, 2, 3.0), - (null, null, 3, 2.0), - (null, null, 4, 1.0) - ).map(Row.fromTuple)) - } -} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType} +import org.apache.spark.sql.{SQLConf, DataFrame, Row} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.execution.{EnsureRequirements, joins, SparkPlan, SparkPlanTest} + +class OuterJoinSuite extends SparkPlanTest with SQLTestUtils { + + private def testOuterJoin( + testName: String, + leftRows: DataFrame, + rightRows: DataFrame, + joinType: JoinType, + condition: Expression, + expectedAnswer: Seq[Product]): Unit = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join).foreach { + case (_, leftKeys, rightKeys, boundCondition, leftChild, rightChild) => + test(s"$testName using ShuffledHashOuterJoin") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(sqlContext).apply( + ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + if (joinType != FullOuter) { + test(s"$testName using BroadcastHashOuterJoin") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using SortMergeOuterJoin") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(sqlContext).apply( + SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = false) + } + } + } + } + + test(s"$testName using BroadcastNestedLoopJoin (build=left)") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + joins.BroadcastNestedLoopJoin(left, right, joins.BuildLeft, joinType, Some(condition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using BroadcastNestedLoopJoin (build=right)") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + joins.BroadcastNestedLoopJoin(left, right, joins.BuildRight, joinType, Some(condition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + Row(1, 2.0), + Row(2, 100.0), + Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches + Row(2, 1.0), + Row(3, 3.0), + Row(5, 1.0), + Row(6, 6.0), + Row(null, null) + )), new StructType().add("a", IntegerType).add("b", DoubleType)) + + val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + Row(0, 0.0), + Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches + Row(2, -1.0), + Row(2, -1.0), + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(5, 3.0), + Row(7, 7.0), + Row(null, null) + )), new StructType().add("c", IntegerType).add("d", DoubleType)) + + val condition = { + And( + (left.col("a") === right.col("c")).expr, + LessThan(left.col("b").expr, right.col("d").expr)) + } + + // --- Basic outer joins ------------------------------------------------------------------------ + + testOuterJoin( + "basic left outer join", + left, + right, + LeftOuter, + condition, + Seq( + (null, null, null, null), + (1, 2.0, null, null), + (2, 100.0, null, null), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null), + (5, 1.0, 5, 3.0), + (6, 6.0, null, null) + ) + ) + + testOuterJoin( + "basic right outer join", + left, + right, + RightOuter, + condition, + Seq( + (null, null, null, null), + (null, null, 0, 0.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (null, null, 2, -1.0), + (null, null, 2, -1.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (null, null, 3, 2.0), + (null, null, 4, 1.0), + (5, 1.0, 5, 3.0), + (null, null, 7, 7.0) + ) + ) + + testOuterJoin( + "basic full outer join", + left, + right, + FullOuter, + condition, + Seq( + (1, 2.0, null, null), + (null, null, 2, -1.0), + (null, null, 2, -1.0), + (2, 100.0, null, null), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null), + (5, 1.0, 5, 3.0), + (6, 6.0, null, null), + (null, null, 0, 0.0), + (null, null, 3, 2.0), + (null, null, 4, 1.0), + (null, null, 7, 7.0), + (null, null, null, null), + (null, null, null, null) + ) + ) + + // --- Both inputs empty ------------------------------------------------------------------------ + + testOuterJoin( + "left outer join with both inputs empty", + left.filter("false"), + right.filter("false"), + LeftOuter, + condition, + Seq.empty + ) + + testOuterJoin( + "right outer join with both inputs empty", + left.filter("false"), + right.filter("false"), + RightOuter, + condition, + Seq.empty + ) + + testOuterJoin( + "full outer join with both inputs empty", + left.filter("false"), + right.filter("false"), + FullOuter, + condition, + Seq.empty + ) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index 927e85a7db3dc..4503ed251fcb1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -17,58 +17,91 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{LessThan, Expression} -import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} +import org.apache.spark.sql.{SQLConf, DataFrame, Row} +import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression} +import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} +class SemiJoinSuite extends SparkPlanTest with SQLTestUtils { -class SemiJoinSuite extends SparkPlanTest{ - val left = Seq( - (1, 2.0), - (1, 2.0), - (2, 1.0), - (2, 1.0), - (3, 3.0) - ).toDF("a", "b") + private def testLeftSemiJoin( + testName: String, + leftRows: DataFrame, + rightRows: DataFrame, + condition: Expression, + expectedAnswer: Seq[Product]): Unit = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join).foreach { + case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) => + test(s"$testName using LeftSemiJoinHash") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext).apply( + LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } - val right = Seq( - (2, 3.0), - (2, 3.0), - (3, 2.0), - (4, 1.0) - ).toDF("c", "d") + test(s"$testName using BroadcastLeftSemiJoinHash") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } - val leftKeys: List[Expression] = 'a :: Nil - val rightKeys: List[Expression] = 'c :: Nil - val condition = Some(LessThan('b, 'd)) - - test("left semi join hash") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - LeftSemiJoinHash(leftKeys, rightKeys, left, right, condition), - Seq( - (2, 1.0), - (2, 1.0) - ).map(Row.fromTuple)) + test(s"$testName using LeftSemiJoinBNL") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + LeftSemiJoinBNL(left, right, Some(condition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } } - test("left semi join BNL") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - LeftSemiJoinBNL(left, right, condition), - Seq( - (1, 2.0), - (1, 2.0), - (2, 1.0), - (2, 1.0) - ).map(Row.fromTuple)) - } + val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + Row(1, 2.0), + Row(1, 2.0), + Row(2, 1.0), + Row(2, 1.0), + Row(3, 3.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("a", IntegerType).add("b", DoubleType)) - test("broadcast left semi join hash") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, condition), - Seq( - (2, 1.0), - (2, 1.0) - ).map(Row.fromTuple)) + val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + Row(2, 3.0), + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("c", IntegerType).add("d", DoubleType)) + + val condition = { + And( + (left.col("a") === right.col("c")).expr, + LessThan(left.col("b").expr, right.col("d").expr)) } + + testLeftSemiJoin( + "basic test", + left, + right, + condition, + Seq( + (2, 1.0), + (2, 1.0) + ) + ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 4c11acdab9ec0..1066695589778 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils trait SQLTestUtils { this: SparkFunSuite => - def sqlContext: SQLContext + protected def sqlContext: SQLContext protected def configuration = sqlContext.sparkContext.hadoopConfiguration diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 567d7fa12ff14..f17177a771c3b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -531,7 +531,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { HashAggregation, Aggregation, LeftSemiJoin, - HashJoin, + EquiJoinSelection, BasicOperators, CartesianProduct, BroadcastNestedLoopJoin From 0f90d6055e5bea9ceb1d454db84f4aa1d59b284d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 10 Aug 2015 23:41:53 -0700 Subject: [PATCH 015/253] [SPARK-9640] [STREAMING] [TEST] Do not run Python Kinesis tests when the Kinesis assembly JAR has not been generated Author: Tathagata Das Closes #7961 from tdas/SPARK-9640 and squashes the following commits: 974ce19 [Tathagata Das] Undo changes related to SPARK-9727 004ae26 [Tathagata Das] style fixes 9bbb97d [Tathagata Das] Minor style fies e6a677e [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into SPARK-9640 ca90719 [Tathagata Das] Removed extra line ba9cfc7 [Tathagata Das] Improved kinesis test selection logic 88d59bd [Tathagata Das] updated test modules 871fcc8 [Tathagata Das] Fixed SparkBuild 94be631 [Tathagata Das] Fixed style b858196 [Tathagata Das] Fixed conditions and few other things based on PR comments. e292e64 [Tathagata Das] Added filters for Kinesis python tests --- python/pyspark/streaming/tests.py | 56 ++++++++++++++++++++++++------- 1 file changed, 44 insertions(+), 12 deletions(-) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 66ae3345f468f..f0ed415f97120 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -971,8 +971,10 @@ def test_kinesis_stream_api(self): "awsAccessKey", "awsSecretKey") def test_kinesis_stream(self): - if os.environ.get('ENABLE_KINESIS_TESTS') != '1': - print("Skip test_kinesis_stream") + if not are_kinesis_tests_enabled: + sys.stderr.write( + "Skipped test_kinesis_stream (enable by setting environment variable %s=1" + % kinesis_test_environ_var) return import random @@ -1013,6 +1015,7 @@ def get_output(_, rdd): traceback.print_exc() raise finally: + self.ssc.stop(False) kinesisTestUtils.deleteStream() kinesisTestUtils.deleteDynamoDBTable(kinesisAppName) @@ -1027,7 +1030,7 @@ def search_kafka_assembly_jar(): ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) + "You need to build Spark with " "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or " - "'build/mvn package' before running this test") + "'build/mvn package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Kafka assembly JARs in %s; please " "remove all but one") % kafka_assembly_dir) @@ -1045,7 +1048,7 @@ def search_flume_assembly_jar(): ("Failed to find Spark Streaming Flume assembly jar in %s. " % flume_assembly_dir) + "You need to build Spark with " "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or " - "'build/mvn package' before running this test") + "'build/mvn package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Flume assembly JARs in %s; please " "remove all but one") % flume_assembly_dir) @@ -1095,11 +1098,7 @@ def search_kinesis_asl_assembly_jar(): os.path.join(kinesis_asl_assembly_dir, "target/scala-*/spark-streaming-kinesis-asl-assembly-*.jar")) if not jars: - raise Exception( - ("Failed to find Spark Streaming Kinesis ASL assembly jar in %s. " % - kinesis_asl_assembly_dir) + "You need to build Spark with " - "'build/sbt -Pkinesis-asl assembly/assembly streaming-kinesis-asl-assembly/assembly' " - "or 'build/mvn -Pkinesis-asl package' before running this test") + return None elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Kinesis ASL assembly JARs in %s; please " "remove all but one") % kinesis_asl_assembly_dir) @@ -1107,6 +1106,10 @@ def search_kinesis_asl_assembly_jar(): return jars[0] +# Must be same as the variable and condition defined in KinesisTestUtils.scala +kinesis_test_environ_var = "ENABLE_KINESIS_TESTS" +are_kinesis_tests_enabled = os.environ.get(kinesis_test_environ_var) == '1' + if __name__ == "__main__": kafka_assembly_jar = search_kafka_assembly_jar() flume_assembly_jar = search_flume_assembly_jar() @@ -1114,8 +1117,37 @@ def search_kinesis_asl_assembly_jar(): mqtt_test_jar = search_mqtt_test_jar() kinesis_asl_assembly_jar = search_kinesis_asl_assembly_jar() - jars = "%s,%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, kinesis_asl_assembly_jar, - mqtt_assembly_jar, mqtt_test_jar) + if kinesis_asl_assembly_jar is None: + kinesis_jar_present = False + jars = "%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, mqtt_assembly_jar, + mqtt_test_jar) + else: + kinesis_jar_present = True + jars = "%s,%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, mqtt_assembly_jar, + mqtt_test_jar, kinesis_asl_assembly_jar) os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars - unittest.main() + testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, + CheckpointTests, KafkaStreamTests, FlumeStreamTests, FlumePollingStreamTests] + + if kinesis_jar_present is True: + testcases.append(KinesisStreamTests) + elif are_kinesis_tests_enabled is False: + sys.stderr.write("Skipping all Kinesis Python tests as the optional Kinesis project was " + "not compiled with -Pkinesis-asl profile. To run these tests, " + "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/assembly " + "streaming-kinesis-asl-assembly/assembly' or " + "'build/mvn -Pkinesis-asl package' before running this test.") + else: + raise Exception( + ("Failed to find Spark Streaming Kinesis assembly jar in %s. " + % kinesis_asl_assembly_dir) + + "You need to build Spark with 'build/sbt -Pkinesis-asl " + "assembly/assembly streaming-kinesis-asl-assembly/assembly'" + "or 'build/mvn -Pkinesis-asl package' before running this test.") + + sys.stderr.write("Running tests: %s \n" % (str(testcases))) + for testcase in testcases: + sys.stderr.write("[Running %s]\n" % (testcase)) + tests = unittest.TestLoader().loadTestsFromTestCase(testcase) + unittest.TextTestRunner(verbosity=2).run(tests) From 55752d88321925da815823f968128832de6fdbbb Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 11 Aug 2015 01:08:30 -0700 Subject: [PATCH 016/253] [SPARK-9810] [BUILD] Remove individual commit messages from the squash commit message For more information, please see the JIRA ticket and the associated dev list discussion. https://issues.apache.org/jira/browse/SPARK-9810 http://apache-spark-developers-list.1001551.n3.nabble.com/discuss-Removing-individual-commit-messages-from-the-squash-commit-message-td13295.html Author: Reynold Xin Closes #8091 from rxin/SPARK-9810. --- dev/merge_spark_pr.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index ad4b76695c9ff..b9bdec3d70864 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -159,11 +159,7 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): merge_message_flags += ["-m", message] # The string "Closes #%s" string is required for GitHub to correctly close the PR - merge_message_flags += [ - "-m", - "Closes #%s from %s and squashes the following commits:" % (pr_num, pr_repo_desc)] - for c in commits: - merge_message_flags += ["-m", c] + merge_message_flags += ["-m", "Closes #%s from %s." % (pr_num, pr_repo_desc)] run_cmd(['git', 'commit', '--author="%s"' % primary_author] + merge_message_flags) From 600031ebe27473d8fffe6ea436c2149223b82896 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 11 Aug 2015 02:41:03 -0700 Subject: [PATCH 017/253] [SPARK-9727] [STREAMING] [BUILD] Updated streaming kinesis SBT project name to be more consistent Author: Tathagata Das Closes #8092 from tdas/SPARK-9727 and squashes the following commits: b1b01fd [Tathagata Das] Updated streaming kinesis project name --- dev/sparktestsupport/modules.py | 4 ++-- extras/kinesis-asl/pom.xml | 2 +- project/SparkBuild.scala | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index d82c0cca37bc6..346452f3174e4 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -134,7 +134,7 @@ def contains_file(self, filename): # files in streaming_kinesis_asl are changed, so that if Kinesis experiences an outage, we don't # fail other PRs. streaming_kinesis_asl = Module( - name="kinesis-asl", + name="streaming-kinesis-asl", dependencies=[], source_file_regexes=[ "extras/kinesis-asl/", @@ -147,7 +147,7 @@ def contains_file(self, filename): "ENABLE_KINESIS_TESTS": "1" }, sbt_test_goals=[ - "kinesis-asl/test", + "streaming-kinesis-asl/test", ] ) diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index c242e7a57b9ab..521b53e230c4a 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -31,7 +31,7 @@ Spark Kinesis Integration - kinesis-asl + streaming-kinesis-asl diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 41a85fa9de778..cad7067ade8c1 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -42,8 +42,8 @@ object BuildCommons { "streaming-zeromq", "launcher", "unsafe").map(ProjectRef(buildLocation, _)) val optionallyEnabledProjects@Seq(yarn, yarnStable, java8Tests, sparkGangliaLgpl, - sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", - "kinesis-asl").map(ProjectRef(buildLocation, _)) + streamingKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", + "streaming-kinesis-asl").map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingMqttAssembly, streamingKinesisAslAssembly) = Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-mqtt-assembly", "streaming-kinesis-asl-assembly") From d378396f86f625f006738d87fe5dbc2ff8fd913d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 11 Aug 2015 08:41:06 -0700 Subject: [PATCH 018/253] [SPARK-9815] Rename PlatformDependent.UNSAFE -> Platform. PlatformDependent.UNSAFE is way too verbose. Author: Reynold Xin Closes #8094 from rxin/SPARK-9815 and squashes the following commits: 229b603 [Reynold Xin] [SPARK-9815] Rename PlatformDependent.UNSAFE -> Platform. --- .../serializer/DummySerializerInstance.java | 6 +- .../unsafe/UnsafeShuffleExternalSorter.java | 22 +-- .../shuffle/unsafe/UnsafeShuffleWriter.java | 4 +- .../spark/unsafe/map/BytesToBytesMap.java | 20 +-- .../unsafe/sort/PrefixComparators.java | 5 +- .../unsafe/sort/UnsafeExternalSorter.java | 22 +-- .../unsafe/sort/UnsafeInMemorySorter.java | 4 +- .../unsafe/sort/UnsafeSorterSpillReader.java | 4 +- .../unsafe/sort/UnsafeSorterSpillWriter.java | 6 +- .../UnsafeShuffleInMemorySorterSuite.java | 20 +-- .../map/AbstractBytesToBytesMapSuite.java | 94 +++++----- .../sort/UnsafeExternalSorterSuite.java | 20 +-- .../sort/UnsafeInMemorySorterSuite.java | 20 +-- .../catalyst/expressions/UnsafeArrayData.java | 51 ++---- .../catalyst/expressions/UnsafeReaders.java | 8 +- .../sql/catalyst/expressions/UnsafeRow.java | 108 +++++------ .../expressions/UnsafeRowWriters.java | 41 ++--- .../catalyst/expressions/UnsafeWriters.java | 43 ++--- .../execution/UnsafeExternalRowSorter.java | 4 +- .../expressions/codegen/CodeGenerator.scala | 4 +- .../codegen/GenerateUnsafeProjection.scala | 32 ++-- .../codegen/GenerateUnsafeRowJoiner.scala | 16 +- .../expressions/stringOperations.scala | 4 +- .../GenerateUnsafeRowJoinerBitsetSuite.scala | 4 +- .../UnsafeFixedWidthAggregationMap.java | 4 +- .../sql/execution/UnsafeKVExternalSorter.java | 4 +- .../sql/execution/UnsafeRowSerializer.scala | 6 +- .../sql/execution/joins/HashedRelation.scala | 13 +- .../org/apache/spark/sql/UnsafeRowSuite.scala | 4 +- .../{PlatformDependent.java => Platform.java} | 170 ++++++++---------- .../spark/unsafe/array/ByteArrayMethods.java | 14 +- .../apache/spark/unsafe/array/LongArray.java | 6 +- .../spark/unsafe/bitset/BitSetMethods.java | 19 +- .../spark/unsafe/hash/Murmur3_x86_32.java | 4 +- .../spark/unsafe/memory/MemoryBlock.java | 4 +- .../unsafe/memory/UnsafeMemoryAllocator.java | 6 +- .../apache/spark/unsafe/types/ByteArray.java | 10 +- .../apache/spark/unsafe/types/UTF8String.java | 30 ++-- .../unsafe/hash/Murmur3_x86_32Suite.java | 14 +- 39 files changed, 371 insertions(+), 499 deletions(-) rename unsafe/src/main/java/org/apache/spark/unsafe/{PlatformDependent.java => Platform.java} (55%) diff --git a/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java index 0399abc63c235..0e58bb4f7101c 100644 --- a/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java +++ b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java @@ -25,7 +25,7 @@ import scala.reflect.ClassTag; import org.apache.spark.annotation.Private; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. @@ -49,7 +49,7 @@ public void flush() { try { s.flush(); } catch (IOException e) { - PlatformDependent.throwException(e); + Platform.throwException(e); } } @@ -64,7 +64,7 @@ public void close() { try { s.close(); } catch (IOException e) { - PlatformDependent.throwException(e); + Platform.throwException(e); } } }; diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 925b60a145886..3d1ef0c48adc5 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -37,7 +37,7 @@ import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.TempShuffleBlockId; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; @@ -211,16 +211,12 @@ private void writeSortedFile(boolean isLastFile) throws IOException { final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); final Object recordPage = taskMemoryManager.getPage(recordPointer); final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer); - int dataRemaining = PlatformDependent.UNSAFE.getInt(recordPage, recordOffsetInPage); + int dataRemaining = Platform.getInt(recordPage, recordOffsetInPage); long recordReadPosition = recordOffsetInPage + 4; // skip over record length while (dataRemaining > 0) { final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining); - PlatformDependent.copyMemory( - recordPage, - recordReadPosition, - writeBuffer, - PlatformDependent.BYTE_ARRAY_OFFSET, - toTransfer); + Platform.copyMemory( + recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer); writer.write(writeBuffer, 0, toTransfer); recordReadPosition += toTransfer; dataRemaining -= toTransfer; @@ -447,14 +443,10 @@ public void insertRecord( final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition); - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes); + Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes); dataPagePosition += 4; - PlatformDependent.copyMemory( - recordBaseObject, - recordBaseOffset, - dataPageBaseObject, - dataPagePosition, - lengthInBytes); + Platform.copyMemory( + recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes); assert(inMemSorter != null); inMemSorter.insertRecord(recordAddress, partitionId); } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 02084f9122e00..2389c28b28395 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -53,7 +53,7 @@ import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.TimeTrackingOutputStream; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.TaskMemoryManager; @Private @@ -244,7 +244,7 @@ void insertRecordIntoSorter(Product2 record) throws IOException { assert (serializedRecordSize > 0); sorter.insertRecord( - serBuffer.getBuf(), PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); + serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); } @VisibleForTesting diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 7f79cd13aab43..85b46ec8bfae3 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -270,10 +270,10 @@ public boolean hasNext() { @Override public Location next() { - int totalLength = PlatformDependent.UNSAFE.getInt(pageBaseObject, offsetInPage); + int totalLength = Platform.getInt(pageBaseObject, offsetInPage); if (totalLength == END_OF_PAGE_MARKER) { advanceToNextPage(); - totalLength = PlatformDependent.UNSAFE.getInt(pageBaseObject, offsetInPage); + totalLength = Platform.getInt(pageBaseObject, offsetInPage); } loc.with(currentPage, offsetInPage); offsetInPage += 4 + totalLength; @@ -402,9 +402,9 @@ private void updateAddressesAndSizes(long fullKeyAddress) { private void updateAddressesAndSizes(final Object page, final long offsetInPage) { long position = offsetInPage; - final int totalLength = PlatformDependent.UNSAFE.getInt(page, position); + final int totalLength = Platform.getInt(page, position); position += 4; - keyLength = PlatformDependent.UNSAFE.getInt(page, position); + keyLength = Platform.getInt(page, position); position += 4; valueLength = totalLength - keyLength - 4; @@ -572,7 +572,7 @@ public boolean putNewKey( // There wasn't enough space in the current page, so write an end-of-page marker: final Object pageBaseObject = currentDataPage.getBaseObject(); final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor; - PlatformDependent.UNSAFE.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER); + Platform.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER); } final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes); if (memoryGranted != pageSizeBytes) { @@ -608,21 +608,21 @@ public boolean putNewKey( final long valueDataOffsetInPage = insertCursor; insertCursor += valueLengthBytes; // word used to store the value size - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, recordOffset, + Platform.putInt(dataPageBaseObject, recordOffset, keyLengthBytes + valueLengthBytes + 4); - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes); + Platform.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes); // Copy the key - PlatformDependent.copyMemory( + Platform.copyMemory( keyBaseObject, keyBaseOffset, dataPageBaseObject, keyDataOffsetInPage, keyLengthBytes); // Copy the value - PlatformDependent.copyMemory(valueBaseObject, valueBaseOffset, dataPageBaseObject, + Platform.copyMemory(valueBaseObject, valueBaseOffset, dataPageBaseObject, valueDataOffsetInPage, valueLengthBytes); // --- Update bookeeping data structures ----------------------------------------------------- if (useOverflowPage) { // Store the end-of-page marker at the end of the data page - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER); + Platform.putInt(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER); } else { pageCursor += requiredSize; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index 5e002ae1b7568..71b76d5ddfaa7 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -20,10 +20,9 @@ import com.google.common.primitives.UnsignedLongs; import org.apache.spark.annotation.Private; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.UTF8String; import org.apache.spark.util.Utils; -import static org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET; @Private public class PrefixComparators { @@ -73,7 +72,7 @@ public static long computePrefix(byte[] bytes) { final int minLen = Math.min(bytes.length, 8); long p = 0; for (int i = 0; i < minLen; ++i) { - p |= (128L + PlatformDependent.UNSAFE.getByte(bytes, BYTE_ARRAY_OFFSET + i)) + p |= (128L + Platform.getByte(bytes, Platform.BYTE_ARRAY_OFFSET + i)) << (56 - 8 * i); } return p; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 5ebbf9b068fd6..9601aafe55464 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -35,7 +35,7 @@ import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.array.ByteArrayMethods; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; import org.apache.spark.util.Utils; @@ -427,14 +427,10 @@ public void insertRecord( final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition); - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes); + Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes); dataPagePosition += 4; - PlatformDependent.copyMemory( - recordBaseObject, - recordBaseOffset, - dataPageBaseObject, - dataPagePosition, - lengthInBytes); + Platform.copyMemory( + recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes); assert(inMemSorter != null); inMemSorter.insertRecord(recordAddress, prefix); } @@ -493,18 +489,16 @@ public void insertKVRecord( final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition); - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, dataPagePosition, keyLen + valueLen + 4); + Platform.putInt(dataPageBaseObject, dataPagePosition, keyLen + valueLen + 4); dataPagePosition += 4; - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, dataPagePosition, keyLen); + Platform.putInt(dataPageBaseObject, dataPagePosition, keyLen); dataPagePosition += 4; - PlatformDependent.copyMemory( - keyBaseObj, keyOffset, dataPageBaseObject, dataPagePosition, keyLen); + Platform.copyMemory(keyBaseObj, keyOffset, dataPageBaseObject, dataPagePosition, keyLen); dataPagePosition += keyLen; - PlatformDependent.copyMemory( - valueBaseObj, valueOffset, dataPageBaseObject, dataPagePosition, valueLen); + Platform.copyMemory(valueBaseObj, valueOffset, dataPageBaseObject, dataPagePosition, valueLen); assert(inMemSorter != null); inMemSorter.insertRecord(recordAddress, prefix); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 1e4b8a116e11a..f7787e1019c2b 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -19,7 +19,7 @@ import java.util.Comparator; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.util.collection.Sorter; import org.apache.spark.unsafe.memory.TaskMemoryManager; @@ -164,7 +164,7 @@ public void loadNext() { final long recordPointer = sortBuffer[position]; baseObject = memoryManager.getPage(recordPointer); baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length - recordLength = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset - 4); + recordLength = Platform.getInt(baseObject, baseOffset - 4); keyPrefix = sortBuffer[position + 1]; position += 2; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index ca1ccedc93c8e..4989b05d63e23 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -23,7 +23,7 @@ import org.apache.spark.storage.BlockId; import org.apache.spark.storage.BlockManager; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description @@ -42,7 +42,7 @@ final class UnsafeSorterSpillReader extends UnsafeSorterIterator { private byte[] arr = new byte[1024 * 1024]; private Object baseObject = arr; - private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET; + private final long baseOffset = Platform.BYTE_ARRAY_OFFSET; public UnsafeSorterSpillReader( BlockManager blockManager, diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java index 44cf6c756d7c3..e59a84ff8d118 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -28,7 +28,7 @@ import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.TempLocalBlockId; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * Spills a list of sorted records to disk. Spill files have the following format: @@ -117,11 +117,11 @@ public void write( long recordReadPosition = baseOffset; while (dataRemaining > 0) { final int toTransfer = Math.min(freeSpaceInWriteBuffer, dataRemaining); - PlatformDependent.copyMemory( + Platform.copyMemory( baseObject, recordReadPosition, writeBuffer, - PlatformDependent.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer), + Platform.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer), toTransfer); writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer) + toTransfer); recordReadPosition += toTransfer; diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java index 8fa72597db24d..40fefe2c9d140 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java @@ -24,7 +24,7 @@ import org.junit.Test; import org.apache.spark.HashPartitioner; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.ExecutorMemoryManager; import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -34,11 +34,7 @@ public class UnsafeShuffleInMemorySorterSuite { private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) { final byte[] strBytes = new byte[strLength]; - PlatformDependent.copyMemory( - baseObject, - baseOffset, - strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, strLength); + Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, strLength); return new String(strBytes); } @@ -74,14 +70,10 @@ public void testBasicSorting() throws Exception { for (String str : dataToSort) { final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position); final byte[] strBytes = str.getBytes("utf-8"); - PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length); + Platform.putInt(baseObject, position, strBytes.length); position += 4; - PlatformDependent.copyMemory( - strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - baseObject, - position, - strBytes.length); + Platform.copyMemory( + strBytes, Platform.BYTE_ARRAY_OFFSET, baseObject, position, strBytes.length); position += strBytes.length; sorter.insertRecord(recordAddress, hashPartitioner.getPartition(str)); } @@ -98,7 +90,7 @@ public void testBasicSorting() throws Exception { Assert.assertTrue("Partition id " + partitionId + " should be >= prev id " + prevPartitionId, partitionId >= prevPartitionId); final long recordAddress = iter.packedRecordPointer.getRecordPointer(); - final int recordLength = PlatformDependent.UNSAFE.getInt( + final int recordLength = Platform.getInt( memoryManager.getPage(recordAddress), memoryManager.getOffsetInPage(recordAddress)); final String str = getStringFromDataPage( memoryManager.getPage(recordAddress), diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index e56a3f0b6d12c..1a79c20c35246 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -32,9 +32,7 @@ import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.memory.*; -import org.apache.spark.unsafe.PlatformDependent; -import static org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET; -import static org.apache.spark.unsafe.PlatformDependent.LONG_ARRAY_OFFSET; +import org.apache.spark.unsafe.Platform; public abstract class AbstractBytesToBytesMapSuite { @@ -80,13 +78,8 @@ public void tearDown() { private static byte[] getByteArray(MemoryLocation loc, int size) { final byte[] arr = new byte[size]; - PlatformDependent.copyMemory( - loc.getBaseObject(), - loc.getBaseOffset(), - arr, - BYTE_ARRAY_OFFSET, - size - ); + Platform.copyMemory( + loc.getBaseObject(), loc.getBaseOffset(), arr, Platform.BYTE_ARRAY_OFFSET, size); return arr; } @@ -108,7 +101,7 @@ private static boolean arrayEquals( long actualLengthBytes) { return (actualLengthBytes == expected.length) && ByteArrayMethods.arrayEquals( expected, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, actualAddr.getBaseObject(), actualAddr.getBaseOffset(), expected.length @@ -124,7 +117,7 @@ public void emptyMap() { final int keyLengthInWords = 10; final int keyLengthInBytes = keyLengthInWords * 8; final byte[] key = getRandomByteArray(keyLengthInWords); - Assert.assertFalse(map.lookup(key, BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined()); + Assert.assertFalse(map.lookup(key, Platform.BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined()); Assert.assertFalse(map.iterator().hasNext()); } finally { map.free(); @@ -141,14 +134,14 @@ public void setAndRetrieveAKey() { final byte[] valueData = getRandomByteArray(recordLengthWords); try { final BytesToBytesMap.Location loc = - map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes); + map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes); Assert.assertFalse(loc.isDefined()); Assert.assertTrue(loc.putNewKey( keyData, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, recordLengthBytes, valueData, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, recordLengthBytes )); // After storing the key and value, the other location methods should return results that @@ -159,7 +152,8 @@ public void setAndRetrieveAKey() { Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes)); // After calling lookup() the location should still point to the correct data. - Assert.assertTrue(map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined()); + Assert.assertTrue( + map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined()); Assert.assertEquals(recordLengthBytes, loc.getKeyLength()); Assert.assertEquals(recordLengthBytes, loc.getValueLength()); Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes)); @@ -168,10 +162,10 @@ public void setAndRetrieveAKey() { try { Assert.assertTrue(loc.putNewKey( keyData, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, recordLengthBytes, valueData, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, recordLengthBytes )); Assert.fail("Should not be able to set a new value for a key"); @@ -191,25 +185,25 @@ private void iteratorTestBase(boolean destructive) throws Exception { for (long i = 0; i < size; i++) { final long[] value = new long[] { i }; final BytesToBytesMap.Location loc = - map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8); + map.lookup(value, Platform.LONG_ARRAY_OFFSET, 8); Assert.assertFalse(loc.isDefined()); // Ensure that we store some zero-length keys if (i % 5 == 0) { Assert.assertTrue(loc.putNewKey( null, - PlatformDependent.LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, 0, value, - PlatformDependent.LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, 8 )); } else { Assert.assertTrue(loc.putNewKey( value, - PlatformDependent.LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, 8, value, - PlatformDependent.LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, 8 )); } @@ -228,14 +222,13 @@ private void iteratorTestBase(boolean destructive) throws Exception { Assert.assertTrue(loc.isDefined()); final MemoryLocation keyAddress = loc.getKeyAddress(); final MemoryLocation valueAddress = loc.getValueAddress(); - final long value = PlatformDependent.UNSAFE.getLong( + final long value = Platform.getLong( valueAddress.getBaseObject(), valueAddress.getBaseOffset()); final long keyLength = loc.getKeyLength(); if (keyLength == 0) { Assert.assertTrue("value " + value + " was not divisible by 5", value % 5 == 0); } else { - final long key = PlatformDependent.UNSAFE.getLong( - keyAddress.getBaseObject(), keyAddress.getBaseOffset()); + final long key = Platform.getLong(keyAddress.getBaseObject(), keyAddress.getBaseOffset()); Assert.assertEquals(value, key); } valuesSeen.set((int) value); @@ -284,16 +277,16 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception { final long[] value = new long[] { i, i, i, i, i }; // 5 * 8 = 40 bytes final BytesToBytesMap.Location loc = map.lookup( key, - LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, KEY_LENGTH ); Assert.assertFalse(loc.isDefined()); Assert.assertTrue(loc.putNewKey( key, - LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, KEY_LENGTH, value, - LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, VALUE_LENGTH )); } @@ -308,18 +301,18 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception { Assert.assertTrue(loc.isDefined()); Assert.assertEquals(KEY_LENGTH, loc.getKeyLength()); Assert.assertEquals(VALUE_LENGTH, loc.getValueLength()); - PlatformDependent.copyMemory( + Platform.copyMemory( loc.getKeyAddress().getBaseObject(), loc.getKeyAddress().getBaseOffset(), key, - LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, KEY_LENGTH ); - PlatformDependent.copyMemory( + Platform.copyMemory( loc.getValueAddress().getBaseObject(), loc.getValueAddress().getBaseOffset(), value, - LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, VALUE_LENGTH ); for (long j : key) { @@ -354,16 +347,16 @@ public void randomizedStressTest() { expected.put(ByteBuffer.wrap(key), value); final BytesToBytesMap.Location loc = map.lookup( key, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, key.length ); Assert.assertFalse(loc.isDefined()); Assert.assertTrue(loc.putNewKey( key, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, key.length, value, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, value.length )); // After calling putNewKey, the following should be true, even before calling @@ -379,7 +372,8 @@ public void randomizedStressTest() { for (Map.Entry entry : expected.entrySet()) { final byte[] key = entry.getKey().array(); final byte[] value = entry.getValue(); - final BytesToBytesMap.Location loc = map.lookup(key, BYTE_ARRAY_OFFSET, key.length); + final BytesToBytesMap.Location loc = + map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length); Assert.assertTrue(loc.isDefined()); Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength())); Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength())); @@ -405,16 +399,16 @@ public void randomizedTestWithRecordsLargerThanPageSize() { expected.put(ByteBuffer.wrap(key), value); final BytesToBytesMap.Location loc = map.lookup( key, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, key.length ); Assert.assertFalse(loc.isDefined()); Assert.assertTrue(loc.putNewKey( key, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, key.length, value, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, value.length )); // After calling putNewKey, the following should be true, even before calling @@ -429,7 +423,8 @@ public void randomizedTestWithRecordsLargerThanPageSize() { for (Map.Entry entry : expected.entrySet()) { final byte[] key = entry.getKey().array(); final byte[] value = entry.getValue(); - final BytesToBytesMap.Location loc = map.lookup(key, BYTE_ARRAY_OFFSET, key.length); + final BytesToBytesMap.Location loc = + map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length); Assert.assertTrue(loc.isDefined()); Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength())); Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength())); @@ -447,12 +442,10 @@ public void failureToAllocateFirstPage() { try { final long[] emptyArray = new long[0]; final BytesToBytesMap.Location loc = - map.lookup(emptyArray, PlatformDependent.LONG_ARRAY_OFFSET, 0); + map.lookup(emptyArray, Platform.LONG_ARRAY_OFFSET, 0); Assert.assertFalse(loc.isDefined()); Assert.assertFalse(loc.putNewKey( - emptyArray, LONG_ARRAY_OFFSET, 0, - emptyArray, LONG_ARRAY_OFFSET, 0 - )); + emptyArray, Platform.LONG_ARRAY_OFFSET, 0, emptyArray, Platform.LONG_ARRAY_OFFSET, 0)); } finally { map.free(); } @@ -468,8 +461,9 @@ public void failureToGrow() { int i; for (i = 0; i < 1024; i++) { final long[] arr = new long[]{i}; - final BytesToBytesMap.Location loc = map.lookup(arr, PlatformDependent.LONG_ARRAY_OFFSET, 8); - success = loc.putNewKey(arr, LONG_ARRAY_OFFSET, 8, arr, LONG_ARRAY_OFFSET, 8); + final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8); + success = + loc.putNewKey(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8); if (!success) { break; } @@ -541,12 +535,12 @@ public void testPeakMemoryUsed() { try { for (long i = 0; i < numRecordsPerPage * 10; i++) { final long[] value = new long[]{i}; - map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8).putNewKey( + map.lookup(value, Platform.LONG_ARRAY_OFFSET, 8).putNewKey( value, - PlatformDependent.LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, 8, value, - PlatformDependent.LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, 8); newPeakMemory = map.getPeakMemoryUsedBytes(); if (i % numRecordsPerPage == 0) { diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 83049b8a21fcf..445a37b83e98a 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -49,7 +49,7 @@ import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.storage.*; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.ExecutorMemoryManager; import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.TaskMemoryManager; @@ -166,14 +166,14 @@ private void assertSpillFilesWereCleanedUp() { private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception { final int[] arr = new int[]{ value }; - sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value); + sorter.insertRecord(arr, Platform.INT_ARRAY_OFFSET, 4, value); } private static void insertRecord( UnsafeExternalSorter sorter, int[] record, long prefix) throws IOException { - sorter.insertRecord(record, PlatformDependent.INT_ARRAY_OFFSET, record.length * 4, prefix); + sorter.insertRecord(record, Platform.INT_ARRAY_OFFSET, record.length * 4, prefix); } private UnsafeExternalSorter newSorter() throws IOException { @@ -205,7 +205,7 @@ public void testSortingOnlyByPrefix() throws Exception { iter.loadNext(); assertEquals(i, iter.getKeyPrefix()); assertEquals(4, iter.getRecordLength()); - assertEquals(i, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset())); + assertEquals(i, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); } sorter.cleanupResources(); @@ -253,7 +253,7 @@ public void spillingOccursInResponseToMemoryPressure() throws Exception { iter.loadNext(); assertEquals(i, iter.getKeyPrefix()); assertEquals(4, iter.getRecordLength()); - assertEquals(i, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset())); + assertEquals(i, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); i++; } sorter.cleanupResources(); @@ -265,7 +265,7 @@ public void testFillingPage() throws Exception { final UnsafeExternalSorter sorter = newSorter(); byte[] record = new byte[16]; while (sorter.getNumberOfAllocatedPages() < 2) { - sorter.insertRecord(record, PlatformDependent.BYTE_ARRAY_OFFSET, record.length, 0); + sorter.insertRecord(record, Platform.BYTE_ARRAY_OFFSET, record.length, 0); } sorter.cleanupResources(); assertSpillFilesWereCleanedUp(); @@ -292,25 +292,25 @@ public void sortingRecordsThatExceedPageSize() throws Exception { iter.loadNext(); assertEquals(123, iter.getKeyPrefix()); assertEquals(smallRecord.length * 4, iter.getRecordLength()); - assertEquals(123, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset())); + assertEquals(123, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); // Small record assertTrue(iter.hasNext()); iter.loadNext(); assertEquals(123, iter.getKeyPrefix()); assertEquals(smallRecord.length * 4, iter.getRecordLength()); - assertEquals(123, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset())); + assertEquals(123, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); // Large record assertTrue(iter.hasNext()); iter.loadNext(); assertEquals(456, iter.getKeyPrefix()); assertEquals(largeRecord.length * 4, iter.getRecordLength()); - assertEquals(456, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset())); + assertEquals(456, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); // Large record assertTrue(iter.hasNext()); iter.loadNext(); assertEquals(456, iter.getKeyPrefix()); assertEquals(largeRecord.length * 4, iter.getRecordLength()); - assertEquals(456, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset())); + assertEquals(456, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); assertFalse(iter.hasNext()); sorter.cleanupResources(); diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 909500930539c..778e813df6b54 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -26,7 +26,7 @@ import static org.mockito.Mockito.mock; import org.apache.spark.HashPartitioner; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.ExecutorMemoryManager; import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -36,11 +36,7 @@ public class UnsafeInMemorySorterSuite { private static String getStringFromDataPage(Object baseObject, long baseOffset, int length) { final byte[] strBytes = new byte[length]; - PlatformDependent.copyMemory( - baseObject, - baseOffset, - strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, length); + Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, length); return new String(strBytes); } @@ -76,14 +72,10 @@ public void testSortingOnlyByIntegerPrefix() throws Exception { long position = dataPage.getBaseOffset(); for (String str : dataToSort) { final byte[] strBytes = str.getBytes("utf-8"); - PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length); + Platform.putInt(baseObject, position, strBytes.length); position += 4; - PlatformDependent.copyMemory( - strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - baseObject, - position, - strBytes.length); + Platform.copyMemory( + strBytes, Platform.BYTE_ARRAY_OFFSET, baseObject, position, strBytes.length); position += strBytes.length; } // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so @@ -113,7 +105,7 @@ public int compare(long prefix1, long prefix2) { position = dataPage.getBaseOffset(); for (int i = 0; i < dataToSort.length; i++) { // position now points to the start of a record (which holds its length). - final int recordLength = PlatformDependent.UNSAFE.getInt(baseObject, position); + final int recordLength = Platform.getInt(baseObject, position); final long address = memoryManager.encodePageNumberAndOffset(dataPage, position); final String str = getStringFromDataPage(baseObject, position + 4, recordLength); final int partitionId = hashPartitioner.getPartition(str); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 0374846d71674..501dff090313c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.types.*; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.types.CalendarInterval; @@ -59,7 +59,7 @@ public class UnsafeArrayData extends ArrayData { private int sizeInBytes; private int getElementOffset(int ordinal) { - return PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + ordinal * 4L); + return Platform.getInt(baseObject, baseOffset + ordinal * 4L); } private int getElementSize(int offset, int ordinal) { @@ -157,7 +157,7 @@ public boolean getBoolean(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return false; - return PlatformDependent.UNSAFE.getBoolean(baseObject, baseOffset + offset); + return Platform.getBoolean(baseObject, baseOffset + offset); } @Override @@ -165,7 +165,7 @@ public byte getByte(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + offset); + return Platform.getByte(baseObject, baseOffset + offset); } @Override @@ -173,7 +173,7 @@ public short getShort(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return PlatformDependent.UNSAFE.getShort(baseObject, baseOffset + offset); + return Platform.getShort(baseObject, baseOffset + offset); } @Override @@ -181,7 +181,7 @@ public int getInt(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset); + return Platform.getInt(baseObject, baseOffset + offset); } @Override @@ -189,7 +189,7 @@ public long getLong(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); + return Platform.getLong(baseObject, baseOffset + offset); } @Override @@ -197,7 +197,7 @@ public float getFloat(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return PlatformDependent.UNSAFE.getFloat(baseObject, baseOffset + offset); + return Platform.getFloat(baseObject, baseOffset + offset); } @Override @@ -205,7 +205,7 @@ public double getDouble(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return PlatformDependent.UNSAFE.getDouble(baseObject, baseOffset + offset); + return Platform.getDouble(baseObject, baseOffset + offset); } @Override @@ -215,7 +215,7 @@ public Decimal getDecimal(int ordinal, int precision, int scale) { if (offset < 0) return null; if (precision <= Decimal.MAX_LONG_DIGITS()) { - final long value = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); + final long value = Platform.getLong(baseObject, baseOffset + offset); return Decimal.apply(value, precision, scale); } else { final byte[] bytes = getBinary(ordinal); @@ -241,12 +241,7 @@ public byte[] getBinary(int ordinal) { if (offset < 0) return null; final int size = getElementSize(offset, ordinal); final byte[] bytes = new byte[size]; - PlatformDependent.copyMemory( - baseObject, - baseOffset + offset, - bytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - size); + Platform.copyMemory(baseObject, baseOffset + offset, bytes, Platform.BYTE_ARRAY_OFFSET, size); return bytes; } @@ -255,9 +250,8 @@ public CalendarInterval getInterval(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return null; - final int months = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); - final long microseconds = - PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset + 8); + final int months = (int) Platform.getLong(baseObject, baseOffset + offset); + final long microseconds = Platform.getLong(baseObject, baseOffset + offset + 8); return new CalendarInterval(months, microseconds); } @@ -307,27 +301,16 @@ public boolean equals(Object other) { } public void writeToMemory(Object target, long targetOffset) { - PlatformDependent.copyMemory( - baseObject, - baseOffset, - target, - targetOffset, - sizeInBytes - ); + Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes); } @Override public UnsafeArrayData copy() { UnsafeArrayData arrayCopy = new UnsafeArrayData(); final byte[] arrayDataCopy = new byte[sizeInBytes]; - PlatformDependent.copyMemory( - baseObject, - baseOffset, - arrayDataCopy, - PlatformDependent.BYTE_ARRAY_OFFSET, - sizeInBytes - ); - arrayCopy.pointTo(arrayDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numElements, sizeInBytes); + Platform.copyMemory( + baseObject, baseOffset, arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); + arrayCopy.pointTo(arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, numElements, sizeInBytes); return arrayCopy; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java index b521b703389d3..7b03185a30e3c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java @@ -17,13 +17,13 @@ package org.apache.spark.sql.catalyst.expressions; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; public class UnsafeReaders { public static UnsafeArrayData readArray(Object baseObject, long baseOffset, int numBytes) { // Read the number of elements from first 4 bytes. - final int numElements = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset); + final int numElements = Platform.getInt(baseObject, baseOffset); final UnsafeArrayData array = new UnsafeArrayData(); // Skip the first 4 bytes. array.pointTo(baseObject, baseOffset + 4, numElements, numBytes - 4); @@ -32,9 +32,9 @@ public static UnsafeArrayData readArray(Object baseObject, long baseOffset, int public static UnsafeMapData readMap(Object baseObject, long baseOffset, int numBytes) { // Read the number of elements from first 4 bytes. - final int numElements = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset); + final int numElements = Platform.getInt(baseObject, baseOffset); // Read the numBytes of key array in second 4 bytes. - final int keyArraySize = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + 4); + final int keyArraySize = Platform.getInt(baseObject, baseOffset + 4); final int valueArraySize = numBytes - 8 - keyArraySize; final UnsafeArrayData keyArray = new UnsafeArrayData(); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index e829acb6285f1..7fd94772090df 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -27,7 +27,7 @@ import java.util.Set; import org.apache.spark.sql.types.*; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; @@ -169,7 +169,7 @@ public void pointTo(Object baseObject, long baseOffset, int numFields, int sizeI * @param sizeInBytes the number of bytes valid in the byte array */ public void pointTo(byte[] buf, int numFields, int sizeInBytes) { - pointTo(buf, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); + pointTo(buf, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); } @Override @@ -179,7 +179,7 @@ public void setNullAt(int i) { // To preserve row equality, zero out the value when setting the column to null. // Since this row does does not currently support updates to variable-length values, we don't // have to worry about zeroing out that data. - PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(i), 0); + Platform.putLong(baseObject, getFieldOffset(i), 0); } @Override @@ -191,14 +191,14 @@ public void update(int ordinal, Object value) { public void setInt(int ordinal, int value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putInt(baseObject, getFieldOffset(ordinal), value); + Platform.putInt(baseObject, getFieldOffset(ordinal), value); } @Override public void setLong(int ordinal, long value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(ordinal), value); + Platform.putLong(baseObject, getFieldOffset(ordinal), value); } @Override @@ -208,28 +208,28 @@ public void setDouble(int ordinal, double value) { if (Double.isNaN(value)) { value = Double.NaN; } - PlatformDependent.UNSAFE.putDouble(baseObject, getFieldOffset(ordinal), value); + Platform.putDouble(baseObject, getFieldOffset(ordinal), value); } @Override public void setBoolean(int ordinal, boolean value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putBoolean(baseObject, getFieldOffset(ordinal), value); + Platform.putBoolean(baseObject, getFieldOffset(ordinal), value); } @Override public void setShort(int ordinal, short value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putShort(baseObject, getFieldOffset(ordinal), value); + Platform.putShort(baseObject, getFieldOffset(ordinal), value); } @Override public void setByte(int ordinal, byte value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putByte(baseObject, getFieldOffset(ordinal), value); + Platform.putByte(baseObject, getFieldOffset(ordinal), value); } @Override @@ -239,7 +239,7 @@ public void setFloat(int ordinal, float value) { if (Float.isNaN(value)) { value = Float.NaN; } - PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); + Platform.putFloat(baseObject, getFieldOffset(ordinal), value); } /** @@ -263,23 +263,23 @@ public void setDecimal(int ordinal, Decimal value, int precision) { long cursor = getLong(ordinal) >>> 32; assert cursor > 0 : "invalid cursor " + cursor; // zero-out the bytes - PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + cursor, 0L); - PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + cursor + 8, 0L); + Platform.putLong(baseObject, baseOffset + cursor, 0L); + Platform.putLong(baseObject, baseOffset + cursor + 8, 0L); if (value == null) { setNullAt(ordinal); // keep the offset for future update - PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(ordinal), cursor << 32); + Platform.putLong(baseObject, getFieldOffset(ordinal), cursor << 32); } else { final BigInteger integer = value.toJavaBigDecimal().unscaledValue(); - final int[] mag = (int[]) PlatformDependent.UNSAFE.getObjectVolatile(integer, - PlatformDependent.BIG_INTEGER_MAG_OFFSET); + final int[] mag = (int[]) Platform.getObjectVolatile(integer, + Platform.BIG_INTEGER_MAG_OFFSET); assert(mag.length <= 4); // Write the bytes to the variable length portion. - PlatformDependent.copyMemory(mag, PlatformDependent.INT_ARRAY_OFFSET, - baseObject, baseOffset + cursor, mag.length * 4); + Platform.copyMemory( + mag, Platform.INT_ARRAY_OFFSET, baseObject, baseOffset + cursor, mag.length * 4); setLong(ordinal, (cursor << 32) | ((long) (((integer.signum() + 1) << 8) + mag.length))); } } @@ -336,43 +336,43 @@ public boolean isNullAt(int ordinal) { @Override public boolean getBoolean(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getBoolean(baseObject, getFieldOffset(ordinal)); + return Platform.getBoolean(baseObject, getFieldOffset(ordinal)); } @Override public byte getByte(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getByte(baseObject, getFieldOffset(ordinal)); + return Platform.getByte(baseObject, getFieldOffset(ordinal)); } @Override public short getShort(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getShort(baseObject, getFieldOffset(ordinal)); + return Platform.getShort(baseObject, getFieldOffset(ordinal)); } @Override public int getInt(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getInt(baseObject, getFieldOffset(ordinal)); + return Platform.getInt(baseObject, getFieldOffset(ordinal)); } @Override public long getLong(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(ordinal)); + return Platform.getLong(baseObject, getFieldOffset(ordinal)); } @Override public float getFloat(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(ordinal)); + return Platform.getFloat(baseObject, getFieldOffset(ordinal)); } @Override public double getDouble(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal)); + return Platform.getDouble(baseObject, getFieldOffset(ordinal)); } private static byte[] EMPTY = new byte[0]; @@ -391,13 +391,13 @@ public Decimal getDecimal(int ordinal, int precision, int scale) { assert signum >=0 && signum <= 2 : "invalid signum " + signum; int size = (int) (offsetAndSize & 0xff); int[] mag = new int[size]; - PlatformDependent.copyMemory(baseObject, baseOffset + offset, - mag, PlatformDependent.INT_ARRAY_OFFSET, size * 4); + Platform.copyMemory( + baseObject, baseOffset + offset, mag, Platform.INT_ARRAY_OFFSET, size * 4); // create a BigInteger using signum and mag BigInteger v = new BigInteger(0, EMPTY); // create the initial object - PlatformDependent.UNSAFE.putInt(v, PlatformDependent.BIG_INTEGER_SIGNUM_OFFSET, signum - 1); - PlatformDependent.UNSAFE.putObjectVolatile(v, PlatformDependent.BIG_INTEGER_MAG_OFFSET, mag); + Platform.putInt(v, Platform.BIG_INTEGER_SIGNUM_OFFSET, signum - 1); + Platform.putObjectVolatile(v, Platform.BIG_INTEGER_MAG_OFFSET, mag); return Decimal.apply(new BigDecimal(v, scale), precision, scale); } } @@ -420,11 +420,11 @@ public byte[] getBinary(int ordinal) { final int offset = (int) (offsetAndSize >> 32); final int size = (int) (offsetAndSize & ((1L << 32) - 1)); final byte[] bytes = new byte[size]; - PlatformDependent.copyMemory( + Platform.copyMemory( baseObject, baseOffset + offset, bytes, - PlatformDependent.BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, size ); return bytes; @@ -438,9 +438,8 @@ public CalendarInterval getInterval(int ordinal) { } else { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); - final int months = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); - final long microseconds = - PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset + 8); + final int months = (int) Platform.getLong(baseObject, baseOffset + offset); + final long microseconds = Platform.getLong(baseObject, baseOffset + offset + 8); return new CalendarInterval(months, microseconds); } } @@ -491,14 +490,14 @@ public MapData getMap(int ordinal) { public UnsafeRow copy() { UnsafeRow rowCopy = new UnsafeRow(); final byte[] rowDataCopy = new byte[sizeInBytes]; - PlatformDependent.copyMemory( + Platform.copyMemory( baseObject, baseOffset, rowDataCopy, - PlatformDependent.BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, sizeInBytes ); - rowCopy.pointTo(rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); + rowCopy.pointTo(rowDataCopy, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); return rowCopy; } @@ -518,18 +517,13 @@ public static UnsafeRow createFromByteArray(int numBytes, int numFields) { */ public void copyFrom(UnsafeRow row) { // copyFrom is only available for UnsafeRow created from byte array. - assert (baseObject instanceof byte[]) && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET; + assert (baseObject instanceof byte[]) && baseOffset == Platform.BYTE_ARRAY_OFFSET; if (row.sizeInBytes > this.sizeInBytes) { // resize the underlying byte[] if it's not large enough. this.baseObject = new byte[row.sizeInBytes]; } - PlatformDependent.copyMemory( - row.baseObject, - row.baseOffset, - this.baseObject, - this.baseOffset, - row.sizeInBytes - ); + Platform.copyMemory( + row.baseObject, row.baseOffset, this.baseObject, this.baseOffset, row.sizeInBytes); // update the sizeInBytes. this.sizeInBytes = row.sizeInBytes; } @@ -544,19 +538,15 @@ public void copyFrom(UnsafeRow row) { */ public void writeToStream(OutputStream out, byte[] writeBuffer) throws IOException { if (baseObject instanceof byte[]) { - int offsetInByteArray = (int) (PlatformDependent.BYTE_ARRAY_OFFSET - baseOffset); + int offsetInByteArray = (int) (Platform.BYTE_ARRAY_OFFSET - baseOffset); out.write((byte[]) baseObject, offsetInByteArray, sizeInBytes); } else { int dataRemaining = sizeInBytes; long rowReadPosition = baseOffset; while (dataRemaining > 0) { int toTransfer = Math.min(writeBuffer.length, dataRemaining); - PlatformDependent.copyMemory( - baseObject, - rowReadPosition, - writeBuffer, - PlatformDependent.BYTE_ARRAY_OFFSET, - toTransfer); + Platform.copyMemory( + baseObject, rowReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer); out.write(writeBuffer, 0, toTransfer); rowReadPosition += toTransfer; dataRemaining -= toTransfer; @@ -584,13 +574,12 @@ public boolean equals(Object other) { * Returns the underlying bytes for this UnsafeRow. */ public byte[] getBytes() { - if (baseObject instanceof byte[] && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET + if (baseObject instanceof byte[] && baseOffset == Platform.BYTE_ARRAY_OFFSET && (((byte[]) baseObject).length == sizeInBytes)) { return (byte[]) baseObject; } else { byte[] bytes = new byte[sizeInBytes]; - PlatformDependent.copyMemory(baseObject, baseOffset, bytes, - PlatformDependent.BYTE_ARRAY_OFFSET, sizeInBytes); + Platform.copyMemory(baseObject, baseOffset, bytes, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); return bytes; } } @@ -600,8 +589,7 @@ public byte[] getBytes() { public String toString() { StringBuilder build = new StringBuilder("["); for (int i = 0; i < sizeInBytes; i += 8) { - build.append(java.lang.Long.toHexString( - PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i))); + build.append(java.lang.Long.toHexString(Platform.getLong(baseObject, baseOffset + i))); build.append(','); } build.append(']'); @@ -619,12 +607,6 @@ public boolean anyNull() { * bytes in this string. */ public void writeToMemory(Object target, long targetOffset) { - PlatformDependent.copyMemory( - baseObject, - baseOffset, - target, - targetOffset, - sizeInBytes - ); + Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java index 28e7ec0a0f120..005351f0883e5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.types.Decimal; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.ByteArray; import org.apache.spark.unsafe.types.CalendarInterval; @@ -58,27 +58,27 @@ public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input final Object base = target.getBaseObject(); final long offset = target.getBaseOffset() + cursor; // zero-out the bytes - PlatformDependent.UNSAFE.putLong(base, offset, 0L); - PlatformDependent.UNSAFE.putLong(base, offset + 8, 0L); + Platform.putLong(base, offset, 0L); + Platform.putLong(base, offset + 8, 0L); if (input == null) { target.setNullAt(ordinal); // keep the offset and length for update int fieldOffset = UnsafeRow.calculateBitSetWidthInBytes(target.numFields()) + ordinal * 8; - PlatformDependent.UNSAFE.putLong(base, target.getBaseOffset() + fieldOffset, + Platform.putLong(base, target.getBaseOffset() + fieldOffset, ((long) cursor) << 32); return SIZE; } final BigInteger integer = input.toJavaBigDecimal().unscaledValue(); int signum = integer.signum() + 1; - final int[] mag = (int[]) PlatformDependent.UNSAFE.getObjectVolatile(integer, - PlatformDependent.BIG_INTEGER_MAG_OFFSET); + final int[] mag = (int[]) Platform.getObjectVolatile( + integer, Platform.BIG_INTEGER_MAG_OFFSET); assert(mag.length <= 4); // Write the bytes to the variable length portion. - PlatformDependent.copyMemory(mag, PlatformDependent.INT_ARRAY_OFFSET, - base, target.getBaseOffset() + cursor, mag.length * 4); + Platform.copyMemory( + mag, Platform.INT_ARRAY_OFFSET, base, target.getBaseOffset() + cursor, mag.length * 4); // Set the fixed length portion. target.setLong(ordinal, (((long) cursor) << 32) | ((long) ((signum << 8) + mag.length))); @@ -99,8 +99,7 @@ public static int write(UnsafeRow target, int ordinal, int cursor, UTF8String in // zero-out the padding bytes if ((numBytes & 0x07) > 0) { - PlatformDependent.UNSAFE.putLong( - target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); + Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); } // Write the bytes to the variable length portion. @@ -125,8 +124,7 @@ public static int write(UnsafeRow target, int ordinal, int cursor, byte[] input) // zero-out the padding bytes if ((numBytes & 0x07) > 0) { - PlatformDependent.UNSAFE.putLong( - target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); + Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); } // Write the bytes to the variable length portion. @@ -167,8 +165,7 @@ public static int write(UnsafeRow target, int ordinal, int cursor, InternalRow i // zero-out the padding bytes if ((numBytes & 0x07) > 0) { - PlatformDependent.UNSAFE.putLong( - target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); + Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); } // Write the bytes to the variable length portion. @@ -191,8 +188,8 @@ public static int write(UnsafeRow target, int ordinal, int cursor, CalendarInter final long offset = target.getBaseOffset() + cursor; // Write the months and microseconds fields of Interval to the variable length portion. - PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset, input.months); - PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset + 8, input.microseconds); + Platform.putLong(target.getBaseObject(), offset, input.months); + Platform.putLong(target.getBaseObject(), offset + 8, input.microseconds); // Set the fixed length portion. target.setLong(ordinal, ((long) cursor) << 32); @@ -212,12 +209,11 @@ public static int write(UnsafeRow target, int ordinal, int cursor, UnsafeArrayDa final long offset = target.getBaseOffset() + cursor; // write the number of elements into first 4 bytes. - PlatformDependent.UNSAFE.putInt(target.getBaseObject(), offset, input.numElements()); + Platform.putInt(target.getBaseObject(), offset, input.numElements()); // zero-out the padding bytes if ((numBytes & 0x07) > 0) { - PlatformDependent.UNSAFE.putLong( - target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); + Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); } // Write the bytes to the variable length portion. @@ -247,14 +243,13 @@ public static int write(UnsafeRow target, int ordinal, int cursor, UnsafeMapData final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes; // write the number of elements into first 4 bytes. - PlatformDependent.UNSAFE.putInt(target.getBaseObject(), offset, input.numElements()); + Platform.putInt(target.getBaseObject(), offset, input.numElements()); // write the numBytes of key array into second 4 bytes. - PlatformDependent.UNSAFE.putInt(target.getBaseObject(), offset + 4, keysNumBytes); + Platform.putInt(target.getBaseObject(), offset + 4, keysNumBytes); // zero-out the padding bytes if ((numBytes & 0x07) > 0) { - PlatformDependent.UNSAFE.putLong( - target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); + Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); } // Write the bytes of key array to the variable length portion. diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java index 0e8e405d055de..cd83695fca033 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java @@ -18,8 +18,7 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.sql.types.Decimal; -import org.apache.spark.unsafe.PlatformDependent; -import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -36,17 +35,11 @@ public static void writeToMemory( // zero-out the padding bytes // if ((numBytes & 0x07) > 0) { -// PlatformDependent.UNSAFE.putLong(targetObject, targetOffset + ((numBytes >> 3) << 3), 0L); +// Platform.putLong(targetObject, targetOffset + ((numBytes >> 3) << 3), 0L); // } // Write the UnsafeData to the target memory. - PlatformDependent.copyMemory( - inputObject, - inputOffset, - targetObject, - targetOffset, - numBytes - ); + Platform.copyMemory(inputObject, inputOffset, targetObject, targetOffset, numBytes); } public static int getRoundedSize(int size) { @@ -68,16 +61,11 @@ public static int write(Object targetObject, long targetOffset, Decimal input) { assert(numBytes <= 16); // zero-out the bytes - PlatformDependent.UNSAFE.putLong(targetObject, targetOffset, 0L); - PlatformDependent.UNSAFE.putLong(targetObject, targetOffset + 8, 0L); + Platform.putLong(targetObject, targetOffset, 0L); + Platform.putLong(targetObject, targetOffset + 8, 0L); // Write the bytes to the variable length portion. - PlatformDependent.copyMemory(bytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - targetObject, - targetOffset, - numBytes); - + Platform.copyMemory(bytes, Platform.BYTE_ARRAY_OFFSET, targetObject, targetOffset, numBytes); return 16; } } @@ -111,8 +99,7 @@ public static int write(Object targetObject, long targetOffset, byte[] input) { final int numBytes = input.length; // Write the bytes to the variable length portion. - writeToMemory(input, PlatformDependent.BYTE_ARRAY_OFFSET, - targetObject, targetOffset, numBytes); + writeToMemory(input, Platform.BYTE_ARRAY_OFFSET, targetObject, targetOffset, numBytes); return getRoundedSize(numBytes); } @@ -144,11 +131,9 @@ public static int getSize(UnsafeRow input) { } public static int write(Object targetObject, long targetOffset, CalendarInterval input) { - // Write the months and microseconds fields of Interval to the variable length portion. - PlatformDependent.UNSAFE.putLong(targetObject, targetOffset, input.months); - PlatformDependent.UNSAFE.putLong(targetObject, targetOffset + 8, input.microseconds); - + Platform.putLong(targetObject, targetOffset, input.months); + Platform.putLong(targetObject, targetOffset + 8, input.microseconds); return 16; } } @@ -165,11 +150,11 @@ public static int write(Object targetObject, long targetOffset, UnsafeArrayData final int numBytes = input.getSizeInBytes(); // write the number of elements into first 4 bytes. - PlatformDependent.UNSAFE.putInt(targetObject, targetOffset, input.numElements()); + Platform.putInt(targetObject, targetOffset, input.numElements()); // Write the bytes to the variable length portion. - writeToMemory(input.getBaseObject(), input.getBaseOffset(), - targetObject, targetOffset + 4, numBytes); + writeToMemory( + input.getBaseObject(), input.getBaseOffset(), targetObject, targetOffset + 4, numBytes); return getRoundedSize(numBytes + 4); } @@ -190,9 +175,9 @@ public static int write(Object targetObject, long targetOffset, UnsafeMapData in final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes; // write the number of elements into first 4 bytes. - PlatformDependent.UNSAFE.putInt(targetObject, targetOffset, input.numElements()); + Platform.putInt(targetObject, targetOffset, input.numElements()); // write the numBytes of key array into second 4 bytes. - PlatformDependent.UNSAFE.putInt(targetObject, targetOffset + 4, keysNumBytes); + Platform.putInt(targetObject, targetOffset + 4, keysNumBytes); // Write the bytes of key array to the variable length portion. writeToMemory(keyArray.getBaseObject(), keyArray.getBaseOffset(), diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index a5ae2b9736527..1d27182912c8a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.types.StructType; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; import org.apache.spark.util.collection.unsafe.sort.RecordComparator; import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter; @@ -157,7 +157,7 @@ public UnsafeRow next() { cleanupResources(); // Scala iterators don't declare any checked exceptions, so we need to use this hack // to re-throw the exception: - PlatformDependent.throwException(e); + Platform.throwException(e); } throw new RuntimeException("Exception should have been re-thrown in next()"); }; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index c21f4d626a74e..bf96248feaef7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -28,7 +28,7 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types._ @@ -371,7 +371,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin // Cannot be under package codegen, or fail with java.lang.InstantiationException evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass") evaluator.setDefaultImports(Array( - classOf[PlatformDependent].getName, + classOf[Platform].getName, classOf[InternalRow].getName, classOf[UnsafeRow].getName, classOf[UTF8String].getName, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 29f6a7b981752..b2fb913850794 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -145,12 +145,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro if ($buffer.length < $numBytes) { // This will not happen frequently, because the buffer is re-used. byte[] $tmp = new byte[$numBytes * 2]; - PlatformDependent.copyMemory($buffer, PlatformDependent.BYTE_ARRAY_OFFSET, - $tmp, PlatformDependent.BYTE_ARRAY_OFFSET, $buffer.length); + Platform.copyMemory($buffer, Platform.BYTE_ARRAY_OFFSET, + $tmp, Platform.BYTE_ARRAY_OFFSET, $buffer.length); $buffer = $tmp; } - $output.pointTo($buffer, PlatformDependent.BYTE_ARRAY_OFFSET, - ${inputTypes.length}, $numBytes); + $output.pointTo($buffer, Platform.BYTE_ARRAY_OFFSET, ${inputTypes.length}, $numBytes); """ } else { "" @@ -183,7 +182,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val code = s""" $cursor = $fixedSize; - $output.pointTo($buffer, PlatformDependent.BYTE_ARRAY_OFFSET, ${inputTypes.length}, $cursor); + $output.pointTo($buffer, Platform.BYTE_ARRAY_OFFSET, ${inputTypes.length}, $cursor); ${ctx.splitExpressions(row, convertedFields)} """ GeneratedExpressionCode(code, "false", output) @@ -267,17 +266,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Should we do word align? val elementSize = elementType.defaultSize s""" - PlatformDependent.UNSAFE.put${ctx.primitiveTypeName(elementType)}( + Platform.put${ctx.primitiveTypeName(elementType)}( $buffer, - PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, + Platform.BYTE_ARRAY_OFFSET + $cursor, ${convertedElement.primitive}); $cursor += $elementSize; """ case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => s""" - PlatformDependent.UNSAFE.putLong( + Platform.putLong( $buffer, - PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, + Platform.BYTE_ARRAY_OFFSET + $cursor, ${convertedElement.primitive}.toUnscaledLong()); $cursor += 8; """ @@ -286,7 +285,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" $cursor += $writer.write( $buffer, - PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, + Platform.BYTE_ARRAY_OFFSET + $cursor, $elements[$index]); """ } @@ -320,23 +319,16 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro for (int $index = 0; $index < $numElements; $index++) { if ($checkNull) { // If element is null, write the negative value address into offset region. - PlatformDependent.UNSAFE.putInt( - $buffer, - PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index, - -$cursor); + Platform.putInt($buffer, Platform.BYTE_ARRAY_OFFSET + 4 * $index, -$cursor); } else { - PlatformDependent.UNSAFE.putInt( - $buffer, - PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index, - $cursor); - + Platform.putInt($buffer, Platform.BYTE_ARRAY_OFFSET + 4 * $index, $cursor); $writeElement } } $output.pointTo( $buffer, - PlatformDependent.BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, $numElements, $numBytes); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index 8aaa5b4300044..da91ff29537b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, Attribute} import org.apache.spark.sql.types.StructType -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.Platform abstract class UnsafeRowJoiner { @@ -52,9 +52,9 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U } def create(schema1: StructType, schema2: StructType): UnsafeRowJoiner = { - val offset = PlatformDependent.BYTE_ARRAY_OFFSET - val getLong = "PlatformDependent.UNSAFE.getLong" - val putLong = "PlatformDependent.UNSAFE.putLong" + val offset = Platform.BYTE_ARRAY_OFFSET + val getLong = "Platform.getLong" + val putLong = "Platform.putLong" val bitset1Words = (schema1.size + 63) / 64 val bitset2Words = (schema2.size + 63) / 64 @@ -96,7 +96,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U var cursor = offset + outputBitsetWords * 8 val copyFixedLengthRow1 = s""" |// Copy fixed length data for row1 - |PlatformDependent.copyMemory( + |Platform.copyMemory( | obj1, offset1 + ${bitset1Words * 8}, | buf, $cursor, | ${schema1.size * 8}); @@ -106,7 +106,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U // --------------------- copy fixed length portion from row 2 ----------------------- // val copyFixedLengthRow2 = s""" |// Copy fixed length data for row2 - |PlatformDependent.copyMemory( + |Platform.copyMemory( | obj2, offset2 + ${bitset2Words * 8}, | buf, $cursor, | ${schema2.size * 8}); @@ -118,7 +118,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U val copyVariableLengthRow1 = s""" |// Copy variable length data for row1 |long numBytesVariableRow1 = row1.getSizeInBytes() - $numBytesBitsetAndFixedRow1; - |PlatformDependent.copyMemory( + |Platform.copyMemory( | obj1, offset1 + ${(bitset1Words + schema1.size) * 8}, | buf, $cursor, | numBytesVariableRow1); @@ -129,7 +129,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U val copyVariableLengthRow2 = s""" |// Copy variable length data for row2 |long numBytesVariableRow2 = row2.getSizeInBytes() - $numBytesBitsetAndFixedRow2; - |PlatformDependent.copyMemory( + |Platform.copyMemory( | obj2, offset2 + ${(bitset2Words + schema2.size) * 8}, | buf, $cursor + numBytesVariableRow1, | numBytesVariableRow2); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 76666bd6b3d27..134f1aa2af9a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -1013,7 +1013,7 @@ case class Decode(bin: Expression, charset: Expression) try { ${ev.primitive} = UTF8String.fromString(new String($bytes, $charset.toString())); } catch (java.io.UnsupportedEncodingException e) { - org.apache.spark.unsafe.PlatformDependent.throwException(e); + org.apache.spark.unsafe.Platform.throwException(e); } """) } @@ -1043,7 +1043,7 @@ case class Encode(value: Expression, charset: Expression) try { ${ev.primitive} = $string.toString().getBytes($charset.toString()); } catch (java.io.UnsupportedEncodingException e) { - org.apache.spark.unsafe.PlatformDependent.throwException(e); + org.apache.spark.unsafe.Platform.throwException(e); }""") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala index aff1bee99faad..796d60032e1a6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala @@ -22,7 +22,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.Platform /** * A test suite for the bitset portion of the row concatenation. @@ -96,7 +96,7 @@ class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite { // This way we can test the joiner when the input UnsafeRows are not the entire arrays. val offset = numFields * 8 val buf = new Array[Byte](sizeInBytes + offset) - row.pointTo(buf, PlatformDependent.BYTE_ARRAY_OFFSET + offset, numFields, sizeInBytes) + row.pointTo(buf, Platform.BYTE_ARRAY_OFFSET + offset, numFields, sizeInBytes) row } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 00218f213054b..5cce41d5a7569 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -27,7 +27,7 @@ import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.KVIterator; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryLocation; import org.apache.spark.unsafe.memory.TaskMemoryManager; @@ -138,7 +138,7 @@ public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow unsafeGroupingKeyRo unsafeGroupingKeyRow.getBaseOffset(), unsafeGroupingKeyRow.getSizeInBytes(), emptyAggregationBuffer, - PlatformDependent.BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, emptyAggregationBuffer.length ); if (!putSucceeded) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 69d6784713a24..7db6b7ff50f22 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -31,7 +31,7 @@ import org.apache.spark.sql.types.StructType; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.KVIterator; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; @@ -225,7 +225,7 @@ public boolean next() throws IOException { int recordLen = underlying.getRecordLength(); // Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself) - int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset); + int keyLen = Platform.getInt(baseObj, recordOffset); int valueLen = recordLen - keyLen - 4; key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen); value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 6c7e5cacc99e7..3860c4bba9a99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -26,7 +26,7 @@ import com.google.common.io.ByteStreams import org.apache.spark.serializer.{SerializationStream, DeserializationStream, SerializerInstance, Serializer} import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.Platform /** * Serializer for serializing [[UnsafeRow]]s during shuffle. Since UnsafeRows are already stored as @@ -116,7 +116,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst rowBuffer = new Array[Byte](rowSize) } ByteStreams.readFully(dIn, rowBuffer, 0, rowSize) - row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize) + row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize) rowSize = dIn.readInt() // read the next row's size if (rowSize == EOF) { // We are returning the last row in this stream val _rowTuple = rowTuple @@ -150,7 +150,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst rowBuffer = new Array[Byte](rowSize) } ByteStreams.readFully(dIn, rowBuffer, 0, rowSize) - row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize) + row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize) row.asInstanceOf[T] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 953abf409f220..63d35d0f02622 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -25,7 +25,7 @@ import org.apache.spark.shuffle.ShuffleMemoryManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} import org.apache.spark.util.Utils @@ -218,8 +218,8 @@ private[joins] final class UnsafeHashedRelation( var offset = loc.getValueAddress.getBaseOffset val last = loc.getValueAddress.getBaseOffset + loc.getValueLength while (offset < last) { - val numFields = PlatformDependent.UNSAFE.getInt(base, offset) - val sizeInBytes = PlatformDependent.UNSAFE.getInt(base, offset + 4) + val numFields = Platform.getInt(base, offset) + val sizeInBytes = Platform.getInt(base, offset + 4) offset += 8 val row = new UnsafeRow @@ -314,10 +314,11 @@ private[joins] final class UnsafeHashedRelation( in.readFully(valuesBuffer, 0, valuesSize) // put it into binary map - val loc = binaryMap.lookup(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize) + val loc = binaryMap.lookup(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize) assert(!loc.isDefined, "Duplicated key found!") - val putSuceeded = loc.putNewKey(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize, - valuesBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, valuesSize) + val putSuceeded = loc.putNewKey( + keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize, + valuesBuffer, Platform.BYTE_ARRAY_OFFSET, valuesSize) if (!putSuceeded) { throw new IOException("Could not allocate memory to grow BytesToBytesMap") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index 89bad1bfdab0a..219435dff5bc8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.memory.MemoryAllocator import org.apache.spark.unsafe.types.UTF8String @@ -51,7 +51,7 @@ class UnsafeRowSuite extends SparkFunSuite { val bytesFromOffheapRow: Array[Byte] = { val offheapRowPage = MemoryAllocator.UNSAFE.allocate(arrayBackedUnsafeRow.getSizeInBytes) try { - PlatformDependent.copyMemory( + Platform.copyMemory( arrayBackedUnsafeRow.getBaseObject, arrayBackedUnsafeRow.getBaseOffset, offheapRowPage.getBaseObject, diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java similarity index 55% rename from unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java rename to unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index b2de2a2590f05..18343efdc3437 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -22,103 +22,111 @@ import sun.misc.Unsafe; -public final class PlatformDependent { +public final class Platform { - /** - * Facade in front of {@link sun.misc.Unsafe}, used to avoid directly exposing Unsafe outside of - * this package. This also lets us avoid accidental use of deprecated methods. - */ - public static final class UNSAFE { - - private UNSAFE() { } + private static final Unsafe _UNSAFE; - public static int getInt(Object object, long offset) { - return _UNSAFE.getInt(object, offset); - } + public static final int BYTE_ARRAY_OFFSET; - public static void putInt(Object object, long offset, int value) { - _UNSAFE.putInt(object, offset, value); - } + public static final int INT_ARRAY_OFFSET; - public static boolean getBoolean(Object object, long offset) { - return _UNSAFE.getBoolean(object, offset); - } + public static final int LONG_ARRAY_OFFSET; - public static void putBoolean(Object object, long offset, boolean value) { - _UNSAFE.putBoolean(object, offset, value); - } + public static final int DOUBLE_ARRAY_OFFSET; - public static byte getByte(Object object, long offset) { - return _UNSAFE.getByte(object, offset); - } + // Support for resetting final fields while deserializing + public static final long BIG_INTEGER_SIGNUM_OFFSET; + public static final long BIG_INTEGER_MAG_OFFSET; - public static void putByte(Object object, long offset, byte value) { - _UNSAFE.putByte(object, offset, value); - } + public static int getInt(Object object, long offset) { + return _UNSAFE.getInt(object, offset); + } - public static short getShort(Object object, long offset) { - return _UNSAFE.getShort(object, offset); - } + public static void putInt(Object object, long offset, int value) { + _UNSAFE.putInt(object, offset, value); + } - public static void putShort(Object object, long offset, short value) { - _UNSAFE.putShort(object, offset, value); - } + public static boolean getBoolean(Object object, long offset) { + return _UNSAFE.getBoolean(object, offset); + } - public static long getLong(Object object, long offset) { - return _UNSAFE.getLong(object, offset); - } + public static void putBoolean(Object object, long offset, boolean value) { + _UNSAFE.putBoolean(object, offset, value); + } - public static void putLong(Object object, long offset, long value) { - _UNSAFE.putLong(object, offset, value); - } + public static byte getByte(Object object, long offset) { + return _UNSAFE.getByte(object, offset); + } - public static float getFloat(Object object, long offset) { - return _UNSAFE.getFloat(object, offset); - } + public static void putByte(Object object, long offset, byte value) { + _UNSAFE.putByte(object, offset, value); + } - public static void putFloat(Object object, long offset, float value) { - _UNSAFE.putFloat(object, offset, value); - } + public static short getShort(Object object, long offset) { + return _UNSAFE.getShort(object, offset); + } - public static double getDouble(Object object, long offset) { - return _UNSAFE.getDouble(object, offset); - } + public static void putShort(Object object, long offset, short value) { + _UNSAFE.putShort(object, offset, value); + } - public static void putDouble(Object object, long offset, double value) { - _UNSAFE.putDouble(object, offset, value); - } + public static long getLong(Object object, long offset) { + return _UNSAFE.getLong(object, offset); + } - public static Object getObjectVolatile(Object object, long offset) { - return _UNSAFE.getObjectVolatile(object, offset); - } + public static void putLong(Object object, long offset, long value) { + _UNSAFE.putLong(object, offset, value); + } - public static void putObjectVolatile(Object object, long offset, Object value) { - _UNSAFE.putObjectVolatile(object, offset, value); - } + public static float getFloat(Object object, long offset) { + return _UNSAFE.getFloat(object, offset); + } - public static long allocateMemory(long size) { - return _UNSAFE.allocateMemory(size); - } + public static void putFloat(Object object, long offset, float value) { + _UNSAFE.putFloat(object, offset, value); + } - public static void freeMemory(long address) { - _UNSAFE.freeMemory(address); - } + public static double getDouble(Object object, long offset) { + return _UNSAFE.getDouble(object, offset); + } + public static void putDouble(Object object, long offset, double value) { + _UNSAFE.putDouble(object, offset, value); } - private static final Unsafe _UNSAFE; + public static Object getObjectVolatile(Object object, long offset) { + return _UNSAFE.getObjectVolatile(object, offset); + } - public static final int BYTE_ARRAY_OFFSET; + public static void putObjectVolatile(Object object, long offset, Object value) { + _UNSAFE.putObjectVolatile(object, offset, value); + } - public static final int INT_ARRAY_OFFSET; + public static long allocateMemory(long size) { + return _UNSAFE.allocateMemory(size); + } - public static final int LONG_ARRAY_OFFSET; + public static void freeMemory(long address) { + _UNSAFE.freeMemory(address); + } - public static final int DOUBLE_ARRAY_OFFSET; + public static void copyMemory( + Object src, long srcOffset, Object dst, long dstOffset, long length) { + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + srcOffset += size; + dstOffset += size; + } + } - // Support for resetting final fields while deserializing - public static final long BIG_INTEGER_SIGNUM_OFFSET; - public static final long BIG_INTEGER_MAG_OFFSET; + /** + * Raises an exception bypassing compiler checks for checked exceptions. + */ + public static void throwException(Throwable t) { + _UNSAFE.throwException(t); + } /** * Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to @@ -162,26 +170,4 @@ public static void freeMemory(long address) { BIG_INTEGER_MAG_OFFSET = 0; } } - - static public void copyMemory( - Object src, - long srcOffset, - Object dst, - long dstOffset, - long length) { - while (length > 0) { - long size = Math.min(length, UNSAFE_COPY_THRESHOLD); - _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); - length -= size; - srcOffset += size; - dstOffset += size; - } - } - - /** - * Raises an exception bypassing compiler checks for checked exceptions. - */ - public static void throwException(Throwable t) { - _UNSAFE.throwException(t); - } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index 70b81ce015ddc..cf42877bf9fd4 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.array; -import static org.apache.spark.unsafe.PlatformDependent.*; +import org.apache.spark.unsafe.Platform; public class ByteArrayMethods { @@ -45,20 +45,18 @@ public static int roundNumberOfBytesToNearestWord(int numBytes) { * @return true if the arrays are equal, false otherwise */ public static boolean arrayEquals( - Object leftBase, - long leftOffset, - Object rightBase, - long rightOffset, - final long length) { + Object leftBase, long leftOffset, Object rightBase, long rightOffset, final long length) { int i = 0; while (i <= length - 8) { - if (UNSAFE.getLong(leftBase, leftOffset + i) != UNSAFE.getLong(rightBase, rightOffset + i)) { + if (Platform.getLong(leftBase, leftOffset + i) != + Platform.getLong(rightBase, rightOffset + i)) { return false; } i += 8; } while (i < length) { - if (UNSAFE.getByte(leftBase, leftOffset + i) != UNSAFE.getByte(rightBase, rightOffset + i)) { + if (Platform.getByte(leftBase, leftOffset + i) != + Platform.getByte(rightBase, rightOffset + i)) { return false; } i += 1; diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java index 18d1f0d2d7eb2..74105050e4191 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.array; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; /** @@ -64,7 +64,7 @@ public long size() { public void set(int index, long value) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - PlatformDependent.UNSAFE.putLong(baseObj, baseOffset + index * WIDTH, value); + Platform.putLong(baseObj, baseOffset + index * WIDTH, value); } /** @@ -73,6 +73,6 @@ public void set(int index, long value) { public long get(int index) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - return PlatformDependent.UNSAFE.getLong(baseObj, baseOffset + index * WIDTH); + return Platform.getLong(baseObj, baseOffset + index * WIDTH); } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java index 27462c7fa5e62..7857bf66a72ad 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.bitset; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * Methods for working with fixed-size uncompressed bitsets. @@ -41,8 +41,8 @@ public static void set(Object baseObject, long baseOffset, int index) { assert index >= 0 : "index (" + index + ") should >= 0"; final long mask = 1L << (index & 0x3f); // mod 64 and shift final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; - final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset); - PlatformDependent.UNSAFE.putLong(baseObject, wordOffset, word | mask); + final long word = Platform.getLong(baseObject, wordOffset); + Platform.putLong(baseObject, wordOffset, word | mask); } /** @@ -52,8 +52,8 @@ public static void unset(Object baseObject, long baseOffset, int index) { assert index >= 0 : "index (" + index + ") should >= 0"; final long mask = 1L << (index & 0x3f); // mod 64 and shift final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; - final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset); - PlatformDependent.UNSAFE.putLong(baseObject, wordOffset, word & ~mask); + final long word = Platform.getLong(baseObject, wordOffset); + Platform.putLong(baseObject, wordOffset, word & ~mask); } /** @@ -63,7 +63,7 @@ public static boolean isSet(Object baseObject, long baseOffset, int index) { assert index >= 0 : "index (" + index + ") should >= 0"; final long mask = 1L << (index & 0x3f); // mod 64 and shift final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; - final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset); + final long word = Platform.getLong(baseObject, wordOffset); return (word & mask) != 0; } @@ -73,7 +73,7 @@ public static boolean isSet(Object baseObject, long baseOffset, int index) { public static boolean anySet(Object baseObject, long baseOffset, long bitSetWidthInWords) { long addr = baseOffset; for (int i = 0; i < bitSetWidthInWords; i++, addr += WORD_SIZE) { - if (PlatformDependent.UNSAFE.getLong(baseObject, addr) != 0) { + if (Platform.getLong(baseObject, addr) != 0) { return true; } } @@ -109,8 +109,7 @@ public static int nextSetBit( // Try to find the next set bit in the current word final int subIndex = fromIndex & 0x3f; - long word = - PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + wi * WORD_SIZE) >> subIndex; + long word = Platform.getLong(baseObject, baseOffset + wi * WORD_SIZE) >> subIndex; if (word != 0) { return (wi << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word); } @@ -118,7 +117,7 @@ public static int nextSetBit( // Find the next set bit in the rest of the words wi += 1; while (wi < bitsetSizeInWords) { - word = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + wi * WORD_SIZE); + word = Platform.getLong(baseObject, baseOffset + wi * WORD_SIZE); if (word != 0) { return (wi << 6) + java.lang.Long.numberOfTrailingZeros(word); } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java index 61f483ced3217..4276f25c2165b 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.hash; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction. @@ -53,7 +53,7 @@ public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, i assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; int h1 = seed; for (int i = 0; i < lengthInBytes; i += 4) { - int halfWord = PlatformDependent.UNSAFE.getInt(base, offset + i); + int halfWord = Platform.getInt(base, offset + i); int k1 = mixK1(halfWord); h1 = mixH1(h1, k1); } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java index 91be46ba21ff8..dd75820834370 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -19,7 +19,7 @@ import javax.annotation.Nullable; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * A consecutive block of memory, starting at a {@link MemoryLocation} with a fixed size. @@ -50,6 +50,6 @@ public long size() { * Creates a memory block pointing to the memory used by the long array. */ public static MemoryBlock fromLongArray(final long[] array) { - return new MemoryBlock(array, PlatformDependent.LONG_ARRAY_OFFSET, array.length * 8); + return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8); } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java index 62f4459696c28..cda7826c8c99b 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.memory; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * A simple {@link MemoryAllocator} that uses {@code Unsafe} to allocate off-heap memory. @@ -29,7 +29,7 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { if (size % 8 != 0) { throw new IllegalArgumentException("Size " + size + " was not a multiple of 8"); } - long address = PlatformDependent.UNSAFE.allocateMemory(size); + long address = Platform.allocateMemory(size); return new MemoryBlock(null, address, size); } @@ -37,6 +37,6 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { public void free(MemoryBlock memory) { assert (memory.obj == null) : "baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?"; - PlatformDependent.UNSAFE.freeMemory(memory.offset); + Platform.freeMemory(memory.offset); } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java index 69b0e206cef18..c08c9c73d2396 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.types; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; public class ByteArray { @@ -27,12 +27,6 @@ public class ByteArray { * hold all the bytes in this string. */ public static void writeToMemory(byte[] src, Object target, long targetOffset) { - PlatformDependent.copyMemory( - src, - PlatformDependent.BYTE_ARRAY_OFFSET, - target, - targetOffset, - src.length - ); + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET, target, targetOffset, src.length); } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index d1014426c0f49..667c00900f2c5 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -24,10 +24,10 @@ import java.util.Arrays; import java.util.Map; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; -import static org.apache.spark.unsafe.PlatformDependent.*; +import static org.apache.spark.unsafe.Platform.*; /** @@ -133,13 +133,7 @@ protected UTF8String(Object base, long offset, int numBytes) { * bytes in this string. */ public void writeToMemory(Object target, long targetOffset) { - PlatformDependent.copyMemory( - base, - offset, - target, - targetOffset, - numBytes - ); + Platform.copyMemory(base, offset, target, targetOffset, numBytes); } /** @@ -183,12 +177,12 @@ public long getPrefix() { long mask = 0; if (isLittleEndian) { if (numBytes >= 8) { - p = PlatformDependent.UNSAFE.getLong(base, offset); + p = Platform.getLong(base, offset); } else if (numBytes > 4) { - p = PlatformDependent.UNSAFE.getLong(base, offset); + p = Platform.getLong(base, offset); mask = (1L << (8 - numBytes) * 8) - 1; } else if (numBytes > 0) { - p = (long) PlatformDependent.UNSAFE.getInt(base, offset); + p = (long) Platform.getInt(base, offset); mask = (1L << (8 - numBytes) * 8) - 1; } else { p = 0; @@ -197,12 +191,12 @@ public long getPrefix() { } else { // byteOrder == ByteOrder.BIG_ENDIAN if (numBytes >= 8) { - p = PlatformDependent.UNSAFE.getLong(base, offset); + p = Platform.getLong(base, offset); } else if (numBytes > 4) { - p = PlatformDependent.UNSAFE.getLong(base, offset); + p = Platform.getLong(base, offset); mask = (1L << (8 - numBytes) * 8) - 1; } else if (numBytes > 0) { - p = ((long) PlatformDependent.UNSAFE.getInt(base, offset)) << 32; + p = ((long) Platform.getInt(base, offset)) << 32; mask = (1L << (8 - numBytes) * 8) - 1; } else { p = 0; @@ -293,7 +287,7 @@ public boolean contains(final UTF8String substring) { * Returns the byte at position `i`. */ private byte getByte(int i) { - return UNSAFE.getByte(base, offset + i); + return Platform.getByte(base, offset + i); } private boolean matchAt(final UTF8String s, int pos) { @@ -769,7 +763,7 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { int len = inputs[i].numBytes; copyMemory( inputs[i].base, inputs[i].offset, - result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, + result, BYTE_ARRAY_OFFSET + offset, len); offset += len; @@ -778,7 +772,7 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { if (j < numInputs) { copyMemory( separator.base, separator.offset, - result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, + result, BYTE_ARRAY_OFFSET + offset, separator.numBytes); offset += separator.numBytes; } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java index 3b9175835229c..2f8cb132ac8b4 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java @@ -22,7 +22,7 @@ import java.util.Set; import junit.framework.Assert; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.junit.Test; /** @@ -83,11 +83,11 @@ public void randomizedStressTestBytes() { rand.nextBytes(bytes); Assert.assertEquals( - hasher.hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize), - hasher.hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); hashcodes.add(hasher.hashUnsafeWords( - bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); } // A very loose bound. @@ -106,11 +106,11 @@ public void randomizedStressTestPaddedStrings() { System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); Assert.assertEquals( - hasher.hashUnsafeWords(paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize), - hasher.hashUnsafeWords(paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); hashcodes.add(hasher.hashUnsafeWords( - paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); } // A very loose bound. From dfe347d2cae3eb05d7539aaf72db3d309e711213 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 11 Aug 2015 08:52:15 -0700 Subject: [PATCH 019/253] [SPARK-9785] [SQL] HashPartitioning compatibility should consider expression ordering HashPartitioning compatibility is currently defined w.r.t the _set_ of expressions, but the ordering of those expressions matters when computing hash codes; this could lead to incorrect answers if we mistakenly avoided a shuffle based on the assumption that HashPartitionings with the same expressions in different orders will produce equivalent row hashcodes. The first commit adds a regression test which illustrates this problem. The fix for this is simple: make `HashPartitioning.compatibleWith` and `HashPartitioning.guarantees` sensitive to the expression ordering (i.e. do not perform set comparison). Author: Josh Rosen Closes #8074 from JoshRosen/hashpartitioning-compatiblewith-fixes and squashes the following commits: b61412f [Josh Rosen] Demonstrate that I haven't cheated in my fix 0b4d7d9 [Josh Rosen] Update so that clusteringSet is only used in satisfies(). dc9c9d7 [Josh Rosen] Add failing regression test for SPARK-9785 --- .../plans/physical/partitioning.scala | 15 ++--- .../sql/catalyst/PartitioningSuite.scala | 55 +++++++++++++++++++ 2 files changed, 60 insertions(+), 10 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 5a89a90b735a6..5ac3f1f5b0cac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -216,26 +216,23 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - lazy val clusteringSet = expressions.toSet - override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true case ClusteredDistribution(requiredClustering) => - clusteringSet.subsetOf(requiredClustering.toSet) + expressions.toSet.subsetOf(requiredClustering.toSet) case _ => false } override def compatibleWith(other: Partitioning): Boolean = other match { - case o: HashPartitioning => - this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions + case o: HashPartitioning => this == o case _ => false } override def guarantees(other: Partitioning): Boolean = other match { - case o: HashPartitioning => - this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions + case o: HashPartitioning => this == o case _ => false } + } /** @@ -257,15 +254,13 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - private[this] lazy val clusteringSet = ordering.map(_.child).toSet - override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true case OrderedDistribution(requiredOrdering) => val minSize = Seq(requiredOrdering.size, ordering.size).min requiredOrdering.take(minSize) == ordering.take(minSize) case ClusteredDistribution(requiredClustering) => - clusteringSet.subsetOf(requiredClustering.toSet) + ordering.map(_.child).toSet.subsetOf(requiredClustering.toSet) case _ => false } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala new file mode 100644 index 0000000000000..5b802ccc637dd --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, Literal} +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning} + +class PartitioningSuite extends SparkFunSuite { + test("HashPartitioning compatibility should be sensitive to expression ordering (SPARK-9785)") { + val expressions = Seq(Literal(2), Literal(3)) + // Consider two HashPartitionings that have the same _set_ of hash expressions but which are + // created with different orderings of those expressions: + val partitioningA = HashPartitioning(expressions, 100) + val partitioningB = HashPartitioning(expressions.reverse, 100) + // These partitionings are not considered equal: + assert(partitioningA != partitioningB) + // However, they both satisfy the same clustered distribution: + val distribution = ClusteredDistribution(expressions) + assert(partitioningA.satisfies(distribution)) + assert(partitioningB.satisfies(distribution)) + // These partitionings compute different hashcodes for the same input row: + def computeHashCode(partitioning: HashPartitioning): Int = { + val hashExprProj = new InterpretedMutableProjection(partitioning.expressions, Seq.empty) + hashExprProj.apply(InternalRow.empty).hashCode() + } + assert(computeHashCode(partitioningA) != computeHashCode(partitioningB)) + // Thus, these partitionings are incompatible: + assert(!partitioningA.compatibleWith(partitioningB)) + assert(!partitioningB.compatibleWith(partitioningA)) + assert(!partitioningA.guarantees(partitioningB)) + assert(!partitioningB.guarantees(partitioningA)) + + // Just to be sure that we haven't cheated by having these methods always return false, + // check that identical partitionings are still compatible with and guarantee each other: + assert(partitioningA === partitioningA) + assert(partitioningA.guarantees(partitioningA)) + assert(partitioningA.compatibleWith(partitioningA)) + } +} From bce72797f3499f14455722600b0d0898d4fd87c9 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Tue, 11 Aug 2015 10:42:17 -0700 Subject: [PATCH 020/253] Fix comment error API is updated but its doc comment is not updated. Author: Jeff Zhang Closes #8097 from zjffdu/dev. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 9ced44131b0d9..6aafb4c5644d7 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -866,7 +866,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * }}} * * Do - * `val rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`, + * `val rdd = sparkContext.binaryFiles("hdfs://a-hdfs-path")`, * * then `rdd` contains * {{{ From 8cad854ef6a2066de5adffcca6b79a205ccfd5f3 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 11 Aug 2015 11:01:59 -0700 Subject: [PATCH 021/253] [SPARK-8345] [ML] Add an SQL node as a feature transformer Implements the transforms which are defined by SQL statement. Currently we only support SQL syntax like 'SELECT ... FROM __THIS__' where '__THIS__' represents the underlying table of the input dataset. Author: Yanbo Liang Closes #7465 from yanboliang/spark-8345 and squashes the following commits: b403fcb [Yanbo Liang] address comments 0d4bb15 [Yanbo Liang] a better transformSchema() implementation 51eb9e7 [Yanbo Liang] Add an SQL node as a feature transformer --- .../spark/ml/feature/SQLTransformer.scala | 72 +++++++++++++++++++ .../ml/feature/SQLTransformerSuite.scala | 44 ++++++++++++ 2 files changed, 116 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala new file mode 100644 index 0000000000000..95e4305638730 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.param.{ParamMap, Param} +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.{SQLContext, DataFrame, Row} +import org.apache.spark.sql.types.StructType + +/** + * :: Experimental :: + * Implements the transforms which are defined by SQL statement. + * Currently we only support SQL syntax like 'SELECT ... FROM __THIS__' + * where '__THIS__' represents the underlying table of the input dataset. + */ +@Experimental +class SQLTransformer (override val uid: String) extends Transformer { + + def this() = this(Identifiable.randomUID("sql")) + + /** + * SQL statement parameter. The statement is provided in string form. + * @group param + */ + final val statement: Param[String] = new Param[String](this, "statement", "SQL statement") + + /** @group setParam */ + def setStatement(value: String): this.type = set(statement, value) + + /** @group getParam */ + def getStatement: String = $(statement) + + private val tableIdentifier: String = "__THIS__" + + override def transform(dataset: DataFrame): DataFrame = { + val tableName = Identifiable.randomUID(uid) + dataset.registerTempTable(tableName) + val realStatement = $(statement).replace(tableIdentifier, tableName) + val outputDF = dataset.sqlContext.sql(realStatement) + outputDF + } + + override def transformSchema(schema: StructType): StructType = { + val sc = SparkContext.getOrCreate() + val sqlContext = SQLContext.getOrCreate(sc) + val dummyRDD = sc.parallelize(Seq(Row.empty)) + val dummyDF = sqlContext.createDataFrame(dummyRDD, schema) + dummyDF.registerTempTable(tableIdentifier) + val outputSchema = sqlContext.sql($(statement)).schema + outputSchema + } + + override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala new file mode 100644 index 0000000000000..d19052881ae45 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + ParamsSuite.checkParams(new SQLTransformer()) + } + + test("transform numeric data") { + val original = sqlContext.createDataFrame( + Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") + val sqlTrans = new SQLTransformer().setStatement( + "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") + val result = sqlTrans.transform(original) + val resultSchema = sqlTrans.transformSchema(original.schema) + val expected = sqlContext.createDataFrame( + Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0))) + .toDF("id", "v1", "v2", "v3", "v4") + assert(result.schema.toString == resultSchema.toString) + assert(resultSchema == expected.schema) + assert(result.collect().toSeq == expected.collect().toSeq) + } +} From dbd778d84d094ca142bc08c351478595b280bc2a Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 11 Aug 2015 11:33:36 -0700 Subject: [PATCH 022/253] [SPARK-8764] [ML] string indexer should take option to handle unseen values As a precursor to adding a public constructor add an option to handle unseen values by skipping rather than throwing an exception (default remains throwing an exception), Author: Holden Karau Closes #7266 from holdenk/SPARK-8764-string-indexer-should-take-option-to-handle-unseen-values and squashes the following commits: 38a4de9 [Holden Karau] fix long line 045bf22 [Holden Karau] Add a second b entry so b gets 0 for sure 81dd312 [Holden Karau] Update the docs for handleInvalid param to be more descriptive 7f37f6e [Holden Karau] remove extra space (scala style) 414e249 [Holden Karau] And switch to using handleInvalid instead of skipInvalid 1e53f9b [Holden Karau] update the param (codegen side) 7a22215 [Holden Karau] fix typo 100a39b [Holden Karau] Merge in master aa5b093 [Holden Karau] Since we filter we should never go down this code path if getSkipInvalid is true 75ffa69 [Holden Karau] Remove extra newline d69ef5e [Holden Karau] Add a test b5734be [Holden Karau] Add support for unseen labels afecd4e [Holden Karau] Add a param to skip invalid entries. --- .../spark/ml/feature/StringIndexer.scala | 26 ++++++++++++--- .../ml/param/shared/SharedParamsCodeGen.scala | 4 +++ .../spark/ml/param/shared/sharedParams.scala | 15 +++++++++ .../spark/ml/feature/StringIndexerSuite.scala | 32 +++++++++++++++++++ 4 files changed, 73 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index ebfa972532358..e4485eb038409 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -33,7 +33,8 @@ import org.apache.spark.util.collection.OpenHashMap /** * Base trait for [[StringIndexer]] and [[StringIndexerModel]]. */ -private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol { +private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol + with HasHandleInvalid { /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { @@ -65,13 +66,16 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod def this() = this(Identifiable.randomUID("strIdx")) + /** @group setParam */ + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, "error") + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - // TODO: handle unseen labels override def fit(dataset: DataFrame): StringIndexerModel = { val counts = dataset.select(col($(inputCol)).cast(StringType)) @@ -111,6 +115,10 @@ class StringIndexerModel private[ml] ( map } + /** @group setParam */ + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, "error") + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -128,14 +136,24 @@ class StringIndexerModel private[ml] ( if (labelToIndex.contains(label)) { labelToIndex(label) } else { - // TODO: handle unseen labels throw new SparkException(s"Unseen label: $label.") } } + val outputColName = $(outputCol) val metadata = NominalAttribute.defaultAttr .withName(outputColName).withValues(labels).toMetadata() - dataset.select(col("*"), + // If we are skipping invalid records, filter them out. + val filteredDataset = (getHandleInvalid) match { + case "skip" => { + val filterer = udf { label: String => + labelToIndex.contains(label) + } + dataset.where(filterer(dataset($(inputCol)))) + } + case _ => dataset + } + filteredDataset.select(col("*"), indexer(dataset($(inputCol)).cast(StringType)).as(outputColName, metadata)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index a97c8059b8d45..da4c076830391 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -59,6 +59,10 @@ private[shared] object SharedParamsCodeGen { ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)", isValid = "ParamValidators.gtEq(1)"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), + ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " + + "will filter out rows with bad values), or error (which will throw an errror). More " + + "options may be added later.", + isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"), ParamDesc[Boolean]("standardization", "whether to standardize the training features" + " before fitting the model.", Some("true")), ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index f332630c32f1b..23e2b6cc43996 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -247,6 +247,21 @@ private[ml] trait HasFitIntercept extends Params { final def getFitIntercept: Boolean = $(fitIntercept) } +/** + * Trait for shared param handleInvalid. + */ +private[ml] trait HasHandleInvalid extends Params { + + /** + * Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.. + * @group param + */ + final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", ParamValidators.inArray(Array("skip", "error"))) + + /** @group getParam */ + final def getHandleInvalid: String = $(handleInvalid) +} + /** * Trait for shared param standardization (default: true). */ diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index d0295a0fe2fc1..b111036087e6a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import org.apache.spark.SparkException import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite @@ -62,6 +63,37 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { reversed2.collect().map(r => (r.getInt(0), r.getString(1))).toSet) } + test("StringIndexerUnseen") { + val data = sc.parallelize(Seq((0, "a"), (1, "b"), (4, "b")), 2) + val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val df2 = sqlContext.createDataFrame(data2).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + // Verify we throw by default with unseen values + intercept[SparkException] { + indexer.transform(df2).collect() + } + val indexerSkipInvalid = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .setHandleInvalid("skip") + .fit(df) + // Verify that we skip the c record + val transformed = indexerSkipInvalid.transform(df2) + val attr = Attribute.fromStructField(transformed.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("b", "a")) + val output = transformed.select("id", "labelIndex").map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 1, b -> 0 + val expected = Set((0, 1.0), (1, 0.0)) + assert(output === expected) + } + test("StringIndexer with a numeric input column") { val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2) val df = sqlContext.createDataFrame(data).toDF("id", "label") From 5b8bb1b213b8738f563fcd00747604410fbb3087 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 11 Aug 2015 12:02:28 -0700 Subject: [PATCH 023/253] [SPARK-9572] [STREAMING] [PYSPARK] Added StreamingContext.getActiveOrCreate() in Python Author: Tathagata Das Closes #8080 from tdas/SPARK-9572 and squashes the following commits: 64a231d [Tathagata Das] Fix based on comments 741a0d0 [Tathagata Das] Fixed style f4f094c [Tathagata Das] Tweaked test 9afcdbe [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into SPARK-9572 e21488d [Tathagata Das] Minor update 1a371d9 [Tathagata Das] Addressed comments. 60479da [Tathagata Das] Fixed indent 9c2da9c [Tathagata Das] Fixed bugs b5bd32c [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into SPARK-9572 b55b348 [Tathagata Das] Removed prints 5781728 [Tathagata Das] Fix style issues b711214 [Tathagata Das] Reverted run-tests.py 643b59d [Tathagata Das] Revert unnecessary change 150e58c [Tathagata Das] Added StreamingContext.getActiveOrCreate() in Python --- python/pyspark/streaming/context.py | 57 +++++++++++- python/pyspark/streaming/tests.py | 133 +++++++++++++++++++++++++--- python/run-tests.py | 2 +- 3 files changed, 177 insertions(+), 15 deletions(-) diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index ac5ba69e8dbbb..e3ba70e4e5e88 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -86,6 +86,9 @@ class StreamingContext(object): """ _transformerSerializer = None + # Reference to a currently active StreamingContext + _activeContext = None + def __init__(self, sparkContext, batchDuration=None, jssc=None): """ Create a new StreamingContext. @@ -142,10 +145,10 @@ def getOrCreate(cls, checkpointPath, setupFunc): Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be recreated from the checkpoint data. If the data does not exist, then the provided setupFunc - will be used to create a JavaStreamingContext. + will be used to create a new context. - @param checkpointPath: Checkpoint directory used in an earlier JavaStreamingContext program - @param setupFunc: Function to create a new JavaStreamingContext and setup DStreams + @param checkpointPath: Checkpoint directory used in an earlier streaming program + @param setupFunc: Function to create a new context and setup DStreams """ # TODO: support checkpoint in HDFS if not os.path.exists(checkpointPath) or not os.listdir(checkpointPath): @@ -170,6 +173,52 @@ def getOrCreate(cls, checkpointPath, setupFunc): cls._transformerSerializer.ctx = sc return StreamingContext(sc, None, jssc) + @classmethod + def getActive(cls): + """ + Return either the currently active StreamingContext (i.e., if there is a context started + but not stopped) or None. + """ + activePythonContext = cls._activeContext + if activePythonContext is not None: + # Verify that the current running Java StreamingContext is active and is the same one + # backing the supposedly active Python context + activePythonContextJavaId = activePythonContext._jssc.ssc().hashCode() + activeJvmContextOption = activePythonContext._jvm.StreamingContext.getActive() + + if activeJvmContextOption.isEmpty(): + cls._activeContext = None + elif activeJvmContextOption.get().hashCode() != activePythonContextJavaId: + cls._activeContext = None + raise Exception("JVM's active JavaStreamingContext is not the JavaStreamingContext " + "backing the action Python StreamingContext. This is unexpected.") + return cls._activeContext + + @classmethod + def getActiveOrCreate(cls, checkpointPath, setupFunc): + """ + Either return the active StreamingContext (i.e. currently started but not stopped), + or recreate a StreamingContext from checkpoint data or create a new StreamingContext + using the provided setupFunc function. If the checkpointPath is None or does not contain + valid checkpoint data, then setupFunc will be called to create a new context and setup + DStreams. + + @param checkpointPath: Checkpoint directory used in an earlier streaming program. Can be + None if the intention is to always create a new context when there + is no active context. + @param setupFunc: Function to create a new JavaStreamingContext and setup DStreams + """ + + if setupFunc is None: + raise Exception("setupFunc cannot be None") + activeContext = cls.getActive() + if activeContext is not None: + return activeContext + elif checkpointPath is not None: + return cls.getOrCreate(checkpointPath, setupFunc) + else: + return setupFunc() + @property def sparkContext(self): """ @@ -182,6 +231,7 @@ def start(self): Start the execution of the streams. """ self._jssc.start() + StreamingContext._activeContext = self def awaitTermination(self, timeout=None): """ @@ -212,6 +262,7 @@ def stop(self, stopSparkContext=True, stopGraceFully=False): of all received data to be completed """ self._jssc.stop(stopSparkContext, stopGraceFully) + StreamingContext._activeContext = None if stopSparkContext: self._sc.stop() diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index f0ed415f97120..6108c845c1efe 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -24,6 +24,7 @@ import tempfile import random import struct +import shutil from functools import reduce if sys.version_info[:2] <= (2, 6): @@ -59,12 +60,21 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): cls.sc.stop() + # Clean up in the JVM just in case there has been some issues in Python API + jSparkContextOption = SparkContext._jvm.SparkContext.get() + if jSparkContextOption.nonEmpty(): + jSparkContextOption.get().stop() def setUp(self): self.ssc = StreamingContext(self.sc, self.duration) def tearDown(self): - self.ssc.stop(False) + if self.ssc is not None: + self.ssc.stop(False) + # Clean up in the JVM just in case there has been some issues in Python API + jStreamingContextOption = StreamingContext._jvm.SparkContext.getActive() + if jStreamingContextOption.nonEmpty(): + jStreamingContextOption.get().stop(False) def wait_for(self, result, n): start_time = time.time() @@ -442,6 +452,7 @@ def test_reduce_by_invalid_window(self): class StreamingContextTests(PySparkStreamingTestCase): duration = 0.1 + setupCalled = False def _add_input_stream(self): inputs = [range(1, x) for x in range(101)] @@ -515,10 +526,85 @@ def func(rdds): self.assertEqual([2, 3, 1], self._take(dstream, 3)) + def test_get_active(self): + self.assertEqual(StreamingContext.getActive(), None) + + # Verify that getActive() returns the active context + self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.assertEqual(StreamingContext.getActive(), self.ssc) + + # Verify that getActive() returns None + self.ssc.stop(False) + self.assertEqual(StreamingContext.getActive(), None) + + # Verify that if the Java context is stopped, then getActive() returns None + self.ssc = StreamingContext(self.sc, self.duration) + self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.assertEqual(StreamingContext.getActive(), self.ssc) + self.ssc._jssc.stop(False) + self.assertEqual(StreamingContext.getActive(), None) + + def test_get_active_or_create(self): + # Test StreamingContext.getActiveOrCreate() without checkpoint data + # See CheckpointTests for tests with checkpoint data + self.ssc = None + self.assertEqual(StreamingContext.getActive(), None) + + def setupFunc(): + ssc = StreamingContext(self.sc, self.duration) + ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.setupCalled = True + return ssc + + # Verify that getActiveOrCreate() (w/o checkpoint) calls setupFunc when no context is active + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) + self.assertTrue(self.setupCalled) + + # Verify that getActiveOrCreate() retuns active context and does not call the setupFunc + self.ssc.start() + self.setupCalled = False + self.assertEqual(StreamingContext.getActiveOrCreate(None, setupFunc), self.ssc) + self.assertFalse(self.setupCalled) + + # Verify that getActiveOrCreate() calls setupFunc after active context is stopped + self.ssc.stop(False) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) + self.assertTrue(self.setupCalled) + + # Verify that if the Java context is stopped, then getActive() returns None + self.ssc = StreamingContext(self.sc, self.duration) + self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.assertEqual(StreamingContext.getActive(), self.ssc) + self.ssc._jssc.stop(False) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) + self.assertTrue(self.setupCalled) + class CheckpointTests(unittest.TestCase): - def test_get_or_create(self): + setupCalled = False + + @staticmethod + def tearDownClass(): + # Clean up in the JVM just in case there has been some issues in Python API + jStreamingContextOption = StreamingContext._jvm.SparkContext.getActive() + if jStreamingContextOption.nonEmpty(): + jStreamingContextOption.get().stop() + jSparkContextOption = SparkContext._jvm.SparkContext.get() + if jSparkContextOption.nonEmpty(): + jSparkContextOption.get().stop() + + def tearDown(self): + if self.ssc is not None: + self.ssc.stop(True) + + def test_get_or_create_and_get_active_or_create(self): inputd = tempfile.mkdtemp() outputd = tempfile.mkdtemp() + "/" @@ -533,11 +619,12 @@ def setup(): wc = dstream.updateStateByKey(updater) wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test") wc.checkpoint(.5) + self.setupCalled = True return ssc cpd = tempfile.mkdtemp("test_streaming_cps") - ssc = StreamingContext.getOrCreate(cpd, setup) - ssc.start() + self.ssc = StreamingContext.getOrCreate(cpd, setup) + self.ssc.start() def check_output(n): while not os.listdir(outputd): @@ -552,7 +639,7 @@ def check_output(n): # not finished time.sleep(0.01) continue - ordd = ssc.sparkContext.textFile(p).map(lambda line: line.split(",")) + ordd = self.ssc.sparkContext.textFile(p).map(lambda line: line.split(",")) d = ordd.values().map(int).collect() if not d: time.sleep(0.01) @@ -568,13 +655,37 @@ def check_output(n): check_output(1) check_output(2) - ssc.stop(True, True) + # Verify the getOrCreate() recovers from checkpoint files + self.ssc.stop(True, True) time.sleep(1) - ssc = StreamingContext.getOrCreate(cpd, setup) - ssc.start() + self.setupCalled = False + self.ssc = StreamingContext.getOrCreate(cpd, setup) + self.assertFalse(self.setupCalled) + self.ssc.start() check_output(3) - ssc.stop(True, True) + + # Verify the getActiveOrCreate() recovers from checkpoint files + self.ssc.stop(True, True) + time.sleep(1) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(cpd, setup) + self.assertFalse(self.setupCalled) + self.ssc.start() + check_output(4) + + # Verify that getActiveOrCreate() returns active context + self.setupCalled = False + self.assertEquals(StreamingContext.getActiveOrCreate(cpd, setup), self.ssc) + self.assertFalse(self.setupCalled) + + # Verify that getActiveOrCreate() calls setup() in absence of checkpoint files + self.ssc.stop(True, True) + shutil.rmtree(cpd) # delete checkpoint directory + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(cpd, setup) + self.assertTrue(self.setupCalled) + self.ssc.stop(True, True) class KafkaStreamTests(PySparkStreamingTestCase): @@ -1134,7 +1245,7 @@ def search_kinesis_asl_assembly_jar(): testcases.append(KinesisStreamTests) elif are_kinesis_tests_enabled is False: sys.stderr.write("Skipping all Kinesis Python tests as the optional Kinesis project was " - "not compiled with -Pkinesis-asl profile. To run these tests, " + "not compiled into a JAR. To run these tests, " "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/assembly " "streaming-kinesis-asl-assembly/assembly' or " "'build/mvn -Pkinesis-asl package' before running this test.") @@ -1150,4 +1261,4 @@ def search_kinesis_asl_assembly_jar(): for testcase in testcases: sys.stderr.write("[Running %s]\n" % (testcase)) tests = unittest.TestLoader().loadTestsFromTestCase(testcase) - unittest.TextTestRunner(verbosity=2).run(tests) + unittest.TextTestRunner(verbosity=3).run(tests) diff --git a/python/run-tests.py b/python/run-tests.py index cc560779373b3..fd56c7ab6e0e2 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -158,7 +158,7 @@ def main(): else: log_level = logging.INFO logging.basicConfig(stream=sys.stdout, level=log_level, format="%(message)s") - LOGGER.info("Running PySpark tests. Output is in python/%s", LOG_FILE) + LOGGER.info("Running PySpark tests. Output is in %s", LOG_FILE) if os.path.exists(LOG_FILE): os.remove(LOG_FILE) python_execs = opts.python_executables.split(',') From 5831294a7a8fa2524133c5d718cbc8187d2b0620 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 11 Aug 2015 12:39:13 -0700 Subject: [PATCH 024/253] [SPARK-9646] [SQL] Add metrics for all join and aggregate operators This PR added metrics for all join and aggregate operators. However, I found the metrics may be confusing in the following two case: 1. The iterator is not totally consumed and the metric values will be less. 2. Recreating the iterators will make metric values look bigger than the size of the input source, such as `CartesianProduct`. Author: zsxwing Closes #8060 from zsxwing/sql-metrics and squashes the following commits: 40f3fc1 [zsxwing] Mark LongSQLMetric private[metric] to avoid using incorrectly and leak memory b1b9071 [zsxwing] Merge branch 'master' into sql-metrics 4bef25a [zsxwing] Add metrics for SortMergeOuterJoin 95ccfc6 [zsxwing] Merge branch 'master' into sql-metrics 67cb4dd [zsxwing] Add metrics for Project and TungstenProject; remove metrics from PhysicalRDD and LocalTableScan 0eb47d4 [zsxwing] Merge branch 'master' into sql-metrics dd9d932 [zsxwing] Avoid creating new Iterators 589ea26 [zsxwing] Add metrics for all join and aggregate operators --- .../spark/sql/execution/Aggregate.scala | 11 + .../spark/sql/execution/ExistingRDD.scala | 2 - .../spark/sql/execution/LocalTableScan.scala | 2 - .../spark/sql/execution/SparkPlan.scala | 25 +- .../aggregate/SortBasedAggregate.scala | 12 +- .../SortBasedAggregationIterator.scala | 18 +- .../aggregate/TungstenAggregate.scala | 12 +- .../TungstenAggregationIterator.scala | 11 +- .../spark/sql/execution/basicOperators.scala | 36 +- .../execution/joins/BroadcastHashJoin.scala | 30 +- .../joins/BroadcastHashOuterJoin.scala | 40 +- .../joins/BroadcastLeftSemiJoinHash.scala | 24 +- .../joins/BroadcastNestedLoopJoin.scala | 27 +- .../execution/joins/CartesianProduct.scala | 25 +- .../spark/sql/execution/joins/HashJoin.scala | 7 +- .../sql/execution/joins/HashOuterJoin.scala | 30 +- .../sql/execution/joins/HashSemiJoin.scala | 23 +- .../sql/execution/joins/HashedRelation.scala | 8 +- .../sql/execution/joins/LeftSemiJoinBNL.scala | 19 +- .../execution/joins/LeftSemiJoinHash.scala | 18 +- .../execution/joins/ShuffledHashJoin.scala | 16 +- .../joins/ShuffledHashOuterJoin.scala | 29 +- .../sql/execution/joins/SortMergeJoin.scala | 21 +- .../execution/joins/SortMergeOuterJoin.scala | 38 +- .../sql/execution/metric/SQLMetrics.scala | 6 + .../execution/joins/HashedRelationSuite.scala | 14 +- .../execution/metric/SQLMetricsSuite.scala | 450 +++++++++++++++++- 27 files changed, 847 insertions(+), 107 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index e8c6a0f8f801d..f3b6a3a5f4a33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.metric.SQLMetrics /** * :: DeveloperApi :: @@ -45,6 +46,10 @@ case class Aggregate( child: SparkPlan) extends UnaryNode { + override private[sql] lazy val metrics = Map( + "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def requiredChildDistribution: List[Distribution] = { if (partial) { UnspecifiedDistribution :: Nil @@ -121,12 +126,15 @@ case class Aggregate( } protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + val numInputRows = longMetric("numInputRows") + val numOutputRows = longMetric("numOutputRows") if (groupingExpressions.isEmpty) { child.execute().mapPartitions { iter => val buffer = newAggregateBuffer() var currentRow: InternalRow = null while (iter.hasNext) { currentRow = iter.next() + numInputRows += 1 var i = 0 while (i < buffer.length) { buffer(i).update(currentRow) @@ -142,6 +150,7 @@ case class Aggregate( i += 1 } + numOutputRows += 1 Iterator(resultProjection(aggregateResults)) } } else { @@ -152,6 +161,7 @@ case class Aggregate( var currentRow: InternalRow = null while (iter.hasNext) { currentRow = iter.next() + numInputRows += 1 val currentGroup = groupingProjection(currentRow) var currentBuffer = hashTable.get(currentGroup) if (currentBuffer == null) { @@ -180,6 +190,7 @@ case class Aggregate( val currentEntry = hashTableIter.next() val currentGroup = currentEntry.getKey val currentBuffer = currentEntry.getValue + numOutputRows += 1 var i = 0 while (i < currentBuffer.length) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index cae7ca5cbdc88..abb60cf12e3a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -99,8 +99,6 @@ private[sql] case class PhysicalRDD( rdd: RDD[InternalRow], extraInformation: String) extends LeafNode { - override protected[sql] val trackNumOfRowsEnabled = true - protected override def doExecute(): RDD[InternalRow] = rdd override def simpleString: String = "Scan " + extraInformation + output.mkString("[", ",", "]") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala index 858dd85fd1fa6..34e926e4582be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -30,8 +30,6 @@ private[sql] case class LocalTableScan( output: Seq[Attribute], rows: Seq[InternalRow]) extends LeafNode { - override protected[sql] val trackNumOfRowsEnabled = true - private lazy val rdd = sqlContext.sparkContext.parallelize(rows) protected override def doExecute(): RDD[InternalRow] = rdd diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 9ba5cf2d2b39e..72f5450510a10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -80,23 +80,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ super.makeCopy(newArgs) } - /** - * Whether track the number of rows output by this SparkPlan - */ - protected[sql] def trackNumOfRowsEnabled: Boolean = false - - private lazy val defaultMetrics: Map[String, SQLMetric[_, _]] = - if (trackNumOfRowsEnabled) { - Map("numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) - } - else { - Map.empty - } - /** * Return all metrics containing metrics of this SparkPlan. */ - private[sql] def metrics: Map[String, SQLMetric[_, _]] = defaultMetrics + private[sql] def metrics: Map[String, SQLMetric[_, _]] = Map.empty /** * Return a LongSQLMetric according to the name. @@ -150,15 +137,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } RDDOperationScope.withScope(sparkContext, nodeName, false, true) { prepare() - if (trackNumOfRowsEnabled) { - val numRows = longMetric("numRows") - doExecute().map { row => - numRows += 1 - row - } - } else { - doExecute() - } + doExecute() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index ad428ad663f30..ab26f9c58aa2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.StructType case class SortBasedAggregate( @@ -38,6 +39,10 @@ case class SortBasedAggregate( child: SparkPlan) extends UnaryNode { + override private[sql] lazy val metrics = Map( + "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def outputsUnsafeRows: Boolean = false override def canProcessUnsafeRows: Boolean = false @@ -63,6 +68,8 @@ case class SortBasedAggregate( } protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + val numInputRows = longMetric("numInputRows") + val numOutputRows = longMetric("numOutputRows") child.execute().mapPartitions { iter => // Because the constructor of an aggregation iterator will read at least the first row, // we need to get the value of iter.hasNext first. @@ -84,10 +91,13 @@ case class SortBasedAggregate( newProjection _, child.output, iter, - outputsUnsafeRows) + outputsUnsafeRows, + numInputRows, + numOutputRows) if (!hasInput && groupingExpressions.isEmpty) { // There is no input and there is no grouping expressions. // We need to output a single row as the output. + numOutputRows += 1 Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) } else { outputIter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 67ebafde25ad3..73d50e07cf0b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2} +import org.apache.spark.sql.execution.metric.LongSQLMetric import org.apache.spark.unsafe.KVIterator /** @@ -37,7 +38,9 @@ class SortBasedAggregationIterator( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - outputsUnsafeRows: Boolean) + outputsUnsafeRows: Boolean, + numInputRows: LongSQLMetric, + numOutputRows: LongSQLMetric) extends AggregationIterator( groupingKeyAttributes, valueAttributes, @@ -103,6 +106,7 @@ class SortBasedAggregationIterator( // Get the grouping key. val groupingKey = inputKVIterator.getKey val currentRow = inputKVIterator.getValue + numInputRows += 1 // Check if the current row belongs the current input row. if (currentGroupingKey == groupingKey) { @@ -137,7 +141,7 @@ class SortBasedAggregationIterator( val outputRow = generateOutput(currentGroupingKey, sortBasedAggregationBuffer) // Initialize buffer values for the next group. initializeBuffer(sortBasedAggregationBuffer) - + numOutputRows += 1 outputRow } else { // no more result @@ -151,7 +155,7 @@ class SortBasedAggregationIterator( nextGroupingKey = inputKVIterator.getKey().copy() firstRowInNextGroup = inputKVIterator.getValue().copy() - + numInputRows += 1 sortedInputHasNewGroup = true } else { // This inputIter is empty. @@ -181,7 +185,9 @@ object SortBasedAggregationIterator { newProjection: (Seq[Expression], Seq[Attribute]) => Projection, inputAttributes: Seq[Attribute], inputIter: Iterator[InternalRow], - outputsUnsafeRows: Boolean): SortBasedAggregationIterator = { + outputsUnsafeRows: Boolean, + numInputRows: LongSQLMetric, + numOutputRows: LongSQLMetric): SortBasedAggregationIterator = { val kvIterator = if (UnsafeProjection.canSupport(groupingExprs)) { AggregationIterator.unsafeKVIterator( groupingExprs, @@ -202,7 +208,9 @@ object SortBasedAggregationIterator { initialInputBufferOffset, resultExpressions, newMutableProjection, - outputsUnsafeRows) + outputsUnsafeRows, + numInputRows, + numOutputRows) } // scalastyle:on } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 1694794a53d9f..6b5935a7ce296 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics case class TungstenAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], @@ -35,6 +36,10 @@ case class TungstenAggregate( child: SparkPlan) extends UnaryNode { + override private[sql] lazy val metrics = Map( + "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def outputsUnsafeRows: Boolean = true override def canProcessUnsafeRows: Boolean = true @@ -61,6 +66,8 @@ case class TungstenAggregate( } protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + val numInputRows = longMetric("numInputRows") + val numOutputRows = longMetric("numOutputRows") child.execute().mapPartitions { iter => val hasInput = iter.hasNext if (!hasInput && groupingExpressions.nonEmpty) { @@ -78,9 +85,12 @@ case class TungstenAggregate( newMutableProjection, child.output, iter, - testFallbackStartsAt) + testFallbackStartsAt, + numInputRows, + numOutputRows) if (!hasInput && groupingExpressions.isEmpty) { + numOutputRows += 1 Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) } else { aggregationIterator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 32160906c3bc8..1f383dd04482f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap} +import org.apache.spark.sql.execution.metric.LongSQLMetric import org.apache.spark.sql.types.StructType /** @@ -83,7 +84,9 @@ class TungstenAggregationIterator( newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), originalInputAttributes: Seq[Attribute], inputIter: Iterator[InternalRow], - testFallbackStartsAt: Option[Int]) + testFallbackStartsAt: Option[Int], + numInputRows: LongSQLMetric, + numOutputRows: LongSQLMetric) extends Iterator[UnsafeRow] with Logging { /////////////////////////////////////////////////////////////////////////// @@ -352,6 +355,7 @@ class TungstenAggregationIterator( private def processInputs(): Unit = { while (!sortBased && inputIter.hasNext) { val newInput = inputIter.next() + numInputRows += 1 val groupingKey = groupProjection.apply(newInput) val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) if (buffer == null) { @@ -371,6 +375,7 @@ class TungstenAggregationIterator( var i = 0 while (!sortBased && inputIter.hasNext) { val newInput = inputIter.next() + numInputRows += 1 val groupingKey = groupProjection.apply(newInput) val buffer: UnsafeRow = if (i < fallbackStartsAt) { hashMap.getAggregationBufferFromUnsafeRow(groupingKey) @@ -439,6 +444,7 @@ class TungstenAggregationIterator( // Process the rest of input rows. while (inputIter.hasNext) { val newInput = inputIter.next() + numInputRows += 1 val groupingKey = groupProjection.apply(newInput) buffer.copyFrom(initialAggregationBuffer) processRow(buffer, newInput) @@ -462,6 +468,7 @@ class TungstenAggregationIterator( // Insert the rest of input rows. while (inputIter.hasNext) { val newInput = inputIter.next() + numInputRows += 1 val groupingKey = groupProjection.apply(newInput) bufferExtractor(newInput) externalSorter.insertKV(groupingKey, buffer) @@ -657,7 +664,7 @@ class TungstenAggregationIterator( TaskContext.get().internalMetricsToAccumulators( InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemory) } - + numOutputRows += 1 res } else { // no more result diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index bf2de244c8e4a..247c900baae9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -41,11 +41,20 @@ import org.apache.spark.{HashPartitioner, SparkEnv} case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) + override private[sql] lazy val metrics = Map( + "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) + @transient lazy val buildProjection = newMutableProjection(projectList, child.output) - protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => - val reusableProjection = buildProjection() - iter.map(reusableProjection) + protected override def doExecute(): RDD[InternalRow] = { + val numRows = longMetric("numRows") + child.execute().mapPartitions { iter => + val reusableProjection = buildProjection() + iter.map { row => + numRows += 1 + reusableProjection(row) + } + } } override def outputOrdering: Seq[SortOrder] = child.outputOrdering @@ -57,19 +66,28 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends */ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { + override private[sql] lazy val metrics = Map( + "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) + override def outputsUnsafeRows: Boolean = true override def canProcessUnsafeRows: Boolean = true override def canProcessSafeRows: Boolean = true override def output: Seq[Attribute] = projectList.map(_.toAttribute) - protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => - this.transformAllExpressions { - case CreateStruct(children) => CreateStructUnsafe(children) - case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + protected override def doExecute(): RDD[InternalRow] = { + val numRows = longMetric("numRows") + child.execute().mapPartitions { iter => + this.transformAllExpressions { + case CreateStruct(children) => CreateStructUnsafe(children) + case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + } + val project = UnsafeProjection.create(projectList, child.output) + iter.map { row => + numRows += 1 + project(row) + } } - val project = UnsafeProjection.create(projectList, child.output) - iter.map(project) } override def outputOrdering: Seq[SortOrder] = child.outputOrdering diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index f7a68e4f5d445..2e108cb814516 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, SQLExecution, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.ThreadUtils import org.apache.spark.{InternalAccumulator, TaskContext} @@ -45,7 +46,10 @@ case class BroadcastHashJoin( right: SparkPlan) extends BinaryNode with HashJoin { - override protected[sql] val trackNumOfRowsEnabled = true + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) val timeout: Duration = { val timeoutValue = sqlContext.conf.broadcastTimeout @@ -65,6 +69,11 @@ case class BroadcastHashJoin( // for the same query. @transient private lazy val broadcastFuture = { + val numBuildRows = buildSide match { + case BuildLeft => longMetric("numLeftRows") + case BuildRight => longMetric("numRightRows") + } + // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) future { @@ -73,8 +82,15 @@ case class BroadcastHashJoin( SQLExecution.withExecutionId(sparkContext, executionId) { // Note that we use .execute().collect() because we don't want to convert data to Scala // types - val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.size) + val input: Array[InternalRow] = buildPlan.execute().map { row => + numBuildRows += 1 + row.copy() + }.collect() + // The following line doesn't run in a job so we cannot track the metric value. However, we + // have already tracked it in the above lines. So here we can use + // `SQLMetrics.nullLongMetric` to ignore it. + val hashed = HashedRelation( + input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size) sparkContext.broadcast(hashed) } }(BroadcastHashJoin.broadcastHashJoinExecutionContext) @@ -85,6 +101,12 @@ case class BroadcastHashJoin( } protected override def doExecute(): RDD[InternalRow] = { + val numStreamedRows = buildSide match { + case BuildLeft => longMetric("numRightRows") + case BuildRight => longMetric("numLeftRows") + } + val numOutputRows = longMetric("numOutputRows") + val broadcastRelation = Await.result(broadcastFuture, timeout) streamedPlan.execute().mapPartitions { streamedIter => @@ -95,7 +117,7 @@ case class BroadcastHashJoin( InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) case _ => } - hashJoin(streamedIter, hashedRelation) + hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index a3626de49aeab..69a8b95eaa7ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SQLExecution, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.{InternalAccumulator, TaskContext} /** @@ -45,6 +46,11 @@ case class BroadcastHashOuterJoin( left: SparkPlan, right: SparkPlan) extends BinaryNode with HashOuterJoin { + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + val timeout = { val timeoutValue = sqlContext.conf.broadcastTimeout if (timeoutValue < 0) { @@ -63,6 +69,14 @@ case class BroadcastHashOuterJoin( // for the same query. @transient private lazy val broadcastFuture = { + val numBuildRows = joinType match { + case RightOuter => longMetric("numLeftRows") + case LeftOuter => longMetric("numRightRows") + case x => + throw new IllegalArgumentException( + s"HashOuterJoin should not take $x as the JoinType") + } + // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) future { @@ -71,8 +85,15 @@ case class BroadcastHashOuterJoin( SQLExecution.withExecutionId(sparkContext, executionId) { // Note that we use .execute().collect() because we don't want to convert data to Scala // types - val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size) + val input: Array[InternalRow] = buildPlan.execute().map { row => + numBuildRows += 1 + row.copy() + }.collect() + // The following line doesn't run in a job so we cannot track the metric value. However, we + // have already tracked it in the above lines. So here we can use + // `SQLMetrics.nullLongMetric` to ignore it. + val hashed = HashedRelation( + input.iterator, SQLMetrics.nullLongMetric, buildKeyGenerator, input.size) sparkContext.broadcast(hashed) } }(BroadcastHashJoin.broadcastHashJoinExecutionContext) @@ -83,6 +104,15 @@ case class BroadcastHashOuterJoin( } override def doExecute(): RDD[InternalRow] = { + val numStreamedRows = joinType match { + case RightOuter => longMetric("numRightRows") + case LeftOuter => longMetric("numLeftRows") + case x => + throw new IllegalArgumentException( + s"HashOuterJoin should not take $x as the JoinType") + } + val numOutputRows = longMetric("numOutputRows") + val broadcastRelation = Await.result(broadcastFuture, timeout) streamedPlan.execute().mapPartitions { streamedIter => @@ -101,16 +131,18 @@ case class BroadcastHashOuterJoin( joinType match { case LeftOuter => streamedIter.flatMap(currentRow => { + numStreamedRows += 1 val rowKey = keyGenerator(currentRow) joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj) + leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows) }) case RightOuter => streamedIter.flatMap(currentRow => { + numStreamedRows += 1 val rowKey = keyGenerator(currentRow) joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj) + rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows) }) case x => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index 5bd06fbdca605..78a8c16c62bca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * :: DeveloperApi :: @@ -37,18 +38,31 @@ case class BroadcastLeftSemiJoinHash( right: SparkPlan, condition: Option[Expression]) extends BinaryNode with HashSemiJoin { + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + protected override def doExecute(): RDD[InternalRow] = { - val input = right.execute().map(_.copy()).collect() + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + + val input = right.execute().map { row => + numRightRows += 1 + row.copy() + }.collect() if (condition.isEmpty) { - val hashSet = buildKeyHashSet(input.toIterator) + val hashSet = buildKeyHashSet(input.toIterator, SQLMetrics.nullLongMetric) val broadcastedRelation = sparkContext.broadcast(hashSet) left.execute().mapPartitions { streamIter => - hashSemiJoin(streamIter, broadcastedRelation.value) + hashSemiJoin(streamIter, numLeftRows, broadcastedRelation.value, numOutputRows) } } else { - val hashRelation = HashedRelation(input.toIterator, rightKeyGenerator, input.size) + val hashRelation = + HashedRelation(input.toIterator, SQLMetrics.nullLongMetric, rightKeyGenerator, input.size) val broadcastedRelation = sparkContext.broadcast(hashRelation) left.execute().mapPartitions { streamIter => @@ -59,7 +73,7 @@ case class BroadcastLeftSemiJoinHash( InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) case _ => } - hashSemiJoin(streamIter, hashedRelation) + hashSemiJoin(streamIter, numLeftRows, hashedRelation, numOutputRows) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 017a44b9ca863..28c88b1b03d02 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.collection.CompactBuffer /** @@ -38,6 +39,11 @@ case class BroadcastNestedLoopJoin( condition: Option[Expression]) extends BinaryNode { // TODO: Override requiredChildDistribution. + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + /** BuildRight means the right relation <=> the broadcast relation. */ private val (streamed, broadcast) = buildSide match { case BuildRight => (left, right) @@ -75,9 +81,17 @@ case class BroadcastNestedLoopJoin( newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) protected override def doExecute(): RDD[InternalRow] = { + val (numStreamedRows, numBuildRows) = buildSide match { + case BuildRight => (longMetric("numLeftRows"), longMetric("numRightRows")) + case BuildLeft => (longMetric("numRightRows"), longMetric("numLeftRows")) + } + val numOutputRows = longMetric("numOutputRows") + val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map(_.copy()) - .collect().toIndexedSeq) + sparkContext.broadcast(broadcast.execute().map { row => + numBuildRows += 1 + row.copy() + }.collect().toIndexedSeq) /** All rows that either match both-way, or rows from streamed joined with nulls. */ val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => @@ -94,6 +108,7 @@ case class BroadcastNestedLoopJoin( streamedIter.foreach { streamedRow => var i = 0 var streamRowMatched = false + numStreamedRows += 1 while (i < broadcastedRelation.value.size) { val broadcastedRow = broadcastedRelation.value(i) @@ -162,6 +177,12 @@ case class BroadcastNestedLoopJoin( // TODO: Breaks lineage. sparkContext.union( - matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls)) + matchesOrStreamedRowsWithNulls.flatMap(_._1), + sparkContext.makeRDD(broadcastRowsWithNulls) + ).map { row => + // `broadcastRowsWithNulls` doesn't run in a job so that we have to track numOutputRows here. + numOutputRows += 1 + row + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index 261b4724159fb..2115f40702286 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -22,6 +22,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * :: DeveloperApi :: @@ -30,13 +31,31 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output ++ right.output + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + protected override def doExecute(): RDD[InternalRow] = { - val leftResults = left.execute().map(_.copy()) - val rightResults = right.execute().map(_.copy()) + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + + val leftResults = left.execute().map { row => + numLeftRows += 1 + row.copy() + } + val rightResults = right.execute().map { row => + numRightRows += 1 + row.copy() + } leftResults.cartesian(rightResults).mapPartitions { iter => val joinedRow = new JoinedRow - iter.map(r => joinedRow(r._1, r._2)) + iter.map { r => + numOutputRows += 1 + joinedRow(r._1, r._2) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 22d46d1c3e3b7..7ce4a517838cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.LongSQLMetric trait HashJoin { @@ -69,7 +70,9 @@ trait HashJoin { protected def hashJoin( streamIter: Iterator[InternalRow], - hashedRelation: HashedRelation): Iterator[InternalRow] = + numStreamRows: LongSQLMetric, + hashedRelation: HashedRelation, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { new Iterator[InternalRow] { private[this] var currentStreamedRow: InternalRow = _ @@ -98,6 +101,7 @@ trait HashJoin { case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) } currentMatchPosition += 1 + numOutputRows += 1 resultProjection(ret) } @@ -113,6 +117,7 @@ trait HashJoin { while (currentHashMatches == null && streamIter.hasNext) { currentStreamedRow = streamIter.next() + numStreamRows += 1 val key = joinKeys(currentStreamedRow) if (!key.anyNull) { currentHashMatches = hashedRelation.get(key) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 701bd3cd86372..66903347c88c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.LongSQLMetric import org.apache.spark.util.collection.CompactBuffer @DeveloperApi @@ -114,22 +115,28 @@ trait HashOuterJoin { key: InternalRow, joinedRow: JoinedRow, rightIter: Iterable[InternalRow], - resultProjection: InternalRow => InternalRow): Iterator[InternalRow] = { + resultProjection: InternalRow => InternalRow, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { if (!key.anyNull) { val temp = if (rightIter != null) { rightIter.collect { - case r if boundCondition(joinedRow.withRight(r)) => resultProjection(joinedRow).copy() + case r if boundCondition(joinedRow.withRight(r)) => { + numOutputRows += 1 + resultProjection(joinedRow).copy() + } } } else { List.empty } if (temp.isEmpty) { + numOutputRows += 1 resultProjection(joinedRow.withRight(rightNullRow)) :: Nil } else { temp } } else { + numOutputRows += 1 resultProjection(joinedRow.withRight(rightNullRow)) :: Nil } } @@ -140,22 +147,28 @@ trait HashOuterJoin { key: InternalRow, leftIter: Iterable[InternalRow], joinedRow: JoinedRow, - resultProjection: InternalRow => InternalRow): Iterator[InternalRow] = { + resultProjection: InternalRow => InternalRow, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { if (!key.anyNull) { val temp = if (leftIter != null) { leftIter.collect { - case l if boundCondition(joinedRow.withLeft(l)) => resultProjection(joinedRow).copy() + case l if boundCondition(joinedRow.withLeft(l)) => { + numOutputRows += 1 + resultProjection(joinedRow).copy() + } } } else { List.empty } if (temp.isEmpty) { + numOutputRows += 1 resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil } else { temp } } else { + numOutputRows += 1 resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil } } @@ -164,7 +177,7 @@ trait HashOuterJoin { protected[this] def fullOuterIterator( key: InternalRow, leftIter: Iterable[InternalRow], rightIter: Iterable[InternalRow], - joinedRow: JoinedRow): Iterator[InternalRow] = { + joinedRow: JoinedRow, numOutputRows: LongSQLMetric): Iterator[InternalRow] = { if (!key.anyNull) { // Store the positions of records in right, if one of its associated row satisfy // the join condition. @@ -177,6 +190,7 @@ trait HashOuterJoin { // append them directly case (r, idx) if boundCondition(joinedRow.withRight(r)) => + numOutputRows += 1 matched = true // if the row satisfy the join condition, add its index into the matched set rightMatchedSet.add(idx) @@ -189,6 +203,7 @@ trait HashOuterJoin { // as we don't know whether we need to append it until finish iterating all // of the records in right side. // If we didn't get any proper row, then append a single row with empty right. + numOutputRows += 1 joinedRow.withRight(rightNullRow).copy() }) } ++ rightIter.zipWithIndex.collect { @@ -197,12 +212,15 @@ trait HashOuterJoin { // Re-visiting the records in right, and append additional row with empty left, if its not // in the matched set. case (r, idx) if !rightMatchedSet.contains(idx) => + numOutputRows += 1 joinedRow(leftNullRow, r).copy() } } else { leftIter.iterator.map[InternalRow] { l => + numOutputRows += 1 joinedRow(l, rightNullRow).copy() } ++ rightIter.iterator.map[InternalRow] { r => + numOutputRows += 1 joinedRow(leftNullRow, r).copy() } } @@ -211,10 +229,12 @@ trait HashOuterJoin { // This is only used by FullOuter protected[this] def buildHashTable( iter: Iterator[InternalRow], + numIterRows: LongSQLMetric, keyGenerator: Projection): JavaHashMap[InternalRow, CompactBuffer[InternalRow]] = { val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]]() while (iter.hasNext) { val currentRow = iter.next() + numIterRows += 1 val rowKey = keyGenerator(currentRow) var existingMatchList = hashTable.get(rowKey) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 82dd6eb7e7ed0..beb141ade616d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.LongSQLMetric trait HashSemiJoin { @@ -61,13 +62,15 @@ trait HashSemiJoin { @transient private lazy val boundCondition = newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) - protected def buildKeyHashSet(buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = { + protected def buildKeyHashSet( + buildIter: Iterator[InternalRow], numBuildRows: LongSQLMetric): java.util.Set[InternalRow] = { val hashSet = new java.util.HashSet[InternalRow]() // Create a Hash set of buildKeys val rightKey = rightKeyGenerator while (buildIter.hasNext) { val currentRow = buildIter.next() + numBuildRows += 1 val rowKey = rightKey(currentRow) if (!rowKey.anyNull) { val keyExists = hashSet.contains(rowKey) @@ -82,25 +85,35 @@ trait HashSemiJoin { protected def hashSemiJoin( streamIter: Iterator[InternalRow], - hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = { + numStreamRows: LongSQLMetric, + hashSet: java.util.Set[InternalRow], + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { val joinKeys = leftKeyGenerator streamIter.filter(current => { + numStreamRows += 1 val key = joinKeys(current) - !key.anyNull && hashSet.contains(key) + val r = !key.anyNull && hashSet.contains(key) + if (r) numOutputRows += 1 + r }) } protected def hashSemiJoin( streamIter: Iterator[InternalRow], - hashedRelation: HashedRelation): Iterator[InternalRow] = { + numStreamRows: LongSQLMetric, + hashedRelation: HashedRelation, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { val joinKeys = leftKeyGenerator val joinedRow = new JoinedRow streamIter.filter { current => + numStreamRows += 1 val key = joinKeys(current) lazy val rowBuffer = hashedRelation.get(key) - !key.anyNull && rowBuffer != null && rowBuffer.exists { + val r = !key.anyNull && rowBuffer != null && rowBuffer.exists { (row: InternalRow) => boundCondition(joinedRow(current, row)) } + if (r) numOutputRows += 1 + r } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 63d35d0f02622..c1bc7947aa39c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -25,6 +25,7 @@ import org.apache.spark.shuffle.ShuffleMemoryManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.execution.metric.LongSQLMetric import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} @@ -112,11 +113,13 @@ private[joins] object HashedRelation { def apply( input: Iterator[InternalRow], + numInputRows: LongSQLMetric, keyGenerator: Projection, sizeEstimate: Int = 64): HashedRelation = { if (keyGenerator.isInstanceOf[UnsafeProjection]) { - return UnsafeHashedRelation(input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate) + return UnsafeHashedRelation( + input, numInputRows, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate) } // TODO: Use Spark's HashMap implementation. @@ -130,6 +133,7 @@ private[joins] object HashedRelation { // Create a mapping of buildKeys -> rows while (input.hasNext) { currentRow = input.next() + numInputRows += 1 val rowKey = keyGenerator(currentRow) if (!rowKey.anyNull) { val existingMatchList = hashTable.get(rowKey) @@ -331,6 +335,7 @@ private[joins] object UnsafeHashedRelation { def apply( input: Iterator[InternalRow], + numInputRows: LongSQLMetric, keyGenerator: UnsafeProjection, sizeEstimate: Int): HashedRelation = { @@ -340,6 +345,7 @@ private[joins] object UnsafeHashedRelation { // Create a mapping of buildKeys -> rows while (input.hasNext) { val unsafeRow = input.next().asInstanceOf[UnsafeRow] + numInputRows += 1 val rowKey = keyGenerator(unsafeRow) if (!rowKey.anyNull) { val existingMatchList = hashTable.get(rowKey) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index 4443455ef11fe..ad6362542f2ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * :: DeveloperApi :: @@ -35,6 +36,11 @@ case class LeftSemiJoinBNL( extends BinaryNode { // TODO: Override requiredChildDistribution. + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def outputPartitioning: Partitioning = streamed.outputPartitioning override def output: Seq[Attribute] = left.output @@ -52,13 +58,21 @@ case class LeftSemiJoinBNL( newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) protected override def doExecute(): RDD[InternalRow] = { + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + sparkContext.broadcast(broadcast.execute().map { row => + numRightRows += 1 + row.copy() + }.collect().toIndexedSeq) streamed.execute().mapPartitions { streamedIter => val joinedRow = new JoinedRow streamedIter.filter(streamedRow => { + numLeftRows += 1 var i = 0 var matched = false @@ -69,6 +83,9 @@ case class LeftSemiJoinBNL( } i += 1 } + if (matched) { + numOutputRows += 1 + } matched }) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index 68ccd34d8ed9b..18808adaac63f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, ClusteredDistribution} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * :: DeveloperApi :: @@ -37,19 +38,28 @@ case class LeftSemiJoinHash( right: SparkPlan, condition: Option[Expression]) extends BinaryNode with HashSemiJoin { + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def outputPartitioning: Partitioning = left.outputPartitioning override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil protected override def doExecute(): RDD[InternalRow] = { + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) => if (condition.isEmpty) { - val hashSet = buildKeyHashSet(buildIter) - hashSemiJoin(streamIter, hashSet) + val hashSet = buildKeyHashSet(buildIter, numRightRows) + hashSemiJoin(streamIter, numLeftRows, hashSet, numOutputRows) } else { - val hashRelation = HashedRelation(buildIter, rightKeyGenerator) - hashSemiJoin(streamIter, hashRelation) + val hashRelation = HashedRelation(buildIter, numRightRows, rightKeyGenerator) + hashSemiJoin(streamIter, numLeftRows, hashRelation, numOutputRows) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index c923dc837c449..fc8c9439a6f07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * :: DeveloperApi :: @@ -38,7 +39,10 @@ case class ShuffledHashJoin( right: SparkPlan) extends BinaryNode with HashJoin { - override protected[sql] val trackNumOfRowsEnabled = true + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) override def outputPartitioning: Partitioning = PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) @@ -47,9 +51,15 @@ case class ShuffledHashJoin( ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil protected override def doExecute(): RDD[InternalRow] = { + val (numBuildRows, numStreamedRows) = buildSide match { + case BuildLeft => (longMetric("numLeftRows"), longMetric("numRightRows")) + case BuildRight => (longMetric("numRightRows"), longMetric("numLeftRows")) + } + val numOutputRows = longMetric("numOutputRows") + buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashed = HashedRelation(buildIter, buildSideKeyGenerator) - hashJoin(streamIter, hashed) + val hashed = HashedRelation(buildIter, numBuildRows, buildSideKeyGenerator) + hashJoin(streamIter, numStreamedRows, hashed, numOutputRows) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index 6a8c35efca8f4..ed282f98b7d71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * :: DeveloperApi :: @@ -41,6 +42,11 @@ case class ShuffledHashOuterJoin( left: SparkPlan, right: SparkPlan) extends BinaryNode with HashOuterJoin { + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil @@ -53,39 +59,48 @@ case class ShuffledHashOuterJoin( } protected override def doExecute(): RDD[InternalRow] = { + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + val joinedRow = new JoinedRow() left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => // TODO this probably can be replaced by external sort (sort merged join?) joinType match { case LeftOuter => - val hashed = HashedRelation(rightIter, buildKeyGenerator) + val hashed = HashedRelation(rightIter, numRightRows, buildKeyGenerator) val keyGenerator = streamedKeyGenerator val resultProj = resultProjection leftIter.flatMap( currentRow => { + numLeftRows += 1 val rowKey = keyGenerator(currentRow) joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey), resultProj) + leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey), resultProj, numOutputRows) }) case RightOuter => - val hashed = HashedRelation(leftIter, buildKeyGenerator) + val hashed = HashedRelation(leftIter, numLeftRows, buildKeyGenerator) val keyGenerator = streamedKeyGenerator val resultProj = resultProjection rightIter.flatMap ( currentRow => { + numRightRows += 1 val rowKey = keyGenerator(currentRow) joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow, resultProj) + rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow, resultProj, numOutputRows) }) case FullOuter => // TODO(davies): use UnsafeRow - val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + val leftHashTable = + buildHashTable(leftIter, numLeftRows, newProjection(leftKeys, left.output)) + val rightHashTable = + buildHashTable(rightIter, numRightRows, newProjection(rightKeys, right.output)) (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => fullOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), rightHashTable.getOrElse(key, EMPTY_LIST), - joinedRow) + joinedRow, + numOutputRows) } case x => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 6d656ea2849a9..6b7322671d6b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} +import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} /** * :: DeveloperApi :: @@ -37,6 +38,11 @@ case class SortMergeJoin( left: SparkPlan, right: SparkPlan) extends BinaryNode { + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def output: Seq[Attribute] = left.output ++ right.output override def outputPartitioning: Partitioning = @@ -70,6 +76,10 @@ case class SortMergeJoin( } protected override def doExecute(): RDD[InternalRow] = { + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => new RowIterator { // An ordering that can be used to compare keys from both sides. @@ -82,7 +92,9 @@ case class SortMergeJoin( rightKeyGenerator, keyOrdering, RowIterator.fromScala(leftIter), - RowIterator.fromScala(rightIter) + numLeftRows, + RowIterator.fromScala(rightIter), + numRightRows ) private[this] val joinRow = new JoinedRow private[this] val resultProjection: (InternalRow) => InternalRow = { @@ -108,6 +120,7 @@ case class SortMergeJoin( if (currentLeftRow != null) { joinRow(currentLeftRow, currentRightMatches(currentMatchIdx)) currentMatchIdx += 1 + numOutputRows += 1 true } else { false @@ -144,7 +157,9 @@ private[joins] class SortMergeJoinScanner( bufferedKeyGenerator: Projection, keyOrdering: Ordering[InternalRow], streamedIter: RowIterator, - bufferedIter: RowIterator) { + numStreamedRows: LongSQLMetric, + bufferedIter: RowIterator, + numBufferedRows: LongSQLMetric) { private[this] var streamedRow: InternalRow = _ private[this] var streamedRowKey: InternalRow = _ private[this] var bufferedRow: InternalRow = _ @@ -269,6 +284,7 @@ private[joins] class SortMergeJoinScanner( if (streamedIter.advanceNext()) { streamedRow = streamedIter.getRow streamedRowKey = streamedKeyGenerator(streamedRow) + numStreamedRows += 1 true } else { streamedRow = null @@ -286,6 +302,7 @@ private[joins] class SortMergeJoinScanner( while (!foundRow && bufferedIter.advanceNext()) { bufferedRow = bufferedIter.getRow bufferedRowKey = bufferedKeyGenerator(bufferedRow) + numBufferedRows += 1 foundRow = !bufferedRowKey.anyNull } if (!foundRow) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index 5326966b07a66..dea9e5e580a1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} +import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} /** * :: DeveloperApi :: @@ -40,6 +41,11 @@ case class SortMergeOuterJoin( left: SparkPlan, right: SparkPlan) extends BinaryNode { + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def output: Seq[Attribute] = { joinType match { case LeftOuter => @@ -108,6 +114,10 @@ case class SortMergeOuterJoin( } override def doExecute(): RDD[InternalRow] = { + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => // An ordering that can be used to compare keys from both sides. val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) @@ -133,10 +143,13 @@ case class SortMergeOuterJoin( bufferedKeyGenerator = createRightKeyGenerator(), keyOrdering, streamedIter = RowIterator.fromScala(leftIter), - bufferedIter = RowIterator.fromScala(rightIter) + numLeftRows, + bufferedIter = RowIterator.fromScala(rightIter), + numRightRows ) val rightNullRow = new GenericInternalRow(right.output.length) - new LeftOuterIterator(smjScanner, rightNullRow, boundCondition, resultProj).toScala + new LeftOuterIterator( + smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows).toScala case RightOuter => val smjScanner = new SortMergeJoinScanner( @@ -144,10 +157,13 @@ case class SortMergeOuterJoin( bufferedKeyGenerator = createLeftKeyGenerator(), keyOrdering, streamedIter = RowIterator.fromScala(rightIter), - bufferedIter = RowIterator.fromScala(leftIter) + numRightRows, + bufferedIter = RowIterator.fromScala(leftIter), + numLeftRows ) val leftNullRow = new GenericInternalRow(left.output.length) - new RightOuterIterator(smjScanner, leftNullRow, boundCondition, resultProj).toScala + new RightOuterIterator( + smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala case x => throw new IllegalArgumentException( @@ -162,7 +178,8 @@ private class LeftOuterIterator( smjScanner: SortMergeJoinScanner, rightNullRow: InternalRow, boundCondition: InternalRow => Boolean, - resultProj: InternalRow => InternalRow + resultProj: InternalRow => InternalRow, + numRows: LongSQLMetric ) extends RowIterator { private[this] val joinedRow: JoinedRow = new JoinedRow() private[this] var rightIdx: Int = 0 @@ -198,7 +215,9 @@ private class LeftOuterIterator( } override def advanceNext(): Boolean = { - advanceRightUntilBoundConditionSatisfied() || advanceLeft() + val r = advanceRightUntilBoundConditionSatisfied() || advanceLeft() + if (r) numRows += 1 + r } override def getRow: InternalRow = resultProj(joinedRow) @@ -208,7 +227,8 @@ private class RightOuterIterator( smjScanner: SortMergeJoinScanner, leftNullRow: InternalRow, boundCondition: InternalRow => Boolean, - resultProj: InternalRow => InternalRow + resultProj: InternalRow => InternalRow, + numRows: LongSQLMetric ) extends RowIterator { private[this] val joinedRow: JoinedRow = new JoinedRow() private[this] var leftIdx: Int = 0 @@ -244,7 +264,9 @@ private class RightOuterIterator( } override def advanceNext(): Boolean = { - advanceLeftUntilBoundConditionSatisfied() || advanceRight() + val r = advanceLeftUntilBoundConditionSatisfied() || advanceRight() + if (r) numRows += 1 + r } override def getRow: InternalRow = resultProj(joinedRow) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 1b51a5e5c8a8e..7a2a98ec18cb8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -112,4 +112,10 @@ private[sql] object SQLMetrics { sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc } + + /** + * A metric that its value will be ignored. Use this one when we need a metric parameter but don't + * care about the value. + */ + val nullLongMetric = new LongSQLMetric("null") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 8b1a9b21a96b9..a1fa2c3864bdb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -22,6 +22,8 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.collection.CompactBuffer @@ -35,7 +37,8 @@ class HashedRelationSuite extends SparkFunSuite { test("GeneralHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) - val hashed = HashedRelation(data.iterator, keyProjection) + val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data") + val hashed = HashedRelation(data.iterator, numDataRows, keyProjection) assert(hashed.isInstanceOf[GeneralHashedRelation]) assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) @@ -45,11 +48,13 @@ class HashedRelationSuite extends SparkFunSuite { val data2 = CompactBuffer[InternalRow](data(2)) data2 += data(2) assert(hashed.get(data(2)) === data2) + assert(numDataRows.value.value === data.length) } test("UniqueKeyHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2)) - val hashed = HashedRelation(data.iterator, keyProjection) + val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data") + val hashed = HashedRelation(data.iterator, numDataRows, keyProjection) assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) @@ -62,17 +67,19 @@ class HashedRelationSuite extends SparkFunSuite { assert(uniqHashed.getValue(data(1)) === data(1)) assert(uniqHashed.getValue(data(2)) === data(2)) assert(uniqHashed.getValue(InternalRow(10)) === null) + assert(numDataRows.value.value === data.length) } test("UnsafeHashedRelation") { val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) + val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data") val toUnsafe = UnsafeProjection.create(schema) val unsafeData = data.map(toUnsafe(_).copy()).toArray val buildKey = Seq(BoundReference(0, IntegerType, false)) val keyGenerator = UnsafeProjection.create(buildKey) - val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1) + val hashed = UnsafeHashedRelation(unsafeData.iterator, numDataRows, keyGenerator, 1) assert(hashed.isInstanceOf[UnsafeHashedRelation]) assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) @@ -94,5 +101,6 @@ class HashedRelationSuite extends SparkFunSuite { assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) assert(hashed2.get(toUnsafe(InternalRow(10))) === null) assert(hashed2.get(unsafeData(2)) === data2) + assert(numDataRows.value.value === data.length) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 953284c98b208..7383d3f8fe024 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -25,15 +25,24 @@ import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm._ import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.ui.SparkPlanGraph +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext} import org.apache.spark.util.Utils +class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { -class SQLMetricsSuite extends SparkFunSuite { + override val sqlContext = TestSQLContext + + import sqlContext.implicits._ test("LongSQLMetric should not box Long") { val l = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "long") - val f = () => { l += 1L } + val f = () => { + l += 1L + l.add(1L) + } BoxingFinder.getClassReader(f.getClass).foreach { cl => val boxingFinder = new BoxingFinder() cl.accept(boxingFinder, 0) @@ -51,6 +60,441 @@ class SQLMetricsSuite extends SparkFunSuite { assert(boxingFinder.boxingInvokes.nonEmpty, "Found find boxing in this test") } } + + /** + * Call `df.collect()` and verify if the collected metrics are same as "expectedMetrics". + * + * @param df `DataFrame` to run + * @param expectedNumOfJobs number of jobs that will run + * @param expectedMetrics the expected metrics. The format is + * `nodeId -> (operatorName, metric name -> metric value)`. + */ + private def testSparkPlanMetrics( + df: DataFrame, + expectedNumOfJobs: Int, + expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { + val previousExecutionIds = TestSQLContext.listener.executionIdToData.keySet + df.collect() + TestSQLContext.sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = TestSQLContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size === 1) + val executionId = executionIds.head + val jobs = TestSQLContext.listener.getExecution(executionId).get.jobs + // Use "<=" because there is a race condition that we may miss some jobs + // TODO Change it to "=" once we fix the race condition that missing the JobStarted event. + assert(jobs.size <= expectedNumOfJobs) + if (jobs.size == expectedNumOfJobs) { + // If we can track all jobs, check the metric values + val metricValues = TestSQLContext.listener.getExecutionMetrics(executionId) + val actualMetrics = SparkPlanGraph(df.queryExecution.executedPlan).nodes.filter { node => + expectedMetrics.contains(node.id) + }.map { node => + val nodeMetrics = node.metrics.map { metric => + val metricValue = metricValues(metric.accumulatorId) + (metric.name, metricValue) + }.toMap + (node.id, node.name -> nodeMetrics) + }.toMap + assert(expectedMetrics === actualMetrics) + } else { + // TODO Remove this "else" once we fix the race condition that missing the JobStarted event. + // Since we cannot track all jobs, the metric values could be wrong and we should not check + // them. + logWarning("Due to a race condition, we miss some jobs and cannot verify the metric values") + } + } + + test("Project metrics") { + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "false", + SQLConf.CODEGEN_ENABLED.key -> "false", + SQLConf.TUNGSTEN_ENABLED.key -> "false") { + // Assume the execution plan is + // PhysicalRDD(nodeId = 1) -> Project(nodeId = 0) + val df = TestData.person.select('name) + testSparkPlanMetrics(df, 1, Map( + 0L ->("Project", Map( + "number of rows" -> 2L))) + ) + } + } + + test("TungstenProject metrics") { + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "true", + SQLConf.CODEGEN_ENABLED.key -> "true", + SQLConf.TUNGSTEN_ENABLED.key -> "true") { + // Assume the execution plan is + // PhysicalRDD(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = TestData.person.select('name) + testSparkPlanMetrics(df, 1, Map( + 0L ->("TungstenProject", Map( + "number of rows" -> 2L))) + ) + } + } + + test("Filter metrics") { + // Assume the execution plan is + // PhysicalRDD(nodeId = 1) -> Filter(nodeId = 0) + val df = TestData.person.filter('age < 25) + testSparkPlanMetrics(df, 1, Map( + 0L -> ("Filter", Map( + "number of input rows" -> 2L, + "number of output rows" -> 1L))) + ) + } + + test("Aggregate metrics") { + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "false", + SQLConf.CODEGEN_ENABLED.key -> "false", + SQLConf.TUNGSTEN_ENABLED.key -> "false") { + // Assume the execution plan is + // ... -> Aggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> Aggregate(nodeId = 0) + val df = TestData.testData2.groupBy().count() // 2 partitions + testSparkPlanMetrics(df, 1, Map( + 2L -> ("Aggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 2L)), + 0L -> ("Aggregate", Map( + "number of input rows" -> 2L, + "number of output rows" -> 1L))) + ) + + // 2 partitions and each partition contains 2 keys + val df2 = TestData.testData2.groupBy('a).count() + testSparkPlanMetrics(df2, 1, Map( + 2L -> ("Aggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 4L)), + 0L -> ("Aggregate", Map( + "number of input rows" -> 4L, + "number of output rows" -> 3L))) + ) + } + } + + test("SortBasedAggregate metrics") { + // Because SortBasedAggregate may skip different rows if the number of partitions is different, + // this test should use the deterministic number of partitions. + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "false", + SQLConf.CODEGEN_ENABLED.key -> "true", + SQLConf.TUNGSTEN_ENABLED.key -> "true") { + // Assume the execution plan is + // ... -> SortBasedAggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> + // SortBasedAggregate(nodeId = 0) + val df = TestData.testData2.groupBy().count() // 2 partitions + testSparkPlanMetrics(df, 1, Map( + 2L -> ("SortBasedAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 2L)), + 0L -> ("SortBasedAggregate", Map( + "number of input rows" -> 2L, + "number of output rows" -> 1L))) + ) + + // Assume the execution plan is + // ... -> SortBasedAggregate(nodeId = 3) -> TungstenExchange(nodeId = 2) + // -> ExternalSort(nodeId = 1)-> SortBasedAggregate(nodeId = 0) + // 2 partitions and each partition contains 2 keys + val df2 = TestData.testData2.groupBy('a).count() + testSparkPlanMetrics(df2, 1, Map( + 3L -> ("SortBasedAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 4L)), + 0L -> ("SortBasedAggregate", Map( + "number of input rows" -> 4L, + "number of output rows" -> 3L))) + ) + } + } + + test("TungstenAggregate metrics") { + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "true", + SQLConf.CODEGEN_ENABLED.key -> "true", + SQLConf.TUNGSTEN_ENABLED.key -> "true") { + // Assume the execution plan is + // ... -> TungstenAggregate(nodeId = 2) -> Exchange(nodeId = 1) + // -> TungstenAggregate(nodeId = 0) + val df = TestData.testData2.groupBy().count() // 2 partitions + testSparkPlanMetrics(df, 1, Map( + 2L -> ("TungstenAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 2L)), + 0L -> ("TungstenAggregate", Map( + "number of input rows" -> 2L, + "number of output rows" -> 1L))) + ) + + // 2 partitions and each partition contains 2 keys + val df2 = TestData.testData2.groupBy('a).count() + testSparkPlanMetrics(df2, 1, Map( + 2L -> ("TungstenAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 4L)), + 0L -> ("TungstenAggregate", Map( + "number of input rows" -> 4L, + "number of output rows" -> 3L))) + ) + } + } + + test("SortMergeJoin metrics") { + // Because SortMergeJoin may skip different rows if the number of partitions is different, this + // test should use the deterministic number of partitions. + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { + val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> SortMergeJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 JOIN testDataForJoin ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df, 1, Map( + 1L -> ("SortMergeJoin", Map( + // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + "number of left rows" -> 4L, + "number of right rows" -> 2L, + "number of output rows" -> 4L))) + ) + } + } + } + + test("SortMergeOuterJoin metrics") { + // Because SortMergeOuterJoin may skip different rows if the number of partitions is different, + // this test should use the deterministic number of partitions. + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { + val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> SortMergeOuterJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 left JOIN testDataForJoin ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df, 1, Map( + 1L -> ("SortMergeOuterJoin", Map( + // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + "number of left rows" -> 6L, + "number of right rows" -> 2L, + "number of output rows" -> 8L))) + ) + + val df2 = sqlContext.sql( + "SELECT * FROM testDataForJoin right JOIN testData2 ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df2, 1, Map( + 1L -> ("SortMergeOuterJoin", Map( + // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + "number of left rows" -> 2L, + "number of right rows" -> 6L, + "number of output rows" -> 8L))) + ) + } + } + } + + test("BroadcastHashJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key", "value") + // Assume the execution plan is + // ... -> BroadcastHashJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = df1.join(broadcast(df2), "key") + testSparkPlanMetrics(df, 2, Map( + 1L -> ("BroadcastHashJoin", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + } + + test("ShuffledHashJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> ShuffledHashJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 JOIN testDataForJoin ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df, 1, Map( + 1L -> ("ShuffledHashJoin", Map( + "number of left rows" -> 6L, + "number of right rows" -> 2L, + "number of output rows" -> 4L))) + ) + } + } + } + + test("ShuffledHashOuterJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value") + val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value") + // Assume the execution plan is + // ... -> ShuffledHashOuterJoin(nodeId = 0) + val df = df1.join(df2, $"key" === $"key2", "left_outer") + testSparkPlanMetrics(df, 1, Map( + 0L -> ("ShuffledHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 5L))) + ) + + val df3 = df1.join(df2, $"key" === $"key2", "right_outer") + testSparkPlanMetrics(df3, 1, Map( + 0L -> ("ShuffledHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 6L))) + ) + + val df4 = df1.join(df2, $"key" === $"key2", "outer") + testSparkPlanMetrics(df4, 1, Map( + 0L -> ("ShuffledHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 7L))) + ) + } + } + + test("BroadcastHashOuterJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value") + val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value") + // Assume the execution plan is + // ... -> BroadcastHashOuterJoin(nodeId = 0) + val df = df1.join(broadcast(df2), $"key" === $"key2", "left_outer") + testSparkPlanMetrics(df, 2, Map( + 0L -> ("BroadcastHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 5L))) + ) + + val df3 = df1.join(broadcast(df2), $"key" === $"key2", "right_outer") + testSparkPlanMetrics(df3, 2, Map( + 0L -> ("BroadcastHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 6L))) + ) + } + } + + test("BroadcastNestedLoopJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { + val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 left JOIN testDataForJoin ON " + + "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a") + testSparkPlanMetrics(df, 3, Map( + 1L -> ("BroadcastNestedLoopJoin", Map( + "number of left rows" -> 12L, // left needs to be scanned twice + "number of right rows" -> 2L, + "number of output rows" -> 12L))) + ) + } + } + } + + test("BroadcastLeftSemiJoinHash metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") + // Assume the execution plan is + // ... -> BroadcastLeftSemiJoinHash(nodeId = 0) + val df = df1.join(broadcast(df2), $"key" === $"key2", "leftsemi") + testSparkPlanMetrics(df, 2, Map( + 0L -> ("BroadcastLeftSemiJoinHash", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + } + + test("LeftSemiJoinHash metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") + // Assume the execution plan is + // ... -> LeftSemiJoinHash(nodeId = 0) + val df = df1.join(df2, $"key" === $"key2", "leftsemi") + testSparkPlanMetrics(df, 1, Map( + 0L -> ("LeftSemiJoinHash", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + } + + test("LeftSemiJoinBNL metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") + // Assume the execution plan is + // ... -> LeftSemiJoinBNL(nodeId = 0) + val df = df1.join(df2, $"key" < $"key2", "leftsemi") + testSparkPlanMetrics(df, 2, Map( + 0L -> ("LeftSemiJoinBNL", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + } + + test("CartesianProduct metrics") { + val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> CartesianProduct(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 JOIN testDataForJoin") + testSparkPlanMetrics(df, 1, Map( + 1L -> ("CartesianProduct", Map( + "number of left rows" -> 12L, // left needs to be scanned twice + "number of right rows" -> 12L, // right is read 6 times + "number of output rows" -> 12L))) + ) + } + } + + test("save metrics") { + withTempPath { file => + val previousExecutionIds = TestSQLContext.listener.executionIdToData.keySet + // Assume the execution plan is + // PhysicalRDD(nodeId = 0) + TestData.person.select('name).write.format("json").save(file.getAbsolutePath) + TestSQLContext.sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = TestSQLContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size === 1) + val executionId = executionIds.head + val jobs = TestSQLContext.listener.getExecution(executionId).get.jobs + // Use "<=" because there is a race condition that we may miss some jobs + // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. + assert(jobs.size <= 1) + val metricValues = TestSQLContext.listener.getExecutionMetrics(executionId) + // Because "save" will create a new DataFrame internally, we cannot get the real metric id. + // However, we still can check the value. + assert(metricValues.values.toSeq === Seq(2L)) + } + } + } private case class MethodIdentifier[T](cls: Class[T], name: String, desc: String) From 520ad44b17f72e6465bf990f64b4e289f8a83447 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Tue, 11 Aug 2015 12:49:47 -0700 Subject: [PATCH 025/253] [SPARK-9750] [MLLIB] Improve equals on SparseMatrix and DenseMatrix Adds unit test for `equals` on `mllib.linalg.Matrix` class and `equals` to both `SparseMatrix` and `DenseMatrix`. Supports equality testing between `SparseMatrix` and `DenseMatrix`. mengxr Author: Feynman Liang Closes #8042 from feynmanliang/SPARK-9750 and squashes the following commits: bb70d5e [Feynman Liang] Breeze compare for dense matrices as well, in case other is sparse ab6f3c8 [Feynman Liang] Sparse matrix compare for equals 22782df [Feynman Liang] Add equality based on matrix semantics, not representation 78f9426 [Feynman Liang] Add casts 43d28fa [Feynman Liang] Fix failing test 6416fa0 [Feynman Liang] Add failing sparse matrix equals tests --- .../apache/spark/mllib/linalg/Matrices.scala | 8 ++++++-- .../spark/mllib/linalg/MatricesSuite.scala | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 1c858348bf20e..1139ce36d50b8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -257,8 +257,7 @@ class DenseMatrix( this(numRows, numCols, values, false) override def equals(o: Any): Boolean = o match { - case m: DenseMatrix => - m.numRows == numRows && m.numCols == numCols && Arrays.equals(toArray, m.toArray) + case m: Matrix => toBreeze == m.toBreeze case _ => false } @@ -519,6 +518,11 @@ class SparseMatrix( rowIndices: Array[Int], values: Array[Double]) = this(numRows, numCols, colPtrs, rowIndices, values, false) + override def equals(o: Any): Boolean = o match { + case m: Matrix => toBreeze == m.toBreeze + case _ => false + } + private[mllib] def toBreeze: BM[Double] = { if (!isTransposed) { new BSM[Double](values, numRows, numCols, colPtrs, rowIndices) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index a270ba2562db9..bfd6d5495f5e0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -74,6 +74,24 @@ class MatricesSuite extends SparkFunSuite { } } + test("equals") { + val dm1 = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)) + assert(dm1 === dm1) + assert(dm1 !== dm1.transpose) + + val dm2 = Matrices.dense(2, 2, Array(0.0, 2.0, 1.0, 3.0)) + assert(dm1 === dm2.transpose) + + val sm1 = dm1.asInstanceOf[DenseMatrix].toSparse + assert(sm1 === sm1) + assert(sm1 === dm1) + assert(sm1 !== sm1.transpose) + + val sm2 = dm2.asInstanceOf[DenseMatrix].toSparse + assert(sm1 === sm2.transpose) + assert(sm1 === dm2.transpose) + } + test("matrix copies are deep copies") { val m = 3 val n = 2 From 2a3be4ddf9d9527353f07ea0ab204ce17dbcba9a Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Tue, 11 Aug 2015 14:02:23 -0700 Subject: [PATCH 026/253] [SPARK-7726] Add import so Scaladoc doesn't fail. This is another import needed so Scala 2.11 doc generation doesn't fail. See SPARK-7726 for more detail. I tested this locally and the 2.11 install goes from failing to succeeding with this patch. Author: Patrick Wendell Closes #8095 from pwendell/scaladoc. --- .../spark/network/shuffle/protocol/mesos/RegisterDriver.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java index 1c28fc1dff246..94a61d6caadc4 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java @@ -23,6 +23,9 @@ import org.apache.spark.network.protocol.Encoders; import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + /** * A message sent from the driver to register with the MesosExternalShuffleService. */ From 00c02728a6c6c4282c389ca90641dd78dd5e3d32 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 11 Aug 2015 14:04:09 -0700 Subject: [PATCH 027/253] [SPARK-9814] [SQL] EqualNotNull not passing to data sources MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Author: hyukjinkwon Author: 권혁진 Closes #8096 from HyukjinKwon/master. --- .../sql/execution/datasources/DataSourceStrategy.scala | 5 +++++ .../scala/org/apache/spark/sql/sources/filters.scala | 9 +++++++++ .../org/apache/spark/sql/sources/FilteredScanSuite.scala | 1 + 3 files changed, 15 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 78a4acdf4b1bf..2a4c40db8bb66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -349,6 +349,11 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { case expressions.EqualTo(Literal(v, _), a: Attribute) => Some(sources.EqualTo(a.name, v)) + case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) => + Some(sources.EqualNullSafe(a.name, v)) + case expressions.EqualNullSafe(Literal(v, _), a: Attribute) => + Some(sources.EqualNullSafe(a.name, v)) + case expressions.GreaterThan(a: Attribute, Literal(v, _)) => Some(sources.GreaterThan(a.name, v)) case expressions.GreaterThan(Literal(v, _), a: Attribute) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index 4d942e4f9287a..3780cbbcc9631 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -36,6 +36,15 @@ abstract class Filter */ case class EqualTo(attribute: String, value: Any) extends Filter +/** + * Performs equality comparison, similar to [[EqualTo]]. However, this differs from [[EqualTo]] + * in that it returns `true` (rather than NULL) if both inputs are NULL, and `false` + * (rather than NULL) if one of the input is NULL and the other is not NULL. + * + * @since 1.5.0 + */ +case class EqualNullSafe(attribute: String, value: Any) extends Filter + /** * A filter that evaluates to `true` iff the attribute evaluates to a value * greater than `value`. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 81b3a0f0c5b3a..5ef365797eace 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -56,6 +56,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL // Predicate test on integer column def translateFilterOnA(filter: Filter): Int => Boolean = filter match { case EqualTo("a", v) => (a: Int) => a == v + case EqualNullSafe("a", v) => (a: Int) => a == v case LessThan("a", v: Int) => (a: Int) => a < v case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v case GreaterThan("a", v: Int) => (a: Int) => a > v From f16bc68dfb25c7b746ae031a57840ace9bafa87f Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 11 Aug 2015 14:06:23 -0700 Subject: [PATCH 028/253] [SPARK-9824] [CORE] Fix the issue that InternalAccumulator leaks WeakReference `InternalAccumulator.create` doesn't call `registerAccumulatorForCleanup` to register itself with ContextCleaner, so `WeakReference`s for these accumulators in `Accumulators.originals` won't be removed. This PR added `registerAccumulatorForCleanup` for internal accumulators to avoid the memory leak. Author: zsxwing Closes #8108 from zsxwing/internal-accumulators-leak. --- .../scala/org/apache/spark/Accumulators.scala | 22 +++++++++++-------- .../org/apache/spark/scheduler/Stage.scala | 2 +- .../org/apache/spark/AccumulatorSuite.scala | 3 ++- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 064246dfa7fc3..c39c8667d013e 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -382,14 +382,18 @@ private[spark] object InternalAccumulator { * add to the same set of accumulators. We do this to report the distribution of accumulator * values across all tasks within each stage. */ - def create(): Seq[Accumulator[Long]] = { - Seq( - // Execution memory refers to the memory used by internal data structures created - // during shuffles, aggregations and joins. The value of this accumulator should be - // approximately the sum of the peak sizes across all such data structures created - // in this task. For SQL jobs, this only tracks all unsafe operators and ExternalSort. - new Accumulator( - 0L, AccumulatorParam.LongAccumulatorParam, Some(PEAK_EXECUTION_MEMORY), internal = true) - ) ++ maybeTestAccumulator.toSeq + def create(sc: SparkContext): Seq[Accumulator[Long]] = { + val internalAccumulators = Seq( + // Execution memory refers to the memory used by internal data structures created + // during shuffles, aggregations and joins. The value of this accumulator should be + // approximately the sum of the peak sizes across all such data structures created + // in this task. For SQL jobs, this only tracks all unsafe operators and ExternalSort. + new Accumulator( + 0L, AccumulatorParam.LongAccumulatorParam, Some(PEAK_EXECUTION_MEMORY), internal = true) + ) ++ maybeTestAccumulator.toSeq + internalAccumulators.foreach { accumulator => + sc.cleaner.foreach(_.registerAccumulatorForCleanup(accumulator)) + } + internalAccumulators } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index de05ee256dbfc..1cf06856ffbc2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -81,7 +81,7 @@ private[spark] abstract class Stage( * accumulators here again will override partial values from the finished tasks. */ def resetInternalAccumulators(): Unit = { - _internalAccumulators = InternalAccumulator.create() + _internalAccumulators = InternalAccumulator.create(rdd.sparkContext) } /** diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 48f549575f4d1..0eb2293a9d063 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -160,7 +160,8 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex } test("internal accumulators in TaskContext") { - val accums = InternalAccumulator.create() + sc = new SparkContext("local", "test") + val accums = InternalAccumulator.create(sc) val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null, accums) val internalMetricsToAccums = taskContext.internalMetricsToAccumulators val collectedInternalAccums = taskContext.collectInternalAccumulators() From 423cdfd83d7fd02a4f8cf3e714db913fd3f9ca09 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 11 Aug 2015 14:08:09 -0700 Subject: [PATCH 029/253] Closes #1290 Closes #4934 From be3e27164133025db860781bd5cdd3ca233edd21 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Tue, 11 Aug 2015 14:21:53 -0700 Subject: [PATCH 030/253] [SPARK-9788] [MLLIB] Fix LDA Binary Compatibility MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Add “asymmetricDocConcentration” and revert docConcentration changes. If the (internal) doc concentration vector is a single value, “getDocConcentration" returns it. If it is a constant vector, getDocConcentration returns the first item, and fails otherwise. 2. Give `LDAModel.gammaShape` a default value in `LDAModel` concrete class constructors. jkbradley Author: Feynman Liang Closes #8077 from feynmanliang/SPARK-9788 and squashes the following commits: 6b07bc8 [Feynman Liang] Code review changes 9d6a71e [Feynman Liang] Add asymmetricAlpha alias bf4e685 [Feynman Liang] Asymmetric docConcentration 4cab972 [Feynman Liang] Default gammaShape --- .../apache/spark/mllib/clustering/LDA.scala | 27 ++++++++++++++++-- .../spark/mllib/clustering/LDAModel.scala | 11 ++++---- .../spark/mllib/clustering/LDAOptimizer.scala | 28 +++++++++---------- .../spark/mllib/clustering/LDASuite.scala | 4 +-- 4 files changed, 46 insertions(+), 24 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index ab124e6d77c5e..0fc9b1ac4d716 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -79,7 +79,24 @@ class LDA private ( * * This is the parameter to a Dirichlet distribution. */ - def getDocConcentration: Vector = this.docConcentration + def getAsymmetricDocConcentration: Vector = this.docConcentration + + /** + * Concentration parameter (commonly named "alpha") for the prior placed on documents' + * distributions over topics ("theta"). + * + * This method assumes the Dirichlet distribution is symmetric and can be described by a single + * [[Double]] parameter. It should fail if docConcentration is asymmetric. + */ + def getDocConcentration: Double = { + val parameter = docConcentration(0) + if (docConcentration.size == 1) { + parameter + } else { + require(docConcentration.toArray.forall(_ == parameter)) + parameter + } + } /** * Concentration parameter (commonly named "alpha") for the prior placed on documents' @@ -106,18 +123,22 @@ class LDA private ( * [[https://github.com/Blei-Lab/onlineldavb]]. */ def setDocConcentration(docConcentration: Vector): this.type = { + require(docConcentration.size > 0, "docConcentration must have > 0 elements") this.docConcentration = docConcentration this } - /** Replicates Double to create a symmetric prior */ + /** Replicates a [[Double]] docConcentration to create a symmetric prior. */ def setDocConcentration(docConcentration: Double): this.type = { this.docConcentration = Vectors.dense(docConcentration) this } + /** Alias for [[getAsymmetricDocConcentration]] */ + def getAsymmetricAlpha: Vector = getAsymmetricDocConcentration + /** Alias for [[getDocConcentration]] */ - def getAlpha: Vector = getDocConcentration + def getAlpha: Double = getDocConcentration /** Alias for [[setDocConcentration()]] */ def setAlpha(alpha: Vector): this.type = setDocConcentration(alpha) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 33babda69bbb9..5dc637ebdc133 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -27,7 +27,6 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaPairRDD -import org.apache.spark.broadcast.Broadcast import org.apache.spark.graphx.{Edge, EdgeContext, Graph, VertexId} import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} @@ -190,7 +189,8 @@ class LocalLDAModel private[clustering] ( val topics: Matrix, override val docConcentration: Vector, override val topicConcentration: Double, - override protected[clustering] val gammaShape: Double) extends LDAModel with Serializable { + override protected[clustering] val gammaShape: Double = 100) + extends LDAModel with Serializable { override def k: Int = topics.numCols @@ -455,8 +455,9 @@ class DistributedLDAModel private[clustering] ( val vocabSize: Int, override val docConcentration: Vector, override val topicConcentration: Double, - override protected[clustering] val gammaShape: Double, - private[spark] val iterationTimes: Array[Double]) extends LDAModel { + private[spark] val iterationTimes: Array[Double], + override protected[clustering] val gammaShape: Double = 100) + extends LDAModel { import LDA._ @@ -756,7 +757,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges) new DistributedLDAModel(graph, globalTopicTotals, globalTopicTotals.length, vocabSize, - docConcentration, topicConcentration, gammaShape, iterationTimes) + docConcentration, topicConcentration, iterationTimes, gammaShape) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index afba2866c7040..a0008f9c99ad7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -95,10 +95,8 @@ final class EMLDAOptimizer extends LDAOptimizer { * Compute bipartite term/doc graph. */ override private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer = { - val docConcentration = lda.getDocConcentration(0) - require({ - lda.getDocConcentration.toArray.forall(_ == docConcentration) - }, "EMLDAOptimizer currently only supports symmetric document-topic priors") + // EMLDAOptimizer currently only supports symmetric document-topic priors + val docConcentration = lda.getDocConcentration val topicConcentration = lda.getTopicConcentration val k = lda.getK @@ -209,11 +207,11 @@ final class EMLDAOptimizer extends LDAOptimizer { override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = { require(graph != null, "graph is null, EMLDAOptimizer not initialized.") this.graphCheckpointer.deleteAllCheckpoints() - // This assumes gammaShape = 100 in OnlineLDAOptimizer to ensure equivalence in LDAModel.toLocal - // conversion + // The constructor's default arguments assume gammaShape = 100 to ensure equivalence in + // LDAModel.toLocal conversion new DistributedLDAModel(this.graph, this.globalTopicTotals, this.k, this.vocabSize, Vectors.dense(Array.fill(this.k)(this.docConcentration)), this.topicConcentration, - 100, iterationTimes) + iterationTimes) } } @@ -378,18 +376,20 @@ final class OnlineLDAOptimizer extends LDAOptimizer { this.k = lda.getK this.corpusSize = docs.count() this.vocabSize = docs.first()._2.size - this.alpha = if (lda.getDocConcentration.size == 1) { - if (lda.getDocConcentration(0) == -1) Vectors.dense(Array.fill(k)(1.0 / k)) + this.alpha = if (lda.getAsymmetricDocConcentration.size == 1) { + if (lda.getAsymmetricDocConcentration(0) == -1) Vectors.dense(Array.fill(k)(1.0 / k)) else { - require(lda.getDocConcentration(0) >= 0, s"all entries in alpha must be >=0, got: $alpha") - Vectors.dense(Array.fill(k)(lda.getDocConcentration(0))) + require(lda.getAsymmetricDocConcentration(0) >= 0, + s"all entries in alpha must be >=0, got: $alpha") + Vectors.dense(Array.fill(k)(lda.getAsymmetricDocConcentration(0))) } } else { - require(lda.getDocConcentration.size == k, s"alpha must have length k, got: $alpha") - lda.getDocConcentration.foreachActive { case (_, x) => + require(lda.getAsymmetricDocConcentration.size == k, + s"alpha must have length k, got: $alpha") + lda.getAsymmetricDocConcentration.foreachActive { case (_, x) => require(x >= 0, s"all entries in alpha must be >= 0, got: $alpha") } - lda.getDocConcentration + lda.getAsymmetricDocConcentration } this.eta = if (lda.getTopicConcentration == -1) 1.0 / k else lda.getTopicConcentration this.randomGenerator = new Random(lda.getSeed) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index fdc2554ab853e..ce6a8eb8e8c46 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -160,8 +160,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { test("setter alias") { val lda = new LDA().setAlpha(2.0).setBeta(3.0) - assert(lda.getAlpha.toArray.forall(_ === 2.0)) - assert(lda.getDocConcentration.toArray.forall(_ === 2.0)) + assert(lda.getAsymmetricAlpha.toArray.forall(_ === 2.0)) + assert(lda.getAsymmetricDocConcentration.toArray.forall(_ === 2.0)) assert(lda.getBeta === 3.0) assert(lda.getTopicConcentration === 3.0) } From 017b5de07ef6cff249e984a2ab781c520249ac76 Mon Sep 17 00:00:00 2001 From: Sudhakar Thota Date: Tue, 11 Aug 2015 14:31:51 -0700 Subject: [PATCH 031/253] [SPARK-8925] [MLLIB] Add @since tags to mllib.util Went thru the history of changes the file MLUtils.scala and picked up the version that the change went in. Author: Sudhakar Thota Author: Sudhakar Thota Closes #7436 from sthota2014/SPARK-8925_thotas. --- .../org/apache/spark/mllib/util/MLUtils.scala | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 7c5cfa7bd84ce..26eb84a8dc0b0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -64,6 +64,7 @@ object MLUtils { * feature dimensions. * @param minPartitions min number of partitions * @return labeled data stored as an RDD[LabeledPoint] + * @since 1.0.0 */ def loadLibSVMFile( sc: SparkContext, @@ -113,7 +114,10 @@ object MLUtils { } // Convenient methods for `loadLibSVMFile`. - + + /** + * @since 1.0.0 + */ @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") def loadLibSVMFile( sc: SparkContext, @@ -126,6 +130,7 @@ object MLUtils { /** * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], with the default number of * partitions. + * @since 1.0.0 */ def loadLibSVMFile( sc: SparkContext, @@ -133,6 +138,9 @@ object MLUtils { numFeatures: Int): RDD[LabeledPoint] = loadLibSVMFile(sc, path, numFeatures, sc.defaultMinPartitions) + /** + * @since 1.0.0 + */ @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") def loadLibSVMFile( sc: SparkContext, @@ -141,6 +149,9 @@ object MLUtils { numFeatures: Int): RDD[LabeledPoint] = loadLibSVMFile(sc, path, numFeatures) + /** + * @since 1.0.0 + */ @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") def loadLibSVMFile( sc: SparkContext, @@ -151,6 +162,7 @@ object MLUtils { /** * Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint], with number of * features determined automatically and the default number of partitions. + * @since 1.0.0 */ def loadLibSVMFile(sc: SparkContext, path: String): RDD[LabeledPoint] = loadLibSVMFile(sc, path, -1) @@ -181,12 +193,14 @@ object MLUtils { * @param path file or directory path in any Hadoop-supported file system URI * @param minPartitions min number of partitions * @return vectors stored as an RDD[Vector] + * @since 1.1.0 */ def loadVectors(sc: SparkContext, path: String, minPartitions: Int): RDD[Vector] = sc.textFile(path, minPartitions).map(Vectors.parse) /** * Loads vectors saved using `RDD[Vector].saveAsTextFile` with the default number of partitions. + * @since 1.1.0 */ def loadVectors(sc: SparkContext, path: String): RDD[Vector] = sc.textFile(path, sc.defaultMinPartitions).map(Vectors.parse) @@ -197,6 +211,7 @@ object MLUtils { * @param path file or directory path in any Hadoop-supported file system URI * @param minPartitions min number of partitions * @return labeled points stored as an RDD[LabeledPoint] + * @since 1.1.0 */ def loadLabeledPoints(sc: SparkContext, path: String, minPartitions: Int): RDD[LabeledPoint] = sc.textFile(path, minPartitions).map(LabeledPoint.parse) @@ -204,6 +219,7 @@ object MLUtils { /** * Loads labeled points saved using `RDD[LabeledPoint].saveAsTextFile` with the default number of * partitions. + * @since 1.1.0 */ def loadLabeledPoints(sc: SparkContext, dir: String): RDD[LabeledPoint] = loadLabeledPoints(sc, dir, sc.defaultMinPartitions) @@ -220,6 +236,7 @@ object MLUtils { * * @deprecated Should use [[org.apache.spark.rdd.RDD#saveAsTextFile]] for saving and * [[org.apache.spark.mllib.util.MLUtils#loadLabeledPoints]] for loading. + * @since 1.0.0 */ @deprecated("Should use MLUtils.loadLabeledPoints instead.", "1.0.1") def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { @@ -241,6 +258,7 @@ object MLUtils { * * @deprecated Should use [[org.apache.spark.rdd.RDD#saveAsTextFile]] for saving and * [[org.apache.spark.mllib.util.MLUtils#loadLabeledPoints]] for loading. + * @since 1.0.0 */ @deprecated("Should use RDD[LabeledPoint].saveAsTextFile instead.", "1.0.1") def saveLabeledData(data: RDD[LabeledPoint], dir: String) { @@ -253,6 +271,7 @@ object MLUtils { * Return a k element array of pairs of RDDs with the first element of each pair * containing the training data, a complement of the validation data and the second * element, the validation data, containing a unique 1/kth of the data. Where k=numFolds. + * @since 1.0.0 */ @Experimental def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = { @@ -268,6 +287,7 @@ object MLUtils { /** * Returns a new vector with `1.0` (bias) appended to the input vector. + * @since 1.0.0 */ def appendBias(vector: Vector): Vector = { vector match { From 736af95bd0c41723d455246b634a0fb68b38a7c7 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 11 Aug 2015 14:52:52 -0700 Subject: [PATCH 032/253] [HOTFIX] Fix style error caused by 017b5de --- mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 26eb84a8dc0b0..11ed23176fc12 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -114,7 +114,7 @@ object MLUtils { } // Convenient methods for `loadLibSVMFile`. - + /** * @since 1.0.0 */ From 5a5bbc29961630d649d4bd4acd5d19eb537b5fd0 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 11 Aug 2015 16:33:08 -0700 Subject: [PATCH 033/253] [SPARK-9074] [LAUNCHER] Allow arbitrary Spark args to be set. This change allows any Spark argument to be added to the app to be started using SparkLauncher. Known arguments are properly validated, while unknown arguments are allowed so that the library can launch newer Spark versions (in case SPARK_HOME points at one). Author: Marcelo Vanzin Closes #7975 from vanzin/SPARK-9074 and squashes the following commits: b5e451a [Marcelo Vanzin] [SPARK-9074] [launcher] Allow arbitrary Spark args to be set. --- .../apache/spark/launcher/SparkLauncher.java | 101 +++++++++++++++++- .../launcher/SparkSubmitCommandBuilder.java | 2 +- .../spark/launcher/SparkLauncherSuite.java | 50 +++++++++ 3 files changed, 150 insertions(+), 3 deletions(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index c0f89c9230692..03c9358bc865d 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -20,12 +20,13 @@ import java.io.File; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Map; import static org.apache.spark.launcher.CommandBuilderUtils.*; -/** +/** * Launcher for Spark applications. *

* Use this class to start Spark applications programmatically. The class uses a builder pattern @@ -57,7 +58,8 @@ public class SparkLauncher { /** Configuration key for the number of executor CPU cores. */ public static final String EXECUTOR_CORES = "spark.executor.cores"; - private final SparkSubmitCommandBuilder builder; + // Visible for testing. + final SparkSubmitCommandBuilder builder; public SparkLauncher() { this(null); @@ -187,6 +189,73 @@ public SparkLauncher setMainClass(String mainClass) { return this; } + /** + * Adds a no-value argument to the Spark invocation. If the argument is known, this method + * validates whether the argument is indeed a no-value argument, and throws an exception + * otherwise. + *

+ * Use this method with caution. It is possible to create an invalid Spark command by passing + * unknown arguments to this method, since those are allowed for forward compatibility. + * + * @param arg Argument to add. + * @return This launcher. + */ + public SparkLauncher addSparkArg(String arg) { + SparkSubmitOptionParser validator = new ArgumentValidator(false); + validator.parse(Arrays.asList(arg)); + builder.sparkArgs.add(arg); + return this; + } + + /** + * Adds an argument with a value to the Spark invocation. If the argument name corresponds to + * a known argument, the code validates that the argument actually expects a value, and throws + * an exception otherwise. + *

+ * It is safe to add arguments modified by other methods in this class (such as + * {@link #setMaster(String)} - the last invocation will be the one to take effect. + *

+ * Use this method with caution. It is possible to create an invalid Spark command by passing + * unknown arguments to this method, since those are allowed for forward compatibility. + * + * @param name Name of argument to add. + * @param value Value of the argument. + * @return This launcher. + */ + public SparkLauncher addSparkArg(String name, String value) { + SparkSubmitOptionParser validator = new ArgumentValidator(true); + if (validator.MASTER.equals(name)) { + setMaster(value); + } else if (validator.PROPERTIES_FILE.equals(name)) { + setPropertiesFile(value); + } else if (validator.CONF.equals(name)) { + String[] vals = value.split("=", 2); + setConf(vals[0], vals[1]); + } else if (validator.CLASS.equals(name)) { + setMainClass(value); + } else if (validator.JARS.equals(name)) { + builder.jars.clear(); + for (String jar : value.split(",")) { + addJar(jar); + } + } else if (validator.FILES.equals(name)) { + builder.files.clear(); + for (String file : value.split(",")) { + addFile(file); + } + } else if (validator.PY_FILES.equals(name)) { + builder.pyFiles.clear(); + for (String file : value.split(",")) { + addPyFile(file); + } + } else { + validator.parse(Arrays.asList(name, value)); + builder.sparkArgs.add(name); + builder.sparkArgs.add(value); + } + return this; + } + /** * Adds command line arguments for the application. * @@ -277,4 +346,32 @@ public Process launch() throws IOException { return pb.start(); } + private static class ArgumentValidator extends SparkSubmitOptionParser { + + private final boolean hasValue; + + ArgumentValidator(boolean hasValue) { + this.hasValue = hasValue; + } + + @Override + protected boolean handle(String opt, String value) { + if (value == null && hasValue) { + throw new IllegalArgumentException(String.format("'%s' does not expect a value.", opt)); + } + return true; + } + + @Override + protected boolean handleUnknown(String opt) { + // Do not fail on unknown arguments, to support future arguments added to SparkSubmit. + return true; + } + + protected void handleExtraArgs(List extra) { + // No op. + } + + }; + } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 87c43aa9980e1..4f354cedee66f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -76,7 +76,7 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { "spark-internal"); } - private final List sparkArgs; + final List sparkArgs; private final boolean printHelp; /** diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 252d5abae1ca3..d0c26dd05679b 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -20,6 +20,7 @@ import java.io.BufferedReader; import java.io.InputStream; import java.io.InputStreamReader; +import java.util.Arrays; import java.util.HashMap; import java.util.Map; @@ -35,8 +36,54 @@ public class SparkLauncherSuite { private static final Logger LOG = LoggerFactory.getLogger(SparkLauncherSuite.class); + @Test + public void testSparkArgumentHandling() throws Exception { + SparkLauncher launcher = new SparkLauncher() + .setSparkHome(System.getProperty("spark.test.home")); + SparkSubmitOptionParser opts = new SparkSubmitOptionParser(); + + launcher.addSparkArg(opts.HELP); + try { + launcher.addSparkArg(opts.PROXY_USER); + fail("Expected IllegalArgumentException."); + } catch (IllegalArgumentException e) { + // Expected. + } + + launcher.addSparkArg(opts.PROXY_USER, "someUser"); + try { + launcher.addSparkArg(opts.HELP, "someValue"); + fail("Expected IllegalArgumentException."); + } catch (IllegalArgumentException e) { + // Expected. + } + + launcher.addSparkArg("--future-argument"); + launcher.addSparkArg("--future-argument", "someValue"); + + launcher.addSparkArg(opts.MASTER, "myMaster"); + assertEquals("myMaster", launcher.builder.master); + + launcher.addJar("foo"); + launcher.addSparkArg(opts.JARS, "bar"); + assertEquals(Arrays.asList("bar"), launcher.builder.jars); + + launcher.addFile("foo"); + launcher.addSparkArg(opts.FILES, "bar"); + assertEquals(Arrays.asList("bar"), launcher.builder.files); + + launcher.addPyFile("foo"); + launcher.addSparkArg(opts.PY_FILES, "bar"); + assertEquals(Arrays.asList("bar"), launcher.builder.pyFiles); + + launcher.setConf("spark.foo", "foo"); + launcher.addSparkArg(opts.CONF, "spark.foo=bar"); + assertEquals("bar", launcher.builder.conf.get("spark.foo")); + } + @Test public void testChildProcLauncher() throws Exception { + SparkSubmitOptionParser opts = new SparkSubmitOptionParser(); Map env = new HashMap(); env.put("SPARK_PRINT_LAUNCH_COMMAND", "1"); @@ -44,9 +91,12 @@ public void testChildProcLauncher() throws Exception { .setSparkHome(System.getProperty("spark.test.home")) .setMaster("local") .setAppResource("spark-internal") + .addSparkArg(opts.CONF, + String.format("%s=-Dfoo=ShouldBeOverriddenBelow", SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS)) .setConf(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, "-Dfoo=bar -Dtest.name=-testChildProcLauncher") .setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH, System.getProperty("java.class.path")) + .addSparkArg(opts.CLASS, "ShouldBeOverriddenBelow") .setMainClass(SparkLauncherTestApp.class.getName()) .addAppArgs("proc"); final Process app = launcher.launch(); From afa757c98c537965007cad4c61c436887f3ac6a6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 11 Aug 2015 18:08:49 -0700 Subject: [PATCH 034/253] [SPARK-9849] [SQL] DirectParquetOutputCommitter qualified name should be backward compatible DirectParquetOutputCommitter was moved in SPARK-9763. However, users can explicitly set the class as a config option, so we must be able to resolve the old committer qualified name. Author: Reynold Xin Closes #8114 from rxin/SPARK-9849. --- .../datasources/parquet/ParquetRelation.scala | 7 +++++ .../datasources/parquet/ParquetIOSuite.scala | 27 ++++++++++++++++++- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 4086a139bed72..c71c69b6e80b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -209,6 +209,13 @@ private[sql] class ParquetRelation( override def prepareJobForWrite(job: Job): OutputWriterFactory = { val conf = ContextUtil.getConfiguration(job) + // SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible + val committerClassname = conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) + if (committerClassname == "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") { + conf.set(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, + classOf[DirectParquetOutputCommitter].getCanonicalName) + } + val committerClass = conf.getClass( SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index ee925afe08508..cb166349fdb26 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -390,7 +390,32 @@ class ParquetIOSuite extends QueryTest with ParquetTest { } } - test("SPARK-8121: spark.sql.parquet.output.committer.class shouldn't be overriden") { + test("SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible") { + val clonedConf = new Configuration(configuration) + + // Write to a parquet file and let it fail. + // _temporary should be missing if direct output committer works. + try { + configuration.set("spark.sql.parquet.output.committer.class", + "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") + sqlContext.udf.register("div0", (x: Int) => x / 0) + withTempPath { dir => + intercept[org.apache.spark.SparkException] { + sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) + } + val path = new Path(dir.getCanonicalPath, "_temporary") + val fs = path.getFileSystem(configuration) + assert(!fs.exists(path)) + } + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + configuration.clear() + clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + } + } + + + test("SPARK-8121: spark.sql.parquet.output.committer.class shouldn't be overridden") { withTempPath { dir => val clonedConf = new Configuration(configuration) From b8fdd5cd1b11dd7954d1f05bb71b1a2ae740d065 Mon Sep 17 00:00:00 2001 From: Neelesh Srinivas Salian Date: Tue, 11 Aug 2015 18:43:10 -0700 Subject: [PATCH 035/253] Added section for preferred yarn and kept the one with deploy-mode for generic submission to help clear up confusion --- docs/submitting-applications.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index 4a564ed0a765b..ae21a16d5ec56 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -48,7 +48,7 @@ Some of the commonly used options are: * `application-jar`: Path to a bundled jar including your application and all dependencies. The URL must be globally visible inside of your cluster, for instance, an `hdfs://` path or a `file://` path that is present on all nodes. * `application-arguments`: Arguments passed to the main method of your main class, if any -Alternatively, for submitting on yarn, +For submitting application to YARN, the preferred options are: {% highlight bash %} ./bin/spark-submit \ From 8c65676a6b7a692d07face111d8e998f36ca0151 Mon Sep 17 00:00:00 2001 From: Neelesh Srinivas Salian Date: Tue, 11 Aug 2015 18:44:36 -0700 Subject: [PATCH 036/253] Moved the Standalone examples together --- docs/submitting-applications.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index ae21a16d5ec56..8778f719e2501 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -125,6 +125,12 @@ run it with `--help`. Here are a few examples of common options: --total-executor-cores 100 \ /path/to/examples.jar \ 1000 + +# Run a Python application on a Spark Standalone cluster +./bin/spark-submit \ + --master spark://207.184.161.138:7077 \ + examples/src/main/python/pi.py \ + 1000 # Run on a Spark Standalone cluster in cluster deploy mode with supervise ./bin/spark-submit \ @@ -146,12 +152,6 @@ export HADOOP_CONF_DIR=XXX --num-executors 50 \ /path/to/examples.jar \ 1000 - -# Run a Python application on a Spark Standalone cluster -./bin/spark-submit \ - --master spark://207.184.161.138:7077 \ - examples/src/main/python/pi.py \ - 1000 {% endhighlight %} # Loading Configuration from a File From ca8f70e9d473d2c81866f3c330cc6545c33bdac7 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 11 Aug 2015 20:46:58 -0700 Subject: [PATCH 037/253] [SPARK-9649] Fix flaky test MasterSuite again - disable REST The REST server is not actually used in most tests and so we can disable it. It is a source of flakiness because it tries to bind to a specific port in vain. There was also some code that avoided the shuffle service in tests. This is actually not necessary because the shuffle service is already off by default. Author: Andrew Or Closes #8084 from andrewor14/fix-master-suite-again. --- pom.xml | 1 + project/SparkBuild.scala | 1 + 2 files changed, 2 insertions(+) diff --git a/pom.xml b/pom.xml index 8942836a7da16..cfd7d32563f2a 100644 --- a/pom.xml +++ b/pom.xml @@ -1895,6 +1895,7 @@ ${project.build.directory}/tmp ${spark.test.home} 1 + false false false true diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index cad7067ade8c1..74f815f941d5b 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -546,6 +546,7 @@ object TestSettings { javaOptions in Test += "-Dspark.test.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", javaOptions in Test += "-Dspark.port.maxRetries=100", + javaOptions in Test += "-Dspark.master.rest.enabled=false", javaOptions in Test += "-Dspark.ui.enabled=false", javaOptions in Test += "-Dspark.ui.showConsoleProgress=false", javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", From 3ef0f32928fc383ad3edd5ad167212aeb9eba6e1 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Tue, 11 Aug 2015 21:16:48 -0700 Subject: [PATCH 038/253] [SPARK-1517] Refactor release scripts to facilitate nightly publishing This update contains some code changes to the release scripts that allow easier nightly publishing. I've been using these new scripts on Jenkins for cutting and publishing nightly snapshots for the last month or so, and it has been going well. I'd like to get them merged back upstream so this can be maintained by the community. The main changes are: 1. Separates the release tagging from various build possibilities for an already tagged release (`release-tag.sh` and `release-build.sh`). 2. Allow for injecting credentials through the environment, including GPG keys. This is then paired with secure key injection in Jenkins. 3. Support for copying build results to a remote directory, and also "rotating" results, e.g. the ability to keep the last N copies of binary or doc builds. I'm happy if anyone wants to take a look at this - it's not user facing but an internal utility used for generating releases. Author: Patrick Wendell Closes #7411 from pwendell/release-script-updates and squashes the following commits: 74f9beb [Patrick Wendell] Moving maven build command to a variable 233ce85 [Patrick Wendell] [SPARK-1517] Refactor release scripts to facilitate nightly publishing --- dev/create-release/create-release.sh | 267 ---------------------- dev/create-release/release-build.sh | 321 +++++++++++++++++++++++++++ dev/create-release/release-tag.sh | 79 +++++++ 3 files changed, 400 insertions(+), 267 deletions(-) delete mode 100755 dev/create-release/create-release.sh create mode 100755 dev/create-release/release-build.sh create mode 100755 dev/create-release/release-tag.sh diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh deleted file mode 100755 index 4311c8c9e4ca6..0000000000000 --- a/dev/create-release/create-release.sh +++ /dev/null @@ -1,267 +0,0 @@ -#!/usr/bin/env bash - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Quick-and-dirty automation of making maven and binary releases. Not robust at all. -# Publishes releases to Maven and packages/copies binary release artifacts. -# Expects to be run in a totally empty directory. -# -# Options: -# --skip-create-release Assume the desired release tag already exists -# --skip-publish Do not publish to Maven central -# --skip-package Do not package and upload binary artifacts -# Would be nice to add: -# - Send output to stderr and have useful logging in stdout - -# Note: The following variables must be set before use! -ASF_USERNAME=${ASF_USERNAME:-pwendell} -ASF_PASSWORD=${ASF_PASSWORD:-XXX} -GPG_PASSPHRASE=${GPG_PASSPHRASE:-XXX} -GIT_BRANCH=${GIT_BRANCH:-branch-1.0} -RELEASE_VERSION=${RELEASE_VERSION:-1.2.0} -# Allows publishing under a different version identifier than -# was present in the actual release sources (e.g. rc-X) -PUBLISH_VERSION=${PUBLISH_VERSION:-$RELEASE_VERSION} -NEXT_VERSION=${NEXT_VERSION:-1.2.1} -RC_NAME=${RC_NAME:-rc2} - -M2_REPO=~/.m2/repository -SPARK_REPO=$M2_REPO/org/apache/spark -NEXUS_ROOT=https://repository.apache.org/service/local/staging -NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads - -if [ -z "$JAVA_HOME" ]; then - echo "Error: JAVA_HOME is not set, cannot proceed." - exit -1 -fi -JAVA_7_HOME=${JAVA_7_HOME:-$JAVA_HOME} - -set -e - -GIT_TAG=v$RELEASE_VERSION-$RC_NAME - -if [[ ! "$@" =~ --skip-create-release ]]; then - echo "Creating release commit and publishing to Apache repository" - # Artifact publishing - git clone https://$ASF_USERNAME:$ASF_PASSWORD@git-wip-us.apache.org/repos/asf/spark.git \ - -b $GIT_BRANCH - pushd spark - export MAVEN_OPTS="-Xmx3g -XX:MaxPermSize=1g -XX:ReservedCodeCacheSize=1g" - - # Create release commits and push them to github - # NOTE: This is done "eagerly" i.e. we don't check if we can succesfully build - # or before we coin the release commit. This helps avoid races where - # other people add commits to this branch while we are in the middle of building. - cur_ver="${RELEASE_VERSION}-SNAPSHOT" - rel_ver="${RELEASE_VERSION}" - next_ver="${NEXT_VERSION}-SNAPSHOT" - - old="^\( \{2,4\}\)${cur_ver}<\/version>$" - new="\1${rel_ver}<\/version>" - find . -name pom.xml | grep -v dev | xargs -I {} sed -i \ - -e "s/${old}/${new}/" {} - find . -name package.scala | grep -v dev | xargs -I {} sed -i \ - -e "s/${old}/${new}/" {} - - git commit -a -m "Preparing Spark release $GIT_TAG" - echo "Creating tag $GIT_TAG at the head of $GIT_BRANCH" - git tag $GIT_TAG - - old="^\( \{2,4\}\)${rel_ver}<\/version>$" - new="\1${next_ver}<\/version>" - find . -name pom.xml | grep -v dev | xargs -I {} sed -i \ - -e "s/$old/$new/" {} - find . -name package.scala | grep -v dev | xargs -I {} sed -i \ - -e "s/${old}/${new}/" {} - git commit -a -m "Preparing development version $next_ver" - git push origin $GIT_TAG - git push origin HEAD:$GIT_BRANCH - popd - rm -rf spark -fi - -if [[ ! "$@" =~ --skip-publish ]]; then - git clone https://$ASF_USERNAME:$ASF_PASSWORD@git-wip-us.apache.org/repos/asf/spark.git - pushd spark - git checkout --force $GIT_TAG - - # Substitute in case published version is different than released - old="^\( \{2,4\}\)${RELEASE_VERSION}<\/version>$" - new="\1${PUBLISH_VERSION}<\/version>" - find . -name pom.xml | grep -v dev | xargs -I {} sed -i \ - -e "s/${old}/${new}/" {} - - # Using Nexus API documented here: - # https://support.sonatype.com/entries/39720203-Uploading-to-a-Staging-Repository-via-REST-API - echo "Creating Nexus staging repository" - repo_request="Apache Spark $GIT_TAG (published as $PUBLISH_VERSION)" - out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ - -H "Content-Type:application/xml" -v \ - $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start) - staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/") - echo "Created Nexus staging repository: $staged_repo_id" - - rm -rf $SPARK_REPO - - build/mvn -DskipTests -Pyarn -Phive \ - -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ - clean install - - ./dev/change-scala-version.sh 2.11 - - build/mvn -DskipTests -Pyarn -Phive \ - -Dscala-2.11 -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ - clean install - - ./dev/change-scala-version.sh 2.10 - - pushd $SPARK_REPO - - # Remove any extra files generated during install - find . -type f |grep -v \.jar |grep -v \.pom | xargs rm - - echo "Creating hash and signature files" - for file in $(find . -type f) - do - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --output $file.asc --detach-sig --armour $file; - if [ $(command -v md5) ]; then - # Available on OS X; -q to keep only hash - md5 -q $file > $file.md5 - else - # Available on Linux; cut to keep only hash - md5sum $file | cut -f1 -d' ' > $file.md5 - fi - shasum -a 1 $file | cut -f1 -d' ' > $file.sha1 - done - - nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id - echo "Uplading files to $nexus_upload" - for file in $(find . -type f) - do - # strip leading ./ - file_short=$(echo $file | sed -e "s/\.\///") - dest_url="$nexus_upload/org/apache/spark/$file_short" - echo " Uploading $file_short" - curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url - done - - echo "Closing nexus staging repository" - repo_request="$staged_repo_idApache Spark $GIT_TAG (published as $PUBLISH_VERSION)" - out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ - -H "Content-Type:application/xml" -v \ - $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish) - echo "Closed Nexus staging repository: $staged_repo_id" - - popd - popd - rm -rf spark -fi - -if [[ ! "$@" =~ --skip-package ]]; then - # Source and binary tarballs - echo "Packaging release tarballs" - git clone https://git-wip-us.apache.org/repos/asf/spark.git - cd spark - git checkout --force $GIT_TAG - release_hash=`git rev-parse HEAD` - - rm .gitignore - rm -rf .git - cd .. - - cp -r spark spark-$RELEASE_VERSION - tar cvzf spark-$RELEASE_VERSION.tgz spark-$RELEASE_VERSION - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --armour --output spark-$RELEASE_VERSION.tgz.asc \ - --detach-sig spark-$RELEASE_VERSION.tgz - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md MD5 spark-$RELEASE_VERSION.tgz > \ - spark-$RELEASE_VERSION.tgz.md5 - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md SHA512 spark-$RELEASE_VERSION.tgz > \ - spark-$RELEASE_VERSION.tgz.sha - rm -rf spark-$RELEASE_VERSION - - # Updated for each binary build - make_binary_release() { - NAME=$1 - FLAGS=$2 - ZINC_PORT=$3 - cp -r spark spark-$RELEASE_VERSION-bin-$NAME - - cd spark-$RELEASE_VERSION-bin-$NAME - - # TODO There should probably be a flag to make-distribution to allow 2.11 support - if [[ $FLAGS == *scala-2.11* ]]; then - ./dev/change-scala-version.sh 2.11 - fi - - export ZINC_PORT=$ZINC_PORT - echo "Creating distribution: $NAME ($FLAGS)" - ./make-distribution.sh --name $NAME --tgz $FLAGS -DzincPort=$ZINC_PORT 2>&1 > \ - ../binary-release-$NAME.log - cd .. - cp spark-$RELEASE_VERSION-bin-$NAME/spark-$RELEASE_VERSION-bin-$NAME.tgz . - - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --armour \ - --output spark-$RELEASE_VERSION-bin-$NAME.tgz.asc \ - --detach-sig spark-$RELEASE_VERSION-bin-$NAME.tgz - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md \ - MD5 spark-$RELEASE_VERSION-bin-$NAME.tgz > \ - spark-$RELEASE_VERSION-bin-$NAME.tgz.md5 - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md \ - SHA512 spark-$RELEASE_VERSION-bin-$NAME.tgz > \ - spark-$RELEASE_VERSION-bin-$NAME.tgz.sha - } - - # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds - # share the same Zinc server. - make_binary_release "hadoop1" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver" "3030" & - make_binary_release "hadoop1-scala2.11" "-Psparkr -Phadoop-1 -Phive -Dscala-2.11" "3031" & - make_binary_release "cdh4" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & - make_binary_release "hadoop2.3" "-Psparkr -Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & - make_binary_release "hadoop2.4" "-Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & - make_binary_release "mapr3" "-Pmapr3 -Psparkr -Phive -Phive-thriftserver" "3035" & - make_binary_release "mapr4" "-Pmapr4 -Psparkr -Pyarn -Phive -Phive-thriftserver" "3036" & - make_binary_release "hadoop2.4-without-hive" "-Psparkr -Phadoop-2.4 -Pyarn" "3037" & - wait - rm -rf spark-$RELEASE_VERSION-bin-*/ - - # Copy data - echo "Copying release tarballs" - rc_folder=spark-$RELEASE_VERSION-$RC_NAME - ssh $ASF_USERNAME@people.apache.org \ - mkdir /home/$ASF_USERNAME/public_html/$rc_folder - scp spark-* \ - $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_folder/ - - # Docs - cd spark - sbt/sbt clean - cd docs - # Compile docs with Java 7 to use nicer format - JAVA_HOME="$JAVA_7_HOME" PRODUCTION=1 RELEASE_VERSION="$RELEASE_VERSION" jekyll build - echo "Copying release documentation" - rc_docs_folder=${rc_folder}-docs - ssh $ASF_USERNAME@people.apache.org \ - mkdir /home/$ASF_USERNAME/public_html/$rc_docs_folder - rsync -r _site/* $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_docs_folder - - echo "Release $RELEASE_VERSION completed:" - echo "Git tag:\t $GIT_TAG" - echo "Release commit:\t $release_hash" - echo "Binary location:\t http://people.apache.org/~$ASF_USERNAME/$rc_folder" - echo "Doc location:\t http://people.apache.org/~$ASF_USERNAME/$rc_docs_folder" -fi diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh new file mode 100755 index 0000000000000..399c73e7bf6bc --- /dev/null +++ b/dev/create-release/release-build.sh @@ -0,0 +1,321 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +function exit_with_usage { + cat << EOF +usage: release-build.sh +Creates build deliverables from a Spark commit. + +Top level targets are + package: Create binary packages and copy them to people.apache + docs: Build docs and copy them to people.apache + publish-snapshot: Publish snapshot release to Apache snapshots + publish-release: Publish a release to Apache release repo + +All other inputs are environment variables + +GIT_REF - Release tag or commit to build from +SPARK_VERSION - Release identifier used when publishing +SPARK_PACKAGE_VERSION - Release identifier in top level package directory +REMOTE_PARENT_DIR - Parent in which to create doc or release builds. +REMOTE_PARENT_MAX_LENGTH - If set, parent directory will be cleaned to only + have this number of subdirectories (by deleting old ones). WARNING: This deletes data. + +ASF_USERNAME - Username of ASF committer account +ASF_PASSWORD - Password of ASF committer account +ASF_RSA_KEY - RSA private key file for ASF committer account + +GPG_KEY - GPG key used to sign release artifacts +GPG_PASSPHRASE - Passphrase for GPG key +EOF + exit 1 +} + +set -e + +if [ $# -eq 0 ]; then + exit_with_usage +fi + +if [[ $@ == *"help"* ]]; then + exit_with_usage +fi + +for env in ASF_USERNAME ASF_RSA_KEY GPG_PASSPHRASE GPG_KEY; do + if [ -z "${!env}" ]; then + echo "ERROR: $env must be set to run this script" + exit_with_usage + fi +done + +# Commit ref to checkout when building +GIT_REF=${GIT_REF:-master} + +# Destination directory parent on remote server +REMOTE_PARENT_DIR=${REMOTE_PARENT_DIR:-/home/$ASF_USERNAME/public_html} + +SSH="ssh -o StrictHostKeyChecking=no -i $ASF_RSA_KEY" +GPG="gpg --no-tty --batch" +NEXUS_ROOT=https://repository.apache.org/service/local/staging +NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads +BASE_DIR=$(pwd) + +MVN="build/mvn --force" +PUBLISH_PROFILES="-Pyarn -Phive -Phadoop-2.2" +PUBLISH_PROFILES="$PUBLISH_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" + +rm -rf spark +git clone https://git-wip-us.apache.org/repos/asf/spark.git +cd spark +git checkout $GIT_REF +git_hash=`git rev-parse --short HEAD` +echo "Checked out Spark git hash $git_hash" + +if [ -z "$SPARK_VERSION" ]; then + SPARK_VERSION=$($MVN help:evaluate -Dexpression=project.version \ + | grep -v INFO | grep -v WARNING | grep -v Download) +fi + +if [ -z "$SPARK_PACKAGE_VERSION" ]; then + SPARK_PACKAGE_VERSION="${SPARK_VERSION}-$(date +%Y_%m_%d_%H_%M)-${git_hash}" +fi + +DEST_DIR_NAME="spark-$SPARK_PACKAGE_VERSION" +USER_HOST="$ASF_USERNAME@people.apache.org" + +rm .gitignore +rm -rf .git +cd .. + +if [ -n "$REMOTE_PARENT_MAX_LENGTH" ]; then + old_dirs=$($SSH $USER_HOST ls -t $REMOTE_PARENT_DIR | tail -n +$REMOTE_PARENT_MAX_LENGTH) + for old_dir in $old_dirs; do + echo "Removing directory: $old_dir" + $SSH $USER_HOST rm -r $REMOTE_PARENT_DIR/$old_dir + done +fi + +if [[ "$1" == "package" ]]; then + # Source and binary tarballs + echo "Packaging release tarballs" + cp -r spark spark-$SPARK_VERSION + tar cvzf spark-$SPARK_VERSION.tgz spark-$SPARK_VERSION + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour --output spark-$SPARK_VERSION.tgz.asc \ + --detach-sig spark-$SPARK_VERSION.tgz + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md MD5 spark-$SPARK_VERSION.tgz > \ + spark-$SPARK_VERSION.tgz.md5 + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ + SHA512 spark-$SPARK_VERSION.tgz > spark-$SPARK_VERSION.tgz.sha + rm -rf spark-$SPARK_VERSION + + # Updated for each binary build + make_binary_release() { + NAME=$1 + FLAGS=$2 + ZINC_PORT=$3 + cp -r spark spark-$SPARK_VERSION-bin-$NAME + + cd spark-$SPARK_VERSION-bin-$NAME + + # TODO There should probably be a flag to make-distribution to allow 2.11 support + if [[ $FLAGS == *scala-2.11* ]]; then + ./dev/change-scala-version.sh 2.11 + fi + + export ZINC_PORT=$ZINC_PORT + echo "Creating distribution: $NAME ($FLAGS)" + ./make-distribution.sh --name $NAME --tgz $FLAGS -DzincPort=$ZINC_PORT 2>&1 > \ + ../binary-release-$NAME.log + cd .. + cp spark-$SPARK_VERSION-bin-$NAME/spark-$SPARK_VERSION-bin-$NAME.tgz . + + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ + --output spark-$SPARK_VERSION-bin-$NAME.tgz.asc \ + --detach-sig spark-$SPARK_VERSION-bin-$NAME.tgz + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ + MD5 spark-$SPARK_VERSION-bin-$NAME.tgz > \ + spark-$SPARK_VERSION-bin-$NAME.tgz.md5 + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ + SHA512 spark-$SPARK_VERSION-bin-$NAME.tgz > \ + spark-$SPARK_VERSION-bin-$NAME.tgz.sha + } + + # TODO: Check exit codes of children here: + # http://stackoverflow.com/questions/1570262/shell-get-exit-code-of-background-process + + # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds + # share the same Zinc server. + make_binary_release "hadoop1" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver" "3030" & + make_binary_release "hadoop1-scala2.11" "-Psparkr -Phadoop-1 -Phive -Dscala-2.11" "3031" & + make_binary_release "cdh4" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & + make_binary_release "hadoop2.3" "-Psparkr -Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & + make_binary_release "hadoop2.4" "-Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & + make_binary_release "hadoop2.6" "-Psparkr -Phadoop-2.6 -Phive -Phive-thriftserver -Pyarn" "3034" & + make_binary_release "hadoop2.4-without-hive" "-Psparkr -Phadoop-2.4 -Pyarn" "3037" & + make_binary_release "without-hadoop" "-Psparkr -Phadoop-provided -Pyarn" "3038" & + wait + rm -rf spark-$SPARK_VERSION-bin-*/ + + # Copy data + dest_dir="$REMOTE_PARENT_DIR/${DEST_DIR_NAME}-bin" + echo "Copying release tarballs to $dest_dir" + $SSH $USER_HOST mkdir $dest_dir + rsync -e "$SSH" spark-* $USER_HOST:$dest_dir + echo "Linking /latest to $dest_dir" + $SSH $USER_HOST rm -f "$REMOTE_PARENT_DIR/latest" + $SSH $USER_HOST ln -s $dest_dir "$REMOTE_PARENT_DIR/latest" + exit 0 +fi + +if [[ "$1" == "docs" ]]; then + # Documentation + cd spark + echo "Building Spark docs" + dest_dir="$REMOTE_PARENT_DIR/${DEST_DIR_NAME}-docs" + cd docs + # Compile docs with Java 7 to use nicer format + # TODO: Make configurable to add this: PRODUCTION=1 + PRODUCTION=1 RELEASE_VERSION="$SPARK_VERSION" jekyll build + echo "Copying release documentation to $dest_dir" + $SSH $USER_HOST mkdir $dest_dir + echo "Linking /latest to $dest_dir" + $SSH $USER_HOST rm -f "$REMOTE_PARENT_DIR/latest" + $SSH $USER_HOST ln -s $dest_dir "$REMOTE_PARENT_DIR/latest" + rsync -e "$SSH" -r _site/* $USER_HOST:$dest_dir + cd .. + exit 0 +fi + +if [[ "$1" == "publish-snapshot" ]]; then + cd spark + # Publish Spark to Maven release repo + echo "Deploying Spark SNAPSHOT at '$GIT_REF' ($git_hash)" + echo "Publish version is $SPARK_VERSION" + if [[ ! $SPARK_VERSION == *"SNAPSHOT"* ]]; then + echo "ERROR: Snapshots must have a version containing SNAPSHOT" + echo "ERROR: You gave version '$SPARK_VERSION'" + exit 1 + fi + # Coerce the requested version + $MVN versions:set -DnewVersion=$SPARK_VERSION + tmp_settings="tmp-settings.xml" + echo "" > $tmp_settings + echo "apache.snapshots.https$ASF_USERNAME" >> $tmp_settings + echo "$ASF_PASSWORD" >> $tmp_settings + echo "" >> $tmp_settings + + # Generate random point for Zinc + export ZINC_PORT=$(python -S -c "import random; print random.randrange(3030,4030)") + + $MVN -DzincPort=$ZINC_PORT --settings $tmp_settings -DskipTests $PUBLISH_PROFILES \ + -Phive-thriftserver deploy + ./dev/change-scala-version.sh 2.10 + $MVN -DzincPort=$ZINC_PORT -Dscala-2.11 --settings $tmp_settings \ + -DskipTests $PUBLISH_PROFILES deploy + + # Clean-up Zinc nailgun process + /usr/sbin/lsof -P |grep $ZINC_PORT | grep LISTEN | awk '{ print $2; }' | xargs kill + + rm $tmp_settings + cd .. + exit 0 +fi + +if [[ "$1" == "publish-release" ]]; then + cd spark + # Publish Spark to Maven release repo + echo "Publishing Spark checkout at '$GIT_REF' ($git_hash)" + echo "Publish version is $SPARK_VERSION" + # Coerce the requested version + $MVN versions:set -DnewVersion=$SPARK_VERSION + + # Using Nexus API documented here: + # https://support.sonatype.com/entries/39720203-Uploading-to-a-Staging-Repository-via-REST-API + echo "Creating Nexus staging repository" + repo_request="Apache Spark $SPARK_VERSION (commit $git_hash)" + out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ + -H "Content-Type:application/xml" -v \ + $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start) + staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/") + echo "Created Nexus staging repository: $staged_repo_id" + + tmp_repo=$(mktemp -d spark-repo-XXXXX) + + # Generate random point for Zinc + export ZINC_PORT=$(python -S -c "import random; print random.randrange(3030,4030)") + + $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -DskipTests $PUBLISH_PROFILES \ + -Phive-thriftserver clean install + + ./dev/change-scala-version.sh 2.11 + + $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -Dscala-2.11 \ + -DskipTests $PUBLISH_PROFILES clean install + + # Clean-up Zinc nailgun process + /usr/sbin/lsof -P |grep $ZINC_PORT | grep LISTEN | awk '{ print $2; }' | xargs kill + + ./dev/change-version-to-2.10.sh + + pushd $tmp_repo/org/apache/spark + + # Remove any extra files generated during install + find . -type f |grep -v \.jar |grep -v \.pom | xargs rm + + echo "Creating hash and signature files" + for file in $(find . -type f) + do + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --output $file.asc \ + --detach-sig --armour $file; + if [ $(command -v md5) ]; then + # Available on OS X; -q to keep only hash + md5 -q $file > $file.md5 + else + # Available on Linux; cut to keep only hash + md5sum $file | cut -f1 -d' ' > $file.md5 + fi + sha1sum $file | cut -f1 -d' ' > $file.sha1 + done + + nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id + echo "Uplading files to $nexus_upload" + for file in $(find . -type f) + do + # strip leading ./ + file_short=$(echo $file | sed -e "s/\.\///") + dest_url="$nexus_upload/org/apache/spark/$file_short" + echo " Uploading $file_short" + curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url + done + + echo "Closing nexus staging repository" + repo_request="$staged_repo_idApache Spark $SPARK_VERSION (commit $git_hash)" + out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ + -H "Content-Type:application/xml" -v \ + $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish) + echo "Closed Nexus staging repository: $staged_repo_id" + popd + rm -rf $tmp_repo + cd .. + exit 0 +fi + +cd .. +rm -rf spark +echo "ERROR: expects to be called with 'package', 'docs', 'publish-release' or 'publish-snapshot'" diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh new file mode 100755 index 0000000000000..b0a3374becc6a --- /dev/null +++ b/dev/create-release/release-tag.sh @@ -0,0 +1,79 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +function exit_with_usage { + cat << EOF +usage: tag-release.sh +Tags a Spark release on a particular branch. + +Inputs are specified with the following environment variables: +ASF_USERNAME - Apache Username +ASF_PASSWORD - Apache Password +GIT_NAME - Name to use with git +GIT_EMAIL - E-mail address to use with git +GIT_BRANCH - Git branch on which to make release +RELEASE_VERSION - Version used in pom files for release +RELEASE_TAG - Name of release tag +NEXT_VERSION - Development version after release +EOF + exit 1 +} + +set -e + +if [[ $@ == *"help"* ]]; then + exit_with_usage +fi + +for env in ASF_USERNAME ASF_PASSWORD RELEASE_VERSION RELEASE_TAG NEXT_VERSION GIT_EMAIL GIT_NAME GIT_BRANCH; do + if [ -z "${!env}" ]; then + echo "$env must be set to run this script" + exit 1 + fi +done + +ASF_SPARK_REPO="git-wip-us.apache.org/repos/asf/spark.git" +MVN="build/mvn --force" + +rm -rf spark +git clone https://$ASF_USERNAME:$ASF_PASSWORD@$ASF_SPARK_REPO -b $GIT_BRANCH +cd spark + +git config user.name "$GIT_NAME" +git config user.email $GIT_EMAIL + +# Create release version +$MVN versions:set -DnewVersion=$RELEASE_VERSION | grep -v "no value" # silence logs +git commit -a -m "Preparing Spark release $RELEASE_TAG" +echo "Creating tag $RELEASE_TAG at the head of $GIT_BRANCH" +git tag $RELEASE_TAG + +# TODO: It would be nice to do some verifications here +# i.e. check whether ec2 scripts have the new version + +# Create next version +$MVN versions:set -DnewVersion=$NEXT_VERSION | grep -v "no value" # silence logs +git commit -a -m "Preparing development version $NEXT_VERSION" + +# Push changes +git push origin $RELEASE_TAG +git push origin HEAD:$GIT_BRANCH + +cd .. +rm -rf spark From 74a293f4537c6982345166f8883538f81d850872 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 11 Aug 2015 21:26:03 -0700 Subject: [PATCH 039/253] [SPARK-9713] [ML] Document SparkR MLlib glm() integration in Spark 1.5 This documents the use of R model formulae in the SparkR guide. Also fixes some bugs in the R api doc. mengxr Author: Eric Liang Closes #8085 from ericl/docs. --- R/pkg/R/generics.R | 4 ++-- R/pkg/R/mllib.R | 8 ++++---- docs/sparkr.md | 37 ++++++++++++++++++++++++++++++++++++- 3 files changed, 42 insertions(+), 7 deletions(-) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index c43b947129e87..379a78b1d833e 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -535,8 +535,8 @@ setGeneric("showDF", function(x,...) { standardGeneric("showDF") }) #' @export setGeneric("summarize", function(x,...) { standardGeneric("summarize") }) -##' rdname summary -##' @export +#' @rdname summary +#' @export setGeneric("summary", function(x, ...) { standardGeneric("summary") }) # @rdname tojson diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index b524d1fd87496..cea3d760d05fe 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -56,10 +56,10 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram #' #' Makes predictions from a model produced by glm(), similarly to R's predict(). #' -#' @param model A fitted MLlib model +#' @param object A fitted MLlib model #' @param newData DataFrame for testing #' @return DataFrame containing predicted values -#' @rdname glm +#' @rdname predict #' @export #' @examples #'\dontrun{ @@ -76,10 +76,10 @@ setMethod("predict", signature(object = "PipelineModel"), #' #' Returns the summary of a model produced by glm(), similarly to R's summary(). #' -#' @param model A fitted MLlib model +#' @param x A fitted MLlib model #' @return a list with a 'coefficient' component, which is the matrix of coefficients. See #' summary.glm for more information. -#' @rdname glm +#' @rdname summary #' @export #' @examples #'\dontrun{ diff --git a/docs/sparkr.md b/docs/sparkr.md index 4385a4eeacd5c..7139d16b4a068 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -11,7 +11,8 @@ title: SparkR (R on Spark) SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. In Spark {{site.SPARK_VERSION}}, SparkR provides a distributed data frame implementation that supports operations like selection, filtering, aggregation etc. (similar to R data frames, -[dplyr](https://github.com/hadley/dplyr)) but on large datasets. +[dplyr](https://github.com/hadley/dplyr)) but on large datasets. SparkR also supports distributed +machine learning using MLlib. # SparkR DataFrames @@ -230,3 +231,37 @@ head(teenagers) {% endhighlight %}

+ +# Machine Learning + +SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', '+', and '-'. The example below shows the use of building a gaussian GLM model using SparkR. + +
+{% highlight r %} +# Create the DataFrame +df <- createDataFrame(sqlContext, iris) + +# Fit a linear model over the dataset. +model <- glm(Sepal_Length ~ Sepal_Width + Species, data = df, family = "gaussian") + +# Model coefficients are returned in a similar format to R's native glm(). +summary(model) +##$coefficients +## Estimate +##(Intercept) 2.2513930 +##Sepal_Width 0.8035609 +##Species_versicolor 1.4587432 +##Species_virginica 1.9468169 + +# Make predictions based on the model. +predictions <- predict(model, newData = df) +head(select(predictions, "Sepal_Length", "prediction")) +## Sepal_Length prediction +##1 5.1 5.063856 +##2 4.9 4.662076 +##3 4.7 4.822788 +##4 4.6 4.742432 +##5 5.0 5.144212 +##6 5.4 5.385281 +{% endhighlight %} +
From c3e9a120e33159fb45cd99f3a55fc5cf16cd7c6c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 11 Aug 2015 22:45:18 -0700 Subject: [PATCH 040/253] [SPARK-9831] [SQL] fix serialization with empty broadcast Author: Davies Liu Closes #8117 from davies/fix_serialization and squashes the following commits: d21ac71 [Davies Liu] fix serialization with empty broadcast --- .../sql/execution/joins/HashedRelation.scala | 2 +- .../execution/joins/HashedRelationSuite.scala | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index c1bc7947aa39c..076afe6e4e960 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -299,7 +299,7 @@ private[joins] final class UnsafeHashedRelation( binaryMap = new BytesToBytesMap( taskMemoryManager, shuffleMemoryManager, - nKeys * 2, // reduce hash collision + (nKeys * 1.5 + 1).toInt, // reduce hash collision pageSizeBytes) var i = 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index a1fa2c3864bdb..c635b2d51f464 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -103,4 +103,21 @@ class HashedRelationSuite extends SparkFunSuite { assert(hashed2.get(unsafeData(2)) === data2) assert(numDataRows.value.value === data.length) } + + test("test serialization empty hash map") { + val os = new ByteArrayOutputStream() + val out = new ObjectOutputStream(os) + val hashed = new UnsafeHashedRelation( + new java.util.HashMap[UnsafeRow, CompactBuffer[UnsafeRow]]) + hashed.writeExternal(out) + out.flush() + val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray)) + val hashed2 = new UnsafeHashedRelation() + hashed2.readExternal(in) + + val schema = StructType(StructField("a", IntegerType, true) :: Nil) + val toUnsafe = UnsafeProjection.create(schema) + val row = toUnsafe(InternalRow(0)) + assert(hashed2.get(row) === null) + } } From b1581ac28840a4d2209ef8bb5c9f8700b4c1b286 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 11 Aug 2015 22:46:59 -0700 Subject: [PATCH 041/253] [SPARK-9854] [SQL] RuleExecutor.timeMap should be thread-safe `RuleExecutor.timeMap` is currently a non-thread-safe mutable HashMap; this can lead to infinite loops if multiple threads are concurrently modifying the map. I believe that this is responsible for some hangs that I've observed in HiveQuerySuite. This patch addresses this by using a Guava `AtomicLongMap`. Author: Josh Rosen Closes #8120 from JoshRosen/rule-executor-time-map-fix. --- .../spark/sql/catalyst/rules/RuleExecutor.scala | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 8b824511a79da..f80d2a93241d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -17,22 +17,25 @@ package org.apache.spark.sql.catalyst.rules +import scala.collection.JavaConverters._ + +import com.google.common.util.concurrent.AtomicLongMap + import org.apache.spark.Logging import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.sideBySide -import scala.collection.mutable - object RuleExecutor { - protected val timeMap = new mutable.HashMap[String, Long].withDefault(_ => 0) + protected val timeMap = AtomicLongMap.create[String]() /** Resets statistics about time spent running specific rules */ def resetTime(): Unit = timeMap.clear() /** Dump statistics about time spent running specific rules. */ def dumpTimeSpent(): String = { - val maxSize = timeMap.keys.map(_.toString.length).max - timeMap.toSeq.sortBy(_._2).reverseMap { case (k, v) => + val map = timeMap.asMap().asScala + val maxSize = map.keys.map(_.toString.length).max + map.toSeq.sortBy(_._2).reverseMap { case (k, v) => s"${k.padTo(maxSize, " ").mkString} $v" }.mkString("\n") } @@ -79,7 +82,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { val startTime = System.nanoTime() val result = rule(plan) val runTime = System.nanoTime() - startTime - RuleExecutor.timeMap(rule.ruleName) = RuleExecutor.timeMap(rule.ruleName) + runTime + RuleExecutor.timeMap.addAndGet(rule.ruleName, runTime) if (!result.fastEquals(plan)) { logTrace( From b85f9a242a12e8096e331fa77d5ebd16e93c844d Mon Sep 17 00:00:00 2001 From: xutingjun Date: Tue, 11 Aug 2015 23:19:35 -0700 Subject: [PATCH 042/253] [SPARK-8366] maxNumExecutorsNeeded should properly handle failed tasks Author: xutingjun Author: meiyoula <1039320815@qq.com> Closes #6817 from XuTingjun/SPARK-8366. --- .../spark/ExecutorAllocationManager.scala | 22 ++++++++++++------- .../ExecutorAllocationManagerSuite.scala | 22 +++++++++++++++++-- 2 files changed, 34 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 1877aaf2cac55..b93536e6536e2 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -599,14 +599,8 @@ private[spark] class ExecutorAllocationManager( // If this is the last pending task, mark the scheduler queue as empty stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex - val numTasksScheduled = stageIdToTaskIndices(stageId).size - val numTasksTotal = stageIdToNumTasks.getOrElse(stageId, -1) - if (numTasksScheduled == numTasksTotal) { - // No more pending tasks for this stage - stageIdToNumTasks -= stageId - if (stageIdToNumTasks.isEmpty) { - allocationManager.onSchedulerQueueEmpty() - } + if (totalPendingTasks() == 0) { + allocationManager.onSchedulerQueueEmpty() } // Mark the executor on which this task is scheduled as busy @@ -618,6 +612,8 @@ private[spark] class ExecutorAllocationManager( override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { val executorId = taskEnd.taskInfo.executorId val taskId = taskEnd.taskInfo.taskId + val taskIndex = taskEnd.taskInfo.index + val stageId = taskEnd.stageId allocationManager.synchronized { numRunningTasks -= 1 // If the executor is no longer running any scheduled tasks, mark it as idle @@ -628,6 +624,16 @@ private[spark] class ExecutorAllocationManager( allocationManager.onExecutorIdle(executorId) } } + + // If the task failed, we expect it to be resubmitted later. To ensure we have + // enough resources to run the resubmitted task, we need to mark the scheduler + // as backlogged again if it's not already marked as such (SPARK-8366) + if (taskEnd.reason != Success) { + if (totalPendingTasks() == 0) { + allocationManager.onSchedulerBacklogged() + } + stageIdToTaskIndices.get(stageId).foreach { _.remove(taskIndex) } + } } } diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 34caca892891c..f374f97f87448 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -206,8 +206,8 @@ class ExecutorAllocationManagerSuite val task2Info = createTaskInfo(1, 0, "executor-1") sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, task2Info)) - sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, null, task1Info, null)) - sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, null, task2Info, null)) + sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, Success, task1Info, null)) + sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, Success, task2Info, null)) assert(adjustRequestedExecutors(manager) === -1) } @@ -787,6 +787,24 @@ class ExecutorAllocationManagerSuite Map("host2" -> 1, "host3" -> 2, "host4" -> 1, "host5" -> 2)) } + test("SPARK-8366: maxNumExecutorsNeeded should properly handle failed tasks") { + sc = createSparkContext() + val manager = sc.executorAllocationManager.get + assert(maxNumExecutorsNeeded(manager) === 0) + + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1))) + assert(maxNumExecutorsNeeded(manager) === 1) + + val taskInfo = createTaskInfo(1, 1, "executor-1") + sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, taskInfo)) + assert(maxNumExecutorsNeeded(manager) === 1) + + // If the task is failed, we expect it to be resubmitted later. + val taskEndReason = ExceptionFailure(null, null, null, null, null) + sc.listenerBus.postToAll(SparkListenerTaskEnd(0, 0, null, taskEndReason, taskInfo, null)) + assert(maxNumExecutorsNeeded(manager) === 1) + } + private def createSparkContext( minExecutors: Int = 1, maxExecutors: Int = 5, From a807fcbe50b2ce18751d80d39e9d21842f7da32a Mon Sep 17 00:00:00 2001 From: Rohit Agarwal Date: Tue, 11 Aug 2015 23:20:39 -0700 Subject: [PATCH 043/253] [SPARK-9806] [WEB UI] Don't share ReplayListenerBus between multiple applications Author: Rohit Agarwal Closes #8088 from mindprince/SPARK-9806. --- .../org/apache/spark/deploy/history/FsHistoryProvider.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index e3060ac3fa1a9..53c18ca3ff50c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -272,9 +272,9 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) * Replay the log files in the list and merge the list of old applications with new ones */ private def mergeApplicationListing(logs: Seq[FileStatus]): Unit = { - val bus = new ReplayListenerBus() val newAttempts = logs.flatMap { fileStatus => try { + val bus = new ReplayListenerBus() val res = replay(fileStatus, bus) res match { case Some(r) => logDebug(s"Application log ${r.logPath} loaded successfully.") From 4e3f4b934f74e8c7c06f4940d6381343f9fd4918 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 11 Aug 2015 23:23:17 -0700 Subject: [PATCH 044/253] [SPARK-9829] [WEBUI] Display the update value for peak execution memory The peak execution memory is not correct because it shows the sum of finished tasks' values when a task finishes. This PR fixes it by using the update value rather than the accumulator value. Author: zsxwing Closes #8121 from zsxwing/SPARK-9829. --- core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 0c94204df6530..fb4556b836859 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -860,7 +860,7 @@ private[ui] class TaskDataSource( } val peakExecutionMemoryUsed = taskInternalAccumulables .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY } - .map { acc => acc.value.toLong } + .map { acc => acc.update.getOrElse("0").toLong } .getOrElse(0L) val maybeInput = metrics.flatMap(_.inputMetrics) From bab89232854de7554e88f29cab76f1a1c349edc1 Mon Sep 17 00:00:00 2001 From: Carson Wang Date: Tue, 11 Aug 2015 23:25:02 -0700 Subject: [PATCH 045/253] [SPARK-9426] [WEBUI] Job page DAG visualization is not shown To reproduce the issue, go to the stage page and click DAG Visualization once, then go to the job page to show the job DAG visualization. You will only see the first stage of the job. Root cause: the java script use local storage to remember your selection. Once you click the stage DAG visualization, the local storage set `expand-dag-viz-arrow-stage` to true. When you go to the job page, the js checks `expand-dag-viz-arrow-stage` in the local storage first and will try to show stage DAG visualization on the job page. To fix this, I set an id to the DAG span to differ job page and stage page. In the js code, we check the id and local storage together to make sure we show the correct DAG visualization. Author: Carson Wang Closes #8104 from carsonwang/SPARK-9426. --- .../resources/org/apache/spark/ui/static/spark-dag-viz.js | 8 ++++---- core/src/main/scala/org/apache/spark/ui/UIUtils.scala | 3 ++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js index 4a893bc0189aa..83dbea40b63f3 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js @@ -109,13 +109,13 @@ function toggleDagViz(forJob) { } $(function (){ - if (window.localStorage.getItem(expandDagVizArrowKey(false)) == "true") { + if ($("#stage-dag-viz").length && + window.localStorage.getItem(expandDagVizArrowKey(false)) == "true") { // Set it to false so that the click function can revert it window.localStorage.setItem(expandDagVizArrowKey(false), "false"); toggleDagViz(false); - } - - if (window.localStorage.getItem(expandDagVizArrowKey(true)) == "true") { + } else if ($("#job-dag-viz").length && + window.localStorage.getItem(expandDagVizArrowKey(true)) == "true") { // Set it to false so that the click function can revert it window.localStorage.setItem(expandDagVizArrowKey(true), "false"); toggleDagViz(true); diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 718aea7e1dc22..f2da417724104 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -352,7 +352,8 @@ private[spark] object UIUtils extends Logging { */ private def showDagViz(graphs: Seq[RDDOperationGraph], forJob: Boolean): Seq[Node] = {
- + From 5c99d8bf98cbf7f568345d02a814fc318cbfca75 Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Tue, 11 Aug 2015 23:26:33 -0700 Subject: [PATCH 046/253] [SPARK-8798] [MESOS] Allow additional uris to be fetched with mesos Some users like to download additional files in their sandbox that they can refer to from their spark program, or even later mount these files to another directory. Author: Timothy Chen Closes #7195 from tnachen/mesos_files. --- .../cluster/mesos/CoarseMesosSchedulerBackend.scala | 5 +++++ .../scheduler/cluster/mesos/MesosClusterScheduler.scala | 3 +++ .../scheduler/cluster/mesos/MesosSchedulerBackend.scala | 5 +++++ .../scheduler/cluster/mesos/MesosSchedulerUtils.scala | 6 ++++++ docs/running-on-mesos.md | 8 ++++++++ 5 files changed, 27 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 15a0915708c7c..d6e1e9e5bebc2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -194,6 +194,11 @@ private[spark] class CoarseMesosSchedulerBackend( s" --app-id $appId") command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get)) } + + conf.getOption("spark.mesos.uris").map { uris => + setupUris(uris, command) + } + command.build() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index f078547e71352..64ec2b8e3db15 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -403,6 +403,9 @@ private[spark] class MesosClusterScheduler( } builder.setValue(s"$executable $cmdOptions $jar $appArguments") builder.setEnvironment(envBuilder.build()) + conf.getOption("spark.mesos.uris").map { uris => + setupUris(uris, builder) + } builder.build() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 3f63ec1c5832f..5c20606d58715 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -133,6 +133,11 @@ private[spark] class MesosSchedulerBackend( builder.addAllResources(usedCpuResources) builder.addAllResources(usedMemResources) + + sc.conf.getOption("spark.mesos.uris").map { uris => + setupUris(uris, command) + } + val executorInfo = builder .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) .setCommand(command) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index c04920e4f5873..5b854aa5c2754 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -331,4 +331,10 @@ private[mesos] trait MesosSchedulerUtils extends Logging { sc.executorMemory } + def setupUris(uris: String, builder: CommandInfo.Builder): Unit = { + uris.split(",").foreach { uri => + builder.addUris(CommandInfo.URI.newBuilder().setValue(uri.trim())) + } + } + } diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index debdd2adf22d6..55e6d4e83a725 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -306,6 +306,14 @@ See the [configuration page](configuration.html) for information on Spark config the final overhead will be this value. + + spark.mesos.uris + (none) + + A list of URIs to be downloaded to the sandbox when driver or executor is launched by Mesos. + This applies to both coarse-grain and fine-grain mode. + + spark.mesos.principal Framework principal to authenticate to Mesos From 741a29f98945538a475579ccc974cd42c1613be4 Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Tue, 11 Aug 2015 23:33:22 -0700 Subject: [PATCH 047/253] [SPARK-9575] [MESOS] Add docuemntation around Mesos shuffle service. andrewor14 Author: Timothy Chen Closes #7907 from tnachen/mesos_shuffle. --- docs/running-on-mesos.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 55e6d4e83a725..cfd219ab02e26 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -216,6 +216,20 @@ node. Please refer to [Hadoop on Mesos](https://github.com/mesos/hadoop). In either case, HDFS runs separately from Hadoop MapReduce, without being scheduled through Mesos. +# Dynamic Resource Allocation with Mesos + +Mesos supports dynamic allocation only with coarse grain mode, which can resize the number of executors based on statistics +of the application. While dynamic allocation supports both scaling up and scaling down the number of executors, the coarse grain scheduler only supports scaling down +since it is already designed to run one executor per slave with the configured amount of resources. However, after scaling down the number of executors the coarse grain scheduler +can scale back up to the same amount of executors when Spark signals more executors are needed. + +Users that like to utilize this feature should launch the Mesos Shuffle Service that +provides shuffle data cleanup functionality on top of the Shuffle Service since Mesos doesn't yet support notifying another framework's +termination. To launch/stop the Mesos Shuffle Service please use the provided sbin/start-mesos-shuffle-service.sh and sbin/stop-mesos-shuffle-service.sh +scripts accordingly. + +The Shuffle Service is expected to be running on each slave node that will run Spark executors. One way to easily achieve this with Mesos +is to launch the Shuffle Service with Marathon with a unique host constraint. # Configuration From 9d0822455ddc8d765440d58c463367a4d67ef456 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 12 Aug 2015 19:54:00 +0800 Subject: [PATCH 048/253] [SPARK-9182] [SQL] Filters are not passed through to jdbc source This PR fixes unable to push filter down to JDBC source caused by `Cast` during pattern matching. While we are comparing columns of different type, there's a big chance we need a cast on the column, therefore not match the pattern directly on Attribute and would fail to push down. Author: Yijie Shen Closes #8049 from yjshen/jdbc_pushdown. --- .../datasources/DataSourceStrategy.scala | 30 ++++++++++++++-- .../execution/datasources/jdbc/JDBCRDD.scala | 2 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 34 +++++++++++++++++++ 3 files changed, 63 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 2a4c40db8bb66..9eea2b0382535 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.{Logging, TaskContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD} -import org.apache.spark.sql.catalyst.{InternalRow, expressions} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.types.{TimestampType, DateType, StringType, StructType} import org.apache.spark.sql.{SaveMode, Strategy, execution, sources, _} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -343,11 +343,17 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { * and convert them. */ protected[sql] def selectFilters(filters: Seq[Expression]) = { + import CatalystTypeConverters._ + def translate(predicate: Expression): Option[Filter] = predicate match { case expressions.EqualTo(a: Attribute, Literal(v, _)) => Some(sources.EqualTo(a.name, v)) case expressions.EqualTo(Literal(v, _), a: Attribute) => Some(sources.EqualTo(a.name, v)) + case expressions.EqualTo(Cast(a: Attribute, _), l: Literal) => + Some(sources.EqualTo(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) + case expressions.EqualTo(l: Literal, Cast(a: Attribute, _)) => + Some(sources.EqualTo(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) => Some(sources.EqualNullSafe(a.name, v)) @@ -358,21 +364,41 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { Some(sources.GreaterThan(a.name, v)) case expressions.GreaterThan(Literal(v, _), a: Attribute) => Some(sources.LessThan(a.name, v)) + case expressions.GreaterThan(Cast(a: Attribute, _), l: Literal) => + Some(sources.GreaterThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) + case expressions.GreaterThan(l: Literal, Cast(a: Attribute, _)) => + Some(sources.LessThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) case expressions.LessThan(a: Attribute, Literal(v, _)) => Some(sources.LessThan(a.name, v)) case expressions.LessThan(Literal(v, _), a: Attribute) => Some(sources.GreaterThan(a.name, v)) + case expressions.LessThan(Cast(a: Attribute, _), l: Literal) => + Some(sources.LessThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) + case expressions.LessThan(l: Literal, Cast(a: Attribute, _)) => + Some(sources.GreaterThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) => Some(sources.GreaterThanOrEqual(a.name, v)) case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) => Some(sources.LessThanOrEqual(a.name, v)) + case expressions.GreaterThanOrEqual(Cast(a: Attribute, _), l: Literal) => + Some(sources.GreaterThanOrEqual(a.name, + convertToScala(Cast(l, a.dataType).eval(), a.dataType))) + case expressions.GreaterThanOrEqual(l: Literal, Cast(a: Attribute, _)) => + Some(sources.LessThanOrEqual(a.name, + convertToScala(Cast(l, a.dataType).eval(), a.dataType))) case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => Some(sources.LessThanOrEqual(a.name, v)) case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => Some(sources.GreaterThanOrEqual(a.name, v)) + case expressions.LessThanOrEqual(Cast(a: Attribute, _), l: Literal) => + Some(sources.LessThanOrEqual(a.name, + convertToScala(Cast(l, a.dataType).eval(), a.dataType))) + case expressions.LessThanOrEqual(l: Literal, Cast(a: Attribute, _)) => + Some(sources.GreaterThanOrEqual(a.name, + convertToScala(Cast(l, a.dataType).eval(), a.dataType))) case expressions.InSet(a: Attribute, set) => Some(sources.In(a.name, set.toArray)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 8eab6a0adccc4..281943e23fcff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -284,7 +284,7 @@ private[sql] class JDBCRDD( /** * `filters`, but as a WHERE clause suitable for injection into a SQL query. */ - private val filterWhereClause: String = { + val filterWhereClause: String = { val filterStrings = filters map compileFilter filter (_ != null) if (filterStrings.size > 0) { val sb = new StringBuilder("WHERE ") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 42f2449afb0f9..b9cfae51e809c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -25,6 +25,8 @@ import org.h2.jdbc.JdbcSQLException import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD +import org.apache.spark.sql.execution.PhysicalRDD import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -148,6 +150,18 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { |OPTIONS (url '$url', dbtable 'TEST.FLTTYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) + conn.prepareStatement("create table test.decimals (a DECIMAL(7, 2), b DECIMAL(4, 0))"). + executeUpdate() + conn.prepareStatement("insert into test.decimals values (12345.67, 1234)").executeUpdate() + conn.prepareStatement("insert into test.decimals values (34567.89, 1428)").executeUpdate() + conn.commit() + sql( + s""" + |CREATE TEMPORARY TABLE decimals + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable 'TEST.DECIMALS', user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + conn.prepareStatement( s""" |create table test.nulltypes (a INT, b BOOLEAN, c TINYINT, d BINARY(20), e VARCHAR(20), @@ -445,4 +459,24 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { assert(agg.getCatalystType(1, "", 1, null) === Some(StringType)) } + test("SPARK-9182: filters are not passed through to jdbc source") { + def checkPushedFilter(query: String, filterStr: String): Unit = { + val rddOpt = sql(query).queryExecution.executedPlan.collectFirst { + case PhysicalRDD(_, rdd: JDBCRDD, _) => rdd + } + assert(rddOpt.isDefined) + val pushedFilterStr = rddOpt.get.filterWhereClause + assert(pushedFilterStr.contains(filterStr), + s"Expected to push [$filterStr], actually we pushed [$pushedFilterStr]") + } + + checkPushedFilter("select * from foobar where NAME = 'fred'", "NAME = 'fred'") + checkPushedFilter("select * from inttypes where A > '15'", "A > 15") + checkPushedFilter("select * from inttypes where C <= 20", "C <= 20") + checkPushedFilter("select * from decimals where A > 1000", "A > 1000.00") + checkPushedFilter("select * from decimals where A > 1000 AND A < 2000", + "A > 1000.00 AND A < 2000.00") + checkPushedFilter("select * from decimals where A = 2000 AND B > 20", "A = 2000.00 AND B > 20") + checkPushedFilter("select * from timetypes where B > '1998-09-10'", "B > 1998-09-10") + } } From 3ecb3794302dc12d0989f8d725483b2cc37762cf Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 12 Aug 2015 20:01:34 +0800 Subject: [PATCH 049/253] [SPARK-9407] [SQL] Relaxes Parquet ValidTypeMap to allow ENUM predicates to be pushed down This PR adds a hacky workaround for PARQUET-201, and should be removed once we upgrade to parquet-mr 1.8.1 or higher versions. In Parquet, not all types of columns can be used for filter push-down optimization. The set of valid column types is controlled by `ValidTypeMap`. Unfortunately, in parquet-mr 1.7.0 and prior versions, this limitation is too strict, and doesn't allow `BINARY (ENUM)` columns to be pushed down. On the other hand, `BINARY (ENUM)` is commonly seen in Parquet files written by libraries like `parquet-avro`. This restriction is problematic for Spark SQL, because Spark SQL doesn't have a type that maps to Parquet `BINARY (ENUM)` directly, and always converts `BINARY (ENUM)` to Catalyst `StringType`. Thus, a predicate involving a `BINARY (ENUM)` is recognized as one involving a string field instead and can be pushed down by the query optimizer. Such predicates are actually perfectly legal except that it fails the `ValidTypeMap` check. The workaround added here is relaxing `ValidTypeMap` to include `BINARY (ENUM)`. I also took the chance to simplify `ParquetCompatibilityTest` a little bit when adding regression test. Author: Cheng Lian Closes #8107 from liancheng/spark-9407/parquet-enum-filter-push-down. --- .../datasources/parquet/ParquetFilters.scala | 38 ++++- .../datasources/parquet/ParquetRelation.scala | 2 +- sql/core/src/test/README.md | 16 +- sql/core/src/test/avro/parquet-compat.avdl | 13 +- sql/core/src/test/avro/parquet-compat.avpr | 13 +- .../parquet/test/avro/CompatibilityTest.java | 2 +- .../datasources/parquet/test/avro/Nested.java | 4 +- .../parquet/test/avro/ParquetAvroCompat.java | 4 +- .../parquet/test/avro/ParquetEnum.java | 142 ++++++++++++++++++ .../datasources/parquet/test/avro/Suit.java | 13 ++ .../ParquetAvroCompatibilitySuite.scala | 105 +++++++------ .../parquet/ParquetCompatibilityTest.scala | 33 +--- .../test/scripts/{gen-code.sh => gen-avro.sh} | 13 +- sql/core/src/test/scripts/gen-thrift.sh | 27 ++++ .../src/test/thrift/parquet-compat.thrift | 2 +- .../hive/ParquetHiveCompatibilitySuite.scala | 83 +++++----- 16 files changed, 374 insertions(+), 136 deletions(-) create mode 100644 sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetEnum.java create mode 100644 sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Suit.java rename sql/core/src/test/scripts/{gen-code.sh => gen-avro.sh} (76%) create mode 100755 sql/core/src/test/scripts/gen-thrift.sh diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 9e2e232f50167..63915e0a28655 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -25,9 +25,10 @@ import org.apache.hadoop.conf.Configuration import org.apache.parquet.filter2.compat.FilterCompat import org.apache.parquet.filter2.compat.FilterCompat._ import org.apache.parquet.filter2.predicate.FilterApi._ -import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Statistics} -import org.apache.parquet.filter2.predicate.UserDefinedPredicate +import org.apache.parquet.filter2.predicate._ import org.apache.parquet.io.api.Binary +import org.apache.parquet.schema.OriginalType +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName import org.apache.spark.SparkEnv import org.apache.spark.sql.catalyst.expressions._ @@ -197,6 +198,8 @@ private[sql] object ParquetFilters { def createFilter(schema: StructType, predicate: sources.Filter): Option[FilterPredicate] = { val dataTypeOf = schema.map(f => f.name -> f.dataType).toMap + relaxParquetValidTypeMap + // NOTE: // // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`, @@ -239,6 +242,37 @@ private[sql] object ParquetFilters { } } + // !! HACK ALERT !! + // + // This lazy val is a workaround for PARQUET-201, and should be removed once we upgrade to + // parquet-mr 1.8.1 or higher versions. + // + // In Parquet, not all types of columns can be used for filter push-down optimization. The set + // of valid column types is controlled by `ValidTypeMap`. Unfortunately, in parquet-mr 1.7.0 and + // prior versions, the limitation is too strict, and doesn't allow `BINARY (ENUM)` columns to be + // pushed down. + // + // This restriction is problematic for Spark SQL, because Spark SQL doesn't have a type that maps + // to Parquet original type `ENUM` directly, and always converts `ENUM` to `StringType`. Thus, + // a predicate involving a `ENUM` field can be pushed-down as a string column, which is perfectly + // legal except that it fails the `ValidTypeMap` check. + // + // Here we add `BINARY (ENUM)` into `ValidTypeMap` lazily via reflection to workaround this issue. + private lazy val relaxParquetValidTypeMap: Unit = { + val constructor = Class + .forName(classOf[ValidTypeMap].getCanonicalName + "$FullTypeDescriptor") + .getDeclaredConstructor(classOf[PrimitiveTypeName], classOf[OriginalType]) + + constructor.setAccessible(true) + val enumTypeDescriptor = constructor + .newInstance(PrimitiveTypeName.BINARY, OriginalType.ENUM) + .asInstanceOf[AnyRef] + + val addMethod = classOf[ValidTypeMap].getDeclaredMethods.find(_.getName == "add").get + addMethod.setAccessible(true) + addMethod.invoke(null, classOf[Binary], enumTypeDescriptor) + } + /** * Converts Catalyst predicate expressions to Parquet filter predicates. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index c71c69b6e80b1..52fac18ba187a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -678,7 +678,7 @@ private[sql] object ParquetRelation extends Logging { val followParquetFormatSpec = sqlContext.conf.followParquetFormatSpec val serializedConf = new SerializableConfiguration(sqlContext.sparkContext.hadoopConfiguration) - // HACK ALERT: + // !! HACK ALERT !! // // Parquet requires `FileStatus`es to read footers. Here we try to send cached `FileStatus`es // to executor side to avoid fetching them again. However, `FileStatus` is not `Serializable` diff --git a/sql/core/src/test/README.md b/sql/core/src/test/README.md index 3dd9861b4896d..421c2ea4f7aed 100644 --- a/sql/core/src/test/README.md +++ b/sql/core/src/test/README.md @@ -6,23 +6,19 @@ The following directories and files are used for Parquet compatibility tests: . ├── README.md # This file ├── avro -│   ├── parquet-compat.avdl # Testing Avro IDL -│   └── parquet-compat.avpr # !! NO TOUCH !! Protocol file generated from parquet-compat.avdl +│   ├── *.avdl # Testing Avro IDL(s) +│   └── *.avpr # !! NO TOUCH !! Protocol files generated from Avro IDL(s) ├── gen-java # !! NO TOUCH !! Generated Java code ├── scripts -│   └── gen-code.sh # Script used to generate Java code for Thrift and Avro +│   ├── gen-avro.sh # Script used to generate Java code for Avro +│   └── gen-thrift.sh # Script used to generate Java code for Thrift └── thrift - └── parquet-compat.thrift # Testing Thrift schema + └── *.thrift # Testing Thrift schema(s) ``` -Generated Java code are used in the following test suites: - -- `org.apache.spark.sql.parquet.ParquetAvroCompatibilitySuite` -- `org.apache.spark.sql.parquet.ParquetThriftCompatibilitySuite` - To avoid code generation during build time, Java code generated from testing Thrift schema and Avro IDL are also checked in. -When updating the testing Thrift schema and Avro IDL, please run `gen-code.sh` to update all the generated Java code. +When updating the testing Thrift schema and Avro IDL, please run `gen-avro.sh` and `gen-thrift.sh` accordingly to update generated Java code. ## Prerequisites diff --git a/sql/core/src/test/avro/parquet-compat.avdl b/sql/core/src/test/avro/parquet-compat.avdl index 24729f6143e6c..8070d0a9170a3 100644 --- a/sql/core/src/test/avro/parquet-compat.avdl +++ b/sql/core/src/test/avro/parquet-compat.avdl @@ -16,8 +16,19 @@ */ // This is a test protocol for testing parquet-avro compatibility. -@namespace("org.apache.spark.sql.parquet.test.avro") +@namespace("org.apache.spark.sql.execution.datasources.parquet.test.avro") protocol CompatibilityTest { + enum Suit { + SPADES, + HEARTS, + DIAMONDS, + CLUBS + } + + record ParquetEnum { + Suit suit; + } + record Nested { array nested_ints_column; string nested_string_column; diff --git a/sql/core/src/test/avro/parquet-compat.avpr b/sql/core/src/test/avro/parquet-compat.avpr index a83b7c990dd2e..060391765034b 100644 --- a/sql/core/src/test/avro/parquet-compat.avpr +++ b/sql/core/src/test/avro/parquet-compat.avpr @@ -1,7 +1,18 @@ { "protocol" : "CompatibilityTest", - "namespace" : "org.apache.spark.sql.parquet.test.avro", + "namespace" : "org.apache.spark.sql.execution.datasources.parquet.test.avro", "types" : [ { + "type" : "enum", + "name" : "Suit", + "symbols" : [ "SPADES", "HEARTS", "DIAMONDS", "CLUBS" ] + }, { + "type" : "record", + "name" : "ParquetEnum", + "fields" : [ { + "name" : "suit", + "type" : "Suit" + } ] + }, { "type" : "record", "name" : "Nested", "fields" : [ { diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java index 70dec1a9d3c92..2368323cb36b9 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java @@ -8,7 +8,7 @@ @SuppressWarnings("all") @org.apache.avro.specific.AvroGenerated public interface CompatibilityTest { - public static final org.apache.avro.Protocol PROTOCOL = org.apache.avro.Protocol.parse("{\"protocol\":\"CompatibilityTest\",\"namespace\":\"org.apache.spark.sql.parquet.test.avro\",\"types\":[{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]},{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}},{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]},{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":\"Nested\"},\"avro.java.string\":\"String\"}}]}],\"messages\":{}}"); + public static final org.apache.avro.Protocol PROTOCOL = org.apache.avro.Protocol.parse("{\"protocol\":\"CompatibilityTest\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"types\":[{\"type\":\"enum\",\"name\":\"Suit\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]},{\"type\":\"record\",\"name\":\"ParquetEnum\",\"fields\":[{\"name\":\"suit\",\"type\":\"Suit\"}]},{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]},{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}},{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]},{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":\"Nested\"},\"avro.java.string\":\"String\"}}]}],\"messages\":{}}"); @SuppressWarnings("all") public interface Callback extends CompatibilityTest { diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java index a0a406bcd10c1..a7bf4841919c5 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java @@ -3,11 +3,11 @@ * * DO NOT EDIT DIRECTLY */ -package org.apache.spark.sql.execution.datasources.parquet.test.avro; +package org.apache.spark.sql.execution.datasources.parquet.test.avro; @SuppressWarnings("all") @org.apache.avro.specific.AvroGenerated public class Nested extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { - public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"Nested\",\"namespace\":\"org.apache.spark.sql.parquet.test.avro\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}"); + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"Nested\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}"); public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } @Deprecated public java.util.List nested_ints_column; @Deprecated public java.lang.String nested_string_column; diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java index 6198b00b1e3ca..681cacbd12c7c 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java @@ -3,11 +3,11 @@ * * DO NOT EDIT DIRECTLY */ -package org.apache.spark.sql.execution.datasources.parquet.test.avro; +package org.apache.spark.sql.execution.datasources.parquet.test.avro; @SuppressWarnings("all") @org.apache.avro.specific.AvroGenerated public class ParquetAvroCompat extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { - public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"namespace\":\"org.apache.spark.sql.parquet.test.avro\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}},{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]},{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}},\"avro.java.string\":\"String\"}}]}"); + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}},{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]},{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}},\"avro.java.string\":\"String\"}}]}"); public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } @Deprecated public boolean bool_column; @Deprecated public int int_column; diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetEnum.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetEnum.java new file mode 100644 index 0000000000000..05fefe4cee754 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetEnum.java @@ -0,0 +1,142 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class ParquetEnum extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"ParquetEnum\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"suit\",\"type\":{\"type\":\"enum\",\"name\":\"Suit\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit suit; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public ParquetEnum() {} + + /** + * All-args constructor. + */ + public ParquetEnum(org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit suit) { + this.suit = suit; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return suit; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: suit = (org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'suit' field. + */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit getSuit() { + return suit; + } + + /** + * Sets the value of the 'suit' field. + * @param value the value to set. + */ + public void setSuit(org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit value) { + this.suit = value; + } + + /** Creates a new ParquetEnum RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder(); + } + + /** Creates a new ParquetEnum RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder(other); + } + + /** Creates a new ParquetEnum RecordBuilder by copying an existing ParquetEnum instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder(other); + } + + /** + * RecordBuilder for ParquetEnum instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit suit; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder other) { + super(other); + if (isValidValue(fields()[0], other.suit)) { + this.suit = data().deepCopy(fields()[0].schema(), other.suit); + fieldSetFlags()[0] = true; + } + } + + /** Creates a Builder by copying an existing ParquetEnum instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.SCHEMA$); + if (isValidValue(fields()[0], other.suit)) { + this.suit = data().deepCopy(fields()[0].schema(), other.suit); + fieldSetFlags()[0] = true; + } + } + + /** Gets the value of the 'suit' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit getSuit() { + return suit; + } + + /** Sets the value of the 'suit' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder setSuit(org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit value) { + validate(fields()[0], value); + this.suit = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'suit' field has been set */ + public boolean hasSuit() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'suit' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder clearSuit() { + suit = null; + fieldSetFlags()[0] = false; + return this; + } + + @Override + public ParquetEnum build() { + try { + ParquetEnum record = new ParquetEnum(); + record.suit = fieldSetFlags()[0] ? this.suit : (org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit) defaultValue(fields()[0]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Suit.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Suit.java new file mode 100644 index 0000000000000..00711a0c2a267 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Suit.java @@ -0,0 +1,13 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public enum Suit { + SPADES, HEARTS, DIAMONDS, CLUBS ; + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"enum\",\"name\":\"Suit\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala index 4d9c07bb7a570..866a975ad5404 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala @@ -22,10 +22,12 @@ import java.util.{List => JList, Map => JMap} import scala.collection.JavaConversions._ +import org.apache.avro.Schema +import org.apache.avro.generic.IndexedRecord import org.apache.hadoop.fs.Path import org.apache.parquet.avro.AvroParquetWriter -import org.apache.spark.sql.execution.datasources.parquet.test.avro.{Nested, ParquetAvroCompat} +import org.apache.spark.sql.execution.datasources.parquet.test.avro.{Nested, ParquetAvroCompat, ParquetEnum, Suit} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.{Row, SQLContext} @@ -34,52 +36,55 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest { override val sqlContext: SQLContext = TestSQLContext - override protected def beforeAll(): Unit = { - super.beforeAll() - - val writer = - new AvroParquetWriter[ParquetAvroCompat]( - new Path(parquetStore.getCanonicalPath), - ParquetAvroCompat.getClassSchema) - - (0 until 10).foreach(i => writer.write(makeParquetAvroCompat(i))) - writer.close() + private def withWriter[T <: IndexedRecord] + (path: String, schema: Schema) + (f: AvroParquetWriter[T] => Unit) = { + val writer = new AvroParquetWriter[T](new Path(path), schema) + try f(writer) finally writer.close() } test("Read Parquet file generated by parquet-avro") { - logInfo( - s"""Schema of the Parquet file written by parquet-avro: - |${readParquetSchema(parquetStore.getCanonicalPath)} + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[ParquetAvroCompat](path, ParquetAvroCompat.getClassSchema) { writer => + (0 until 10).foreach(i => writer.write(makeParquetAvroCompat(i))) + } + + logInfo( + s"""Schema of the Parquet file written by parquet-avro: + |${readParquetSchema(path)} """.stripMargin) - checkAnswer(sqlContext.read.parquet(parquetStore.getCanonicalPath), (0 until 10).map { i => - def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i) - - Row( - i % 2 == 0, - i, - i.toLong * 10, - i.toFloat + 0.1f, - i.toDouble + 0.2d, - s"val_$i".getBytes, - s"val_$i", - - nullable(i % 2 == 0: java.lang.Boolean), - nullable(i: Integer), - nullable(i.toLong: java.lang.Long), - nullable(i.toFloat + 0.1f: java.lang.Float), - nullable(i.toDouble + 0.2d: java.lang.Double), - nullable(s"val_$i".getBytes), - nullable(s"val_$i"), - - Seq.tabulate(3)(n => s"arr_${i + n}"), - Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap, - Seq.tabulate(3) { n => - (i + n).toString -> Seq.tabulate(3) { m => - Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}") - } - }.toMap) - }) + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i) + + Row( + i % 2 == 0, + i, + i.toLong * 10, + i.toFloat + 0.1f, + i.toDouble + 0.2d, + s"val_$i".getBytes, + s"val_$i", + + nullable(i % 2 == 0: java.lang.Boolean), + nullable(i: Integer), + nullable(i.toLong: java.lang.Long), + nullable(i.toFloat + 0.1f: java.lang.Float), + nullable(i.toDouble + 0.2d: java.lang.Double), + nullable(s"val_$i".getBytes), + nullable(s"val_$i"), + + Seq.tabulate(3)(n => s"arr_${i + n}"), + Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap, + Seq.tabulate(3) { n => + (i + n).toString -> Seq.tabulate(3) { m => + Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}") + } + }.toMap) + }) + } } def makeParquetAvroCompat(i: Int): ParquetAvroCompat = { @@ -122,4 +127,20 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest { .build() } + + test("SPARK-9407 Don't push down predicates involving Parquet ENUM columns") { + import sqlContext.implicits._ + + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[ParquetEnum](path, ParquetEnum.getClassSchema) { writer => + (0 until 4).foreach { i => + writer.write(ParquetEnum.newBuilder().setSuit(Suit.values.apply(i)).build()) + } + } + + checkAnswer(sqlContext.read.parquet(path).filter('suit === "SPADES"), Row("SPADES")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala index 68f35b1f3aa83..0ea64aa2a509b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala @@ -16,45 +16,28 @@ */ package org.apache.spark.sql.execution.datasources.parquet -import java.io.File import scala.collection.JavaConversions._ -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.parquet.hadoop.ParquetFileReader import org.apache.parquet.schema.MessageType import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.QueryTest -import org.apache.spark.util.Utils abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest with BeforeAndAfterAll { - protected var parquetStore: File = _ - - /** - * Optional path to a staging subdirectory which may be created during query processing - * (Hive does this). - * Parquet files under this directory will be ignored in [[readParquetSchema()]] - * @return an optional staging directory to ignore when scanning for parquet files. - */ - protected def stagingDir: Option[String] = None - - override protected def beforeAll(): Unit = { - parquetStore = Utils.createTempDir(namePrefix = "parquet-compat_") - parquetStore.delete() - } - - override protected def afterAll(): Unit = { - Utils.deleteRecursively(parquetStore) + def readParquetSchema(path: String): MessageType = { + readParquetSchema(path, { path => !path.getName.startsWith("_") }) } - def readParquetSchema(path: String): MessageType = { + def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = { val fsPath = new Path(path) val fs = fsPath.getFileSystem(configuration) - val parquetFiles = fs.listStatus(fsPath).toSeq.filterNot { status => - status.getPath.getName.startsWith("_") || - stagingDir.map(status.getPath.getName.startsWith).getOrElse(false) - } + val parquetFiles = fs.listStatus(fsPath, new PathFilter { + override def accept(path: Path): Boolean = pathFilter(path) + }).toSeq + val footers = ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles, true) footers.head.getParquetMetadata.getFileMetaData.getSchema } diff --git a/sql/core/src/test/scripts/gen-code.sh b/sql/core/src/test/scripts/gen-avro.sh similarity index 76% rename from sql/core/src/test/scripts/gen-code.sh rename to sql/core/src/test/scripts/gen-avro.sh index 5d8d8ad08555c..48174b287fd7c 100755 --- a/sql/core/src/test/scripts/gen-code.sh +++ b/sql/core/src/test/scripts/gen-avro.sh @@ -22,10 +22,9 @@ cd - rm -rf $BASEDIR/gen-java mkdir -p $BASEDIR/gen-java -thrift\ - --gen java\ - -out $BASEDIR/gen-java\ - $BASEDIR/thrift/parquet-compat.thrift - -avro-tools idl $BASEDIR/avro/parquet-compat.avdl > $BASEDIR/avro/parquet-compat.avpr -avro-tools compile -string protocol $BASEDIR/avro/parquet-compat.avpr $BASEDIR/gen-java +for input in `ls $BASEDIR/avro/*.avdl`; do + filename=$(basename "$input") + filename="${filename%.*}" + avro-tools idl $input> $BASEDIR/avro/${filename}.avpr + avro-tools compile -string protocol $BASEDIR/avro/${filename}.avpr $BASEDIR/gen-java +done diff --git a/sql/core/src/test/scripts/gen-thrift.sh b/sql/core/src/test/scripts/gen-thrift.sh new file mode 100755 index 0000000000000..ada432c68ab95 --- /dev/null +++ b/sql/core/src/test/scripts/gen-thrift.sh @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +cd $(dirname $0)/.. +BASEDIR=`pwd` +cd - + +rm -rf $BASEDIR/gen-java +mkdir -p $BASEDIR/gen-java + +for input in `ls $BASEDIR/thrift/*.thrift`; do + thrift --gen java -out $BASEDIR/gen-java $input +done diff --git a/sql/core/src/test/thrift/parquet-compat.thrift b/sql/core/src/test/thrift/parquet-compat.thrift index fa5ed8c62306a..98bf778aec5d6 100644 --- a/sql/core/src/test/thrift/parquet-compat.thrift +++ b/sql/core/src/test/thrift/parquet-compat.thrift @@ -15,7 +15,7 @@ * limitations under the License. */ -namespace java org.apache.spark.sql.parquet.test.thrift +namespace java org.apache.spark.sql.execution.datasources.parquet.test.thrift enum Suit { SPADES, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index 80eb9f122ad90..251e0324bfa5f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -32,53 +32,54 @@ class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest { * Set the staging directory (and hence path to ignore Parquet files under) * to that set by [[HiveConf.ConfVars.STAGINGDIR]]. */ - override val stagingDir: Option[String] = - Some(new HiveConf().getVar(HiveConf.ConfVars.STAGINGDIR)) + private val stagingDir = new HiveConf().getVar(HiveConf.ConfVars.STAGINGDIR) - override protected def beforeAll(): Unit = { - super.beforeAll() + test("Read Parquet file generated by parquet-hive") { + withTable("parquet_compat") { + withTempPath { dir => + val path = dir.getCanonicalPath - withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { - withTempTable("data") { - sqlContext.sql( - s"""CREATE TABLE parquet_compat( - | bool_column BOOLEAN, - | byte_column TINYINT, - | short_column SMALLINT, - | int_column INT, - | long_column BIGINT, - | float_column FLOAT, - | double_column DOUBLE, - | - | strings_column ARRAY, - | int_to_string_column MAP - |) - |STORED AS PARQUET - |LOCATION '${parquetStore.getCanonicalPath}' - """.stripMargin) + withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { + withTempTable("data") { + sqlContext.sql( + s"""CREATE TABLE parquet_compat( + | bool_column BOOLEAN, + | byte_column TINYINT, + | short_column SMALLINT, + | int_column INT, + | long_column BIGINT, + | float_column FLOAT, + | double_column DOUBLE, + | + | strings_column ARRAY, + | int_to_string_column MAP + |) + |STORED AS PARQUET + |LOCATION '$path' + """.stripMargin) - val schema = sqlContext.table("parquet_compat").schema - val rowRDD = sqlContext.sparkContext.parallelize(makeRows).coalesce(1) - sqlContext.createDataFrame(rowRDD, schema).registerTempTable("data") - sqlContext.sql("INSERT INTO TABLE parquet_compat SELECT * FROM data") - } - } - } + val schema = sqlContext.table("parquet_compat").schema + val rowRDD = sqlContext.sparkContext.parallelize(makeRows).coalesce(1) + sqlContext.createDataFrame(rowRDD, schema).registerTempTable("data") + sqlContext.sql("INSERT INTO TABLE parquet_compat SELECT * FROM data") + } + } - override protected def afterAll(): Unit = { - sqlContext.sql("DROP TABLE parquet_compat") - } + val schema = readParquetSchema(path, { path => + !path.getName.startsWith("_") && !path.getName.startsWith(stagingDir) + }) - test("Read Parquet file generated by parquet-hive") { - logInfo( - s"""Schema of the Parquet file written by parquet-hive: - |${readParquetSchema(parquetStore.getCanonicalPath)} - """.stripMargin) + logInfo( + s"""Schema of the Parquet file written by parquet-hive: + |$schema + """.stripMargin) - // Unfortunately parquet-hive doesn't add `UTF8` annotation to BINARY when writing strings. - // Have to assume all BINARY values are strings here. - withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") { - checkAnswer(sqlContext.read.parquet(parquetStore.getCanonicalPath), makeRows) + // Unfortunately parquet-hive doesn't add `UTF8` annotation to BINARY when writing strings. + // Have to assume all BINARY values are strings here. + withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") { + checkAnswer(sqlContext.read.parquet(path), makeRows) + } + } } } From 2e680668f7b6fc158aa068aedd19c1878ecf759e Mon Sep 17 00:00:00 2001 From: Tom White Date: Wed, 12 Aug 2015 10:06:27 -0500 Subject: [PATCH 050/253] [SPARK-8625] [CORE] Propagate user exceptions in tasks back to driver This allows clients to retrieve the original exception from the cause field of the SparkException that is thrown by the driver. If the original exception is not in fact Serializable then it will not be returned, but the message and stacktrace will be. (All Java Throwables implement the Serializable interface, but this is no guarantee that a particular implementation can actually be serialized.) Author: Tom White Closes #7014 from tomwhite/propagate-user-exceptions. --- .../org/apache/spark/TaskEndReason.scala | 44 ++++++++++++- .../org/apache/spark/executor/Executor.scala | 14 +++- .../apache/spark/scheduler/DAGScheduler.scala | 44 ++++++++----- .../spark/scheduler/DAGSchedulerEvent.scala | 3 +- .../spark/scheduler/TaskSetManager.scala | 12 ++-- .../org/apache/spark/util/JsonProtocol.scala | 2 +- .../ExecutorAllocationManagerSuite.scala | 2 +- .../scala/org/apache/spark/FailureSuite.scala | 66 ++++++++++++++++++- .../spark/scheduler/DAGSchedulerSuite.scala | 2 +- .../spark/scheduler/TaskSetManagerSuite.scala | 5 +- .../ui/jobs/JobProgressListenerSuite.scala | 2 +- .../apache/spark/util/JsonProtocolSuite.scala | 3 +- 12 files changed, 165 insertions(+), 34 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 48fd3e7e23d52..934d00dc708b9 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -17,6 +17,8 @@ package org.apache.spark +import java.io.{IOException, ObjectInputStream, ObjectOutputStream} + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId @@ -90,6 +92,10 @@ case class FetchFailed( * * `fullStackTrace` is a better representation of the stack trace because it contains the whole * stack trace including the exception and its causes + * + * `exception` is the actual exception that caused the task to fail. It may be `None` in + * the case that the exception is not in fact serializable. If a task fails more than + * once (due to retries), `exception` is that one that caused the last failure. */ @DeveloperApi case class ExceptionFailure( @@ -97,11 +103,26 @@ case class ExceptionFailure( description: String, stackTrace: Array[StackTraceElement], fullStackTrace: String, - metrics: Option[TaskMetrics]) + metrics: Option[TaskMetrics], + private val exceptionWrapper: Option[ThrowableSerializationWrapper]) extends TaskFailedReason { + /** + * `preserveCause` is used to keep the exception itself so it is available to the + * driver. This may be set to `false` in the event that the exception is not in fact + * serializable. + */ + private[spark] def this(e: Throwable, metrics: Option[TaskMetrics], preserveCause: Boolean) { + this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics, + if (preserveCause) Some(new ThrowableSerializationWrapper(e)) else None) + } + private[spark] def this(e: Throwable, metrics: Option[TaskMetrics]) { - this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics) + this(e, metrics, preserveCause = true) + } + + def exception: Option[Throwable] = exceptionWrapper.flatMap { + (w: ThrowableSerializationWrapper) => Option(w.exception) } override def toErrorString: String = @@ -127,6 +148,25 @@ case class ExceptionFailure( } } +/** + * A class for recovering from exceptions when deserializing a Throwable that was + * thrown in user task code. If the Throwable cannot be deserialized it will be null, + * but the stacktrace and message will be preserved correctly in SparkException. + */ +private[spark] class ThrowableSerializationWrapper(var exception: Throwable) extends + Serializable with Logging { + private def writeObject(out: ObjectOutputStream): Unit = { + out.writeObject(exception) + } + private def readObject(in: ObjectInputStream): Unit = { + try { + exception = in.readObject().asInstanceOf[Throwable] + } catch { + case e : Exception => log.warn("Task exception could not be deserialized", e) + } + } +} + /** * :: DeveloperApi :: * The task finished successfully, but the result was lost from the executor's block manager before diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 5d78a9dc8885e..42a85e42ea2b6 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -17,7 +17,7 @@ package org.apache.spark.executor -import java.io.File +import java.io.{File, NotSerializableException} import java.lang.management.ManagementFactory import java.net.URL import java.nio.ByteBuffer @@ -305,8 +305,16 @@ private[spark] class Executor( m } } - val taskEndReason = new ExceptionFailure(t, metrics) - execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(taskEndReason)) + val serializedTaskEndReason = { + try { + ser.serialize(new ExceptionFailure(t, metrics)) + } catch { + case _: NotSerializableException => + // t is not serializable so just send the stacktrace + ser.serialize(new ExceptionFailure(t, metrics, false)) + } + } + execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason) // Don't forcibly exit unless the exception was inherently fatal, to avoid // stopping other tasks unnecessarily. diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index bb489c6b6e98f..7ab5ccf50adb7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -200,8 +200,8 @@ class DAGScheduler( // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or // cancellation of the job itself. - def taskSetFailed(taskSet: TaskSet, reason: String): Unit = { - eventProcessLoop.post(TaskSetFailed(taskSet, reason)) + def taskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]): Unit = { + eventProcessLoop.post(TaskSetFailed(taskSet, reason, exception)) } private[scheduler] @@ -677,8 +677,11 @@ class DAGScheduler( submitWaitingStages() } - private[scheduler] def handleTaskSetFailed(taskSet: TaskSet, reason: String) { - stageIdToStage.get(taskSet.stageId).foreach {abortStage(_, reason) } + private[scheduler] def handleTaskSetFailed( + taskSet: TaskSet, + reason: String, + exception: Option[Throwable]): Unit = { + stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason, exception) } submitWaitingStages() } @@ -762,7 +765,7 @@ class DAGScheduler( } } } else { - abortStage(stage, "No active job for stage " + stage.id) + abortStage(stage, "No active job for stage " + stage.id, None) } } @@ -816,7 +819,7 @@ class DAGScheduler( case NonFatal(e) => stage.makeNewStageAttempt(partitionsToCompute.size) listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) - abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}") + abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e)) runningStages -= stage return } @@ -845,13 +848,13 @@ class DAGScheduler( } catch { // In the case of a failure during serialization, abort the stage. case e: NotSerializableException => - abortStage(stage, "Task not serializable: " + e.toString) + abortStage(stage, "Task not serializable: " + e.toString, Some(e)) runningStages -= stage // Abort execution return case NonFatal(e) => - abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}") + abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}", Some(e)) runningStages -= stage return } @@ -878,7 +881,7 @@ class DAGScheduler( } } catch { case NonFatal(e) => - abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}") + abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e)) runningStages -= stage return } @@ -1098,7 +1101,8 @@ class DAGScheduler( } if (disallowStageRetryForTest) { - abortStage(failedStage, "Fetch failure will not retry stage due to testing config") + abortStage(failedStage, "Fetch failure will not retry stage due to testing config", + None) } else if (failedStages.isEmpty) { // Don't schedule an event to resubmit failed stages if failed isn't empty, because // in that case the event will already have been scheduled. @@ -1126,7 +1130,7 @@ class DAGScheduler( case commitDenied: TaskCommitDenied => // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits - case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) => + case exceptionFailure: ExceptionFailure => // Do nothing here, left up to the TaskScheduler to decide how to handle user failures case TaskResultLost => @@ -1235,7 +1239,10 @@ class DAGScheduler( * Aborts all jobs depending on a particular Stage. This is called in response to a task set * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. */ - private[scheduler] def abortStage(failedStage: Stage, reason: String) { + private[scheduler] def abortStage( + failedStage: Stage, + reason: String, + exception: Option[Throwable]): Unit = { if (!stageIdToStage.contains(failedStage.id)) { // Skip all the actions if the stage has been removed. return @@ -1244,7 +1251,7 @@ class DAGScheduler( activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq failedStage.latestInfo.completionTime = Some(clock.getTimeMillis()) for (job <- dependentJobs) { - failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason") + failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason", exception) } if (dependentJobs.isEmpty) { logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") @@ -1252,8 +1259,11 @@ class DAGScheduler( } /** Fails a job and all stages that are only used by that job, and cleans up relevant state. */ - private def failJobAndIndependentStages(job: ActiveJob, failureReason: String) { - val error = new SparkException(failureReason) + private def failJobAndIndependentStages( + job: ActiveJob, + failureReason: String, + exception: Option[Throwable] = None): Unit = { + val error = new SparkException(failureReason, exception.getOrElse(null)) var ableToCancelStages = true val shouldInterruptThread = @@ -1462,8 +1472,8 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) => dagScheduler.handleTaskCompletion(completion) - case TaskSetFailed(taskSet, reason) => - dagScheduler.handleTaskSetFailed(taskSet, reason) + case TaskSetFailed(taskSet, reason, exception) => + dagScheduler.handleTaskSetFailed(taskSet, reason, exception) case ResubmitFailedStages => dagScheduler.resubmitFailedStages() diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index a213d419cf033..f72a52e85dc15 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -73,6 +73,7 @@ private[scheduler] case class ExecutorAdded(execId: String, host: String) extend private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent private[scheduler] -case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent +case class TaskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]) + extends DAGSchedulerEvent private[scheduler] case object ResubmitFailedStages extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 82455b0426a5d..818b95d67f6be 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -662,7 +662,7 @@ private[spark] class TaskSetManager( val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}): " + reason.asInstanceOf[TaskFailedReason].toErrorString - reason match { + val failureException: Option[Throwable] = reason match { case fetchFailed: FetchFailed => logWarning(failureReason) if (!successful(index)) { @@ -671,6 +671,7 @@ private[spark] class TaskSetManager( } // Not adding to failed executors for FetchFailed. isZombie = true + None case ef: ExceptionFailure => taskMetrics = ef.metrics.orNull @@ -706,12 +707,15 @@ private[spark] class TaskSetManager( s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid) on executor ${info.host}: " + s"${ef.className} (${ef.description}) [duplicate $dupCount]") } + ef.exception case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others logWarning(failureReason) + None case e: TaskEndReason => logError("Unknown TaskEndReason: " + e) + None } // always add to failed executors failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()). @@ -728,16 +732,16 @@ private[spark] class TaskSetManager( logError("Task %d in stage %s failed %d times; aborting job".format( index, taskSet.id, maxTaskFailures)) abort("Task %d in stage %s failed %d times, most recent failure: %s\nDriver stacktrace:" - .format(index, taskSet.id, maxTaskFailures, failureReason)) + .format(index, taskSet.id, maxTaskFailures, failureReason), failureException) return } } maybeFinishTaskSet() } - def abort(message: String): Unit = sched.synchronized { + def abort(message: String, exception: Option[Throwable] = None): Unit = sched.synchronized { // TODO: Kill running tasks if we were not terminated due to a Mesos error - sched.dagScheduler.taskSetFailed(taskSet, message) + sched.dagScheduler.taskSetFailed(taskSet, message, exception) isZombie = true maybeFinishTaskSet() } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index c600319d9ddb4..cbc94fd6d54d9 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -790,7 +790,7 @@ private[spark] object JsonProtocol { val fullStackTrace = Utils.jsonOption(json \ "Full Stack Trace"). map(_.extract[String]).orNull val metrics = Utils.jsonOption(json \ "Metrics").map(taskMetricsFromJson) - ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) + ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics, None) case `taskResultLost` => TaskResultLost case `taskKilled` => TaskKilled case `executorLostFailure` => diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index f374f97f87448..116f027a0f987 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -800,7 +800,7 @@ class ExecutorAllocationManagerSuite assert(maxNumExecutorsNeeded(manager) === 1) // If the task is failed, we expect it to be resubmitted later. - val taskEndReason = ExceptionFailure(null, null, null, null, null) + val taskEndReason = ExceptionFailure(null, null, null, null, null, None) sc.listenerBus.postToAll(SparkListenerTaskEnd(0, 0, null, taskEndReason, taskInfo, null)) assert(maxNumExecutorsNeeded(manager) === 1) } diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index 69cb4b44cf7ef..aa50a49c50232 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark import org.apache.spark.util.NonSerializable -import java.io.NotSerializableException +import java.io.{IOException, NotSerializableException, ObjectInputStream} // Common state shared by FailureSuite-launched tasks. We use a global object // for this because any local variables used in the task closures will rightfully @@ -166,5 +166,69 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { assert(thrownDueToMemoryLeak.getMessage.contains("memory leak")) } + // Run a 3-task map job in which task 1 always fails with a exception message that + // depends on the failure number, and check that we get the last failure. + test("last failure cause is sent back to driver") { + sc = new SparkContext("local[1,2]", "test") + val data = sc.makeRDD(1 to 3, 3).map { x => + FailureSuiteState.synchronized { + FailureSuiteState.tasksRun += 1 + if (x == 3) { + FailureSuiteState.tasksFailed += 1 + throw new UserException("oops", + new IllegalArgumentException("failed=" + FailureSuiteState.tasksFailed)) + } + } + x * x + } + val thrown = intercept[SparkException] { + data.collect() + } + FailureSuiteState.synchronized { + assert(FailureSuiteState.tasksRun === 4) + } + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getCause.getClass === classOf[UserException]) + assert(thrown.getCause.getMessage === "oops") + assert(thrown.getCause.getCause.getClass === classOf[IllegalArgumentException]) + assert(thrown.getCause.getCause.getMessage === "failed=2") + FailureSuiteState.clear() + } + + test("failure cause stacktrace is sent back to driver if exception is not serializable") { + sc = new SparkContext("local", "test") + val thrown = intercept[SparkException] { + sc.makeRDD(1 to 3).foreach { _ => throw new NonSerializableUserException } + } + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getCause === null) + assert(thrown.getMessage.contains("NonSerializableUserException")) + FailureSuiteState.clear() + } + + test("failure cause stacktrace is sent back to driver if exception is not deserializable") { + sc = new SparkContext("local", "test") + val thrown = intercept[SparkException] { + sc.makeRDD(1 to 3).foreach { _ => throw new NonDeserializableUserException } + } + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getCause === null) + assert(thrown.getMessage.contains("NonDeserializableUserException")) + FailureSuiteState.clear() + } + // TODO: Need to add tests with shuffle fetch failures. } + +class UserException(message: String, cause: Throwable) + extends RuntimeException(message, cause) + +class NonSerializableUserException extends RuntimeException { + val nonSerializableInstanceVariable = new NonSerializable +} + +class NonDeserializableUserException extends RuntimeException { + private def readObject(in: ObjectInputStream): Unit = { + throw new IOException("Intentional exception during deserialization.") + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 86dff8fb577d5..b0ca49cbea4f7 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -242,7 +242,7 @@ class DAGSchedulerSuite /** Sends TaskSetFailed to the scheduler. */ private def failed(taskSet: TaskSet, message: String) { - runEvent(TaskSetFailed(taskSet, message)) + runEvent(TaskSetFailed(taskSet, message, None)) } /** Sends JobCancelled to the DAG scheduler. */ diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index f7cc4bb61d574..edbdb485c5ea4 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -48,7 +48,10 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) override def executorLost(execId: String) {} - override def taskSetFailed(taskSet: TaskSet, reason: String) { + override def taskSetFailed( + taskSet: TaskSet, + reason: String, + exception: Option[Throwable]): Unit = { taskScheduler.taskSetsFailed += taskSet.id } } diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 56f7b9cf1f358..b140387d309f3 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -240,7 +240,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with val taskFailedReasons = Seq( Resubmitted, new FetchFailed(null, 0, 0, 0, "ignored"), - ExceptionFailure("Exception", "description", null, null, None), + ExceptionFailure("Exception", "description", null, null, None, None), TaskResultLost, TaskKilled, ExecutorLostFailure("0"), diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index dde95f3778434..343a4139b0ca8 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -163,7 +163,8 @@ class JsonProtocolSuite extends SparkFunSuite { } test("ExceptionFailure backward compatibility") { - val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, None) + val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, + None, None) val oldEvent = JsonProtocol.taskEndReasonToJson(exceptionFailure) .removeField({ _._1 == "Full Stack Trace" }) assertEquals(exceptionFailure, JsonProtocol.taskEndReasonFromJson(oldEvent)) From be5d1912076c2ffd21ec88611e53d3b3c59b7ecc Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 12 Aug 2015 09:24:50 -0700 Subject: [PATCH 051/253] [SPARK-9795] Dynamic allocation: avoid double counting when killing same executor twice This is based on KaiXinXiaoLei's changes in #7716. The issue is that when someone calls `sc.killExecutor("1")` on the same executor twice quickly, then the executor target will be adjusted downwards by 2 instead of 1 even though we're only actually killing one executor. In certain cases where we don't adjust the target back upwards quickly, we'll end up with jobs hanging. This is a common danger because there are many places where this is called: - `HeartbeatReceiver` kills an executor that has not been sending heartbeats - `ExecutorAllocationManager` kills an executor that has been idle - The user code might call this, which may interfere with the previous callers While it's not clear whether this fixes SPARK-9745, fixing this potential race condition seems like a strict improvement. I've added a regression test to illustrate the issue. Author: Andrew Or Closes #8078 from andrewor14/da-double-kill. --- .../CoarseGrainedSchedulerBackend.scala | 11 ++++++---- .../StandaloneDynamicAllocationSuite.scala | 20 +++++++++++++++++++ 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 6acf8a9a5e9b4..5730a87f960a0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -422,16 +422,19 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp logWarning(s"Executor to kill $id does not exist!") } + // If an executor is already pending to be removed, do not kill it again (SPARK-9795) + val executorsToKill = knownExecutors.filter { id => !executorsPendingToRemove.contains(id) } + executorsPendingToRemove ++= executorsToKill + // If we do not wish to replace the executors we kill, sync the target number of executors // with the cluster manager to avoid allocating new ones. When computing the new target, // take into account executors that are pending to be added or removed. if (!replace) { - doRequestTotalExecutors(numExistingExecutors + numPendingExecutors - - executorsPendingToRemove.size - knownExecutors.size) + doRequestTotalExecutors( + numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) } - executorsPendingToRemove ++= knownExecutors - doKillExecutors(knownExecutors) + doKillExecutors(executorsToKill) } /** diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index 08c41a897a861..1f2a0f0d309ce 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -283,6 +283,26 @@ class StandaloneDynamicAllocationSuite assert(master.apps.head.getExecutorLimit === 1000) } + test("kill the same executor twice (SPARK-9795)") { + sc = new SparkContext(appConf) + val appId = sc.applicationId + assert(master.apps.size === 1) + assert(master.apps.head.id === appId) + assert(master.apps.head.executors.size === 2) + assert(master.apps.head.getExecutorLimit === Int.MaxValue) + // sync executors between the Master and the driver, needed because + // the driver refuses to kill executors it does not know about + syncExecutors(sc) + // kill the same executor twice + val executors = getExecutorIds(sc) + assert(executors.size === 2) + assert(sc.killExecutor(executors.head)) + assert(sc.killExecutor(executors.head)) + assert(master.apps.head.executors.size === 1) + // The limit should not be lowered twice + assert(master.apps.head.getExecutorLimit === 1) + } + // =============================== // | Utility methods for testing | // =============================== From 66d87c1d76bea2b81993156ac1fa7dad6c312ebf Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 12 Aug 2015 09:35:32 -0700 Subject: [PATCH 052/253] [SPARK-7583] [MLLIB] User guide update for RegexTokenizer jira: https://issues.apache.org/jira/browse/SPARK-7583 User guide update for RegexTokenizer Author: Yuhao Yang Closes #7828 from hhbyyh/regexTokenizerDoc. --- docs/ml-features.md | 41 ++++++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index fa0ad1f00ab12..cec2cbe673407 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -217,21 +217,32 @@ for feature in result.select("result").take(3): [Tokenization](http://en.wikipedia.org/wiki/Lexical_analysis#Tokenization) is the process of taking text (such as a sentence) and breaking it into individual terms (usually words). A simple [Tokenizer](api/scala/index.html#org.apache.spark.ml.feature.Tokenizer) class provides this functionality. The example below shows how to split sentences into sequences of words. -Note: A more advanced tokenizer is provided via [RegexTokenizer](api/scala/index.html#org.apache.spark.ml.feature.RegexTokenizer). +[RegexTokenizer](api/scala/index.html#org.apache.spark.ml.feature.RegexTokenizer) allows more + advanced tokenization based on regular expression (regex) matching. + By default, the parameter "pattern" (regex, default: \\s+) is used as delimiters to split the input text. + Alternatively, users can set parameter "gaps" to false indicating the regex "pattern" denotes + "tokens" rather than splitting gaps, and find all matching occurrences as the tokenization result.
{% highlight scala %} -import org.apache.spark.ml.feature.Tokenizer +import org.apache.spark.ml.feature.{Tokenizer, RegexTokenizer} val sentenceDataFrame = sqlContext.createDataFrame(Seq( (0, "Hi I heard about Spark"), - (0, "I wish Java could use case classes"), - (1, "Logistic regression models are neat") + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") )).toDF("label", "sentence") val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") -val wordsDataFrame = tokenizer.transform(sentenceDataFrame) -wordsDataFrame.select("words", "label").take(3).foreach(println) +val regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false) + +val tokenized = tokenizer.transform(sentenceDataFrame) +tokenized.select("words", "label").take(3).foreach(println) +val regexTokenized = regexTokenizer.transform(sentenceDataFrame) +regexTokenized.select("words", "label").take(3).foreach(println) {% endhighlight %}
@@ -240,6 +251,7 @@ wordsDataFrame.select("words", "label").take(3).foreach(println) import com.google.common.collect.Lists; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RegexTokenizer; import org.apache.spark.ml.feature.Tokenizer; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.sql.DataFrame; @@ -252,8 +264,8 @@ import org.apache.spark.sql.types.StructType; JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( RowFactory.create(0, "Hi I heard about Spark"), - RowFactory.create(0, "I wish Java could use case classes"), - RowFactory.create(1, "Logistic regression models are neat") + RowFactory.create(1, "I wish Java could use case classes"), + RowFactory.create(2, "Logistic,regression,models,are,neat") )); StructType schema = new StructType(new StructField[]{ new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), @@ -267,22 +279,29 @@ for (Row r : wordsDataFrame.select("words", "label").take(3)) { for (String word : words) System.out.print(word + " "); System.out.println(); } + +RegexTokenizer regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); {% endhighlight %}
{% highlight python %} -from pyspark.ml.feature import Tokenizer +from pyspark.ml.feature import Tokenizer, RegexTokenizer sentenceDataFrame = sqlContext.createDataFrame([ (0, "Hi I heard about Spark"), - (0, "I wish Java could use case classes"), - (1, "Logistic regression models are neat") + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") ], ["label", "sentence"]) tokenizer = Tokenizer(inputCol="sentence", outputCol="words") wordsDataFrame = tokenizer.transform(sentenceDataFrame) for words_label in wordsDataFrame.select("words", "label").take(3): print(words_label) +regexTokenizer = RegexTokenizer(inputCol="sentence", outputCol="words", pattern="\\W") +# alternatively, pattern="\\w+", gaps(False) {% endhighlight %}
From e0110792ef71ebfd3727b970346a2e13695990a4 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 12 Aug 2015 10:08:35 -0700 Subject: [PATCH 053/253] [SPARK-9747] [SQL] Avoid starving an unsafe operator in aggregation This is the sister patch to #8011, but for aggregation. In a nutshell: create the `TungstenAggregationIterator` before computing the parent partition. Internally this creates a `BytesToBytesMap` which acquires a page in the constructor as of this patch. This ensures that the aggregation operator is not starved since we reserve at least 1 page in advance. rxin yhuai Author: Andrew Or Closes #8038 from andrewor14/unsafe-starve-memory-agg. --- .../spark/unsafe/map/BytesToBytesMap.java | 34 +++++-- .../unsafe/sort/UnsafeExternalSorter.java | 9 +- .../map/AbstractBytesToBytesMapSuite.java | 11 ++- .../UnsafeFixedWidthAggregationMap.java | 7 ++ .../aggregate/TungstenAggregate.scala | 72 +++++++++------ .../TungstenAggregationIterator.scala | 88 +++++++++++-------- .../TungstenAggregationIteratorSuite.scala | 56 ++++++++++++ 7 files changed, 201 insertions(+), 76 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 85b46ec8bfae3..87ed47e88c4ef 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -193,6 +193,11 @@ public BytesToBytesMap( TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES); } allocate(initialCapacity); + + // Acquire a new page as soon as we construct the map to ensure that we have at least + // one page to work with. Otherwise, other operators in the same task may starve this + // map (SPARK-9747). + acquireNewPage(); } public BytesToBytesMap( @@ -574,16 +579,9 @@ public boolean putNewKey( final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor; Platform.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER); } - final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes); - if (memoryGranted != pageSizeBytes) { - shuffleMemoryManager.release(memoryGranted); - logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes); + if (!acquireNewPage()) { return false; } - MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes); - dataPages.add(newPage); - pageCursor = 0; - currentDataPage = newPage; dataPage = currentDataPage; dataPageBaseObject = currentDataPage.getBaseObject(); dataPageInsertOffset = currentDataPage.getBaseOffset(); @@ -642,6 +640,24 @@ public boolean putNewKey( } } + /** + * Acquire a new page from the {@link ShuffleMemoryManager}. + * @return whether there is enough space to allocate the new page. + */ + private boolean acquireNewPage() { + final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryGranted != pageSizeBytes) { + shuffleMemoryManager.release(memoryGranted); + logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes); + return false; + } + MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes); + dataPages.add(newPage); + pageCursor = 0; + currentDataPage = newPage; + return true; + } + /** * Allocate new data structures for this map. When calling this outside of the constructor, * make sure to keep references to the old data structures so that you can free them. @@ -748,7 +764,7 @@ public long getNumHashCollisions() { } @VisibleForTesting - int getNumDataPages() { + public int getNumDataPages() { return dataPages.size(); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 9601aafe55464..fc364e0a895b1 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -132,16 +132,15 @@ private UnsafeExternalSorter( if (existingInMemorySorter == null) { initializeForWriting(); + // Acquire a new page as soon as we construct the sorter to ensure that we have at + // least one page to work with. Otherwise, other operators in the same task may starve + // this sorter (SPARK-9709). We don't need to do this if we already have an existing sorter. + acquireNewPage(); } else { this.isInMemSorterExternal = true; this.inMemSorter = existingInMemorySorter; } - // Acquire a new page as soon as we construct the sorter to ensure that we have at - // least one page to work with. Otherwise, other operators in the same task may starve - // this sorter (SPARK-9709). - acquireNewPage(); - // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at // the end of the task. This is necessary to avoid memory leaks in when the downstream operator // does not fully consume the sorter's output (e.g. sort followed by limit). diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 1a79c20c35246..ab480b60adaed 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -543,7 +543,7 @@ public void testPeakMemoryUsed() { Platform.LONG_ARRAY_OFFSET, 8); newPeakMemory = map.getPeakMemoryUsedBytes(); - if (i % numRecordsPerPage == 0) { + if (i % numRecordsPerPage == 0 && i > 0) { // We allocated a new page for this record, so peak memory should change assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); } else { @@ -561,4 +561,13 @@ public void testPeakMemoryUsed() { map.free(); } } + + @Test + public void testAcquirePageInConstructor() { + final BytesToBytesMap map = new BytesToBytesMap( + taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES); + assertEquals(1, map.getNumDataPages()); + map.free(); + } + } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 5cce41d5a7569..09511ff35f785 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -19,6 +19,8 @@ import java.io.IOException; +import com.google.common.annotations.VisibleForTesting; + import org.apache.spark.SparkEnv; import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.sql.catalyst.InternalRow; @@ -220,6 +222,11 @@ public long getPeakMemoryUsedBytes() { return map.getPeakMemoryUsedBytes(); } + @VisibleForTesting + public int getNumDataPages() { + return map.getNumDataPages(); + } + /** * Free the memory associated with this map. This is idempotent and can be called multiple times. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 6b5935a7ce296..c40ca973796a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.execution.aggregate -import org.apache.spark.rdd.RDD +import org.apache.spark.TaskContext +import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -68,35 +69,56 @@ case class TungstenAggregate( protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { val numInputRows = longMetric("numInputRows") val numOutputRows = longMetric("numOutputRows") - child.execute().mapPartitions { iter => - val hasInput = iter.hasNext - if (!hasInput && groupingExpressions.nonEmpty) { - // This is a grouped aggregate and the input iterator is empty, - // so return an empty iterator. - Iterator.empty.asInstanceOf[Iterator[UnsafeRow]] - } else { - val aggregationIterator = - new TungstenAggregationIterator( - groupingExpressions, - nonCompleteAggregateExpressions, - completeAggregateExpressions, - initialInputBufferOffset, - resultExpressions, - newMutableProjection, - child.output, - iter, - testFallbackStartsAt, - numInputRows, - numOutputRows) - - if (!hasInput && groupingExpressions.isEmpty) { + + /** + * Set up the underlying unsafe data structures used before computing the parent partition. + * This makes sure our iterator is not starved by other operators in the same task. + */ + def preparePartition(): TungstenAggregationIterator = { + new TungstenAggregationIterator( + groupingExpressions, + nonCompleteAggregateExpressions, + completeAggregateExpressions, + initialInputBufferOffset, + resultExpressions, + newMutableProjection, + child.output, + testFallbackStartsAt, + numInputRows, + numOutputRows) + } + + /** Compute a partition using the iterator already set up previously. */ + def executePartition( + context: TaskContext, + partitionIndex: Int, + aggregationIterator: TungstenAggregationIterator, + parentIterator: Iterator[InternalRow]): Iterator[UnsafeRow] = { + val hasInput = parentIterator.hasNext + if (!hasInput) { + // We're not using the underlying map, so we just can free it here + aggregationIterator.free() + if (groupingExpressions.isEmpty) { numOutputRows += 1 Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) } else { - aggregationIterator + // This is a grouped aggregate and the input iterator is empty, + // so return an empty iterator. + Iterator[UnsafeRow]() } + } else { + aggregationIterator.start(parentIterator) + aggregationIterator } } + + // Note: we need to set up the iterator in each partition before computing the + // parent partition, so we cannot simply use `mapPartitions` here (SPARK-9747). + val resultRdd = { + new MapPartitionsWithPreparationRDD[UnsafeRow, InternalRow, TungstenAggregationIterator]( + child.execute(), preparePartition, executePartition, preservesPartitioning = true) + } + resultRdd.asInstanceOf[RDD[InternalRow]] } override def simpleString: String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 1f383dd04482f..af7e0fcedbe4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -72,8 +72,6 @@ import org.apache.spark.sql.types.StructType * the function used to create mutable projections. * @param originalInputAttributes * attributes of representing input rows from `inputIter`. - * @param inputIter - * the iterator containing input [[UnsafeRow]]s. */ class TungstenAggregationIterator( groupingExpressions: Seq[NamedExpression], @@ -83,12 +81,14 @@ class TungstenAggregationIterator( resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), originalInputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow], testFallbackStartsAt: Option[Int], numInputRows: LongSQLMetric, numOutputRows: LongSQLMetric) extends Iterator[UnsafeRow] with Logging { + // The parent partition iterator, to be initialized later in `start` + private[this] var inputIter: Iterator[InternalRow] = null + /////////////////////////////////////////////////////////////////////////// // Part 1: Initializing aggregate functions. /////////////////////////////////////////////////////////////////////////// @@ -348,11 +348,15 @@ class TungstenAggregationIterator( false // disable tracking of performance metrics ) + // Exposed for testing + private[aggregate] def getHashMap: UnsafeFixedWidthAggregationMap = hashMap + // The function used to read and process input rows. When processing input rows, // it first uses hash-based aggregation by putting groups and their buffers in // hashMap. If we could not allocate more memory for the map, we switch to // sort-based aggregation (by calling switchToSortBasedAggregation). private def processInputs(): Unit = { + assert(inputIter != null, "attempted to process input when iterator was null") while (!sortBased && inputIter.hasNext) { val newInput = inputIter.next() numInputRows += 1 @@ -372,6 +376,7 @@ class TungstenAggregationIterator( // that it switch to sort-based aggregation after `fallbackStartsAt` input rows have // been processed. private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = { + assert(inputIter != null, "attempted to process input when iterator was null") var i = 0 while (!sortBased && inputIter.hasNext) { val newInput = inputIter.next() @@ -412,6 +417,7 @@ class TungstenAggregationIterator( * Switch to sort-based aggregation when the hash-based approach is unable to acquire memory. */ private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: InternalRow): Unit = { + assert(inputIter != null, "attempted to process input when iterator was null") logInfo("falling back to sort based aggregation.") // Step 1: Get the ExternalSorter containing sorted entries of the map. externalSorter = hashMap.destructAndCreateExternalSorter() @@ -431,6 +437,11 @@ class TungstenAggregationIterator( case _ => false } + // Note: Since we spill the sorter's contents immediately after creating it, we must insert + // something into the sorter here to ensure that we acquire at least a page of memory. + // This is done through `externalSorter.insertKV`, which will trigger the page allocation. + // Otherwise, children operators may steal the window of opportunity and starve our sorter. + if (needsProcess) { // First, we create a buffer. val buffer = createNewAggregationBuffer() @@ -588,27 +599,33 @@ class TungstenAggregationIterator( // have not switched to sort-based aggregation. /////////////////////////////////////////////////////////////////////////// - // Starts to process input rows. - testFallbackStartsAt match { - case None => - processInputs() - case Some(fallbackStartsAt) => - // This is the testing path. processInputsWithControlledFallback is same as processInputs - // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows - // have been processed. - processInputsWithControlledFallback(fallbackStartsAt) - } + /** + * Start processing input rows. + * Only after this method is called will this iterator be non-empty. + */ + def start(parentIter: Iterator[InternalRow]): Unit = { + inputIter = parentIter + testFallbackStartsAt match { + case None => + processInputs() + case Some(fallbackStartsAt) => + // This is the testing path. processInputsWithControlledFallback is same as processInputs + // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows + // have been processed. + processInputsWithControlledFallback(fallbackStartsAt) + } - // If we did not switch to sort-based aggregation in processInputs, - // we pre-load the first key-value pair from the map (to make hasNext idempotent). - if (!sortBased) { - // First, set aggregationBufferMapIterator. - aggregationBufferMapIterator = hashMap.iterator() - // Pre-load the first key-value pair from the aggregationBufferMapIterator. - mapIteratorHasNext = aggregationBufferMapIterator.next() - // If the map is empty, we just free it. - if (!mapIteratorHasNext) { - hashMap.free() + // If we did not switch to sort-based aggregation in processInputs, + // we pre-load the first key-value pair from the map (to make hasNext idempotent). + if (!sortBased) { + // First, set aggregationBufferMapIterator. + aggregationBufferMapIterator = hashMap.iterator() + // Pre-load the first key-value pair from the aggregationBufferMapIterator. + mapIteratorHasNext = aggregationBufferMapIterator.next() + // If the map is empty, we just free it. + if (!mapIteratorHasNext) { + hashMap.free() + } } } @@ -673,21 +690,20 @@ class TungstenAggregationIterator( } /////////////////////////////////////////////////////////////////////////// - // Part 8: A utility function used to generate a output row when there is no - // input and there is no grouping expression. + // Part 8: Utility functions /////////////////////////////////////////////////////////////////////////// + /** + * Generate a output row when there is no input and there is no grouping expression. + */ def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { - if (groupingExpressions.isEmpty) { - sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer) - // We create a output row and copy it. So, we can free the map. - val resultCopy = - generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer).copy() - hashMap.free() - resultCopy - } else { - throw new IllegalStateException( - "This method should not be called when groupingExpressions is not empty.") - } + assert(groupingExpressions.isEmpty) + assert(inputIter == null) + generateOutput(UnsafeRow.createFromByteArray(0, 0), initialAggregationBuffer) + } + + /** Free memory used in the underlying map. */ + def free(): Unit = { + hashMap.free() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala new file mode 100644 index 0000000000000..ac22c2f3c0a58 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.unsafe.memory.TaskMemoryManager + +class TungstenAggregationIteratorSuite extends SparkFunSuite { + + test("memory acquired on construction") { + // set up environment + val ctx = TestSQLContext + + val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager) + val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty) + TaskContext.setTaskContext(taskContext) + + // Assert that a page is allocated before processing starts + var iter: TungstenAggregationIterator = null + try { + val newMutableProjection = (expr: Seq[Expression], schema: Seq[Attribute]) => { + () => new InterpretedMutableProjection(expr, schema) + } + val dummyAccum = SQLMetrics.createLongMetric(ctx.sparkContext, "dummy") + iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, 0, + Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum) + val numPages = iter.getHashMap.getNumDataPages + assert(numPages === 1) + } finally { + // Clean up + if (iter != null) { + iter.free() + } + TaskContext.unset() + } + } +} From 57ec27dd7784ce15a2ece8a6c8ac7bd5fd25aea2 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 12 Aug 2015 10:38:30 -0700 Subject: [PATCH 054/253] [SPARK-9804] [HIVE] Use correct value for isSrcLocal parameter. If the correct parameter is not provided, Hive will run into an error because it calls methods that are specific to the local filesystem to copy the data. Author: Marcelo Vanzin Closes #8086 from vanzin/SPARK-9804. --- .../org/apache/spark/sql/hive/client/HiveShim.scala | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 6e826ce552204..8fc8935b1dc3c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -25,7 +25,7 @@ import java.util.concurrent.TimeUnit import scala.collection.JavaConversions._ -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} @@ -429,7 +429,7 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { isSkewedStoreAsSubdir: Boolean): Unit = { loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, holdDDLTime: JBoolean, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean, - JBoolean.TRUE, JBoolean.FALSE) + isSrcLocal(loadPath, hive.getConf()): JBoolean, JBoolean.FALSE) } override def loadTable( @@ -439,7 +439,7 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { replace: Boolean, holdDDLTime: Boolean): Unit = { loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, holdDDLTime: JBoolean, - JBoolean.TRUE, JBoolean.FALSE, JBoolean.FALSE) + isSrcLocal(loadPath, hive.getConf()): JBoolean, JBoolean.FALSE, JBoolean.FALSE) } override def loadDynamicPartitions( @@ -461,6 +461,13 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY, TimeUnit.MILLISECONDS).asInstanceOf[Long] } + + protected def isSrcLocal(path: Path, conf: HiveConf): Boolean = { + val localFs = FileSystem.getLocal(conf) + val pathFs = FileSystem.get(path.toUri(), conf) + localFs.getUri() == pathFs.getUri() + } + } private[client] class Shim_v1_0 extends Shim_v0_14 { From 70fe558867ccb4bcff6ec673438b03608bb02252 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 12 Aug 2015 10:48:52 -0700 Subject: [PATCH 055/253] [SPARK-9847] [ML] Modified copyValues to distinguish between default, explicit param values From JIRA: Currently, Params.copyValues copies default parameter values to the paramMap of the target instance, rather than the defaultParamMap. It should copy to the defaultParamMap because explicitly setting a parameter can change the semantics. This issue arose in SPARK-9789, where 2 params "threshold" and "thresholds" for LogisticRegression can have mutually exclusive values. If thresholds is set, then fit() will copy the default value of threshold as well, easily resulting in inconsistent settings for the 2 params. CC: mengxr Author: Joseph K. Bradley Closes #8115 from jkbradley/copyvalues-fix. --- .../org/apache/spark/ml/param/params.scala | 19 ++++++++++++++++--- .../apache/spark/ml/param/ParamsSuite.scala | 8 ++++++++ 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index d68f5ff0053c9..91c0a5631319d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -559,13 +559,26 @@ trait Params extends Identifiable with Serializable { /** * Copies param values from this instance to another instance for params shared by them. - * @param to the target instance - * @param extra extra params to be copied + * + * This handles default Params and explicitly set Params separately. + * Default Params are copied from and to [[defaultParamMap]], and explicitly set Params are + * copied from and to [[paramMap]]. + * Warning: This implicitly assumes that this [[Params]] instance and the target instance + * share the same set of default Params. + * + * @param to the target instance, which should work with the same set of default Params as this + * source instance + * @param extra extra params to be copied to the target's [[paramMap]] * @return the target instance with param values copied */ protected def copyValues[T <: Params](to: T, extra: ParamMap = ParamMap.empty): T = { - val map = extractParamMap(extra) + val map = paramMap ++ extra params.foreach { param => + // copy default Params + if (defaultParamMap.contains(param) && to.hasParam(param.name)) { + to.defaultParamMap.put(to.getParam(param.name), defaultParamMap(param)) + } + // copy explicitly set Params if (map.contains(param) && to.hasParam(param.name)) { to.set(param.name, map(param)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 050d4170ea017..be95638d81686 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -200,6 +200,14 @@ class ParamsSuite extends SparkFunSuite { val inArray = ParamValidators.inArray[Int](Array(1, 2)) assert(inArray(1) && inArray(2) && !inArray(0)) } + + test("Params.copyValues") { + val t = new TestParams() + val t2 = t.copy(ParamMap.empty) + assert(!t2.isSet(t2.maxIter)) + val t3 = t.copy(ParamMap(t.maxIter -> 20)) + assert(t3.isSet(t3.maxIter)) + } } object ParamsSuite extends SparkFunSuite { From 60103ecd3d9c92709a5878be7ebd57012813ab48 Mon Sep 17 00:00:00 2001 From: Brennan Ashton Date: Wed, 12 Aug 2015 11:57:30 -0700 Subject: [PATCH 056/253] [SPARK-9726] [PYTHON] PySpark DF join no longer accepts on=None rxin First pull request for Spark so let me know if I am missing anything The contribution is my original work and I license the work to the project under the project's open source license. Author: Brennan Ashton Closes #8016 from btashton/patch-1. --- python/pyspark/sql/dataframe.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 47d5a6a43a84d..09647ff6d0749 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -566,8 +566,7 @@ def join(self, other, on=None, how=None): if on is None or len(on) == 0: jdf = self._jdf.join(other._jdf) - - if isinstance(on[0], basestring): + elif isinstance(on[0], basestring): jdf = self._jdf.join(other._jdf, self._jseq(on)) else: assert isinstance(on[0], Column), "on should be Column or list of Column" From 762bacc16ac5e74c8b05a7c1e3e367d1d1633cef Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 12 Aug 2015 13:24:18 -0700 Subject: [PATCH 057/253] [SPARK-9766] [ML] [PySpark] check and add miss docs for PySpark ML Check and add miss docs for PySpark ML (this issue only check miss docs for o.a.s.ml not o.a.s.mllib). Author: Yanbo Liang Closes #8059 from yanboliang/SPARK-9766. --- python/pyspark/ml/classification.py | 12 ++++++++++-- python/pyspark/ml/clustering.py | 4 +++- python/pyspark/ml/evaluation.py | 3 ++- python/pyspark/ml/feature.py | 9 +++++---- 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 5978d8f4d3a01..6702dce5545a9 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -34,6 +34,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol): """ Logistic regression. + Currently, this class only supports binary classification. >>> from pyspark.sql import Row >>> from pyspark.mllib.linalg import Vectors @@ -96,8 +97,8 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred # is an L2 penalty. For alpha = 1, it is an L1 penalty. self.elasticNetParam = \ Param(self, "elasticNetParam", - "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty " + - "is an L2 penalty. For alpha = 1, it is an L1 penalty.") + "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + + "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") #: param for whether to fit an intercept term. self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") #: param for threshold in binary classification prediction, in range [0, 1]. @@ -656,6 +657,13 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H HasRawPredictionCol): """ Naive Bayes Classifiers. + It supports both Multinomial and Bernoulli NB. Multinomial NB + (`http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html`) + can handle finitely supported discrete data. For example, by converting documents into + TF-IDF vectors, it can be used for document classification. By making every vector a + binary (0/1) data, it can also be used as Bernoulli NB + (`http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html`). + The input feature values must be nonnegative. >>> from pyspark.sql import Row >>> from pyspark.mllib.linalg import Vectors diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index b5e9b6549d9f1..48338713a29ea 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -37,7 +37,9 @@ def clusterCenters(self): @inherit_doc class KMeans(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed): """ - K-means Clustering + K-means clustering with support for multiple parallel runs and a k-means++ like initialization + mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested, + they are executed together with joint passes over the data for efficiency. >>> from pyspark.mllib.linalg import Vectors >>> data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),), diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 06e809352225b..2734092575ad9 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -23,7 +23,8 @@ from pyspark.ml.util import keyword_only from pyspark.mllib.common import inherit_doc -__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator'] +__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator', + 'MulticlassClassificationEvaluator'] @inherit_doc diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index cb4dfa21298ce..535d55326646c 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -26,10 +26,11 @@ from pyspark.mllib.common import inherit_doc from pyspark.mllib.linalg import _convert_to_vector -__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder', - 'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel', - 'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', - 'Word2Vec', 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel'] +__all__ = ['Binarizer', 'Bucketizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', + 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', + 'StandardScalerModel', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', + 'VectorAssembler', 'VectorIndexer', 'Word2Vec', 'Word2VecModel', 'PCA', + 'PCAModel', 'RFormula', 'RFormulaModel'] @inherit_doc From 8a331d0444f58d3c14c1c12c4f087f1a02d5b8d1 Mon Sep 17 00:00:00 2001 From: Neelesh Srinivas Salian Date: Wed, 12 Aug 2015 14:19:58 -0700 Subject: [PATCH 058/253] Moved Master URLs --- docs/submitting-applications.md | 48 ++++++++++++++++----------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index 8778f719e2501..d864bc9f59ffb 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -62,30 +62,6 @@ For submitting application to YARN, the preferred options are: * `--master`: The --master parameter is either `yarn-client` or `yarn-cluster`. Defaults to `yarn-client` -# Master URLs - -The master URL passed to Spark can be in one of the following formats: - - - - - - - - - - -
Master URLMeaning
local Run Spark locally with one worker thread (i.e. no parallelism at all).
local[K] Run Spark locally with K worker threads (ideally, set this to the number of cores on your machine).
local[*] Run Spark locally with as many worker threads as logical cores on your machine.
spark://HOST:PORT Connect to the given Spark standalone - cluster master. The port must be whichever one your master is configured to use, which is 7077 by default. -
mesos://HOST:PORT Connect to the given Mesos cluster. - The port must be whichever one your is configured to use, which is 5050 by default. - Or, for a Mesos cluster using ZooKeeper, use mesos://zk://.... -
yarn-client Connect to a YARN cluster in -client mode. The cluster location will be found based on the HADOOP_CONF_DIR or YARN_CONF_DIR variable. -
yarn-cluster Connect to a YARN cluster in -cluster mode. The cluster location will be found based on the HADOOP_CONF_DIR or YARN_CONF_DIR variable. -
- A common deployment strategy is to submit your application from a gateway machine that is physically co-located with your worker machines (e.g. Master node in a standalone EC2 cluster). @@ -154,6 +130,30 @@ export HADOOP_CONF_DIR=XXX 1000 {% endhighlight %} +# Master URLs + +The master URL passed to Spark can be in one of the following formats: + + + + + + + + + + +
Master URLMeaning
local Run Spark locally with one worker thread (i.e. no parallelism at all).
local[K] Run Spark locally with K worker threads (ideally, set this to the number of cores on your machine).
local[*] Run Spark locally with as many worker threads as logical cores on your machine.
spark://HOST:PORT Connect to the given Spark standalone + cluster master. The port must be whichever one your master is configured to use, which is 7077 by default. +
mesos://HOST:PORT Connect to the given Mesos cluster. + The port must be whichever one your is configured to use, which is 5050 by default. + Or, for a Mesos cluster using ZooKeeper, use mesos://zk://.... +
yarn-client Connect to a YARN cluster in +client mode. The cluster location will be found based on the HADOOP_CONF_DIR or YARN_CONF_DIR variable. +
yarn-cluster Connect to a YARN cluster in +cluster mode. The cluster location will be found based on the HADOOP_CONF_DIR or YARN_CONF_DIR variable. +
+ # Loading Configuration from a File The `spark-submit` script can load default [Spark configuration values](configuration.html) from a From 551def5d6972440365bd7436d484a67138d9a8f3 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 12 Aug 2015 14:27:13 -0700 Subject: [PATCH 059/253] [SPARK-9789] [ML] Added logreg threshold param back Reinstated LogisticRegression.threshold Param for binary compatibility. Param thresholds overrides threshold, if set. CC: mengxr dbtsai feynmanliang Author: Joseph K. Bradley Closes #8079 from jkbradley/logreg-reinstate-threshold. --- .../classification/LogisticRegression.scala | 127 ++++++++++++++---- .../ml/param/shared/SharedParamsCodeGen.scala | 4 +- .../spark/ml/param/shared/sharedParams.scala | 6 +- .../JavaLogisticRegressionSuite.java | 7 +- .../LogisticRegressionSuite.scala | 33 +++-- python/pyspark/ml/classification.py | 98 +++++++++----- 6 files changed, 199 insertions(+), 76 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index f55134d258857..5bcd7117b668c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -34,8 +34,7 @@ import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SQLContext} -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.storage.StorageLevel /** @@ -43,44 +42,115 @@ import org.apache.spark.storage.StorageLevel */ private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol - with HasStandardization { + with HasStandardization with HasThreshold { /** - * Version of setThresholds() for binary classification, available for backwards - * compatibility. + * Set threshold in binary classification, in range [0, 1]. * - * Calling this with threshold p will effectively call `setThresholds(Array(1-p, p))`. + * If the estimated probability of class label 1 is > threshold, then predict 1, else 0. + * A high threshold encourages the model to predict 0 more often; + * a low threshold encourages the model to predict 1 more often. + * + * Note: Calling this with threshold p is equivalent to calling `setThresholds(Array(1-p, p))`. + * When [[setThreshold()]] is called, any user-set value for [[thresholds]] will be cleared. + * If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be + * equivalent. + * + * Default is 0.5. + * @group setParam + */ + def setThreshold(value: Double): this.type = { + if (isSet(thresholds)) clear(thresholds) + set(threshold, value) + } + + /** + * Get threshold for binary classification. + * + * If [[threshold]] is set, returns that value. + * Otherwise, if [[thresholds]] is set with length 2 (i.e., binary classification), + * this returns the equivalent threshold: {{{1 / (1 + thresholds(0) / thresholds(1))}}}. + * Otherwise, returns [[threshold]] default value. + * + * @group getParam + * @throws IllegalArgumentException if [[thresholds]] is set to an array of length other than 2. + */ + override def getThreshold: Double = { + checkThresholdConsistency() + if (isSet(thresholds)) { + val ts = $(thresholds) + require(ts.length == 2, "Logistic Regression getThreshold only applies to" + + " binary classification, but thresholds has length != 2. thresholds: " + ts.mkString(",")) + 1.0 / (1.0 + ts(0) / ts(1)) + } else { + $(threshold) + } + } + + /** + * Set thresholds in multiclass (or binary) classification to adjust the probability of + * predicting each class. Array must have length equal to the number of classes, with values >= 0. + * The class with largest value p/t is predicted, where p is the original probability of that + * class and t is the class' threshold. + * + * Note: When [[setThresholds()]] is called, any user-set value for [[threshold]] will be cleared. + * If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be + * equivalent. * - * Default is effectively 0.5. * @group setParam */ - def setThreshold(value: Double): this.type = set(thresholds, Array(1.0 - value, value)) + def setThresholds(value: Array[Double]): this.type = { + if (isSet(threshold)) clear(threshold) + set(thresholds, value) + } /** - * Version of [[getThresholds()]] for binary classification, available for backwards - * compatibility. + * Get thresholds for binary or multiclass classification. + * + * If [[thresholds]] is set, return its value. + * Otherwise, if [[threshold]] is set, return the equivalent thresholds for binary + * classification: (1-threshold, threshold). + * If neither are set, throw an exception. * - * Param thresholds must have length 2 (or not be specified). - * This returns {{{1 / (1 + thresholds(0) / thresholds(1))}}}. * @group getParam */ - def getThreshold: Double = { - if (isDefined(thresholds)) { - val thresholdValues = $(thresholds) - assert(thresholdValues.length == 2, "Logistic Regression getThreshold only applies to" + - " binary classification, but thresholds has length != 2." + - s" thresholds: ${thresholdValues.mkString(",")}") - 1.0 / (1.0 + thresholdValues(0) / thresholdValues(1)) + override def getThresholds: Array[Double] = { + checkThresholdConsistency() + if (!isSet(thresholds) && isSet(threshold)) { + val t = $(threshold) + Array(1-t, t) } else { - 0.5 + $(thresholds) + } + } + + /** + * If [[threshold]] and [[thresholds]] are both set, ensures they are consistent. + * @throws IllegalArgumentException if [[threshold]] and [[thresholds]] are not equivalent + */ + protected def checkThresholdConsistency(): Unit = { + if (isSet(threshold) && isSet(thresholds)) { + val ts = $(thresholds) + require(ts.length == 2, "Logistic Regression found inconsistent values for threshold and" + + s" thresholds. Param threshold is set (${$(threshold)}), indicating binary" + + s" classification, but Param thresholds is set with length ${ts.length}." + + " Clear one Param value to fix this problem.") + val t = 1.0 / (1.0 + ts(0) / ts(1)) + require(math.abs($(threshold) - t) < 1E-5, "Logistic Regression getThreshold found" + + s" inconsistent values for threshold (${$(threshold)}) and thresholds (equivalent to $t)") } } + + override def validateParams(): Unit = { + checkThresholdConsistency() + } } /** * :: Experimental :: * Logistic regression. - * Currently, this class only supports binary classification. + * Currently, this class only supports binary classification. It will support multiclass + * in the future. */ @Experimental class LogisticRegression(override val uid: String) @@ -128,7 +198,7 @@ class LogisticRegression(override val uid: String) * Whether to fit an intercept term. * Default is true. * @group setParam - * */ + */ def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) setDefault(fitIntercept -> true) @@ -140,7 +210,7 @@ class LogisticRegression(override val uid: String) * is applied. In R's GLMNET package, the default behavior is true as well. * Default is true. * @group setParam - * */ + */ def setStandardization(value: Boolean): this.type = set(standardization, value) setDefault(standardization -> true) @@ -148,6 +218,10 @@ class LogisticRegression(override val uid: String) override def getThreshold: Double = super.getThreshold + override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value) + + override def getThresholds: Array[Double] = super.getThresholds + override protected def train(dataset: DataFrame): LogisticRegressionModel = { // Extract columns from data. If dataset is persisted, do not persist oldDataset. val instances = extractLabeledPoints(dataset).map { @@ -314,6 +388,10 @@ class LogisticRegressionModel private[ml] ( override def getThreshold: Double = super.getThreshold + override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value) + + override def getThresholds: Array[Double] = super.getThresholds + /** Margin (rawPrediction) for class label 1. For binary classification only. */ private val margin: Vector => Double = (features) => { BLAS.dot(features, weights) + intercept @@ -364,6 +442,7 @@ class LogisticRegressionModel private[ml] ( * The behavior of this can be adjusted using [[thresholds]]. */ override protected def predict(features: Vector): Double = { + // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden. if (score(features) > getThreshold) 1 else 0 } @@ -393,6 +472,7 @@ class LogisticRegressionModel private[ml] ( } override protected def raw2prediction(rawPrediction: Vector): Double = { + // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden. val t = getThreshold val rawThreshold = if (t == 0.0) { Double.NegativeInfinity @@ -405,6 +485,7 @@ class LogisticRegressionModel private[ml] ( } override protected def probability2prediction(probability: Vector): Double = { + // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden. if (probability(1) > getThreshold) 1 else 0 } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index da4c076830391..9e12f1856a940 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -45,14 +45,14 @@ private[shared] object SharedParamsCodeGen { " These probabilities should be treated as confidences, not precise probabilities.", Some("\"probability\"")), ParamDesc[Double]("threshold", - "threshold in binary classification prediction, in range [0, 1]", + "threshold in binary classification prediction, in range [0, 1]", Some("0.5"), isValid = "ParamValidators.inRange(0, 1)", finalMethods = false), ParamDesc[Array[Double]]("thresholds", "Thresholds in multi-class classification" + " to adjust the probability of predicting each class." + " Array must have length equal to the number of classes, with values >= 0." + " The class with largest value p/t is predicted, where p is the original probability" + " of that class and t is the class' threshold.", - isValid = "(t: Array[Double]) => t.forall(_ >= 0)"), + isValid = "(t: Array[Double]) => t.forall(_ >= 0)", finalMethods = false), ParamDesc[String]("inputCol", "input column name"), ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 23e2b6cc43996..a17d4ea960a90 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -139,7 +139,7 @@ private[ml] trait HasProbabilityCol extends Params { } /** - * Trait for shared param threshold. + * Trait for shared param threshold (default: 0.5). */ private[ml] trait HasThreshold extends Params { @@ -149,6 +149,8 @@ private[ml] trait HasThreshold extends Params { */ final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction, in range [0, 1]", ParamValidators.inRange(0, 1)) + setDefault(threshold, 0.5) + /** @group getParam */ def getThreshold: Double = $(threshold) } @@ -165,7 +167,7 @@ private[ml] trait HasThresholds extends Params { final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", (t: Array[Double]) => t.forall(_ >= 0)) /** @group getParam */ - final def getThresholds: Array[Double] = $(thresholds) + def getThresholds: Array[Double] = $(thresholds) } /** diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index 7e9aa383728f0..618b95b9bd126 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -100,9 +100,7 @@ public void logisticRegressionWithSetters() { assert(r.getDouble(0) == 0.0); } // Call transform with params, and check that the params worked. - double[] thresholds = {1.0, 0.0}; - model.transform( - dataset, model.thresholds().w(thresholds), model.probabilityCol().w("myProb")) + model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb")) .registerTempTable("predNotAllZero"); DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero"); boolean foundNonZero = false; @@ -112,9 +110,8 @@ public void logisticRegressionWithSetters() { assert(foundNonZero); // Call fit() with new params, and check as many params as we can. - double[] thresholds2 = {0.6, 0.4}; LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), - lr.thresholds().w(thresholds2), lr.probabilityCol().w("theProb")); + lr.threshold().w(0.4), lr.probabilityCol().w("theProb")); LogisticRegression parent2 = (LogisticRegression) model2.parent(); assert(parent2.getMaxIter() == 5); assert(parent2.getRegParam() == 0.1); diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 8c3d4590f5ae9..e354e161c6dee 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -94,12 +94,13 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { test("setThreshold, getThreshold") { val lr = new LogisticRegression // default - withClue("LogisticRegression should not have thresholds set by default") { - intercept[java.util.NoSuchElementException] { + assert(lr.getThreshold === 0.5, "LogisticRegression.threshold should default to 0.5") + withClue("LogisticRegression should not have thresholds set by default.") { + intercept[java.util.NoSuchElementException] { // Note: The exception type may change in future lr.getThresholds } } - // Set via thresholds. + // Set via threshold. // Intuition: Large threshold or large thresholds(1) makes class 0 more likely. lr.setThreshold(1.0) assert(lr.getThresholds === Array(0.0, 1.0)) @@ -107,10 +108,26 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(lr.getThresholds === Array(1.0, 0.0)) lr.setThreshold(0.5) assert(lr.getThresholds === Array(0.5, 0.5)) - // Test getThreshold - lr.setThresholds(Array(0.3, 0.7)) + // Set via thresholds + val lr2 = new LogisticRegression + lr2.setThresholds(Array(0.3, 0.7)) val expectedThreshold = 1.0 / (1.0 + 0.3 / 0.7) - assert(lr.getThreshold ~== expectedThreshold relTol 1E-7) + assert(lr2.getThreshold ~== expectedThreshold relTol 1E-7) + // thresholds and threshold must be consistent + lr2.setThresholds(Array(0.1, 0.2, 0.3)) + withClue("getThreshold should throw error if thresholds has length != 2.") { + intercept[IllegalArgumentException] { + lr2.getThreshold + } + } + // thresholds and threshold must be consistent: values + withClue("fit with ParamMap should throw error if threshold, thresholds do not match.") { + intercept[IllegalArgumentException] { + val lr2model = lr2.fit(dataset, + lr2.thresholds -> Array(0.3, 0.7), lr2.threshold -> (expectedThreshold / 2.0)) + lr2model.getThreshold + } + } } test("logistic regression doesn't fit intercept when fitIntercept is off") { @@ -145,7 +162,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.") // Call transform with params, and check that the params worked. val predNotAllZero = - model.transform(dataset, model.thresholds -> Array(1.0, 0.0), + model.transform(dataset, model.threshold -> 0.0, model.probabilityCol -> "myProb") .select("prediction", "myProb") .collect() @@ -153,8 +170,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(predNotAllZero.exists(_ !== 0.0)) // Call fit() with new params, and check as many params as we can. + lr.setThresholds(Array(0.6, 0.4)) val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, - lr.thresholds -> Array(0.6, 0.4), lr.probabilityCol -> "theProb") val parent2 = model2.parent.asInstanceOf[LogisticRegression] assert(parent2.getMaxIter === 5) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 6702dce5545a9..83f808efc3bf0 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -76,19 +76,21 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti " Array must have length equal to the number of classes, with values >= 0." + " The class with largest value p/t is predicted, where p is the original" + " probability of that class and t is the class' threshold.") + threshold = Param(Params._dummy(), "threshold", + "Threshold in binary classification prediction, in range [0, 1]." + + " If threshold and thresholds are both set, they must match.") @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - threshold=None, thresholds=None, + threshold=0.5, thresholds=None, probabilityCol="probability", rawPredictionCol="rawPrediction"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - threshold=None, thresholds=None, \ + threshold=0.5, thresholds=None, \ probabilityCol="probability", rawPredictionCol="rawPrediction") - Param thresholds overrides Param threshold; threshold is provided - for backwards compatibility and only applies to binary classification. + If the threshold and thresholds Params are both set, they must be equivalent. """ super(LogisticRegression, self).__init__() self._java_obj = self._new_java_obj( @@ -101,7 +103,11 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") #: param for whether to fit an intercept term. self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") - #: param for threshold in binary classification prediction, in range [0, 1]. + #: param for threshold in binary classification, in range [0, 1]. + self.threshold = Param(self, "threshold", + "Threshold in binary classification prediction, in range [0, 1]." + + " If threshold and thresholds are both set, they must match.") + #: param for thresholds or cutoffs in binary or multiclass classification self.thresholds = \ Param(self, "thresholds", "Thresholds in multi-class classification" + @@ -110,29 +116,28 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred " The class with largest value p/t is predicted, where p is the original" + " probability of that class and t is the class' threshold.") self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6, - fitIntercept=True) + fitIntercept=True, threshold=0.5) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) + self._checkThresholdConsistency() @keyword_only def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - threshold=None, thresholds=None, + threshold=0.5, thresholds=None, probabilityCol="probability", rawPredictionCol="rawPrediction"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - threshold=None, thresholds=None, \ + threshold=0.5, thresholds=None, \ probabilityCol="probability", rawPredictionCol="rawPrediction") Sets params for logistic regression. - Param thresholds overrides Param threshold; threshold is provided - for backwards compatibility and only applies to binary classification. + If the threshold and thresholds Params are both set, they must be equivalent. """ - # Under the hood we use thresholds so translate threshold to thresholds if applicable - if thresholds is None and threshold is not None: - kwargs[thresholds] = [1-threshold, threshold] kwargs = self.setParams._input_kwargs - return self._set(**kwargs) + self._set(**kwargs) + self._checkThresholdConsistency() + return self def _create_model(self, java_model): return LogisticRegressionModel(java_model) @@ -165,44 +170,65 @@ def getFitIntercept(self): def setThreshold(self, value): """ - Sets the value of :py:attr:`thresholds` using [1-value, value]. + Sets the value of :py:attr:`threshold`. + Clears value of :py:attr:`thresholds` if it has been set. + """ + self._paramMap[self.threshold] = value + if self.isSet(self.thresholds): + del self._paramMap[self.thresholds] + return self - >>> lr = LogisticRegression() - >>> lr.getThreshold() - 0.5 - >>> lr.setThreshold(0.6) - LogisticRegression_... - >>> abs(lr.getThreshold() - 0.6) < 1e-5 - True + def getThreshold(self): + """ + Gets the value of threshold or its default value. """ - return self.setThresholds([1-value, value]) + self._checkThresholdConsistency() + if self.isSet(self.thresholds): + ts = self.getOrDefault(self.thresholds) + if len(ts) != 2: + raise ValueError("Logistic Regression getThreshold only applies to" + + " binary classification, but thresholds has length != 2." + + " thresholds: " + ",".join(ts)) + return 1.0/(1.0 + ts[0]/ts[1]) + else: + return self.getOrDefault(self.threshold) def setThresholds(self, value): """ Sets the value of :py:attr:`thresholds`. + Clears value of :py:attr:`threshold` if it has been set. """ self._paramMap[self.thresholds] = value + if self.isSet(self.threshold): + del self._paramMap[self.threshold] return self def getThresholds(self): """ - Gets the value of thresholds or its default value. + If :py:attr:`thresholds` is set, return its value. + Otherwise, if :py:attr:`threshold` is set, return the equivalent thresholds for binary + classification: (1-threshold, threshold). + If neither are set, throw an error. """ - return self.getOrDefault(self.thresholds) + self._checkThresholdConsistency() + if not self.isSet(self.thresholds) and self.isSet(self.threshold): + t = self.getOrDefault(self.threshold) + return [1.0-t, t] + else: + return self.getOrDefault(self.thresholds) - def getThreshold(self): - """ - Gets the value of threshold or its default value. - """ - if self.isDefined(self.thresholds): - thresholds = self.getOrDefault(self.thresholds) - if len(thresholds) != 2: + def _checkThresholdConsistency(self): + if self.isSet(self.threshold) and self.isSet(self.thresholds): + ts = self.getParam(self.thresholds) + if len(ts) != 2: raise ValueError("Logistic Regression getThreshold only applies to" + " binary classification, but thresholds has length != 2." + - " thresholds: " + ",".join(thresholds)) - return 1.0/(1.0+thresholds[0]/thresholds[1]) - else: - return 0.5 + " thresholds: " + ",".join(ts)) + t = 1.0/(1.0 + ts[0]/ts[1]) + t2 = self.getParam(self.threshold) + if abs(t2 - t) >= 1E-5: + raise ValueError("Logistic Regression getThreshold found inconsistent values for" + + " threshold (%g) and thresholds (equivalent to %g)" % (t2, t)) class LogisticRegressionModel(JavaModel): From 6f60298b1d7aa97268a42eca1e3b4851a7e88cb5 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 12 Aug 2015 14:28:23 -0700 Subject: [PATCH 060/253] [SPARK-8967] [DOC] add Since annotation Add `Since` as a Scala annotation. The benefit is that we can use it without having explicit JavaDoc. This is useful for inherited methods. The limitation is that is doesn't show up in the generated Java API documentation. This might be fixed by modifying genjavadoc. I think we could leave it as a TODO. This is how the generated Scala doc looks: `since` JavaDoc tag: ![screen shot 2015-08-11 at 10 00 37 pm](https://cloud.githubusercontent.com/assets/829644/9230761/fa72865c-40d8-11e5-807e-0f3c815c5acd.png) `Since` annotation: ![screen shot 2015-08-11 at 10 00 28 pm](https://cloud.githubusercontent.com/assets/829644/9230764/0041d7f4-40d9-11e5-8124-c3f3e5d5b31f.png) rxin Author: Xiangrui Meng Closes #8131 from mengxr/SPARK-8967. --- .../org/apache/spark/annotation/Since.scala | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 core/src/main/scala/org/apache/spark/annotation/Since.scala diff --git a/core/src/main/scala/org/apache/spark/annotation/Since.scala b/core/src/main/scala/org/apache/spark/annotation/Since.scala new file mode 100644 index 0000000000000..fa59393c22476 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/annotation/Since.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.annotation + +import scala.annotation.StaticAnnotation + +/** + * A Scala annotation that specifies the Spark version when a definition was added. + * Different from the `@since` tag in JavaDoc, this annotation does not require explicit JavaDoc and + * hence works for overridden methods that inherit API documentation directly from parents. + * The limitation is that it does not show up in the generated Java API documentation. + */ +private[spark] class Since(version: String) extends StaticAnnotation From a17384fa343628cec44437da5b80b9403ecd5838 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 12 Aug 2015 15:27:52 -0700 Subject: [PATCH 061/253] [SPARK-9907] [SQL] Python crc32 is mistakenly calling md5 Author: Reynold Xin Closes #8138 from rxin/SPARK-9907. --- python/pyspark/sql/functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 95f46044d324a..e98979533f901 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -885,10 +885,10 @@ def crc32(col): returns the value as a bigint. >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(crc32('a').alias('crc32')).collect() - [Row(crc32=u'902fbdd2b1df0c4f70b4a5d23525e932')] + [Row(crc32=2743272264)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.md5(_to_java_column(col))) + return Column(sc._jvm.functions.crc32(_to_java_column(col))) @ignore_unicode_prefix From 738f353988dbf02704bd63f5e35d94402c59ed79 Mon Sep 17 00:00:00 2001 From: Niranjan Padmanabhan Date: Wed, 12 Aug 2015 16:10:21 -0700 Subject: [PATCH 062/253] [SPARK-9092] Fixed incompatibility when both num-executors and dynamic... MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … allocation are set. Now, dynamic allocation is set to false when num-executors is explicitly specified as an argument. Consequently, executorAllocationManager in not initialized in the SparkContext. Author: Niranjan Padmanabhan Closes #7657 from neurons/SPARK-9092. --- .../scala/org/apache/spark/SparkConf.scala | 19 +++++++++++++++++++ .../scala/org/apache/spark/SparkContext.scala | 6 +++++- .../org/apache/spark/deploy/SparkSubmit.scala | 4 ++-- .../scala/org/apache/spark/util/Utils.scala | 11 +++++++++++ .../org/apache/spark/SparkContextSuite.scala | 8 ++++++++ .../spark/deploy/SparkSubmitSuite.scala | 1 - docs/running-on-yarn.md | 2 +- .../spark/deploy/yarn/ApplicationMaster.scala | 4 ++-- .../yarn/ApplicationMasterArguments.scala | 5 ----- .../org/apache/spark/deploy/yarn/Client.scala | 5 ++++- .../spark/deploy/yarn/ClientArguments.scala | 8 +------- .../spark/deploy/yarn/YarnAllocator.scala | 9 ++++++++- .../cluster/YarnClientSchedulerBackend.scala | 3 --- .../deploy/yarn/YarnAllocatorSuite.scala | 5 +++-- 14 files changed, 64 insertions(+), 26 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 8ff154fb5e334..b344b5e173d67 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -389,6 +389,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { val driverOptsKey = "spark.driver.extraJavaOptions" val driverClassPathKey = "spark.driver.extraClassPath" val driverLibraryPathKey = "spark.driver.extraLibraryPath" + val sparkExecutorInstances = "spark.executor.instances" // Used by Yarn in 1.1 and before sys.props.get("spark.driver.libraryPath").foreach { value => @@ -476,6 +477,24 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } } } + + if (!contains(sparkExecutorInstances)) { + sys.env.get("SPARK_WORKER_INSTANCES").foreach { value => + val warning = + s""" + |SPARK_WORKER_INSTANCES was detected (set to '$value'). + |This is deprecated in Spark 1.0+. + | + |Please instead use: + | - ./spark-submit with --num-executors to specify the number of executors + | - Or set SPARK_EXECUTOR_INSTANCES + | - spark.executor.instances to configure the number of instances in the spark config. + """.stripMargin + logWarning(warning) + + set("spark.executor.instances", value) + } + } } /** diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 6aafb4c5644d7..207a0c1bffeb3 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -528,7 +528,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } // Optionally scale number of executors dynamically based on workload. Exposed for testing. - val dynamicAllocationEnabled = _conf.getBoolean("spark.dynamicAllocation.enabled", false) + val dynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(_conf) + if (!dynamicAllocationEnabled && _conf.getBoolean("spark.dynamicAllocation.enabled", false)) { + logInfo("Dynamic Allocation and num executors both set, thus dynamic allocation disabled.") + } + _executorAllocationManager = if (dynamicAllocationEnabled) { Some(new ExecutorAllocationManager(this, listenerBus, _conf)) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 7ac6cbce4cd1d..02fa3088eded0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -422,7 +422,8 @@ object SparkSubmit { // Yarn client only OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"), - OptionAssigner(args.numExecutors, YARN, CLIENT, sysProp = "spark.executor.instances"), + OptionAssigner(args.numExecutors, YARN, ALL_DEPLOY_MODES, + sysProp = "spark.executor.instances"), OptionAssigner(args.files, YARN, CLIENT, sysProp = "spark.yarn.dist.files"), OptionAssigner(args.archives, YARN, CLIENT, sysProp = "spark.yarn.dist.archives"), OptionAssigner(args.principal, YARN, CLIENT, sysProp = "spark.yarn.principal"), @@ -433,7 +434,6 @@ object SparkSubmit { OptionAssigner(args.driverMemory, YARN, CLUSTER, clOption = "--driver-memory"), OptionAssigner(args.driverCores, YARN, CLUSTER, clOption = "--driver-cores"), OptionAssigner(args.queue, YARN, CLUSTER, clOption = "--queue"), - OptionAssigner(args.numExecutors, YARN, CLUSTER, clOption = "--num-executors"), OptionAssigner(args.executorMemory, YARN, CLUSTER, clOption = "--executor-memory"), OptionAssigner(args.executorCores, YARN, CLUSTER, clOption = "--executor-cores"), OptionAssigner(args.files, YARN, CLUSTER, clOption = "--files"), diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index c4012d0e83f7d..a90d8541366f4 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2286,6 +2286,17 @@ private[spark] object Utils extends Logging { isInDirectory(parent, child.getParentFile) } + /** + * Return whether dynamic allocation is enabled in the given conf + * Dynamic allocation and explicitly setting the number of executors are inherently + * incompatible. In environments where dynamic allocation is turned on by default, + * the latter should override the former (SPARK-9092). + */ + def isDynamicAllocationEnabled(conf: SparkConf): Boolean = { + conf.contains("spark.dynamicAllocation.enabled") && + conf.getInt("spark.executor.instances", 0) == 0 + } + } private [util] class SparkShutdownHookManager { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 5c57940fa5f77..d4f2ea87650a9 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -285,4 +285,12 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { } } + test("No exception when both num-executors and dynamic allocation set.") { + noException should be thrownBy { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local") + .set("spark.dynamicAllocation.enabled", "true").set("spark.executor.instances", "6")) + assert(sc.executorAllocationManager.isEmpty) + assert(sc.getConf.getInt("spark.executor.instances", 0) === 6) + } + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 757e0ce3d278b..2456c5d0d49b0 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -159,7 +159,6 @@ class SparkSubmitSuite childArgsStr should include ("--executor-cores 5") childArgsStr should include ("--arg arg1 --arg arg2") childArgsStr should include ("--queue thequeue") - childArgsStr should include ("--num-executors 6") childArgsStr should include regex ("--jar .*thejar.jar") childArgsStr should include regex ("--addJars .*one.jar,.*two.jar,.*three.jar") childArgsStr should include regex ("--files .*file1.txt,.*file2.txt") diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index cac08a91b97d9..ec32c419b7c51 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -199,7 +199,7 @@ If you need a reference to the proper location to put log files in the YARN so t spark.executor.instances 2 - The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. + The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. If both spark.dynamicAllocation.enabled and spark.executor.instances are specified, dynamic allocation is turned off and the specified number of spark.executor.instances is used. diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 1d67b3ebb51b7..e19940d8d6642 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -64,7 +64,8 @@ private[spark] class ApplicationMaster( // Default to numExecutors * 2, with minimum of 3 private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures", - sparkConf.getInt("spark.yarn.max.worker.failures", math.max(args.numExecutors * 2, 3))) + sparkConf.getInt("spark.yarn.max.worker.failures", + math.max(sparkConf.getInt("spark.executor.instances", 0) * 2, 3))) @volatile private var exitCode = 0 @volatile private var unregistered = false @@ -493,7 +494,6 @@ private[spark] class ApplicationMaster( */ private def startUserApplication(): Thread = { logInfo("Starting the user application in a separate Thread") - System.setProperty("spark.executor.instances", args.numExecutors.toString) val classpath = Client.getUserClasspath(sparkConf) val urls = classpath.map { entry => diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index 37f793763367e..b08412414aa1c 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -29,7 +29,6 @@ class ApplicationMasterArguments(val args: Array[String]) { var userArgs: Seq[String] = Nil var executorMemory = 1024 var executorCores = 1 - var numExecutors = DEFAULT_NUMBER_EXECUTORS var propertiesFile: String = null parseArgs(args.toList) @@ -63,10 +62,6 @@ class ApplicationMasterArguments(val args: Array[String]) { userArgsBuffer += value args = tail - case ("--num-workers" | "--num-executors") :: IntParam(value) :: tail => - numExecutors = value - args = tail - case ("--worker-memory" | "--executor-memory") :: MemoryParam(value) :: tail => executorMemory = value args = tail diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index b4ba3f0221600..6d63ddaf15852 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -751,7 +751,6 @@ private[spark] class Client( userArgs ++ Seq( "--executor-memory", args.executorMemory.toString + "m", "--executor-cores", args.executorCores.toString, - "--num-executors ", args.numExecutors.toString, "--properties-file", buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), LOCALIZED_CONF_DIR, SPARK_CONF_FILE)) @@ -960,6 +959,10 @@ object Client extends Logging { val sparkConf = new SparkConf val args = new ClientArguments(argStrings, sparkConf) + // to maintain backwards-compatibility + if (!Utils.isDynamicAllocationEnabled(sparkConf)) { + sparkConf.setIfMissing("spark.executor.instances", args.numExecutors.toString) + } new Client(args, sparkConf).run() } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 20d63d40cf605..4f42ffefa77f9 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -53,8 +53,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) private val amMemOverheadKey = "spark.yarn.am.memoryOverhead" private val driverCoresKey = "spark.driver.cores" private val amCoresKey = "spark.yarn.am.cores" - private val isDynamicAllocationEnabled = - sparkConf.getBoolean("spark.dynamicAllocation.enabled", false) + private val isDynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(sparkConf) parseArgs(args.toList) loadEnvironmentArgs() @@ -196,11 +195,6 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) if (args(0) == "--num-workers") { println("--num-workers is deprecated. Use --num-executors instead.") } - // Dynamic allocation is not compatible with this option - if (isDynamicAllocationEnabled) { - throw new IllegalArgumentException("Explicitly setting the number " + - "of executors is not compatible with spark.dynamicAllocation.enabled!") - } numExecutors = value args = tail diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 59caa787b6e20..ccf753e69f4b6 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -21,6 +21,8 @@ import java.util.Collections import java.util.concurrent._ import java.util.regex.Pattern +import org.apache.spark.util.Utils + import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -86,7 +88,12 @@ private[yarn] class YarnAllocator( private var executorIdCounter = 0 @volatile private var numExecutorsFailed = 0 - @volatile private var targetNumExecutors = args.numExecutors + @volatile private var targetNumExecutors = + if (Utils.isDynamicAllocationEnabled(sparkConf)) { + sparkConf.getInt("spark.dynamicAllocation.initialExecutors", 0) + } else { + sparkConf.getInt("spark.executor.instances", YarnSparkHadoopUtil.DEFAULT_NUMBER_EXECUTORS) + } // Keep track of which container is running which executor to remove the executors later // Visible for testing. diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index d225061fcd1b4..d06d95140438c 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -81,8 +81,6 @@ private[spark] class YarnClientSchedulerBackend( // List of (target Client argument, environment variable, Spark property) val optionTuples = List( - ("--num-executors", "SPARK_WORKER_INSTANCES", "spark.executor.instances"), - ("--num-executors", "SPARK_EXECUTOR_INSTANCES", "spark.executor.instances"), ("--executor-memory", "SPARK_WORKER_MEMORY", "spark.executor.memory"), ("--executor-memory", "SPARK_EXECUTOR_MEMORY", "spark.executor.memory"), ("--executor-cores", "SPARK_WORKER_CORES", "spark.executor.cores"), @@ -92,7 +90,6 @@ private[spark] class YarnClientSchedulerBackend( ) // Warn against the following deprecated environment variables: env var -> suggestion val deprecatedEnvVars = Map( - "SPARK_WORKER_INSTANCES" -> "SPARK_WORKER_INSTANCES or --num-executors through spark-submit", "SPARK_WORKER_MEMORY" -> "SPARK_EXECUTOR_MEMORY or --executor-memory through spark-submit", "SPARK_WORKER_CORES" -> "SPARK_EXECUTOR_CORES or --executor-cores through spark-submit") optionTuples.foreach { case (optionName, envVar, sparkProp) => diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 58318bf9bcc08..5d05f514adde3 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -87,16 +87,17 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter def createAllocator(maxExecutors: Int = 5): YarnAllocator = { val args = Array( - "--num-executors", s"$maxExecutors", "--executor-cores", "5", "--executor-memory", "2048", "--jar", "somejar.jar", "--class", "SomeClass") + val sparkConfClone = sparkConf.clone() + sparkConfClone.set("spark.executor.instances", maxExecutors.toString) new YarnAllocator( "not used", mock(classOf[RpcEndpointRef]), conf, - sparkConf, + sparkConfClone, rmClient, appAttemptId, new ApplicationMasterArguments(args), From ab7e721cfec63155641e81e72b4ad43cf6a7d4c7 Mon Sep 17 00:00:00 2001 From: Michel Lemay Date: Wed, 12 Aug 2015 16:17:58 -0700 Subject: [PATCH 063/253] [SPARK-9826] [CORE] Fix cannot use custom classes in log4j.properties Refactor Utils class and create ShutdownHookManager. NOTE: Wasn't able to run /dev/run-tests on windows machine. Manual tests were conducted locally using custom log4j.properties file with Redis appender and logstash formatter (bundled in the fat-jar submitted to spark) ex: log4j.rootCategory=WARN,console,redis log4j.appender.console=org.apache.log4j.ConsoleAppender log4j.appender.console.target=System.err log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n log4j.logger.org.eclipse.jetty=WARN log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO log4j.logger.org.apache.spark.graphx.Pregel=INFO log4j.appender.redis=com.ryantenney.log4j.FailoverRedisAppender log4j.appender.redis.endpoints=hostname:port log4j.appender.redis.key=mykey log4j.appender.redis.alwaysBatch=false log4j.appender.redis.layout=net.logstash.log4j.JSONEventLayoutV1 Author: michellemay Closes #8109 from michellemay/SPARK-9826. --- .../scala/org/apache/spark/SparkContext.scala | 5 +- .../spark/deploy/history/HistoryServer.scala | 4 +- .../spark/deploy/worker/ExecutorRunner.scala | 7 +- .../org/apache/spark/rdd/HadoopRDD.scala | 4 +- .../org/apache/spark/rdd/NewHadoopRDD.scala | 4 +- .../apache/spark/rdd/SqlNewHadoopRDD.scala | 4 +- .../spark/storage/DiskBlockManager.scala | 10 +- .../spark/storage/TachyonBlockManager.scala | 6 +- .../spark/util/ShutdownHookManager.scala | 266 ++++++++++++++++++ .../util/SparkUncaughtExceptionHandler.scala | 2 +- .../scala/org/apache/spark/util/Utils.scala | 222 +-------------- .../hive/thriftserver/HiveThriftServer2.scala | 4 +- .../hive/thriftserver/SparkSQLCLIDriver.scala | 4 +- .../apache/spark/sql/hive/test/TestHive.scala | 4 +- .../spark/streaming/StreamingContext.scala | 8 +- .../spark/deploy/yarn/ApplicationMaster.scala | 5 +- 16 files changed, 307 insertions(+), 252 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 207a0c1bffeb3..2e01a9a18c784 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -563,7 +563,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Make sure the context is stopped if the user forgets about it. This avoids leaving // unfinished event logs around after the JVM exits cleanly. It doesn't help if the JVM // is killed, though. - _shutdownHookRef = Utils.addShutdownHook(Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY) { () => + _shutdownHookRef = ShutdownHookManager.addShutdownHook( + ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY) { () => logInfo("Invoking stop() from shutdown hook") stop() } @@ -1671,7 +1672,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli return } if (_shutdownHookRef != null) { - Utils.removeShutdownHook(_shutdownHookRef) + ShutdownHookManager.removeShutdownHook(_shutdownHookRef) } Utils.tryLogNonFatalError { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index a076a9c3f984d..d4f327cc588fe 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -30,7 +30,7 @@ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, Applica UIRoot} import org.apache.spark.ui.{SparkUI, UIUtils, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.{SignalLogger, Utils} +import org.apache.spark.util.{ShutdownHookManager, SignalLogger, Utils} /** * A web server that renders SparkUIs of completed applications. @@ -238,7 +238,7 @@ object HistoryServer extends Logging { val server = new HistoryServer(conf, provider, securityManager, port) server.bind() - Utils.addShutdownHook { () => server.stop() } + ShutdownHookManager.addShutdownHook { () => server.stop() } // Wait until the end of the world... or if the HistoryServer process is manually stopped while(true) { Thread.sleep(Int.MaxValue) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 29a5042285578..ab3fea475c2a5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -28,7 +28,7 @@ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.util.logging.FileAppender /** @@ -70,7 +70,8 @@ private[deploy] class ExecutorRunner( } workerThread.start() // Shutdown hook that kills actors on shutdown. - shutdownHook = Utils.addShutdownHook { () => killProcess(Some("Worker shutting down")) } + shutdownHook = ShutdownHookManager.addShutdownHook { () => + killProcess(Some("Worker shutting down")) } } /** @@ -102,7 +103,7 @@ private[deploy] class ExecutorRunner( workerThread = null state = ExecutorState.KILLED try { - Utils.removeShutdownHook(shutdownHook) + ShutdownHookManager.removeShutdownHook(shutdownHook) } catch { case e: IllegalStateException => None } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index f1c17369cb48c..e1f8719eead02 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -44,7 +44,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.{SerializableConfiguration, NextIterator, Utils} +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, NextIterator, Utils} import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation} import org.apache.spark.storage.StorageLevel @@ -274,7 +274,7 @@ class HadoopRDD[K, V]( } } catch { case e: Exception => { - if (!Utils.inShutdown()) { + if (!ShutdownHookManager.inShutdown()) { logWarning("Exception in RecordReader.close()", e) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index f83a051f5da11..6a9c004d65cff 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -33,7 +33,7 @@ import org.apache.spark._ import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Utils} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.storage.StorageLevel @@ -186,7 +186,7 @@ class NewHadoopRDD[K, V]( } } catch { case e: Exception => { - if (!Utils.inShutdown()) { + if (!ShutdownHookManager.inShutdown()) { logWarning("Exception in RecordReader.close()", e) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala index 6a95e44c57fec..fa3fecc80cb63 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala @@ -33,7 +33,7 @@ import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.{Partition => SparkPartition, _} import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Utils} private[spark] class SqlNewHadoopPartition( @@ -212,7 +212,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( } } catch { case e: Exception => - if (!Utils.inShutdown()) { + if (!ShutdownHookManager.inShutdown()) { logWarning("Exception in RecordReader.close()", e) } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 56a33d5ca7d60..3f8d26e1d4cab 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -22,7 +22,7 @@ import java.io.{IOException, File} import org.apache.spark.{SparkConf, Logging} import org.apache.spark.executor.ExecutorExitCode -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** * Creates and maintains the logical mapping between logical blocks and physical on-disk @@ -144,7 +144,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon } private def addShutdownHook(): AnyRef = { - Utils.addShutdownHook(Utils.TEMP_DIR_SHUTDOWN_PRIORITY + 1) { () => + ShutdownHookManager.addShutdownHook(ShutdownHookManager.TEMP_DIR_SHUTDOWN_PRIORITY + 1) { () => logInfo("Shutdown hook called") DiskBlockManager.this.doStop() } @@ -154,7 +154,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon private[spark] def stop() { // Remove the shutdown hook. It causes memory leaks if we leave it around. try { - Utils.removeShutdownHook(shutdownHook) + ShutdownHookManager.removeShutdownHook(shutdownHook) } catch { case e: Exception => logError(s"Exception while removing shutdown hook.", e) @@ -168,7 +168,9 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon localDirs.foreach { localDir => if (localDir.isDirectory() && localDir.exists()) { try { - if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) + if (!ShutdownHookManager.hasRootAsShutdownDeleteDir(localDir)) { + Utils.deleteRecursively(localDir) + } } catch { case e: Exception => logError(s"Exception while deleting local spark dir: $localDir", e) diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index ebad5bc5ab28d..22878783fca67 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -32,7 +32,7 @@ import tachyon.TachyonURI import org.apache.spark.Logging import org.apache.spark.executor.ExecutorExitCode -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** @@ -80,7 +80,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log // in order to avoid having really large inodes at the top level in Tachyon. tachyonDirs = createTachyonDirs() subDirs = Array.fill(tachyonDirs.length)(new Array[TachyonFile](subDirsPerTachyonDir)) - tachyonDirs.foreach(tachyonDir => Utils.registerShutdownDeleteDir(tachyonDir)) + tachyonDirs.foreach(tachyonDir => ShutdownHookManager.registerShutdownDeleteDir(tachyonDir)) } override def toString: String = {"ExternalBlockStore-Tachyon"} @@ -240,7 +240,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log logDebug("Shutdown hook called") tachyonDirs.foreach { tachyonDir => try { - if (!Utils.hasRootAsShutdownDeleteDir(tachyonDir)) { + if (!ShutdownHookManager.hasRootAsShutdownDeleteDir(tachyonDir)) { Utils.deleteRecursively(tachyonDir, client) } } catch { diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala new file mode 100644 index 0000000000000..61ff9b89ec1c1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.File +import java.util.PriorityQueue + +import scala.util.{Failure, Success, Try} +import tachyon.client.TachyonFile + +import org.apache.hadoop.fs.FileSystem +import org.apache.spark.Logging + +/** + * Various utility methods used by Spark. + */ +private[spark] object ShutdownHookManager extends Logging { + val DEFAULT_SHUTDOWN_PRIORITY = 100 + + /** + * The shutdown priority of the SparkContext instance. This is lower than the default + * priority, so that by default hooks are run before the context is shut down. + */ + val SPARK_CONTEXT_SHUTDOWN_PRIORITY = 50 + + /** + * The shutdown priority of temp directory must be lower than the SparkContext shutdown + * priority. Otherwise cleaning the temp directories while Spark jobs are running can + * throw undesirable errors at the time of shutdown. + */ + val TEMP_DIR_SHUTDOWN_PRIORITY = 25 + + private lazy val shutdownHooks = { + val manager = new SparkShutdownHookManager() + manager.install() + manager + } + + private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]() + private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() + + // Add a shutdown hook to delete the temp dirs when the JVM exits + addShutdownHook(TEMP_DIR_SHUTDOWN_PRIORITY) { () => + logInfo("Shutdown hook called") + shutdownDeletePaths.foreach { dirPath => + try { + logInfo("Deleting directory " + dirPath) + Utils.deleteRecursively(new File(dirPath)) + } catch { + case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e) + } + } + } + + // Register the path to be deleted via shutdown hook + def registerShutdownDeleteDir(file: File) { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths += absolutePath + } + } + + // Register the tachyon path to be deleted via shutdown hook + def registerShutdownDeleteDir(tachyonfile: TachyonFile) { + val absolutePath = tachyonfile.getPath() + shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths += absolutePath + } + } + + // Remove the path to be deleted via shutdown hook + def removeShutdownDeleteDir(file: File) { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths.remove(absolutePath) + } + } + + // Remove the tachyon path to be deleted via shutdown hook + def removeShutdownDeleteDir(tachyonfile: TachyonFile) { + val absolutePath = tachyonfile.getPath() + shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths.remove(absolutePath) + } + } + + // Is the path already registered to be deleted via a shutdown hook ? + def hasShutdownDeleteDir(file: File): Boolean = { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths.contains(absolutePath) + } + } + + // Is the path already registered to be deleted via a shutdown hook ? + def hasShutdownDeleteTachyonDir(file: TachyonFile): Boolean = { + val absolutePath = file.getPath() + shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths.contains(absolutePath) + } + } + + // Note: if file is child of some registered path, while not equal to it, then return true; + // else false. This is to ensure that two shutdown hooks do not try to delete each others + // paths - resulting in IOException and incomplete cleanup. + def hasRootAsShutdownDeleteDir(file: File): Boolean = { + val absolutePath = file.getAbsolutePath() + val retval = shutdownDeletePaths.synchronized { + shutdownDeletePaths.exists { path => + !absolutePath.equals(path) && absolutePath.startsWith(path) + } + } + if (retval) { + logInfo("path = " + file + ", already present as root for deletion.") + } + retval + } + + // Note: if file is child of some registered path, while not equal to it, then return true; + // else false. This is to ensure that two shutdown hooks do not try to delete each others + // paths - resulting in Exception and incomplete cleanup. + def hasRootAsShutdownDeleteDir(file: TachyonFile): Boolean = { + val absolutePath = file.getPath() + val retval = shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths.exists { path => + !absolutePath.equals(path) && absolutePath.startsWith(path) + } + } + if (retval) { + logInfo("path = " + file + ", already present as root for deletion.") + } + retval + } + + /** + * Detect whether this thread might be executing a shutdown hook. Will always return true if + * the current thread is a running a shutdown hook but may spuriously return true otherwise (e.g. + * if System.exit was just called by a concurrent thread). + * + * Currently, this detects whether the JVM is shutting down by Runtime#addShutdownHook throwing + * an IllegalStateException. + */ + def inShutdown(): Boolean = { + try { + val hook = new Thread { + override def run() {} + } + Runtime.getRuntime.addShutdownHook(hook) + Runtime.getRuntime.removeShutdownHook(hook) + } catch { + case ise: IllegalStateException => return true + } + false + } + + /** + * Adds a shutdown hook with default priority. + * + * @param hook The code to run during shutdown. + * @return A handle that can be used to unregister the shutdown hook. + */ + def addShutdownHook(hook: () => Unit): AnyRef = { + addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY)(hook) + } + + /** + * Adds a shutdown hook with the given priority. Hooks with lower priority values run + * first. + * + * @param hook The code to run during shutdown. + * @return A handle that can be used to unregister the shutdown hook. + */ + def addShutdownHook(priority: Int)(hook: () => Unit): AnyRef = { + shutdownHooks.add(priority, hook) + } + + /** + * Remove a previously installed shutdown hook. + * + * @param ref A handle returned by `addShutdownHook`. + * @return Whether the hook was removed. + */ + def removeShutdownHook(ref: AnyRef): Boolean = { + shutdownHooks.remove(ref) + } + +} + +private [util] class SparkShutdownHookManager { + + private val hooks = new PriorityQueue[SparkShutdownHook]() + private var shuttingDown = false + + /** + * Install a hook to run at shutdown and run all registered hooks in order. Hadoop 1.x does not + * have `ShutdownHookManager`, so in that case we just use the JVM's `Runtime` object and hope for + * the best. + */ + def install(): Unit = { + val hookTask = new Runnable() { + override def run(): Unit = runAll() + } + Try(Utils.classForName("org.apache.hadoop.util.ShutdownHookManager")) match { + case Success(shmClass) => + val fsPriority = classOf[FileSystem].getField("SHUTDOWN_HOOK_PRIORITY").get() + .asInstanceOf[Int] + val shm = shmClass.getMethod("get").invoke(null) + shm.getClass().getMethod("addShutdownHook", classOf[Runnable], classOf[Int]) + .invoke(shm, hookTask, Integer.valueOf(fsPriority + 30)) + + case Failure(_) => + Runtime.getRuntime.addShutdownHook(new Thread(hookTask, "Spark Shutdown Hook")); + } + } + + def runAll(): Unit = synchronized { + shuttingDown = true + while (!hooks.isEmpty()) { + Try(Utils.logUncaughtExceptions(hooks.poll().run())) + } + } + + def add(priority: Int, hook: () => Unit): AnyRef = synchronized { + checkState() + val hookRef = new SparkShutdownHook(priority, hook) + hooks.add(hookRef) + hookRef + } + + def remove(ref: AnyRef): Boolean = synchronized { + hooks.remove(ref) + } + + private def checkState(): Unit = { + if (shuttingDown) { + throw new IllegalStateException("Shutdown hooks cannot be modified during shutdown.") + } + } + +} + +private class SparkShutdownHook(private val priority: Int, hook: () => Unit) + extends Comparable[SparkShutdownHook] { + + override def compareTo(other: SparkShutdownHook): Int = { + other.priority - priority + } + + def run(): Unit = hook() + +} diff --git a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala index ad3db1fbb57ed..7248187247330 100644 --- a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala +++ b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala @@ -33,7 +33,7 @@ private[spark] object SparkUncaughtExceptionHandler // We may have been called from a shutdown hook. If so, we must not call System.exit(). // (If we do, we will deadlock.) - if (!Utils.inShutdown()) { + if (!ShutdownHookManager.inShutdown()) { if (exception.isInstanceOf[OutOfMemoryError]) { System.exit(SparkExitCode.OOM) } else { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index a90d8541366f4..f2abf227dc129 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -21,7 +21,7 @@ import java.io._ import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer -import java.util.{PriorityQueue, Properties, Locale, Random, UUID} +import java.util.{Properties, Locale, Random, UUID} import java.util.concurrent._ import javax.net.ssl.HttpsURLConnection @@ -65,21 +65,6 @@ private[spark] object CallSite { private[spark] object Utils extends Logging { val random = new Random() - val DEFAULT_SHUTDOWN_PRIORITY = 100 - - /** - * The shutdown priority of the SparkContext instance. This is lower than the default - * priority, so that by default hooks are run before the context is shut down. - */ - val SPARK_CONTEXT_SHUTDOWN_PRIORITY = 50 - - /** - * The shutdown priority of temp directory must be lower than the SparkContext shutdown - * priority. Otherwise cleaning the temp directories while Spark jobs are running can - * throw undesirable errors at the time of shutdown. - */ - val TEMP_DIR_SHUTDOWN_PRIORITY = 25 - /** * Define a default value for driver memory here since this value is referenced across the code * base and nearly all files already use Utils.scala @@ -90,9 +75,6 @@ private[spark] object Utils extends Logging { @volatile private var localRootDirs: Array[String] = null - private val shutdownHooks = new SparkShutdownHookManager() - shutdownHooks.install() - /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() @@ -205,86 +187,6 @@ private[spark] object Utils extends Logging { } } - private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]() - private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() - - // Add a shutdown hook to delete the temp dirs when the JVM exits - addShutdownHook(TEMP_DIR_SHUTDOWN_PRIORITY) { () => - logInfo("Shutdown hook called") - shutdownDeletePaths.foreach { dirPath => - try { - logInfo("Deleting directory " + dirPath) - Utils.deleteRecursively(new File(dirPath)) - } catch { - case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e) - } - } - } - - // Register the path to be deleted via shutdown hook - def registerShutdownDeleteDir(file: File) { - val absolutePath = file.getAbsolutePath() - shutdownDeletePaths.synchronized { - shutdownDeletePaths += absolutePath - } - } - - // Register the tachyon path to be deleted via shutdown hook - def registerShutdownDeleteDir(tachyonfile: TachyonFile) { - val absolutePath = tachyonfile.getPath() - shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths += absolutePath - } - } - - // Is the path already registered to be deleted via a shutdown hook ? - def hasShutdownDeleteDir(file: File): Boolean = { - val absolutePath = file.getAbsolutePath() - shutdownDeletePaths.synchronized { - shutdownDeletePaths.contains(absolutePath) - } - } - - // Is the path already registered to be deleted via a shutdown hook ? - def hasShutdownDeleteTachyonDir(file: TachyonFile): Boolean = { - val absolutePath = file.getPath() - shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths.contains(absolutePath) - } - } - - // Note: if file is child of some registered path, while not equal to it, then return true; - // else false. This is to ensure that two shutdown hooks do not try to delete each others - // paths - resulting in IOException and incomplete cleanup. - def hasRootAsShutdownDeleteDir(file: File): Boolean = { - val absolutePath = file.getAbsolutePath() - val retval = shutdownDeletePaths.synchronized { - shutdownDeletePaths.exists { path => - !absolutePath.equals(path) && absolutePath.startsWith(path) - } - } - if (retval) { - logInfo("path = " + file + ", already present as root for deletion.") - } - retval - } - - // Note: if file is child of some registered path, while not equal to it, then return true; - // else false. This is to ensure that two shutdown hooks do not try to delete each others - // paths - resulting in Exception and incomplete cleanup. - def hasRootAsShutdownDeleteDir(file: TachyonFile): Boolean = { - val absolutePath = file.getPath() - val retval = shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths.exists { path => - !absolutePath.equals(path) && absolutePath.startsWith(path) - } - } - if (retval) { - logInfo("path = " + file + ", already present as root for deletion.") - } - retval - } - /** * JDK equivalent of `chmod 700 file`. * @@ -333,7 +235,7 @@ private[spark] object Utils extends Logging { root: String = System.getProperty("java.io.tmpdir"), namePrefix: String = "spark"): File = { val dir = createDirectory(root, namePrefix) - registerShutdownDeleteDir(dir) + ShutdownHookManager.registerShutdownDeleteDir(dir) dir } @@ -973,9 +875,7 @@ private[spark] object Utils extends Logging { if (savedIOException != null) { throw savedIOException } - shutdownDeletePaths.synchronized { - shutdownDeletePaths.remove(file.getAbsolutePath) - } + ShutdownHookManager.removeShutdownDeleteDir(file) } } finally { if (!file.delete()) { @@ -1478,27 +1378,6 @@ private[spark] object Utils extends Logging { serializer.deserialize[T](serializer.serialize(value)) } - /** - * Detect whether this thread might be executing a shutdown hook. Will always return true if - * the current thread is a running a shutdown hook but may spuriously return true otherwise (e.g. - * if System.exit was just called by a concurrent thread). - * - * Currently, this detects whether the JVM is shutting down by Runtime#addShutdownHook throwing - * an IllegalStateException. - */ - def inShutdown(): Boolean = { - try { - val hook = new Thread { - override def run() {} - } - Runtime.getRuntime.addShutdownHook(hook) - Runtime.getRuntime.removeShutdownHook(hook) - } catch { - case ise: IllegalStateException => return true - } - false - } - private def isSpace(c: Char): Boolean = { " \t\r\n".indexOf(c) != -1 } @@ -2221,37 +2100,6 @@ private[spark] object Utils extends Logging { msg.startsWith(BACKUP_STANDALONE_MASTER_PREFIX) } - /** - * Adds a shutdown hook with default priority. - * - * @param hook The code to run during shutdown. - * @return A handle that can be used to unregister the shutdown hook. - */ - def addShutdownHook(hook: () => Unit): AnyRef = { - addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY)(hook) - } - - /** - * Adds a shutdown hook with the given priority. Hooks with lower priority values run - * first. - * - * @param hook The code to run during shutdown. - * @return A handle that can be used to unregister the shutdown hook. - */ - def addShutdownHook(priority: Int)(hook: () => Unit): AnyRef = { - shutdownHooks.add(priority, hook) - } - - /** - * Remove a previously installed shutdown hook. - * - * @param ref A handle returned by `addShutdownHook`. - * @return Whether the hook was removed. - */ - def removeShutdownHook(ref: AnyRef): Boolean = { - shutdownHooks.remove(ref) - } - /** * To avoid calling `Utils.getCallSite` for every single RDD we create in the body, * set a dummy call site that RDDs use instead. This is for performance optimization. @@ -2299,70 +2147,6 @@ private[spark] object Utils extends Logging { } -private [util] class SparkShutdownHookManager { - - private val hooks = new PriorityQueue[SparkShutdownHook]() - private var shuttingDown = false - - /** - * Install a hook to run at shutdown and run all registered hooks in order. Hadoop 1.x does not - * have `ShutdownHookManager`, so in that case we just use the JVM's `Runtime` object and hope for - * the best. - */ - def install(): Unit = { - val hookTask = new Runnable() { - override def run(): Unit = runAll() - } - Try(Utils.classForName("org.apache.hadoop.util.ShutdownHookManager")) match { - case Success(shmClass) => - val fsPriority = classOf[FileSystem].getField("SHUTDOWN_HOOK_PRIORITY").get() - .asInstanceOf[Int] - val shm = shmClass.getMethod("get").invoke(null) - shm.getClass().getMethod("addShutdownHook", classOf[Runnable], classOf[Int]) - .invoke(shm, hookTask, Integer.valueOf(fsPriority + 30)) - - case Failure(_) => - Runtime.getRuntime.addShutdownHook(new Thread(hookTask, "Spark Shutdown Hook")); - } - } - - def runAll(): Unit = synchronized { - shuttingDown = true - while (!hooks.isEmpty()) { - Try(Utils.logUncaughtExceptions(hooks.poll().run())) - } - } - - def add(priority: Int, hook: () => Unit): AnyRef = synchronized { - checkState() - val hookRef = new SparkShutdownHook(priority, hook) - hooks.add(hookRef) - hookRef - } - - def remove(ref: AnyRef): Boolean = synchronized { - hooks.remove(ref) - } - - private def checkState(): Unit = { - if (shuttingDown) { - throw new IllegalStateException("Shutdown hooks cannot be modified during shutdown.") - } - } - -} - -private class SparkShutdownHook(private val priority: Int, hook: () => Unit) - extends Comparable[SparkShutdownHook] { - - override def compareTo(other: SparkShutdownHook): Int = { - other.priority - priority - } - - def run(): Unit = hook() - -} - /** * A utility class to redirect the child process's stdout or stderr. */ diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 9c047347cb58d..2c9fa595b2dad 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.{Logging, SparkContext} @@ -76,7 +76,7 @@ object HiveThriftServer2 extends Logging { logInfo("Starting SparkContext") SparkSQLEnv.init() - Utils.addShutdownHook { () => + ShutdownHookManager.addShutdownHook { () => SparkSQLEnv.stop() uiTab.foreach(_.detach()) } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index d3886142b388d..7799704c819d9 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -39,7 +39,7 @@ import org.apache.thrift.transport.TSocket import org.apache.spark.Logging import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** * This code doesn't support remote connections in Hive 1.2+, as the underlying CliDriver @@ -114,7 +114,7 @@ private[hive] object SparkSQLCLIDriver extends Logging { SessionState.start(sessionState) // Clean up after we exit - Utils.addShutdownHook { () => SparkSQLEnv.stop() } + ShutdownHookManager.addShutdownHook { () => SparkSQLEnv.stop() } val remoteMode = isRemoteMode(sessionState) // "-h" option has been passed, so connect to Hive thrift server. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 296cc5c5e0b04..4eae699ac3b51 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.CacheTableCommand import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.execution.HiveNativeCommand -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.{SparkConf, SparkContext} /* Implicit conversions */ @@ -154,7 +154,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { val hiveFilesTemp = File.createTempFile("catalystHiveFiles", "") hiveFilesTemp.delete() hiveFilesTemp.mkdir() - Utils.registerShutdownDeleteDir(hiveFilesTemp) + ShutdownHookManager.registerShutdownDeleteDir(hiveFilesTemp) val inRepoTests = if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) { new File("src" + File.separator + "test" + File.separator + "resources" + File.separator) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 177e710ace54b..b496d1f341a0b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -44,7 +44,7 @@ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver} import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener} import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} -import org.apache.spark.util.{CallSite, Utils} +import org.apache.spark.util.{CallSite, ShutdownHookManager, Utils} /** * Main entry point for Spark Streaming functionality. It provides methods used to create @@ -604,7 +604,7 @@ class StreamingContext private[streaming] ( } StreamingContext.setActiveContext(this) } - shutdownHookRef = Utils.addShutdownHook( + shutdownHookRef = ShutdownHookManager.addShutdownHook( StreamingContext.SHUTDOWN_HOOK_PRIORITY)(stopOnShutdown) // Registering Streaming Metrics at the start of the StreamingContext assert(env.metricsSystem != null) @@ -691,7 +691,7 @@ class StreamingContext private[streaming] ( StreamingContext.setActiveContext(null) waiter.notifyStop() if (shutdownHookRef != null) { - Utils.removeShutdownHook(shutdownHookRef) + ShutdownHookManager.removeShutdownHook(shutdownHookRef) } logInfo("StreamingContext stopped successfully") } @@ -725,7 +725,7 @@ object StreamingContext extends Logging { */ private val ACTIVATION_LOCK = new Object() - private val SHUTDOWN_HOOK_PRIORITY = Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY + 1 + private val SHUTDOWN_HOOK_PRIORITY = ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY + 1 private val activeContext = new AtomicReference[StreamingContext](null) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index e19940d8d6642..6a8ddb37b29e8 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -112,7 +112,8 @@ private[spark] class ApplicationMaster( val fs = FileSystem.get(yarnConf) // This shutdown hook should run *after* the SparkContext is shut down. - Utils.addShutdownHook(Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY - 1) { () => + val priority = ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY - 1 + ShutdownHookManager.addShutdownHook(priority) { () => val maxAppAttempts = client.getMaxRegAttempts(sparkConf, yarnConf) val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts @@ -199,7 +200,7 @@ private[spark] class ApplicationMaster( final def finish(status: FinalApplicationStatus, code: Int, msg: String = null): Unit = { synchronized { if (!finished) { - val inShutdown = Utils.inShutdown() + val inShutdown = ShutdownHookManager.inShutdown() logInfo(s"Final app status: $status, exitCode: $code" + Option(msg).map(msg => s", (reason: $msg)").getOrElse("")) exitCode = code From 7035d880a0cf06910c19b4afd49645124c620f14 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 12 Aug 2015 16:45:15 -0700 Subject: [PATCH 064/253] [SPARK-9894] [SQL] Json writer should handle MapData. https://issues.apache.org/jira/browse/SPARK-9894 Author: Yin Huai Closes #8137 from yhuai/jsonMapData. --- .../datasources/json/JacksonGenerator.scala | 10 +-- .../sources/JsonHadoopFsRelationSuite.scala | 78 +++++++++++++++++++ .../SimpleTextHadoopFsRelationSuite.scala | 30 ------- 3 files changed, 83 insertions(+), 35 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala index 37c2b5a296c15..99ac7730bd1c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala @@ -107,12 +107,12 @@ private[sql] object JacksonGenerator { v.foreach(ty, (_, value) => valWriter(ty, value)) gen.writeEndArray() - case (MapType(kv, vv, _), v: Map[_, _]) => + case (MapType(kt, vt, _), v: MapData) => gen.writeStartObject() - v.foreach { p => - gen.writeFieldName(p._1.toString) - valWriter(vv, p._2) - } + v.foreach(kt, vt, { (k, v) => + gen.writeFieldName(k.toString) + valWriter(vt, v) + }) gen.writeEndObject() case (StructType(ty), v: InternalRow) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala new file mode 100644 index 0000000000000..ed6d512ab36fe --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import org.apache.hadoop.fs.Path + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ + +class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { + override val dataSourceName: String = "json" + + import sqlContext._ + + test("save()/load() - partitioned table - simple queries - partition columns in data") { + withTempDir { file => + val basePath = new Path(file.getCanonicalPath) + val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) + val qualifiedBasePath = fs.makeQualified(basePath) + + for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { + val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + sparkContext + .parallelize(for (i <- 1 to 3) yield s"""{"a":$i,"b":"val_$i"}""") + .saveAsTextFile(partitionDir.toString) + } + + val dataSchemaWithPartition = + StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) + + checkQueries( + read.format(dataSourceName) + .option("dataSchema", dataSchemaWithPartition.json) + .load(file.getCanonicalPath)) + } + } + + test("SPARK-9894: save complex types to JSON") { + withTempDir { file => + file.delete() + + val schema = + new StructType() + .add("array", ArrayType(LongType)) + .add("map", MapType(StringType, new StructType().add("innerField", LongType))) + + val data = + Row(Seq(1L, 2L, 3L), Map("m1" -> Row(4L))) :: + Row(Seq(5L, 6L, 7L), Map("m2" -> Row(10L))) :: Nil + val df = createDataFrame(sparkContext.parallelize(data), schema) + + // Write the data out. + df.write.format(dataSourceName).save(file.getCanonicalPath) + + // Read it back and check the result. + checkAnswer( + read.format(dataSourceName).schema(schema).load(file.getCanonicalPath), + df + ) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index 48c37a1fa1022..e8975e5f5cd08 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -50,33 +50,3 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { } } } - -class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { - override val dataSourceName: String = - classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource].getCanonicalName - - import sqlContext._ - - test("save()/load() - partitioned table - simple queries - partition columns in data") { - withTempDir { file => - val basePath = new Path(file.getCanonicalPath) - val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) - val qualifiedBasePath = fs.makeQualified(basePath) - - for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { - val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") - sparkContext - .parallelize(for (i <- 1 to 3) yield s"""{"a":$i,"b":"val_$i"}""") - .saveAsTextFile(partitionDir.toString) - } - - val dataSchemaWithPartition = - StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) - - checkQueries( - read.format(dataSourceName) - .option("dataSchema", dataSchemaWithPartition.json) - .load(file.getCanonicalPath)) - } - } -} From caa14d9dc9e2eb1102052b22445b63b0e004e3c7 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 12 Aug 2015 16:53:47 -0700 Subject: [PATCH 065/253] [SPARK-9913] [MLLIB] LDAUtils should be private feynmanliang Author: Xiangrui Meng Closes #8142 from mengxr/SPARK-9913. --- .../main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala index f7e5ce1665fe6..a9ba7b60bad08 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala @@ -22,7 +22,7 @@ import breeze.numerics._ /** * Utility methods for LDA. */ -object LDAUtils { +private[clustering] object LDAUtils { /** * Log Sum Exp with overflow protection using the identity: * For any a: \log \sum_{n=1}^N \exp\{x_n\} = a + \log \sum_{n=1}^N \exp\{x_n - a\} From 6e409bc1357f49de2efdfc4226d074b943fb1153 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 12 Aug 2015 16:54:45 -0700 Subject: [PATCH 066/253] [SPARK-9909] [ML] [TRIVIAL] move weightCol to shared params As per the TODO move weightCol to Shared Params. Author: Holden Karau Closes #8144 from holdenk/SPARK-9909-move-weightCol-toSharedParams. --- .../ml/param/shared/SharedParamsCodeGen.scala | 4 +++- .../spark/ml/param/shared/sharedParams.scala | 15 +++++++++++++++ .../spark/ml/regression/IsotonicRegression.scala | 16 ++-------------- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 9e12f1856a940..8c16c6149b40d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -70,7 +70,9 @@ private[shared] object SharedParamsCodeGen { " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", isValid = "ParamValidators.inRange(0, 1)"), ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"), - ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization.")) + ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization."), + ParamDesc[String]("weightCol", "weight column name. If this is not set or empty, we treat " + + "all instance weights as 1.0.")) val code = genSharedParams(params) val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index a17d4ea960a90..c26768953e3db 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -342,4 +342,19 @@ private[ml] trait HasStepSize extends Params { /** @group getParam */ final def getStepSize: Double = $(stepSize) } + +/** + * Trait for shared param weightCol. + */ +private[ml] trait HasWeightCol extends Params { + + /** + * Param for weight column name. If this is not set or empty, we treat all instance weights as 1.0.. + * @group param + */ + final val weightCol: Param[String] = new Param[String](this, "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.") + + /** @group getParam */ + final def getWeightCol: String = $(weightCol) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index f570590960a62..0f33bae30e622 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -21,7 +21,7 @@ import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasPredictionCol} +import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasPredictionCol, HasWeightCol} import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression, IsotonicRegressionModel => MLlibIsotonicRegressionModel} @@ -35,19 +35,7 @@ import org.apache.spark.storage.StorageLevel * Params for isotonic regression. */ private[regression] trait IsotonicRegressionBase extends Params with HasFeaturesCol - with HasLabelCol with HasPredictionCol with Logging { - - /** - * Param for weight column name (default: none). - * @group param - */ - // TODO: Move weightCol to sharedParams. - final val weightCol: Param[String] = - new Param[String](this, "weightCol", - "weight column name. If this is not set or empty, we treat all instance weights as 1.0.") - - /** @group getParam */ - final def getWeightCol: String = $(weightCol) + with HasLabelCol with HasPredictionCol with HasWeightCol with Logging { /** * Param for whether the output sequence should be isotonic/increasing (true) or From e6aef55766d0e2a48e0f9cb6eda0e31a71b962f3 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 12 Aug 2015 17:04:31 -0700 Subject: [PATCH 067/253] [SPARK-9912] [MLLIB] QRDecomposition should use QType and RType for type names instead of UType and VType hhbyyh Author: Xiangrui Meng Closes #8140 from mengxr/SPARK-9912. --- .../apache/spark/mllib/linalg/SingularValueDecomposition.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala index b416d50a5631e..cff5dbeee3e57 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala @@ -31,5 +31,5 @@ case class SingularValueDecomposition[UType, VType](U: UType, s: Vector, V: VTyp * Represents QR factors. */ @Experimental -case class QRDecomposition[UType, VType](Q: UType, R: VType) +case class QRDecomposition[QType, RType](Q: QType, R: RType) From fc1c7fd66e64ccea53b31cd2fbb98bc6d307329c Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 12 Aug 2015 17:06:12 -0700 Subject: [PATCH 068/253] [SPARK-9915] [ML] stopWords should use StringArrayParam hhbyyh Author: Xiangrui Meng Closes #8141 from mengxr/SPARK-9915. --- .../org/apache/spark/ml/feature/StopWordsRemover.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 3cc41424460f2..5d77ea08db657 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -19,12 +19,12 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.Experimental import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.param.{ParamMap, BooleanParam, Param} import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.types.{StringType, StructField, ArrayType, StructType} import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType} /** * stop words list @@ -100,7 +100,7 @@ class StopWordsRemover(override val uid: String) * the stop words set to be filtered out * @group param */ - val stopWords: Param[Array[String]] = new Param(this, "stopWords", "stop words") + val stopWords: StringArrayParam = new StringArrayParam(this, "stopWords", "stop words") /** @group setParam */ def setStopWords(value: Array[String]): this.type = set(stopWords, value) From 660e6dcff8125b83cc73dbe00c90cbe58744bc66 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 12 Aug 2015 17:07:29 -0700 Subject: [PATCH 069/253] [SPARK-9449] [SQL] Include MetastoreRelation's inputFiles Author: Michael Armbrust Closes #8119 from marmbrus/metastoreInputFiles. --- .../org/apache/spark/sql/DataFrame.scala | 10 ++++--- .../spark/sql/execution/FileRelation.scala | 28 +++++++++++++++++++ .../apache/spark/sql/sources/interfaces.scala | 6 ++-- .../org/apache/spark/sql/DataFrameSuite.scala | 26 +++++++++-------- .../spark/sql/hive/HiveMetastoreCatalog.scala | 16 +++++++++-- 5 files changed, 66 insertions(+), 20 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 27b994f1f0caf..c466d9e6cb349 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -34,10 +34,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Filter, _} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} -import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD, SQLExecution} +import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.sources.HadoopFsRelation @@ -1560,8 +1560,10 @@ class DataFrame private[sql]( */ def inputFiles: Array[String] = { val files: Seq[String] = logicalPlan.collect { - case LogicalRelation(fsBasedRelation: HadoopFsRelation) => - fsBasedRelation.paths.toSeq + case LogicalRelation(fsBasedRelation: FileRelation) => + fsBasedRelation.inputFiles + case fr: FileRelation => + fr.inputFiles }.flatten files.toSet.toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala new file mode 100644 index 0000000000000..7a2a9eed5807d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +/** + * An interface for relations that are backed by files. When a class implements this interface, + * the list of paths that it returns will be returned to a user who calls `inputPaths` on any + * DataFrame that queries this relation. + */ +private[sql] trait FileRelation { + /** Returns the list of files that will be read when scanning this relation. */ + def inputFiles: Array[String] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 2f8417a48d32e..b3b326fe612c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -31,7 +31,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection -import org.apache.spark.sql.execution.RDDConversions +import org.apache.spark.sql.execution.{FileRelation, RDDConversions} import org.apache.spark.sql.execution.datasources.{PartitioningUtils, PartitionSpec, Partition} import org.apache.spark.sql.types.StructType import org.apache.spark.sql._ @@ -406,7 +406,7 @@ abstract class OutputWriter { */ @Experimental abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[PartitionSpec]) - extends BaseRelation with Logging { + extends BaseRelation with FileRelation with Logging { override def toString: String = getClass.getSimpleName + paths.mkString("[", ",", "]") @@ -516,6 +516,8 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio */ def paths: Array[String] + override def inputFiles: Array[String] = cachedLeafStatuses().map(_.getPath.toString).toArray + /** * Partition columns. Can be either defined by [[userDefinedPartitionColumns]] or automatically * discovered. Note that they should always be nullable. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index adbd95197d7ca..2feec29955bc8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -485,21 +485,23 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("inputFiles") { - val fakeRelation1 = new ParquetRelation(Array("/my/path", "/my/other/path"), - Some(testData.schema), None, Map.empty)(sqlContext) - val df1 = DataFrame(sqlContext, LogicalRelation(fakeRelation1)) - assert(df1.inputFiles.toSet == fakeRelation1.paths.toSet) + withTempDir { dir => + val df = Seq((1, 22)).toDF("a", "b") - val fakeRelation2 = new JSONRelation( - None, 1, Some(testData.schema), None, None, Array("/json/path"))(sqlContext) - val df2 = DataFrame(sqlContext, LogicalRelation(fakeRelation2)) - assert(df2.inputFiles.toSet == fakeRelation2.paths.toSet) + val parquetDir = new File(dir, "parquet").getCanonicalPath + df.write.parquet(parquetDir) + val parquetDF = sqlContext.read.parquet(parquetDir) + assert(parquetDF.inputFiles.nonEmpty) - val unionDF = df1.unionAll(df2) - assert(unionDF.inputFiles.toSet == fakeRelation1.paths.toSet ++ fakeRelation2.paths) + val jsonDir = new File(dir, "json").getCanonicalPath + df.write.json(jsonDir) + val jsonDF = sqlContext.read.json(jsonDir) + assert(parquetDF.inputFiles.nonEmpty) - val filtered = df1.filter("false").unionAll(df2.intersect(df2)) - assert(filtered.inputFiles.toSet == fakeRelation1.paths.toSet ++ fakeRelation2.paths) + val unioned = jsonDF.unionAll(parquetDF).inputFiles.sorted + val allFiles = (jsonDF.inputFiles ++ parquetDF.inputFiles).toSet.toArray.sorted + assert(unioned === allFiles) + } } ignore("show") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index ac9aaed19d566..5e5497837a393 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier} -import org.apache.spark.sql.execution.datasources +import org.apache.spark.sql.execution.{FileRelation, datasources} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation @@ -739,7 +739,7 @@ private[hive] case class MetastoreRelation (databaseName: String, tableName: String, alias: Option[String]) (val table: HiveTable) (@transient sqlContext: SQLContext) - extends LeafNode with MultiInstanceRelation { + extends LeafNode with MultiInstanceRelation with FileRelation { override def equals(other: Any): Boolean = other match { case relation: MetastoreRelation => @@ -888,6 +888,18 @@ private[hive] case class MetastoreRelation /** An attribute map for determining the ordinal for non-partition columns. */ val columnOrdinals = AttributeMap(attributes.zipWithIndex) + override def inputFiles: Array[String] = { + val partLocations = table.getPartitions(Nil).map(_.storage.location).toArray + if (partLocations.nonEmpty) { + partLocations + } else { + Array( + table.location.getOrElse( + sys.error(s"Could not get the location of ${table.qualifiedName}."))) + } + } + + override def newInstance(): MetastoreRelation = { MetastoreRelation(databaseName, tableName, alias)(table)(sqlContext) } From 8ce60963cb0928058ef7b6e29ff94eb69d1143af Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Wed, 12 Aug 2015 17:44:16 -0700 Subject: [PATCH 070/253] =?UTF-8?q?[SPARK-9780]=20[STREAMING]=20[KAFKA]=20?= =?UTF-8?q?prevent=20NPE=20if=20KafkaRDD=20instantiation=20=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …fails Author: cody koeninger Closes #8133 from koeninger/SPARK-9780 and squashes the following commits: 406259d [cody koeninger] [SPARK-9780][Streaming][Kafka] prevent NPE if KafkaRDD instantiation fails --- .../scala/org/apache/spark/streaming/kafka/KafkaRDD.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index 1a9d78c0d4f59..ea5f842c6cafe 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -197,7 +197,11 @@ class KafkaRDD[ .dropWhile(_.offset < requestOffset) } - override def close(): Unit = consumer.close() + override def close(): Unit = { + if (consumer != null) { + consumer.close() + } + } override def getNext(): R = { if (iter == null || !iter.hasNext) { From 0d1d146c220f0d47d0e62b368d5b94d3bd9dd197 Mon Sep 17 00:00:00 2001 From: Rohit Agarwal Date: Wed, 12 Aug 2015 17:48:43 -0700 Subject: [PATCH 071/253] [SPARK-9724] [WEB UI] Avoid unnecessary redirects in the Spark Web UI. Author: Rohit Agarwal Closes #8014 from mindprince/SPARK-9724 and squashes the following commits: a7af5ff [Rohit Agarwal] [SPARK-9724] [WEB UI] Inline attachPrefix and attachPrefixForRedirect. Fix logic of attachPrefix 8a977cd [Rohit Agarwal] [SPARK-9724] [WEB UI] Address review comments: Remove unneeded code, update scaladoc. b257844 [Rohit Agarwal] [SPARK-9724] [WEB UI] Avoid unnecessary redirects in the Spark Web UI. --- .../main/scala/org/apache/spark/ui/JettyUtils.scala | 13 ++++++------- .../main/scala/org/apache/spark/ui/SparkUI.scala | 4 ++-- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index c8356467fab87..779c0ba083596 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -106,7 +106,11 @@ private[spark] object JettyUtils extends Logging { path: String, servlet: HttpServlet, basePath: String): ServletContextHandler = { - val prefixedPath = attachPrefix(basePath, path) + val prefixedPath = if (basePath == "" && path == "/") { + path + } else { + (basePath + path).stripSuffix("/") + } val contextHandler = new ServletContextHandler val holder = new ServletHolder(servlet) contextHandler.setContextPath(prefixedPath) @@ -121,7 +125,7 @@ private[spark] object JettyUtils extends Logging { beforeRedirect: HttpServletRequest => Unit = x => (), basePath: String = "", httpMethods: Set[String] = Set("GET")): ServletContextHandler = { - val prefixedDestPath = attachPrefix(basePath, destPath) + val prefixedDestPath = basePath + destPath val servlet = new HttpServlet { override def doGet(request: HttpServletRequest, response: HttpServletResponse): Unit = { if (httpMethods.contains("GET")) { @@ -246,11 +250,6 @@ private[spark] object JettyUtils extends Logging { val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, conf, serverName) ServerInfo(server, boundPort, collection) } - - /** Attach a prefix to the given path, but avoid returning an empty path */ - private def attachPrefix(basePath: String, relativePath: String): String = { - if (basePath == "") relativePath else (basePath + relativePath).stripSuffix("/") - } } private[spark] case class ServerInfo( diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 3788916cf39bb..d8b90568b7b9a 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -64,11 +64,11 @@ private[spark] class SparkUI private ( attachTab(new EnvironmentTab(this)) attachTab(new ExecutorsTab(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) - attachHandler(createRedirectHandler("/", "/jobs", basePath = basePath)) + attachHandler(createRedirectHandler("/", "/jobs/", basePath = basePath)) attachHandler(ApiRootResource.getServletHandler(this)) // This should be POST only, but, the YARN AM proxy won't proxy POSTs attachHandler(createRedirectHandler( - "/stages/stage/kill", "/stages", stagesTab.handleKillRequest, + "/stages/stage/kill", "/stages/", stagesTab.handleKillRequest, httpMethods = Set("GET", "POST"))) } initialize() From f4bc01f1f33a93e6affe5c8a3e33ffbd92d03f38 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Wed, 12 Aug 2015 18:33:27 -0700 Subject: [PATCH 072/253] [SPARK-9855] [SPARKR] Add expression functions into SparkR whose params are simple I added lots of expression functions for SparkR. This PR includes only functions whose params are only `(Column)` or `(Column, Column)`. And I think we need to improve how to test those functions. However, it would be better to work on another issue. ## Diff Summary - Add lots of functions in `functions.R` and their generic in `generic.R` - Add aliases for `ceiling` and `sign` - Move expression functions from `column.R` to `functions.R` - Modify `rdname` from `column` to `functions` I haven't supported `not` function, because the name has a collesion with `testthat` package. I didn't think of the way to define it. ## New Supported Functions ``` approxCountDistinct ascii base64 bin bitwiseNOT ceil (alias: ceiling) crc32 dayofmonth dayofyear explode factorial hex hour initcap isNaN last_day length log2 ltrim md5 minute month negate quarter reverse round rtrim second sha1 signum (alias: sign) size soundex to_date trim unbase64 unhex weekofyear year datediff levenshtein months_between nanvl pmod ``` ## JIRA [[SPARK-9855] Add expression functions into SparkR whose params are simple - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-9855) Author: Yu ISHIKAWA Closes #8123 from yu-iskw/SPARK-9855. --- R/pkg/DESCRIPTION | 1 + R/pkg/R/column.R | 81 -------------- R/pkg/R/functions.R | 123 ++++++++++++++++++++ R/pkg/R/generics.R | 185 +++++++++++++++++++++++++++++-- R/pkg/inst/tests/test_sparkSQL.R | 21 ++-- 5 files changed, 309 insertions(+), 102 deletions(-) create mode 100644 R/pkg/R/functions.R diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 4949d86d20c91..83e64897216b1 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -29,6 +29,7 @@ Collate: 'client.R' 'context.R' 'deserialize.R' + 'functions.R' 'mllib.R' 'serialize.R' 'sparkR.R' diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index eeaf9f193b728..328f595d0805f 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -60,12 +60,6 @@ operators <- list( ) column_functions1 <- c("asc", "desc", "isNull", "isNotNull") column_functions2 <- c("like", "rlike", "startsWith", "endsWith", "getField", "getItem", "contains") -functions <- c("min", "max", "sum", "avg", "mean", "count", "abs", "sqrt", - "first", "last", "lower", "upper", "sumDistinct", - "acos", "asin", "atan", "cbrt", "ceiling", "cos", "cosh", "exp", - "expm1", "floor", "log", "log10", "log1p", "rint", "sign", - "sin", "sinh", "tan", "tanh", "toDegrees", "toRadians") -binary_mathfunctions <- c("atan2", "hypot") createOperator <- function(op) { setMethod(op, @@ -111,33 +105,6 @@ createColumnFunction2 <- function(name) { }) } -createStaticFunction <- function(name) { - setMethod(name, - signature(x = "Column"), - function(x) { - if (name == "ceiling") { - name <- "ceil" - } - if (name == "sign") { - name <- "signum" - } - jc <- callJStatic("org.apache.spark.sql.functions", name, x@jc) - column(jc) - }) -} - -createBinaryMathfunctions <- function(name) { - setMethod(name, - signature(y = "Column"), - function(y, x) { - if (class(x) == "Column") { - x <- x@jc - } - jc <- callJStatic("org.apache.spark.sql.functions", name, y@jc, x) - column(jc) - }) -} - createMethods <- function() { for (op in names(operators)) { createOperator(op) @@ -148,12 +115,6 @@ createMethods <- function() { for (name in column_functions2) { createColumnFunction2(name) } - for (x in functions) { - createStaticFunction(x) - } - for (name in binary_mathfunctions) { - createBinaryMathfunctions(name) - } } createMethods() @@ -242,45 +203,3 @@ setMethod("%in%", jc <- callJMethod(x@jc, "in", table) return(column(jc)) }) - -#' Approx Count Distinct -#' -#' @rdname column -#' @return the approximate number of distinct items in a group. -setMethod("approxCountDistinct", - signature(x = "Column"), - function(x, rsd = 0.95) { - jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd) - column(jc) - }) - -#' Count Distinct -#' -#' @rdname column -#' @return the number of distinct items in a group. -setMethod("countDistinct", - signature(x = "Column"), - function(x, ...) { - jcol <- lapply(list(...), function (x) { - x@jc - }) - jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc, - listToSeq(jcol)) - column(jc) - }) - -#' @rdname column -#' @aliases countDistinct -setMethod("n_distinct", - signature(x = "Column"), - function(x, ...) { - countDistinct(x, ...) - }) - -#' @rdname column -#' @aliases count -setMethod("n", - signature(x = "Column"), - function(x) { - count(x) - }) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R new file mode 100644 index 0000000000000..a15d2d5da534e --- /dev/null +++ b/R/pkg/R/functions.R @@ -0,0 +1,123 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +#' @include generics.R column.R +NULL + +#' @title S4 expression functions for DataFrame column(s) +#' @description These are expression functions on DataFrame columns + +functions1 <- c( + "abs", "acos", "approxCountDistinct", "ascii", "asin", "atan", + "avg", "base64", "bin", "bitwiseNOT", "cbrt", "ceil", "cos", "cosh", "count", + "crc32", "dayofmonth", "dayofyear", "exp", "explode", "expm1", "factorial", + "first", "floor", "hex", "hour", "initcap", "isNaN", "last", "last_day", + "length", "log", "log10", "log1p", "log2", "lower", "ltrim", "max", "md5", + "mean", "min", "minute", "month", "negate", "quarter", "reverse", + "rint", "round", "rtrim", "second", "sha1", "signum", "sin", "sinh", "size", + "soundex", "sqrt", "sum", "sumDistinct", "tan", "tanh", "toDegrees", + "toRadians", "to_date", "trim", "unbase64", "unhex", "upper", "weekofyear", + "year") +functions2 <- c( + "atan2", "datediff", "hypot", "levenshtein", "months_between", "nanvl", "pmod") + +createFunction1 <- function(name) { + setMethod(name, + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", name, x@jc) + column(jc) + }) +} + +createFunction2 <- function(name) { + setMethod(name, + signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", name, y@jc, x) + column(jc) + }) +} + +createFunctions <- function() { + for (name in functions1) { + createFunction1(name) + } + for (name in functions2) { + createFunction2(name) + } +} + +createFunctions() + +#' Approx Count Distinct +#' +#' @rdname functions +#' @return the approximate number of distinct items in a group. +setMethod("approxCountDistinct", + signature(x = "Column"), + function(x, rsd = 0.95) { + jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd) + column(jc) + }) + +#' Count Distinct +#' +#' @rdname functions +#' @return the number of distinct items in a group. +setMethod("countDistinct", + signature(x = "Column"), + function(x, ...) { + jcol <- lapply(list(...), function (x) { + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc, + listToSeq(jcol)) + column(jc) + }) + +#' @rdname functions +#' @aliases ceil +setMethod("ceiling", + signature(x = "Column"), + function(x) { + ceil(x) + }) + +#' @rdname functions +#' @aliases signum +setMethod("sign", signature(x = "Column"), + function(x) { + signum(x) + }) + +#' @rdname functions +#' @aliases countDistinct +setMethod("n_distinct", signature(x = "Column"), + function(x, ...) { + countDistinct(x, ...) + }) + +#' @rdname functions +#' @aliases count +setMethod("n", signature(x = "Column"), + function(x) { + count(x) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 379a78b1d833e..f11e7fcb6a07c 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -575,10 +575,6 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun #' @export setGeneric("asc", function(x) { standardGeneric("asc") }) -#' @rdname column -#' @export -setGeneric("avg", function(x, ...) { standardGeneric("avg") }) - #' @rdname column #' @export setGeneric("between", function(x, bounds) { standardGeneric("between") }) @@ -587,13 +583,10 @@ setGeneric("between", function(x, bounds) { standardGeneric("between") }) #' @export setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) -#' @rdname column -#' @export -setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) - #' @rdname column #' @export setGeneric("contains", function(x, ...) { standardGeneric("contains") }) + #' @rdname column #' @export setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") }) @@ -658,22 +651,190 @@ setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) #' @export setGeneric("startsWith", function(x, ...) { standardGeneric("startsWith") }) -#' @rdname column + +###################### Expression Function Methods ########################## + +#' @rdname functions +#' @export +setGeneric("ascii", function(x) { standardGeneric("ascii") }) + +#' @rdname functions +#' @export +setGeneric("avg", function(x, ...) { standardGeneric("avg") }) + +#' @rdname functions +#' @export +setGeneric("base64", function(x) { standardGeneric("base64") }) + +#' @rdname functions +#' @export +setGeneric("bin", function(x) { standardGeneric("bin") }) + +#' @rdname functions +#' @export +setGeneric("bitwiseNOT", function(x) { standardGeneric("bitwiseNOT") }) + +#' @rdname functions +#' @export +setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) + +#' @rdname functions +#' @export +setGeneric("ceil", function(x) { standardGeneric("ceil") }) + +#' @rdname functions +#' @export +setGeneric("crc32", function(x) { standardGeneric("crc32") }) + +#' @rdname functions +#' @export +setGeneric("datediff", function(y, x) { standardGeneric("datediff") }) + +#' @rdname functions +#' @export +setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) + +#' @rdname functions +#' @export +setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) + +#' @rdname functions +#' @export +setGeneric("explode", function(x) { standardGeneric("explode") }) + +#' @rdname functions +#' @export +setGeneric("hex", function(x) { standardGeneric("hex") }) + +#' @rdname functions +#' @export +setGeneric("hour", function(x) { standardGeneric("hour") }) + +#' @rdname functions +#' @export +setGeneric("initcap", function(x) { standardGeneric("initcap") }) + +#' @rdname functions +#' @export +setGeneric("isNaN", function(x) { standardGeneric("isNaN") }) + +#' @rdname functions +#' @export +setGeneric("last_day", function(x) { standardGeneric("last_day") }) + +#' @rdname functions +#' @export +setGeneric("levenshtein", function(y, x) { standardGeneric("levenshtein") }) + +#' @rdname functions +#' @export +setGeneric("lower", function(x) { standardGeneric("lower") }) + +#' @rdname functions +#' @export +setGeneric("ltrim", function(x) { standardGeneric("ltrim") }) + +#' @rdname functions +#' @export +setGeneric("md5", function(x) { standardGeneric("md5") }) + +#' @rdname functions +#' @export +setGeneric("minute", function(x) { standardGeneric("minute") }) + +#' @rdname functions +#' @export +setGeneric("month", function(x) { standardGeneric("month") }) + +#' @rdname functions +#' @export +setGeneric("months_between", function(y, x) { standardGeneric("months_between") }) + +#' @rdname functions +#' @export +setGeneric("nanvl", function(y, x) { standardGeneric("nanvl") }) + +#' @rdname functions +#' @export +setGeneric("negate", function(x) { standardGeneric("negate") }) + +#' @rdname functions +#' @export +setGeneric("pmod", function(y, x) { standardGeneric("pmod") }) + +#' @rdname functions +#' @export +setGeneric("quarter", function(x) { standardGeneric("quarter") }) + +#' @rdname functions +#' @export +setGeneric("reverse", function(x) { standardGeneric("reverse") }) + +#' @rdname functions +#' @export +setGeneric("rtrim", function(x) { standardGeneric("rtrim") }) + +#' @rdname functions +#' @export +setGeneric("second", function(x) { standardGeneric("second") }) + +#' @rdname functions +#' @export +setGeneric("sha1", function(x) { standardGeneric("sha1") }) + +#' @rdname functions +#' @export +setGeneric("signum", function(x) { standardGeneric("signum") }) + +#' @rdname functions +#' @export +setGeneric("size", function(x) { standardGeneric("size") }) + +#' @rdname functions +#' @export +setGeneric("soundex", function(x) { standardGeneric("soundex") }) + +#' @rdname functions #' @export setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) -#' @rdname column +#' @rdname functions #' @export setGeneric("toDegrees", function(x) { standardGeneric("toDegrees") }) -#' @rdname column +#' @rdname functions #' @export setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) -#' @rdname column +#' @rdname functions +#' @export +setGeneric("to_date", function(x) { standardGeneric("to_date") }) + +#' @rdname functions +#' @export +setGeneric("trim", function(x) { standardGeneric("trim") }) + +#' @rdname functions +#' @export +setGeneric("unbase64", function(x) { standardGeneric("unbase64") }) + +#' @rdname functions +#' @export +setGeneric("unhex", function(x) { standardGeneric("unhex") }) + +#' @rdname functions #' @export setGeneric("upper", function(x) { standardGeneric("upper") }) +#' @rdname functions +#' @export +setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") }) + +#' @rdname functions +#' @export +setGeneric("year", function(x) { standardGeneric("year") }) + + #' @rdname glm #' @export setGeneric("glm") diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 7377fc8f1ca9c..e6d3b21ff825b 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -640,15 +640,18 @@ test_that("column operators", { test_that("column functions", { c <- SparkR:::col("a") - c2 <- min(c) + max(c) + sum(c) + avg(c) + count(c) + abs(c) + sqrt(c) - c3 <- lower(c) + upper(c) + first(c) + last(c) - c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string") - c5 <- n(c) + n_distinct(c) - c5 <- acos(c) + asin(c) + atan(c) + cbrt(c) - c6 <- ceiling(c) + cos(c) + cosh(c) + exp(c) + expm1(c) - c7 <- floor(c) + log(c) + log10(c) + log1p(c) + rint(c) - c8 <- sign(c) + sin(c) + sinh(c) + tan(c) + tanh(c) - c9 <- toDegrees(c) + toRadians(c) + c1 <- abs(c) + acos(c) + approxCountDistinct(c) + ascii(c) + asin(c) + atan(c) + c2 <- avg(c) + base64(c) + bin(c) + bitwiseNOT(c) + cbrt(c) + ceil(c) + cos(c) + c3 <- cosh(c) + count(c) + crc32(c) + dayofmonth(c) + dayofyear(c) + exp(c) + c4 <- explode(c) + expm1(c) + factorial(c) + first(c) + floor(c) + hex(c) + c5 <- hour(c) + initcap(c) + isNaN(c) + last(c) + last_day(c) + length(c) + c6 <- log(c) + (c) + log1p(c) + log2(c) + lower(c) + ltrim(c) + max(c) + md5(c) + c7 <- mean(c) + min(c) + minute(c) + month(c) + negate(c) + quarter(c) + c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + second(c) + sha1(c) + c9 <- signum(c) + sin(c) + sinh(c) + size(c) + soundex(c) + sqrt(c) + sum(c) + c10 <- sumDistinct(c) + tan(c) + tanh(c) + toDegrees(c) + toRadians(c) + c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c) + weekofyear(c) + c12 <- year(c) df <- jsonFile(sqlContext, jsonPath) df2 <- select(df, between(df$age, c(20, 30)), between(df$age, c(10, 20))) From 7b13ed27c1296cf76d0946e400f3449c335c8471 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 12 Aug 2015 18:52:11 -0700 Subject: [PATCH 073/253] [SPARK-9870] Disable driver UI and Master REST server in SparkSubmitSuite I think that we should pass additional configuration flags to disable the driver UI and Master REST server in SparkSubmitSuite and HiveSparkSubmitSuite. This might cut down on port-contention-related flakiness in Jenkins. Author: Josh Rosen Closes #8124 from JoshRosen/disable-ui-in-sparksubmitsuite. --- .../org/apache/spark/deploy/SparkSubmitSuite.scala | 7 +++++++ .../apache/spark/sql/hive/HiveSparkSubmitSuite.scala | 10 +++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 2456c5d0d49b0..1110ca6051a40 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -324,6 +324,8 @@ class SparkSubmitSuite "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", "--master", "local", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", unusedJar.toString) runSparkSubmit(args) } @@ -337,6 +339,8 @@ class SparkSubmitSuite "--class", JarCreationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", "--jars", jarsString, unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") runSparkSubmit(args) @@ -355,6 +359,7 @@ class SparkSubmitSuite "--packages", Seq(main, dep).mkString(","), "--repositories", repo, "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", unusedJar.toString, "my.great.lib.MyLib", "my.great.dep.MyLib") runSparkSubmit(args) @@ -500,6 +505,8 @@ class SparkSubmitSuite "--master", "local", "--conf", "spark.driver.extraClassPath=" + systemJar, "--conf", "spark.driver.userClassPathFirst=true", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", userJar.toString) runSparkSubmit(args) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index b8d41065d3f02..1e1972d1ac353 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -57,6 +57,8 @@ class HiveSparkSubmitSuite "--class", SparkSubmitClassLoaderTest.getClass.getName.stripSuffix("$"), "--name", "SparkSubmitClassLoaderTest", "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", "--jars", jarsString, unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") runSparkSubmit(args) @@ -68,6 +70,8 @@ class HiveSparkSubmitSuite "--class", SparkSQLConfTest.getClass.getName.stripSuffix("$"), "--name", "SparkSQLConfTest", "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", unusedJar.toString) runSparkSubmit(args) } @@ -79,7 +83,11 @@ class HiveSparkSubmitSuite // the HiveContext code mistakenly overrides the class loader that contains user classes. // For more detail, see sql/hive/src/test/resources/regression-test-SPARK-8489/*scala. val testJar = "sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar" - val args = Seq("--class", "Main", testJar) + val args = Seq( + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--class", "Main", + testJar) runSparkSubmit(args) } From 7c35746c916cf0019367850e75a080d7e739dba0 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 12 Aug 2015 20:02:55 -0700 Subject: [PATCH 074/253] [SPARK-9827] [SQL] fix fd leak in UnsafeRowSerializer Currently, UnsafeRowSerializer does not close the InputStream, will cause fd leak if the InputStream has an open fd in it. TODO: the fd could still be leaked, if any items in the stream is not consumed. Currently it replies on GC to close the fd in this case. cc JoshRosen Author: Davies Liu Closes #8116 from davies/fd_leak. --- .../sql/execution/UnsafeRowSerializer.scala | 2 ++ .../execution/UnsafeRowSerializerSuite.scala | 31 +++++++++++++++++-- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 3860c4bba9a99..5c18558f9bde7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -108,6 +108,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst override def asKeyValueIterator: Iterator[(Int, UnsafeRow)] = { new Iterator[(Int, UnsafeRow)] { private[this] var rowSize: Int = dIn.readInt() + if (rowSize == EOF) dIn.close() override def hasNext: Boolean = rowSize != EOF @@ -119,6 +120,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize) rowSize = dIn.readInt() // read the next row's size if (rowSize == EOF) { // We are returning the last row in this stream + dIn.close() val _rowTuple = rowTuple // Null these out so that the byte array can be garbage collected once the entire // iterator has been consumed diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index 40b47ae18d648..bd02c73a26ace 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.io.{DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row @@ -25,6 +25,18 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.types._ + +/** + * used to test close InputStream in UnsafeRowSerializer + */ +class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStream(buf) { + var closed: Boolean = false + override def close(): Unit = { + closed = true + super.close() + } +} + class UnsafeRowSerializerSuite extends SparkFunSuite { private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = { @@ -52,8 +64,8 @@ class UnsafeRowSerializerSuite extends SparkFunSuite { serializerStream.writeValue(unsafeRow) } serializerStream.close() - val deserializerIter = serializer.deserializeStream( - new ByteArrayInputStream(baos.toByteArray)).asKeyValueIterator + val input = new ClosableByteArrayInputStream(baos.toByteArray) + val deserializerIter = serializer.deserializeStream(input).asKeyValueIterator for (expectedRow <- unsafeRows) { val actualRow = deserializerIter.next().asInstanceOf[(Integer, UnsafeRow)]._2 assert(expectedRow.getSizeInBytes === actualRow.getSizeInBytes) @@ -61,5 +73,18 @@ class UnsafeRowSerializerSuite extends SparkFunSuite { assert(expectedRow.getInt(1) === actualRow.getInt(1)) } assert(!deserializerIter.hasNext) + assert(input.closed) + } + + test("close empty input stream") { + val baos = new ByteArrayOutputStream() + val dout = new DataOutputStream(baos) + dout.writeInt(-1) // EOF + dout.flush() + val input = new ClosableByteArrayInputStream(baos.toByteArray) + val serializer = new UnsafeRowSerializer(numFields = 2).newInstance() + val deserializerIter = serializer.deserializeStream(input).asKeyValueIterator + assert(!deserializerIter.hasNext) + assert(input.closed) } } From 4413d0855aaba5cb00f737dc6934a0b92d9bc05d Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 12 Aug 2015 20:03:55 -0700 Subject: [PATCH 075/253] [SPARK-9908] [SQL] When spark.sql.tungsten.enabled is false, broadcast join does not work https://issues.apache.org/jira/browse/SPARK-9908 Author: Yin Huai Closes #8149 from yhuai/SPARK-9908. --- .../apache/spark/sql/execution/joins/HashedRelation.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 076afe6e4e960..bb333b4d5ed18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -66,7 +66,8 @@ private[joins] final class GeneralHashedRelation( private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]]) extends HashedRelation with Externalizable { - private def this() = this(null) // Needed for serialization + // Needed for serialization (it is public to make Java serialization work) + def this() = this(null) override def get(key: InternalRow): Seq[InternalRow] = hashTable.get(key) @@ -88,7 +89,8 @@ private[joins] final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalRow, InternalRow]) extends HashedRelation with Externalizable { - private def this() = this(null) // Needed for serialization + // Needed for serialization (it is public to make Java serialization work) + def this() = this(null) override def get(key: InternalRow): Seq[InternalRow] = { val v = hashTable.get(key) From d2d5e7fe2df582e1c866334b3014d7cb351f5b70 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 12 Aug 2015 20:43:36 -0700 Subject: [PATCH 076/253] [SPARK-9704] [ML] Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs Made ProbabilisticClassifier, Identifiable, VectorUDT public. All are annotated as DeveloperApi. CC: mengxr EronWright Author: Joseph K. Bradley Closes #8004 from jkbradley/ml-api-public-items and squashes the following commits: 7ebefda [Joseph K. Bradley] update per code review 7ff0768 [Joseph K. Bradley] attepting to add mima fix 756d84c [Joseph K. Bradley] VectorUDT annotated as AlphaComponent ae7767d [Joseph K. Bradley] added another warning 94fd553 [Joseph K. Bradley] Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs --- .../classification/ProbabilisticClassifier.scala | 4 ++-- .../org/apache/spark/ml/util/Identifiable.scala | 16 ++++++++++++++-- .../org/apache/spark/mllib/linalg/Vectors.scala | 10 ++++------ project/MimaExcludes.scala | 4 ++++ 4 files changed, 24 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 1e50a895a9a05..fdd1851ae5508 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -50,7 +50,7 @@ private[classification] trait ProbabilisticClassifierParams * @tparam M Concrete Model type */ @DeveloperApi -private[spark] abstract class ProbabilisticClassifier[ +abstract class ProbabilisticClassifier[ FeaturesType, E <: ProbabilisticClassifier[FeaturesType, E, M], M <: ProbabilisticClassificationModel[FeaturesType, M]] @@ -74,7 +74,7 @@ private[spark] abstract class ProbabilisticClassifier[ * @tparam M Concrete Model type */ @DeveloperApi -private[spark] abstract class ProbabilisticClassificationModel[ +abstract class ProbabilisticClassificationModel[ FeaturesType, M <: ProbabilisticClassificationModel[FeaturesType, M]] extends ClassificationModel[FeaturesType, M] with ProbabilisticClassifierParams { diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala index ddd34a54503a6..bd213e7362e94 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala @@ -19,11 +19,19 @@ package org.apache.spark.ml.util import java.util.UUID +import org.apache.spark.annotation.DeveloperApi + /** + * :: DeveloperApi :: + * * Trait for an object with an immutable unique ID that identifies itself and its derivatives. + * + * WARNING: There have not yet been final discussions on this API, so it may be broken in future + * releases. */ -private[spark] trait Identifiable { +@DeveloperApi +trait Identifiable { /** * An immutable unique ID for the object and its derivatives. @@ -33,7 +41,11 @@ private[spark] trait Identifiable { override def toString: String = uid } -private[spark] object Identifiable { +/** + * :: DeveloperApi :: + */ +@DeveloperApi +object Identifiable { /** * Returns a random UID that concatenates the given prefix, "_", and 12 random hex chars. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 86c461fa91633..df15d985c814c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -26,7 +26,7 @@ import scala.collection.JavaConverters._ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} import org.apache.spark.SparkException -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.AlphaComponent import org.apache.spark.mllib.util.NumericParser import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow @@ -159,15 +159,13 @@ sealed trait Vector extends Serializable { } /** - * :: DeveloperApi :: + * :: AlphaComponent :: * * User-defined type for [[Vector]] which allows easy interaction with SQL * via [[org.apache.spark.sql.DataFrame]]. - * - * NOTE: This is currently private[spark] but will be made public later once it is stabilized. */ -@DeveloperApi -private[spark] class VectorUDT extends UserDefinedType[Vector] { +@AlphaComponent +class VectorUDT extends UserDefinedType[Vector] { override def sqlType: StructType = { // type: 0 = sparse, 1 = dense diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 90261ca3d61aa..784f83c10e023 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -178,6 +178,10 @@ object MimaExcludes { // SPARK-4751 Dynamic allocation for standalone mode ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.SparkContext.supportDynamicAllocation") + ) ++ Seq( + // SPARK-9704 Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.mllib.linalg.VectorUDT.serialize") ) case v if v.startsWith("1.4") => From d7053bea985679c514b3add029631ea23e1730ce Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 12 Aug 2015 20:44:40 -0700 Subject: [PATCH 077/253] [SPARK-9903] [MLLIB] skip local processing in PrefixSpan if there are no small prefixes There exists a chance that the prefixes keep growing to the maximum pattern length. Then the final local processing step becomes unnecessary. feynmanliang Author: Xiangrui Meng Closes #8136 from mengxr/SPARK-9903. --- .../apache/spark/mllib/fpm/PrefixSpan.scala | 37 +++++++++++-------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index ad6715b52f337..dc4ae1d0b69ed 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -282,25 +282,30 @@ object PrefixSpan extends Logging { largePrefixes = newLargePrefixes } - // Switch to local processing. - val bcSmallPrefixes = sc.broadcast(smallPrefixes) - val distributedFreqPattern = postfixes.flatMap { postfix => - bcSmallPrefixes.value.values.map { prefix => - (prefix.id, postfix.project(prefix).compressed) - }.filter(_._2.nonEmpty) - }.groupByKey().flatMap { case (id, projPostfixes) => - val prefix = bcSmallPrefixes.value(id) - val localPrefixSpan = new LocalPrefixSpan(minCount, maxPatternLength - prefix.length) - // TODO: We collect projected postfixes into memory. We should also compare the performance - // TODO: of keeping them on shuffle files. - localPrefixSpan.run(projPostfixes.toArray).map { case (pattern, count) => - (prefix.items ++ pattern, count) + var freqPatterns = sc.parallelize(localFreqPatterns, 1) + + val numSmallPrefixes = smallPrefixes.size + logInfo(s"number of small prefixes for local processing: $numSmallPrefixes") + if (numSmallPrefixes > 0) { + // Switch to local processing. + val bcSmallPrefixes = sc.broadcast(smallPrefixes) + val distributedFreqPattern = postfixes.flatMap { postfix => + bcSmallPrefixes.value.values.map { prefix => + (prefix.id, postfix.project(prefix).compressed) + }.filter(_._2.nonEmpty) + }.groupByKey().flatMap { case (id, projPostfixes) => + val prefix = bcSmallPrefixes.value(id) + val localPrefixSpan = new LocalPrefixSpan(minCount, maxPatternLength - prefix.length) + // TODO: We collect projected postfixes into memory. We should also compare the performance + // TODO: of keeping them on shuffle files. + localPrefixSpan.run(projPostfixes.toArray).map { case (pattern, count) => + (prefix.items ++ pattern, count) + } } + // Union local frequent patterns and distributed ones. + freqPatterns = freqPatterns ++ distributedFreqPattern } - // Union local frequent patterns and distributed ones. - val freqPatterns = (sc.parallelize(localFreqPatterns, 1) ++ distributedFreqPattern) - .persist(StorageLevel.MEMORY_AND_DISK) freqPatterns } From 2fb4901b71cee65d40a43e61e3f4411c30cdefc3 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 12 Aug 2015 20:59:38 -0700 Subject: [PATCH 078/253] [SPARK-9916] [BUILD] [SPARKR] removed left-over sparkr.zip copy/create commands from codebase sparkr.zip is now built by SparkSubmit on a need-to-build basis. cc shivaram Author: Burak Yavuz Closes #8147 from brkyvz/make-dist-fix. --- R/install-dev.bat | 5 ----- make-distribution.sh | 1 - 2 files changed, 6 deletions(-) diff --git a/R/install-dev.bat b/R/install-dev.bat index f32670b67de96..008a5c668bc45 100644 --- a/R/install-dev.bat +++ b/R/install-dev.bat @@ -25,8 +25,3 @@ set SPARK_HOME=%~dp0.. MKDIR %SPARK_HOME%\R\lib R.exe CMD INSTALL --library="%SPARK_HOME%\R\lib" %SPARK_HOME%\R\pkg\ - -rem Zip the SparkR package so that it can be distributed to worker nodes on YARN -pushd %SPARK_HOME%\R\lib -%JAVA_HOME%\bin\jar.exe cfM "%SPARK_HOME%\R\lib\sparkr.zip" SparkR -popd diff --git a/make-distribution.sh b/make-distribution.sh index 4789b0e09cc8a..247a81341e4a4 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -219,7 +219,6 @@ cp -r "$SPARK_HOME/ec2" "$DISTDIR" if [ -d "$SPARK_HOME"/R/lib/SparkR ]; then mkdir -p "$DISTDIR"/R/lib cp -r "$SPARK_HOME/R/lib/SparkR" "$DISTDIR"/R/lib - cp "$SPARK_HOME/R/lib/sparkr.zip" "$DISTDIR"/R/lib fi # Download and copy in tachyon, if requested From 2278219054314f1d31ffc358a59aa5067f9f5de9 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 12 Aug 2015 21:24:15 -0700 Subject: [PATCH 079/253] [SPARK-9920] [SQL] The simpleString of TungstenAggregate does not show its output https://issues.apache.org/jira/browse/SPARK-9920 Taking `sqlContext.sql("select i, sum(j1) as sum from testAgg group by i").explain()` as an example, the output of our current master is ``` == Physical Plan == TungstenAggregate(key=[i#0], value=[(sum(cast(j1#1 as bigint)),mode=Final,isDistinct=false)] TungstenExchange hashpartitioning(i#0) TungstenAggregate(key=[i#0], value=[(sum(cast(j1#1 as bigint)),mode=Partial,isDistinct=false)] Scan ParquetRelation[file:/user/hive/warehouse/testagg][i#0,j1#1] ``` With this PR, the output will be ``` == Physical Plan == TungstenAggregate(key=[i#0], functions=[(sum(cast(j1#1 as bigint)),mode=Final,isDistinct=false)], output=[i#0,sum#18L]) TungstenExchange hashpartitioning(i#0) TungstenAggregate(key=[i#0], functions=[(sum(cast(j1#1 as bigint)),mode=Partial,isDistinct=false)], output=[i#0,currentSum#22L]) Scan ParquetRelation[file:/user/hive/warehouse/testagg][i#0,j1#1] ``` Author: Yin Huai Closes #8150 from yhuai/SPARK-9920. --- .../spark/sql/execution/aggregate/SortBasedAggregate.scala | 6 +++++- .../spark/sql/execution/aggregate/TungstenAggregate.scala | 7 ++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index ab26f9c58aa2e..f4c14a9b3556f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -108,6 +108,10 @@ case class SortBasedAggregate( override def simpleString: String = { val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions - s"""SortBasedAggregate ${groupingExpressions} ${allAggregateExpressions}""" + + val keyString = groupingExpressions.mkString("[", ",", "]") + val functionString = allAggregateExpressions.mkString("[", ",", "]") + val outputString = output.mkString("[", ",", "]") + s"SortBasedAggregate(key=$keyString, functions=$functionString, output=$outputString)" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index c40ca973796a6..99f51ba5b6935 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -127,11 +127,12 @@ case class TungstenAggregate( testFallbackStartsAt match { case None => val keyString = groupingExpressions.mkString("[", ",", "]") - val valueString = allAggregateExpressions.mkString("[", ",", "]") - s"TungstenAggregate(key=$keyString, value=$valueString" + val functionString = allAggregateExpressions.mkString("[", ",", "]") + val outputString = output.mkString("[", ",", "]") + s"TungstenAggregate(key=$keyString, functions=$functionString, output=$outputString)" case Some(fallbackStartsAt) => s"TungstenAggregateWithControlledFallback $groupingExpressions " + - s"$allAggregateExpressions fallbackStartsAt=$fallbackStartsAt" + s"$allAggregateExpressions $resultExpressions fallbackStartsAt=$fallbackStartsAt" } } } From a8ab2634c1eee143a4deaf309204df8add727f9e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 12 Aug 2015 21:26:00 -0700 Subject: [PATCH 080/253] [SPARK-9832] [SQL] add a thread-safe lookup for BytesToBytseMap This patch add a thread-safe lookup for BytesToBytseMap, and use that in broadcasted HashedRelation. Author: Davies Liu Closes #8151 from davies/safeLookup. --- .../spark/unsafe/map/BytesToBytesMap.java | 30 ++++++++++++++----- .../sql/execution/joins/HashedRelation.scala | 6 ++-- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 87ed47e88c4ef..5f3a4fcf4d585 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -17,25 +17,24 @@ package org.apache.spark.unsafe.map; -import java.lang.Override; -import java.lang.UnsupportedOperationException; +import javax.annotation.Nullable; import java.util.Iterator; import java.util.LinkedList; import java.util.List; -import javax.annotation.Nullable; - import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.shuffle.ShuffleMemoryManager; -import org.apache.spark.unsafe.*; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.bitset.BitSet; import org.apache.spark.unsafe.hash.Murmur3_x86_32; -import org.apache.spark.unsafe.memory.*; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.MemoryLocation; +import org.apache.spark.unsafe.memory.TaskMemoryManager; /** * An append-only hash map where keys and values are contiguous regions of bytes. @@ -328,6 +327,20 @@ public Location lookup( Object keyBaseObject, long keyBaseOffset, int keyRowLengthBytes) { + safeLookup(keyBaseObject, keyBaseOffset, keyRowLengthBytes, loc); + return loc; + } + + /** + * Looks up a key, and saves the result in provided `loc`. + * + * This is a thread-safe version of `lookup`, could be used by multiple threads. + */ + public void safeLookup( + Object keyBaseObject, + long keyBaseOffset, + int keyRowLengthBytes, + Location loc) { assert(bitset != null); assert(longArray != null); @@ -343,7 +356,8 @@ public Location lookup( } if (!bitset.isSet(pos)) { // This is a new key. - return loc.with(pos, hashcode, false); + loc.with(pos, hashcode, false); + return; } else { long stored = longArray.get(pos * 2 + 1); if ((int) (stored) == hashcode) { @@ -361,7 +375,7 @@ public Location lookup( keyRowLengthBytes ); if (areEqual) { - return loc; + return; } else { if (enablePerfMetrics) { numHashCollisions++; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index bb333b4d5ed18..ea02076b41a6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -215,8 +215,10 @@ private[joins] final class UnsafeHashedRelation( if (binaryMap != null) { // Used in Broadcast join - val loc = binaryMap.lookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, - unsafeKey.getSizeInBytes) + val map = binaryMap // avoid the compiler error + val loc = new map.Location // this could be allocated in stack + binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, + unsafeKey.getSizeInBytes, loc) if (loc.isDefined) { val buffer = CompactBuffer[UnsafeRow]() From 5fc058a1fc5d83ad53feec936475484aef3800b3 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 12 Aug 2015 21:33:38 -0700 Subject: [PATCH 081/253] [SPARK-9917] [ML] add getMin/getMax and doc for originalMin/origianlMax in MinMaxScaler hhbyyh Author: Xiangrui Meng Closes #8145 from mengxr/SPARK-9917. --- .../org/apache/spark/ml/feature/MinMaxScaler.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index b30adf3df48d2..9a473dd23772d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -41,6 +41,9 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H val min: DoubleParam = new DoubleParam(this, "min", "lower bound of the output feature range") + /** @group getParam */ + def getMin: Double = $(min) + /** * upper bound after transformation, shared by all features * Default: 1.0 @@ -49,6 +52,9 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H val max: DoubleParam = new DoubleParam(this, "max", "upper bound of the output feature range") + /** @group getParam */ + def getMax: Double = $(max) + /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { val inputType = schema($(inputCol)).dataType @@ -115,6 +121,9 @@ class MinMaxScaler(override val uid: String) * :: Experimental :: * Model fitted by [[MinMaxScaler]]. * + * @param originalMin min value for each original column during fitting + * @param originalMax max value for each original column during fitting + * * TODO: The transformer does not yet set the metadata in the output column (SPARK-8529). */ @Experimental @@ -136,7 +145,6 @@ class MinMaxScalerModel private[ml] ( /** @group setParam */ def setMax(value: Double): this.type = set(max, value) - override def transform(dataset: DataFrame): DataFrame = { val originalRange = (originalMax.toBreeze - originalMin.toBreeze).toArray val minArray = originalMin.toArray From df543892122342b97e5137b266959ba97589b3ef Mon Sep 17 00:00:00 2001 From: "shikai.tang" Date: Wed, 12 Aug 2015 21:53:15 -0700 Subject: [PATCH 082/253] [SPARK-8922] [DOCUMENTATION, MLLIB] Add @since tags to mllib.evaluation Author: shikai.tang Closes #7429 from mosessky/master. --- .../BinaryClassificationMetrics.scala | 32 ++++++++++++++++--- .../mllib/evaluation/MulticlassMetrics.scala | 9 ++++++ .../mllib/evaluation/MultilabelMetrics.scala | 4 +++ .../mllib/evaluation/RankingMetrics.scala | 4 +++ .../mllib/evaluation/RegressionMetrics.scala | 6 ++++ 5 files changed, 50 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala index c1d1a224817e8..486741edd6f5a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.DataFrame * of bins may not exactly equal numBins. The last bin in each partition may * be smaller as a result, meaning there may be an extra sample at * partition boundaries. + * @since 1.3.0 */ @Experimental class BinaryClassificationMetrics( @@ -51,6 +52,7 @@ class BinaryClassificationMetrics( /** * Defaults `numBins` to 0. + * @since 1.0.0 */ def this(scoreAndLabels: RDD[(Double, Double)]) = this(scoreAndLabels, 0) @@ -61,12 +63,18 @@ class BinaryClassificationMetrics( private[mllib] def this(scoreAndLabels: DataFrame) = this(scoreAndLabels.map(r => (r.getDouble(0), r.getDouble(1)))) - /** Unpersist intermediate RDDs used in the computation. */ + /** + * Unpersist intermediate RDDs used in the computation. + * @since 1.0.0 + */ def unpersist() { cumulativeCounts.unpersist() } - /** Returns thresholds in descending order. */ + /** + * Returns thresholds in descending order. + * @since 1.0.0 + */ def thresholds(): RDD[Double] = cumulativeCounts.map(_._1) /** @@ -74,6 +82,7 @@ class BinaryClassificationMetrics( * which is an RDD of (false positive rate, true positive rate) * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic + * @since 1.0.0 */ def roc(): RDD[(Double, Double)] = { val rocCurve = createCurve(FalsePositiveRate, Recall) @@ -85,6 +94,7 @@ class BinaryClassificationMetrics( /** * Computes the area under the receiver operating characteristic (ROC) curve. + * @since 1.0.0 */ def areaUnderROC(): Double = AreaUnderCurve.of(roc()) @@ -92,6 +102,7 @@ class BinaryClassificationMetrics( * Returns the precision-recall curve, which is an RDD of (recall, precision), * NOT (precision, recall), with (0.0, 1.0) prepended to it. * @see http://en.wikipedia.org/wiki/Precision_and_recall + * @since 1.0.0 */ def pr(): RDD[(Double, Double)] = { val prCurve = createCurve(Recall, Precision) @@ -102,6 +113,7 @@ class BinaryClassificationMetrics( /** * Computes the area under the precision-recall curve. + * @since 1.0.0 */ def areaUnderPR(): Double = AreaUnderCurve.of(pr()) @@ -110,16 +122,26 @@ class BinaryClassificationMetrics( * @param beta the beta factor in F-Measure computation. * @return an RDD of (threshold, F-Measure) pairs. * @see http://en.wikipedia.org/wiki/F1_score + * @since 1.0.0 */ def fMeasureByThreshold(beta: Double): RDD[(Double, Double)] = createCurve(FMeasure(beta)) - /** Returns the (threshold, F-Measure) curve with beta = 1.0. */ + /** + * Returns the (threshold, F-Measure) curve with beta = 1.0. + * @since 1.0.0 + */ def fMeasureByThreshold(): RDD[(Double, Double)] = fMeasureByThreshold(1.0) - /** Returns the (threshold, precision) curve. */ + /** + * Returns the (threshold, precision) curve. + * @since 1.0.0 + */ def precisionByThreshold(): RDD[(Double, Double)] = createCurve(Precision) - /** Returns the (threshold, recall) curve. */ + /** + * Returns the (threshold, recall) curve. + * @since 1.0.0 + */ def recallByThreshold(): RDD[(Double, Double)] = createCurve(Recall) private lazy val ( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index 4628dc5690913..dddfa3ea5b800 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.DataFrame * Evaluator for multiclass classification. * * @param predictionAndLabels an RDD of (prediction, label) pairs. + * @since 1.1.0 */ @Experimental class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { @@ -64,6 +65,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * predicted classes are in columns, * they are ordered by class label ascending, * as in "labels" + * @since 1.1.0 */ def confusionMatrix: Matrix = { val n = labels.size @@ -83,12 +85,14 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns true positive rate for a given label (category) * @param label the label. + * @since 1.1.0 */ def truePositiveRate(label: Double): Double = recall(label) /** * Returns false positive rate for a given label (category) * @param label the label. + * @since 1.1.0 */ def falsePositiveRate(label: Double): Double = { val fp = fpByClass.getOrElse(label, 0) @@ -98,6 +102,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns precision for a given label (category) * @param label the label. + * @since 1.1.0 */ def precision(label: Double): Double = { val tp = tpByClass(label) @@ -108,6 +113,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns recall for a given label (category) * @param label the label. + * @since 1.1.0 */ def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label) @@ -115,6 +121,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * Returns f-measure for a given label (category) * @param label the label. * @param beta the beta parameter. + * @since 1.1.0 */ def fMeasure(label: Double, beta: Double): Double = { val p = precision(label) @@ -126,6 +133,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns f1-measure for a given label (category) * @param label the label. + * @since 1.1.0 */ def fMeasure(label: Double): Double = fMeasure(label, 1.0) @@ -179,6 +187,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns weighted averaged f-measure * @param beta the beta parameter. + * @since 1.1.0 */ def weightedFMeasure(beta: Double): Double = labelCountByClass.map { case (category, count) => fMeasure(category, beta) * count.toDouble / labelCount diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala index bf6eb1d5bd2ab..77cb1e09bdbb5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.DataFrame * Evaluator for multilabel classification. * @param predictionAndLabels an RDD of (predictions, labels) pairs, * both are non-null Arrays, each with unique elements. + * @since 1.2.0 */ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) { @@ -103,6 +104,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns precision for a given label (category) * @param label the label. + * @since 1.2.0 */ def precision(label: Double): Double = { val tp = tpPerClass(label) @@ -113,6 +115,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns recall for a given label (category) * @param label the label. + * @since 1.2.0 */ def recall(label: Double): Double = { val tp = tpPerClass(label) @@ -123,6 +126,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns f1-measure for a given label (category) * @param label the label. + * @since 1.2.0 */ def f1Measure(label: Double): Double = { val p = precision(label) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala index 5b5a2a1450f7f..063fbed8cdeea 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala @@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD * Java users should use [[RankingMetrics$.of]] to create a [[RankingMetrics]] instance. * * @param predictionAndLabels an RDD of (predicted ranking, ground truth set) pairs. + * @since 1.2.0 */ @Experimental class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]) @@ -55,6 +56,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])] * * @param k the position to compute the truncated precision, must be positive * @return the average precision at the first k ranking positions + * @since 1.2.0 */ def precisionAt(k: Int): Double = { require(k > 0, "ranking position k should be positive") @@ -124,6 +126,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])] * * @param k the position to compute the truncated ndcg, must be positive * @return the average ndcg at the first k ranking positions + * @since 1.2.0 */ def ndcgAt(k: Int): Double = { require(k > 0, "ranking position k should be positive") @@ -162,6 +165,7 @@ object RankingMetrics { /** * Creates a [[RankingMetrics]] instance (for Java users). * @param predictionAndLabels a JavaRDD of (predicted ranking, ground truth set) pairs + * @since 1.4.0 */ def of[E, T <: jl.Iterable[E]](predictionAndLabels: JavaRDD[(T, T)]): RankingMetrics[E] = { implicit val tag = JavaSparkContext.fakeClassTag[E] diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index 408847afa800d..54dfd8c099494 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.DataFrame * Evaluator for regression. * * @param predictionAndObservations an RDD of (prediction, observation) pairs. + * @since 1.2.0 */ @Experimental class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extends Logging { @@ -66,6 +67,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend * Returns the variance explained by regression. * explainedVariance = \sum_i (\hat{y_i} - \bar{y})^2 / n * @see [[https://en.wikipedia.org/wiki/Fraction_of_variance_unexplained]] + * @since 1.2.0 */ def explainedVariance: Double = { SSreg / summary.count @@ -74,6 +76,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend /** * Returns the mean absolute error, which is a risk function corresponding to the * expected value of the absolute error loss or l1-norm loss. + * @since 1.2.0 */ def meanAbsoluteError: Double = { summary.normL1(1) / summary.count @@ -82,6 +85,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend /** * Returns the mean squared error, which is a risk function corresponding to the * expected value of the squared error loss or quadratic loss. + * @since 1.2.0 */ def meanSquaredError: Double = { SSerr / summary.count @@ -90,6 +94,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend /** * Returns the root mean squared error, which is defined as the square root of * the mean squared error. + * @since 1.2.0 */ def rootMeanSquaredError: Double = { math.sqrt(this.meanSquaredError) @@ -98,6 +103,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend /** * Returns R^2^, the unadjusted coefficient of determination. * @see [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] + * @since 1.2.0 */ def r2: Double = { 1 - SSerr / SStot From d7eb371eb6369a34e58a09179efe058c4101de9e Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 12 Aug 2015 22:30:33 -0700 Subject: [PATCH 083/253] [SPARK-9914] [ML] define setters explicitly for Java and use setParam group in RFormula The problem with defining setters in the base class is that it doesn't return the correct type in Java. ericl Author: Xiangrui Meng Closes #8143 from mengxr/SPARK-9914 and squashes the following commits: d36c887 [Xiangrui Meng] remove setters from model a49021b [Xiangrui Meng] define setters explicitly for Java and use setParam group --- .../scala/org/apache/spark/ml/feature/RFormula.scala | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index d5360c9217ea9..a752dacd72d95 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -33,11 +33,6 @@ import org.apache.spark.sql.types._ * Base trait for [[RFormula]] and [[RFormulaModel]]. */ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol { - /** @group getParam */ - def setFeaturesCol(value: String): this.type = set(featuresCol, value) - - /** @group getParam */ - def setLabelCol(value: String): this.type = set(labelCol, value) protected def hasLabelCol(schema: StructType): Boolean = { schema.map(_.name).contains($(labelCol)) @@ -71,6 +66,12 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R /** @group getParam */ def getFormula: String = $(formula) + /** @group setParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + /** Whether the formula specifies fitting an intercept. */ private[ml] def hasIntercept: Boolean = { require(isDefined(formula), "Formula must be defined first.") From d0b18919d16e6a2f19159516bd2767b60b595279 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Thu, 13 Aug 2015 13:33:39 +0800 Subject: [PATCH 084/253] [SPARK-9927] [SQL] Revert 8049 since it's pushing wrong filter down I made a mistake in #8049 by casting literal value to attribute's data type, which would cause simply truncate the literal value and push a wrong filter down. JIRA: https://issues.apache.org/jira/browse/SPARK-9927 Author: Yijie Shen Closes #8157 from yjshen/rever8049. --- .../datasources/DataSourceStrategy.scala | 30 ++-------------- .../execution/datasources/jdbc/JDBCRDD.scala | 2 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 35 ------------------- 3 files changed, 3 insertions(+), 64 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 9eea2b0382535..2a4c40db8bb66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.{Logging, TaskContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD} -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions} +import org.apache.spark.sql.catalyst.{InternalRow, expressions} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{TimestampType, DateType, StringType, StructType} +import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.sql.{SaveMode, Strategy, execution, sources, _} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -343,17 +343,11 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { * and convert them. */ protected[sql] def selectFilters(filters: Seq[Expression]) = { - import CatalystTypeConverters._ - def translate(predicate: Expression): Option[Filter] = predicate match { case expressions.EqualTo(a: Attribute, Literal(v, _)) => Some(sources.EqualTo(a.name, v)) case expressions.EqualTo(Literal(v, _), a: Attribute) => Some(sources.EqualTo(a.name, v)) - case expressions.EqualTo(Cast(a: Attribute, _), l: Literal) => - Some(sources.EqualTo(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) - case expressions.EqualTo(l: Literal, Cast(a: Attribute, _)) => - Some(sources.EqualTo(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) => Some(sources.EqualNullSafe(a.name, v)) @@ -364,41 +358,21 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { Some(sources.GreaterThan(a.name, v)) case expressions.GreaterThan(Literal(v, _), a: Attribute) => Some(sources.LessThan(a.name, v)) - case expressions.GreaterThan(Cast(a: Attribute, _), l: Literal) => - Some(sources.GreaterThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) - case expressions.GreaterThan(l: Literal, Cast(a: Attribute, _)) => - Some(sources.LessThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) case expressions.LessThan(a: Attribute, Literal(v, _)) => Some(sources.LessThan(a.name, v)) case expressions.LessThan(Literal(v, _), a: Attribute) => Some(sources.GreaterThan(a.name, v)) - case expressions.LessThan(Cast(a: Attribute, _), l: Literal) => - Some(sources.LessThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) - case expressions.LessThan(l: Literal, Cast(a: Attribute, _)) => - Some(sources.GreaterThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) => Some(sources.GreaterThanOrEqual(a.name, v)) case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) => Some(sources.LessThanOrEqual(a.name, v)) - case expressions.GreaterThanOrEqual(Cast(a: Attribute, _), l: Literal) => - Some(sources.GreaterThanOrEqual(a.name, - convertToScala(Cast(l, a.dataType).eval(), a.dataType))) - case expressions.GreaterThanOrEqual(l: Literal, Cast(a: Attribute, _)) => - Some(sources.LessThanOrEqual(a.name, - convertToScala(Cast(l, a.dataType).eval(), a.dataType))) case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => Some(sources.LessThanOrEqual(a.name, v)) case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => Some(sources.GreaterThanOrEqual(a.name, v)) - case expressions.LessThanOrEqual(Cast(a: Attribute, _), l: Literal) => - Some(sources.LessThanOrEqual(a.name, - convertToScala(Cast(l, a.dataType).eval(), a.dataType))) - case expressions.LessThanOrEqual(l: Literal, Cast(a: Attribute, _)) => - Some(sources.GreaterThanOrEqual(a.name, - convertToScala(Cast(l, a.dataType).eval(), a.dataType))) case expressions.InSet(a: Attribute, set) => Some(sources.In(a.name, set.toArray)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 281943e23fcff..8eab6a0adccc4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -284,7 +284,7 @@ private[sql] class JDBCRDD( /** * `filters`, but as a WHERE clause suitable for injection into a SQL query. */ - val filterWhereClause: String = { + private val filterWhereClause: String = { val filterStrings = filters map compileFilter filter (_ != null) if (filterStrings.size > 0) { val sb = new StringBuilder("WHERE ") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index b9cfae51e809c..e4dcf4c75d208 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -25,8 +25,6 @@ import org.h2.jdbc.JdbcSQLException import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD -import org.apache.spark.sql.execution.PhysicalRDD import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -150,18 +148,6 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { |OPTIONS (url '$url', dbtable 'TEST.FLTTYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) - conn.prepareStatement("create table test.decimals (a DECIMAL(7, 2), b DECIMAL(4, 0))"). - executeUpdate() - conn.prepareStatement("insert into test.decimals values (12345.67, 1234)").executeUpdate() - conn.prepareStatement("insert into test.decimals values (34567.89, 1428)").executeUpdate() - conn.commit() - sql( - s""" - |CREATE TEMPORARY TABLE decimals - |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable 'TEST.DECIMALS', user 'testUser', password 'testPass') - """.stripMargin.replaceAll("\n", " ")) - conn.prepareStatement( s""" |create table test.nulltypes (a INT, b BOOLEAN, c TINYINT, d BINARY(20), e VARCHAR(20), @@ -458,25 +444,4 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { assert(agg.getCatalystType(0, "", 1, null) === Some(LongType)) assert(agg.getCatalystType(1, "", 1, null) === Some(StringType)) } - - test("SPARK-9182: filters are not passed through to jdbc source") { - def checkPushedFilter(query: String, filterStr: String): Unit = { - val rddOpt = sql(query).queryExecution.executedPlan.collectFirst { - case PhysicalRDD(_, rdd: JDBCRDD, _) => rdd - } - assert(rddOpt.isDefined) - val pushedFilterStr = rddOpt.get.filterWhereClause - assert(pushedFilterStr.contains(filterStr), - s"Expected to push [$filterStr], actually we pushed [$pushedFilterStr]") - } - - checkPushedFilter("select * from foobar where NAME = 'fred'", "NAME = 'fred'") - checkPushedFilter("select * from inttypes where A > '15'", "A > 15") - checkPushedFilter("select * from inttypes where C <= 20", "C <= 20") - checkPushedFilter("select * from decimals where A > 1000", "A > 1000.00") - checkPushedFilter("select * from decimals where A > 1000 AND A < 2000", - "A > 1000.00 AND A < 2000.00") - checkPushedFilter("select * from decimals where A = 2000 AND B > 20", "A = 2000.00 AND B > 20") - checkPushedFilter("select * from timetypes where B > '1998-09-10'", "B > 1998-09-10") - } } From 68f99571492f67596b3656e9f076deeb96616f4a Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 12 Aug 2015 23:04:59 -0700 Subject: [PATCH 085/253] [SPARK-9918] [MLLIB] remove runs from k-means and rename epsilon to tol This requires some discussion. I'm not sure whether `runs` is a useful parameter. It certainly complicates the implementation. We might want to optimize the k-means implementation with block matrix operations. In this case, having `runs` may not be worth the trade-off. Also it increases the communication cost in a single job, which might cause other issues. This PR also renames `epsilon` to `tol` to have consistent naming among algorithms. The Python constructor is updated to include all parameters. jkbradley yu-iskw Author: Xiangrui Meng Closes #8148 from mengxr/SPARK-9918 and squashes the following commits: 149b9e5 [Xiangrui Meng] fix constructor in Python and rename epsilon to tol 3cc15b3 [Xiangrui Meng] fix test and change initStep to initSteps in python a0a0274 [Xiangrui Meng] remove runs from k-means in the pipeline API --- .../apache/spark/ml/clustering/KMeans.scala | 51 +++------------ .../spark/ml/clustering/KMeansSuite.scala | 12 +--- python/pyspark/ml/clustering.py | 63 ++++--------------- 3 files changed, 26 insertions(+), 100 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index dc192add6ca13..47a18cdb31b53 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -18,8 +18,8 @@ package org.apache.spark.ml.clustering import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.param.{Param, Params, IntParam, DoubleParam, ParamMap} -import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasMaxIter, HasPredictionCol, HasSeed} +import org.apache.spark.ml.param.{Param, Params, IntParam, ParamMap} +import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} @@ -27,14 +27,13 @@ import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.util.Utils /** * Common params for KMeans and KMeansModel */ -private[clustering] trait KMeansParams - extends Params with HasMaxIter with HasFeaturesCol with HasSeed with HasPredictionCol { +private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol + with HasSeed with HasPredictionCol with HasTol { /** * Set the number of clusters to create (k). Must be > 1. Default: 2. @@ -45,31 +44,6 @@ private[clustering] trait KMeansParams /** @group getParam */ def getK: Int = $(k) - /** - * Param the number of runs of the algorithm to execute in parallel. We initialize the algorithm - * this many times with random starting conditions (configured by the initialization mode), then - * return the best clustering found over any run. Must be >= 1. Default: 1. - * @group param - */ - final val runs = new IntParam(this, "runs", - "number of runs of the algorithm to execute in parallel", (value: Int) => value >= 1) - - /** @group getParam */ - def getRuns: Int = $(runs) - - /** - * Param the distance threshold within which we've consider centers to have converged. - * If all centers move less than this Euclidean distance, we stop iterating one run. - * Must be >= 0.0. Default: 1e-4 - * @group param - */ - final val epsilon = new DoubleParam(this, "epsilon", - "distance threshold within which we've consider centers to have converge", - (value: Double) => value >= 0.0) - - /** @group getParam */ - def getEpsilon: Double = $(epsilon) - /** * Param for the initialization algorithm. This can be either "random" to choose random points as * initial cluster centers, or "k-means||" to use a parallel variant of k-means++ @@ -136,9 +110,9 @@ class KMeansModel private[ml] ( /** * :: Experimental :: - * K-means clustering with support for multiple parallel runs and a k-means++ like initialization - * mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested, - * they are executed together with joint passes over the data for efficiency. + * K-means clustering with support for k-means|| initialization proposed by Bahmani et al. + * + * @see [[http://dx.doi.org/10.14778/2180912.2180915 Bahmani et al., Scalable k-means++.]] */ @Experimental class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMeansParams { @@ -146,10 +120,9 @@ class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMean setDefault( k -> 2, maxIter -> 20, - runs -> 1, initMode -> MLlibKMeans.K_MEANS_PARALLEL, initSteps -> 5, - epsilon -> 1e-4) + tol -> 1e-4) override def copy(extra: ParamMap): KMeans = defaultCopy(extra) @@ -174,10 +147,7 @@ class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMean def setMaxIter(value: Int): this.type = set(maxIter, value) /** @group setParam */ - def setRuns(value: Int): this.type = set(runs, value) - - /** @group setParam */ - def setEpsilon(value: Double): this.type = set(epsilon, value) + def setTol(value: Double): this.type = set(tol, value) /** @group setParam */ def setSeed(value: Long): this.type = set(seed, value) @@ -191,8 +161,7 @@ class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMean .setInitializationSteps($(initSteps)) .setMaxIterations($(maxIter)) .setSeed($(seed)) - .setEpsilon($(epsilon)) - .setRuns($(runs)) + .setEpsilon($(tol)) val parentModel = algo.run(rdd) val model = new KMeansModel(uid, parentModel) copyValues(model) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 1f15ac02f4008..688b0e31f91dc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -52,10 +52,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { assert(kmeans.getFeaturesCol === "features") assert(kmeans.getPredictionCol === "prediction") assert(kmeans.getMaxIter === 20) - assert(kmeans.getRuns === 1) assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL) assert(kmeans.getInitSteps === 5) - assert(kmeans.getEpsilon === 1e-4) + assert(kmeans.getTol === 1e-4) } test("set parameters") { @@ -64,21 +63,19 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { .setFeaturesCol("test_feature") .setPredictionCol("test_prediction") .setMaxIter(33) - .setRuns(7) .setInitMode(MLlibKMeans.RANDOM) .setInitSteps(3) .setSeed(123) - .setEpsilon(1e-3) + .setTol(1e-3) assert(kmeans.getK === 9) assert(kmeans.getFeaturesCol === "test_feature") assert(kmeans.getPredictionCol === "test_prediction") assert(kmeans.getMaxIter === 33) - assert(kmeans.getRuns === 7) assert(kmeans.getInitMode === MLlibKMeans.RANDOM) assert(kmeans.getInitSteps === 3) assert(kmeans.getSeed === 123) - assert(kmeans.getEpsilon === 1e-3) + assert(kmeans.getTol === 1e-3) } test("parameters validation") { @@ -91,9 +88,6 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { intercept[IllegalArgumentException] { new KMeans().setInitSteps(0) } - intercept[IllegalArgumentException] { - new KMeans().setRuns(0) - } } test("fit & transform") { diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 48338713a29ea..cb4c16e25a7a3 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -19,7 +19,6 @@ from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * from pyspark.mllib.common import inherit_doc -from pyspark.mllib.linalg import _convert_to_vector __all__ = ['KMeans', 'KMeansModel'] @@ -35,7 +34,7 @@ def clusterCenters(self): @inherit_doc -class KMeans(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed): +class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed): """ K-means clustering with support for multiple parallel runs and a k-means++ like initialization mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested, @@ -45,7 +44,7 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed): >>> data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),), ... (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)] >>> df = sqlContext.createDataFrame(data, ["features"]) - >>> kmeans = KMeans().setK(2).setSeed(1).setFeaturesCol("features") + >>> kmeans = KMeans(k=2, seed=1) >>> model = kmeans.fit(df) >>> centers = model.clusterCenters() >>> len(centers) @@ -60,10 +59,6 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed): # a placeholder to make it appear in the generated doc k = Param(Params._dummy(), "k", "number of clusters to create") - epsilon = Param(Params._dummy(), "epsilon", - "distance threshold within which " + - "we've consider centers to have converged") - runs = Param(Params._dummy(), "runs", "number of runs of the algorithm to execute in parallel") initMode = Param(Params._dummy(), "initMode", "the initialization algorithm. This can be either \"random\" to " + "choose random points as initial cluster centers, or \"k-means||\" " + @@ -71,21 +66,21 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed): initSteps = Param(Params._dummy(), "initSteps", "steps for k-means initialization mode") @keyword_only - def __init__(self, k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initStep=5): + def __init__(self, featuresCol="features", predictionCol="prediction", k=2, + initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None): + """ + __init__(self, featuresCol="features", predictionCol="prediction", k=2, \ + initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None) + """ super(KMeans, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.KMeans", self.uid) self.k = Param(self, "k", "number of clusters to create") - self.epsilon = Param(self, "epsilon", - "distance threshold within which " + - "we've consider centers to have converged") - self.runs = Param(self, "runs", "number of runs of the algorithm to execute in parallel") - self.seed = Param(self, "seed", "random seed") self.initMode = Param(self, "initMode", "the initialization algorithm. This can be either \"random\" to " + "choose random points as initial cluster centers, or \"k-means||\" " + "to use a parallel variant of k-means++") self.initSteps = Param(self, "initSteps", "steps for k-means initialization mode") - self._setDefault(k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initSteps=5) + self._setDefault(k=2, initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -93,9 +88,11 @@ def _create_model(self, java_model): return KMeansModel(java_model) @keyword_only - def setParams(self, k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initSteps=5): + def setParams(self, featuresCol="features", predictionCol="prediction", k=2, + initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None): """ - setParams(self, k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initSteps=5): + setParams(self, featuresCol="features", predictionCol="prediction", k=2, \ + initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None) Sets params for KMeans. """ @@ -119,40 +116,6 @@ def getK(self): """ return self.getOrDefault(self.k) - def setEpsilon(self, value): - """ - Sets the value of :py:attr:`epsilon`. - - >>> algo = KMeans().setEpsilon(1e-5) - >>> abs(algo.getEpsilon() - 1e-5) < 1e-5 - True - """ - self._paramMap[self.epsilon] = value - return self - - def getEpsilon(self): - """ - Gets the value of `epsilon` - """ - return self.getOrDefault(self.epsilon) - - def setRuns(self, value): - """ - Sets the value of :py:attr:`runs`. - - >>> algo = KMeans().setRuns(10) - >>> algo.getRuns() - 10 - """ - self._paramMap[self.runs] = value - return self - - def getRuns(self): - """ - Gets the value of `runs` - """ - return self.getOrDefault(self.runs) - def setInitMode(self, value): """ Sets the value of :py:attr:`initMode`. From 84a27916a62980c8fcb0977c3a7fdb73c0bd5812 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 13 Aug 2015 15:08:57 +0800 Subject: [PATCH 086/253] [SPARK-9885] [SQL] Also pass barrierPrefixes and sharedPrefixes to IsolatedClientLoader when hiveMetastoreJars is set to maven. https://issues.apache.org/jira/browse/SPARK-9885 cc marmbrus liancheng Author: Yin Huai Closes #8158 from yhuai/classloaderMaven. --- .../scala/org/apache/spark/sql/hive/HiveContext.scala | 6 +++++- .../spark/sql/hive/client/IsolatedClientLoader.scala | 11 +++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index f17177a771c3b..17762649fd70d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -231,7 +231,11 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { // TODO: Support for loading the jars from an already downloaded location. logInfo( s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion using maven.") - IsolatedClientLoader.forVersion(hiveMetastoreVersion, allConfig) + IsolatedClientLoader.forVersion( + version = hiveMetastoreVersion, + config = allConfig, + barrierPrefixes = hiveMetastoreBarrierPrefixes, + sharedPrefixes = hiveMetastoreSharedPrefixes) } else { // Convert to files and expand any directories. val jars = diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index a7d5a991948d9..7856037508412 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -42,11 +42,18 @@ private[hive] object IsolatedClientLoader { def forVersion( version: String, config: Map[String, String] = Map.empty, - ivyPath: Option[String] = None): IsolatedClientLoader = synchronized { + ivyPath: Option[String] = None, + sharedPrefixes: Seq[String] = Seq.empty, + barrierPrefixes: Seq[String] = Seq.empty): IsolatedClientLoader = synchronized { val resolvedVersion = hiveVersion(version) val files = resolvedVersions.getOrElseUpdate(resolvedVersion, downloadVersion(resolvedVersion, ivyPath)) - new IsolatedClientLoader(hiveVersion(version), files, config) + new IsolatedClientLoader( + version = hiveVersion(version), + execJars = files, + config = config, + sharedPrefixes = sharedPrefixes, + barrierPrefixes = barrierPrefixes) } def hiveVersion(version: String): HiveVersion = version match { From 69930310115501f0de094fe6f5c6c60dade342bd Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 13 Aug 2015 16:16:50 +0800 Subject: [PATCH 087/253] [SPARK-9757] [SQL] Fixes persistence of Parquet relation with decimal column PR #7967 enables us to save data source relations to metastore in Hive compatible format when possible. But it fails to persist Parquet relations with decimal column(s) to Hive metastore of versions lower than 1.2.0. This is because `ParquetHiveSerDe` in Hive versions prior to 1.2.0 doesn't support decimal. This PR checks for this case and falls back to Spark SQL specific metastore table format. Author: Yin Huai Author: Cheng Lian Closes #8130 from liancheng/spark-9757/old-hive-parquet-decimal. --- .../apache/spark/sql/types/ArrayType.scala | 6 +- .../org/apache/spark/sql/types/DataType.scala | 5 ++ .../org/apache/spark/sql/types/MapType.scala | 6 +- .../apache/spark/sql/types/StructType.scala | 8 ++- .../spark/sql/types/DataTypeSuite.scala | 24 +++++++ .../spark/sql/hive/HiveMetastoreCatalog.scala | 39 ++++++++--- .../sql/hive/client/ClientInterface.scala | 3 + .../spark/sql/hive/client/ClientWrapper.scala | 2 +- .../spark/sql/hive/client/package.scala | 2 +- .../sql/hive/HiveMetastoreCatalogSuite.scala | 17 +++-- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 68 +++++++++++++++++-- 11 files changed, 150 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 5094058164b2f..5770f59b53077 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -75,6 +75,10 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT override def simpleString: String = s"array<${elementType.simpleString}>" - private[spark] override def asNullable: ArrayType = + override private[spark] def asNullable: ArrayType = ArrayType(elementType.asNullable, containsNull = true) + + override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { + f(this) || elementType.existsRecursively(f) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index f4428c2e8b202..7bcd623b3f33e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -77,6 +77,11 @@ abstract class DataType extends AbstractDataType { */ private[spark] def asNullable: DataType + /** + * Returns true if any `DataType` of this DataType tree satisfies the given function `f`. + */ + private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = f(this) + override private[sql] def defaultConcreteType: DataType = this override private[sql] def acceptsType(other: DataType): Boolean = sameType(other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index ac34b642827ca..00461e529ca0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -62,8 +62,12 @@ case class MapType( override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>" - private[spark] override def asNullable: MapType = + override private[spark] def asNullable: MapType = MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true) + + override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { + f(this) || keyType.existsRecursively(f) || valueType.existsRecursively(f) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 9cbc207538d4f..d8968ef806390 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -24,7 +24,7 @@ import org.json4s.JsonDSL._ import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, AttributeReference, Attribute, InterpretedOrdering$} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} /** @@ -292,7 +292,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru private[sql] def merge(that: StructType): StructType = StructType.merge(this, that).asInstanceOf[StructType] - private[spark] override def asNullable: StructType = { + override private[spark] def asNullable: StructType = { val newFields = fields.map { case StructField(name, dataType, nullable, metadata) => StructField(name, dataType.asNullable, nullable = true, metadata) @@ -301,6 +301,10 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru StructType(newFields) } + override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { + f(this) || fields.exists(field => field.dataType.existsRecursively(f)) + } + private[sql] val interpretedOrdering = InterpretedOrdering.forSchema(this.fields.map(_.dataType)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 88b221cd81d74..706ecd29d1355 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -170,6 +170,30 @@ class DataTypeSuite extends SparkFunSuite { } } + test("existsRecursively") { + val struct = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + assert(struct.existsRecursively(_.isInstanceOf[LongType])) + assert(struct.existsRecursively(_.isInstanceOf[StructType])) + assert(!struct.existsRecursively(_.isInstanceOf[IntegerType])) + + val mapType = MapType(struct, StringType) + assert(mapType.existsRecursively(_.isInstanceOf[LongType])) + assert(mapType.existsRecursively(_.isInstanceOf[StructType])) + assert(mapType.existsRecursively(_.isInstanceOf[StringType])) + assert(mapType.existsRecursively(_.isInstanceOf[MapType])) + assert(!mapType.existsRecursively(_.isInstanceOf[IntegerType])) + + val arrayType = ArrayType(mapType) + assert(arrayType.existsRecursively(_.isInstanceOf[LongType])) + assert(arrayType.existsRecursively(_.isInstanceOf[StructType])) + assert(arrayType.existsRecursively(_.isInstanceOf[StringType])) + assert(arrayType.existsRecursively(_.isInstanceOf[MapType])) + assert(arrayType.existsRecursively(_.isInstanceOf[ArrayType])) + assert(!arrayType.existsRecursively(_.isInstanceOf[IntegerType])) + } + def checkDataTypeJsonRepr(dataType: DataType): Unit = { test(s"JSON - $dataType") { assert(DataType.fromJson(dataType.json) === dataType) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 5e5497837a393..6770462bb0ad3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -33,15 +33,14 @@ import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.Logging import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier} -import org.apache.spark.sql.execution.{FileRelation, datasources} +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} +import org.apache.spark.sql.execution.{FileRelation, datasources} import org.apache.spark.sql.hive.client._ -import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} @@ -86,9 +85,9 @@ private[hive] object HiveSerDe { serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"))) val key = source.toLowerCase match { - case _ if source.startsWith("org.apache.spark.sql.parquet") => "parquet" - case _ if source.startsWith("org.apache.spark.sql.orc") => "orc" - case _ => source.toLowerCase + case s if s.startsWith("org.apache.spark.sql.parquet") => "parquet" + case s if s.startsWith("org.apache.spark.sql.orc") => "orc" + case s => s } serdeMap.get(key) @@ -309,11 +308,31 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val hiveTable = (maybeSerDe, dataSource.relation) match { case (Some(serde), relation: HadoopFsRelation) if relation.paths.length == 1 && relation.partitionColumns.isEmpty => - logInfo { - "Persisting data source relation with a single input path into Hive metastore in Hive " + - s"compatible format. Input path: ${relation.paths.head}" + // Hive ParquetSerDe doesn't support decimal type until 1.2.0. + val isParquetSerDe = serde.inputFormat.exists(_.toLowerCase.contains("parquet")) + val hasDecimalFields = relation.schema.existsRecursively(_.isInstanceOf[DecimalType]) + + val hiveParquetSupportsDecimal = client.version match { + case org.apache.spark.sql.hive.client.hive.v1_2 => true + case _ => false + } + + if (isParquetSerDe && !hiveParquetSupportsDecimal && hasDecimalFields) { + // If Hive version is below 1.2.0, we cannot save Hive compatible schema to + // metastore when the file format is Parquet and the schema has DecimalType. + logWarning { + "Persisting Parquet relation with decimal field(s) into Hive metastore in Spark SQL " + + "specific format, which is NOT compatible with Hive. Because ParquetHiveSerDe in " + + s"Hive ${client.version.fullVersion} doesn't support decimal type. See HIVE-6384." + } + newSparkSQLSpecificMetastoreTable() + } else { + logInfo { + "Persisting data source relation with a single input path into Hive metastore in " + + s"Hive compatible format. Input path: ${relation.paths.head}" + } + newHiveCompatibleMetastoreTable(relation, serde) } - newHiveCompatibleMetastoreTable(relation, serde) case (Some(serde), relation: HadoopFsRelation) if relation.partitionColumns.nonEmpty => logWarning { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala index a82e152dcda2c..3811c152a7ae6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala @@ -88,6 +88,9 @@ private[hive] case class HiveTable( */ private[hive] trait ClientInterface { + /** Returns the Hive Version of this client. */ + def version: HiveVersion + /** Returns the configuration for the given key in the current session. */ def getConf(key: String, defaultValue: String): String diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 3d05b583cf9e0..f49c97de8ff4e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -58,7 +58,7 @@ import org.apache.spark.util.{CircularBuffer, Utils} * this ClientWrapper. */ private[hive] class ClientWrapper( - version: HiveVersion, + override val version: HiveVersion, config: Map[String, String], initClassLoader: ClassLoader) extends ClientInterface diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 0503691a44249..b1b8439efa011 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -25,7 +25,7 @@ package object client { val exclusions: Seq[String] = Nil) // scalastyle:off - private[client] object hive { + private[hive] object hive { case object v12 extends HiveVersion("0.12.0") case object v13 extends HiveVersion("0.13.1") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 332c3ec0c28b8..59e65ff97b8e0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.hive import java.io.File -import org.apache.spark.sql.hive.client.{ExternalTable, HiveColumn, ManagedTable} +import org.apache.spark.sql.hive.client.{ExternalTable, ManagedTable} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.sources.DataSourceTest import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DecimalType, StringType, StructType} import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.{Logging, SparkFunSuite} @@ -55,7 +55,10 @@ class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging { class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTestUtils { override val sqlContext = TestHive - private val testDF = (1 to 2).map(i => (i, s"val_$i")).toDF("d1", "d2").coalesce(1) + private val testDF = range(1, 3).select( + ('id + 0.1) cast DecimalType(10, 3) as 'd1, + 'id cast StringType as 'd2 + ).coalesce(1) Seq( "parquet" -> ( @@ -88,10 +91,10 @@ class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTes val columns = hiveTable.schema assert(columns.map(_.name) === Seq("d1", "d2")) - assert(columns.map(_.hiveType) === Seq("int", "string")) + assert(columns.map(_.hiveType) === Seq("decimal(10,3)", "string")) checkAnswer(table("t"), testDF) - assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1", "2\tval_2")) + assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) } } @@ -117,10 +120,10 @@ class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTes val columns = hiveTable.schema assert(columns.map(_.name) === Seq("d1", "d2")) - assert(columns.map(_.hiveType) === Seq("int", "string")) + assert(columns.map(_.hiveType) === Seq("decimal(10,3)", "string")) checkAnswer(table("t"), testDF) - assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1", "2\tval_2")) + assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 1e1972d1ac353..0c29646114465 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -20,16 +20,18 @@ package org.apache.spark.sql.hive import java.io.File import scala.collection.mutable.ArrayBuffer -import scala.sys.process.{ProcessLogger, Process} +import scala.sys.process.{Process, ProcessLogger} +import org.scalatest.Matchers +import org.scalatest.concurrent.Timeouts import org.scalatest.exceptions.TestFailedDueToTimeoutException +import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.sql.QueryTest import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} +import org.apache.spark.sql.types.DecimalType import org.apache.spark.util.{ResetSystemProperties, Utils} -import org.scalatest.Matchers -import org.scalatest.concurrent.Timeouts -import org.scalatest.time.SpanSugar._ /** * This suite tests spark-submit with applications using HiveContext. @@ -50,8 +52,8 @@ class HiveSparkSubmitSuite val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) - val jar3 = TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath() - val jar4 = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath() + val jar3 = TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath + val jar4 = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath val jarsString = Seq(jar1, jar2, jar3, jar4).map(j => j.toString).mkString(",") val args = Seq( "--class", SparkSubmitClassLoaderTest.getClass.getName.stripSuffix("$"), @@ -91,6 +93,16 @@ class HiveSparkSubmitSuite runSparkSubmit(args) } + test("SPARK-9757 Persist Parquet relation with decimal column") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", SPARK_9757.getClass.getName.stripSuffix("$"), + "--name", "SparkSQLConfTest", + "--master", "local-cluster[2,1,1024]", + unusedJar.toString) + runSparkSubmit(args) + } + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. // This is copied from org.apache.spark.deploy.SparkSubmitSuite private def runSparkSubmit(args: Seq[String]): Unit = { @@ -213,7 +225,7 @@ object SparkSQLConfTest extends Logging { // before spark.sql.hive.metastore.jars get set, we will see the following exception: // Exception in thread "main" java.lang.IllegalArgumentException: Builtin jars can only // be used when hive execution version == hive metastore version. - // Execution: 0.13.1 != Metastore: 0.12. Specify a vaild path to the correct hive jars + // Execution: 0.13.1 != Metastore: 0.12. Specify a valid path to the correct hive jars // using $HIVE_METASTORE_JARS or change spark.sql.hive.metastore.version to 0.13.1. val conf = new SparkConf() { override def getAll: Array[(String, String)] = { @@ -239,3 +251,45 @@ object SparkSQLConfTest extends Logging { sc.stop() } } + +object SPARK_9757 extends QueryTest with Logging { + def main(args: Array[String]): Unit = { + Utils.configTestLog4j("INFO") + + val sparkContext = new SparkContext( + new SparkConf() + .set("spark.sql.hive.metastore.version", "0.13.1") + .set("spark.sql.hive.metastore.jars", "maven")) + + val hiveContext = new TestHiveContext(sparkContext) + import hiveContext.implicits._ + import org.apache.spark.sql.functions._ + + val dir = Utils.createTempDir() + dir.delete() + + try { + { + val df = + hiveContext + .range(10) + .select(('id + 0.1) cast DecimalType(10, 3) as 'dec) + df.write.option("path", dir.getCanonicalPath).mode("overwrite").saveAsTable("t") + checkAnswer(hiveContext.table("t"), df) + } + + { + val df = + hiveContext + .range(10) + .select(callUDF("struct", ('id + 0.2) cast DecimalType(10, 3)) as 'dec_struct) + df.write.option("path", dir.getCanonicalPath).mode("overwrite").saveAsTable("t") + checkAnswer(hiveContext.table("t"), df) + } + } finally { + dir.delete() + hiveContext.sql("DROP TABLE t") + sparkContext.stop() + } + } +} From 2932e25da4532de9e86b01d08bce0cb680874e70 Mon Sep 17 00:00:00 2001 From: lewuathe Date: Thu, 13 Aug 2015 09:17:19 -0700 Subject: [PATCH 088/253] [SPARK-9073] [ML] spark.ml Models copy() should call setParent when there is a parent Copied ML models must have the same parent of original ones Author: lewuathe Author: Lewuathe Closes #7447 from Lewuathe/SPARK-9073. --- .../examples/ml/JavaDeveloperApiExample.java | 3 +- .../examples/ml/DeveloperApiExample.scala | 2 +- .../scala/org/apache/spark/ml/Pipeline.scala | 2 +- .../DecisionTreeClassifier.scala | 1 + .../ml/classification/GBTClassifier.scala | 2 +- .../classification/LogisticRegression.scala | 2 +- .../spark/ml/classification/OneVsRest.scala | 2 +- .../RandomForestClassifier.scala | 1 + .../apache/spark/ml/feature/Bucketizer.scala | 4 ++- .../org/apache/spark/ml/feature/IDF.scala | 2 +- .../spark/ml/feature/MinMaxScaler.scala | 2 +- .../org/apache/spark/ml/feature/PCA.scala | 2 +- .../spark/ml/feature/StandardScaler.scala | 2 +- .../spark/ml/feature/StringIndexer.scala | 2 +- .../spark/ml/feature/VectorIndexer.scala | 2 +- .../apache/spark/ml/feature/Word2Vec.scala | 2 +- .../apache/spark/ml/recommendation/ALS.scala | 2 +- .../ml/regression/DecisionTreeRegressor.scala | 2 +- .../spark/ml/regression/GBTRegressor.scala | 2 +- .../ml/regression/LinearRegression.scala | 2 +- .../ml/regression/RandomForestRegressor.scala | 2 +- .../spark/ml/tuning/CrossValidator.scala | 2 +- .../org/apache/spark/ml/PipelineSuite.scala | 3 ++ .../DecisionTreeClassifierSuite.scala | 4 +++ .../classification/GBTClassifierSuite.scala | 4 +++ .../LogisticRegressionSuite.scala | 4 +++ .../ml/classification/OneVsRestSuite.scala | 6 +++- .../RandomForestClassifierSuite.scala | 4 +++ .../spark/ml/feature/BucketizerSuite.scala | 1 + .../spark/ml/feature/MinMaxScalerSuite.scala | 4 +++ .../apache/spark/ml/feature/PCASuite.scala | 4 +++ .../spark/ml/feature/StringIndexerSuite.scala | 5 ++++ .../spark/ml/feature/VectorIndexerSuite.scala | 5 ++++ .../spark/ml/feature/Word2VecSuite.scala | 4 +++ .../spark/ml/recommendation/ALSSuite.scala | 4 +++ .../DecisionTreeRegressorSuite.scala | 11 +++++++ .../ml/regression/GBTRegressorSuite.scala | 5 ++++ .../ml/regression/LinearRegressionSuite.scala | 5 ++++ .../RandomForestRegressorSuite.scala | 7 ++++- .../spark/ml/tuning/CrossValidatorSuite.scala | 5 ++++ .../apache/spark/ml/util/MLTestingUtils.scala | 30 +++++++++++++++++++ 41 files changed, 138 insertions(+), 22 deletions(-) create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index 9df26ffca5775..3f1fe900b0008 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -230,6 +230,7 @@ public Vector predictRaw(Vector features) { */ @Override public MyJavaLogisticRegressionModel copy(ParamMap extra) { - return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra); + return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra) + .setParent(parent()); } } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index 78f31b4ffe56a..340c3559b15ef 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -179,7 +179,7 @@ private class MyLogisticRegressionModel( * This is used for the default implementation of [[transform()]]. */ override def copy(extra: ParamMap): MyLogisticRegressionModel = { - copyValues(new MyLogisticRegressionModel(uid, weights), extra) + copyValues(new MyLogisticRegressionModel(uid, weights), extra).setParent(parent) } } // scalastyle:on println diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index aef2c019d2871..a3e59401c5cfb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -198,6 +198,6 @@ class PipelineModel private[ml] ( } override def copy(extra: ParamMap): PipelineModel = { - new PipelineModel(uid, stages.map(_.copy(extra))) + new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 29598f3f05c2d..6f70b96b17ec6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -141,6 +141,7 @@ final class DecisionTreeClassificationModel private[ml] ( override def copy(extra: ParamMap): DecisionTreeClassificationModel = { copyValues(new DecisionTreeClassificationModel(uid, rootNode, numClasses), extra) + .setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index c3891a9599262..3073a2a61ce83 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -196,7 +196,7 @@ final class GBTClassificationModel( } override def copy(extra: ParamMap): GBTClassificationModel = { - copyValues(new GBTClassificationModel(uid, _trees, _treeWeights), extra) + copyValues(new GBTClassificationModel(uid, _trees, _treeWeights), extra).setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 5bcd7117b668c..21fbe38ca8233 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -468,7 +468,7 @@ class LogisticRegressionModel private[ml] ( } override def copy(extra: ParamMap): LogisticRegressionModel = { - copyValues(new LogisticRegressionModel(uid, weights, intercept), extra) + copyValues(new LogisticRegressionModel(uid, weights, intercept), extra).setParent(parent) } override protected def raw2prediction(rawPrediction: Vector): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 1741f19dc911c..1132d8046df67 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -138,7 +138,7 @@ final class OneVsRestModel private[ml] ( override def copy(extra: ParamMap): OneVsRestModel = { val copied = new OneVsRestModel( uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]])) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 156050aaf7a45..11a6d72468333 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -189,6 +189,7 @@ final class RandomForestClassificationModel private[ml] ( override def copy(extra: ParamMap): RandomForestClassificationModel = { copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra) + .setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 67e4785bc3553..cfca494dcf468 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -90,7 +90,9 @@ final class Bucketizer(override val uid: String) SchemaUtils.appendColumn(schema, prepOutputField(schema)) } - override def copy(extra: ParamMap): Bucketizer = defaultCopy(extra) + override def copy(extra: ParamMap): Bucketizer = { + defaultCopy[Bucketizer](extra).setParent(parent) + } } private[feature] object Bucketizer { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index ecde80810580c..938447447a0a2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -114,6 +114,6 @@ class IDFModel private[ml] ( override def copy(extra: ParamMap): IDFModel = { val copied = new IDFModel(uid, idfModel) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index 9a473dd23772d..1b494ec8b1727 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -173,6 +173,6 @@ class MinMaxScalerModel private[ml] ( override def copy(extra: ParamMap): MinMaxScalerModel = { val copied = new MinMaxScalerModel(uid, originalMin, originalMax) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 2d3bb680cf309..539084704b653 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -125,6 +125,6 @@ class PCAModel private[ml] ( override def copy(extra: ParamMap): PCAModel = { val copied = new PCAModel(uid, pcaModel) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 72b545e5db3e4..f6d0b0c0e9e75 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -136,6 +136,6 @@ class StandardScalerModel private[ml] ( override def copy(extra: ParamMap): StandardScalerModel = { val copied = new StandardScalerModel(uid, scaler) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index e4485eb038409..9e4b0f0add612 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -168,7 +168,7 @@ class StringIndexerModel private[ml] ( override def copy(extra: ParamMap): StringIndexerModel = { val copied = new StringIndexerModel(uid, labels) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index c73bdccdef5fa..6875aefe065bb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -405,6 +405,6 @@ class VectorIndexerModel private[ml] ( override def copy(extra: ParamMap): VectorIndexerModel = { val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 29acc3eb5865f..5af775a4159ad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -221,6 +221,6 @@ class Word2VecModel private[ml] ( override def copy(extra: ParamMap): Word2VecModel = { val copied = new Word2VecModel(uid, wordVectors) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 2e44cd4cc6a22..7db8ad8d27918 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -219,7 +219,7 @@ class ALSModel private[ml] ( override def copy(extra: ParamMap): ALSModel = { val copied = new ALSModel(uid, rank, userFactors, itemFactors) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index dc94a14014542..a2bcd67401d08 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -114,7 +114,7 @@ final class DecisionTreeRegressionModel private[ml] ( } override def copy(extra: ParamMap): DecisionTreeRegressionModel = { - copyValues(new DecisionTreeRegressionModel(uid, rootNode), extra) + copyValues(new DecisionTreeRegressionModel(uid, rootNode), extra).setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 5633bc320273a..b66e61f37dd5e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -185,7 +185,7 @@ final class GBTRegressionModel( } override def copy(extra: ParamMap): GBTRegressionModel = { - copyValues(new GBTRegressionModel(uid, _trees, _treeWeights), extra) + copyValues(new GBTRegressionModel(uid, _trees, _treeWeights), extra).setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 92d819bad8654..884003eb38524 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -312,7 +312,7 @@ class LinearRegressionModel private[ml] ( override def copy(extra: ParamMap): LinearRegressionModel = { val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept)) if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) - newModel + newModel.setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index db75c0d26392f..2f36da371f577 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -151,7 +151,7 @@ final class RandomForestRegressionModel private[ml] ( } override def copy(extra: ParamMap): RandomForestRegressionModel = { - copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra) + copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index f979319cc4b58..4792eb0f0a288 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -160,6 +160,6 @@ class CrossValidatorModel private[ml] ( uid, bestModel.copy(extra).asInstanceOf[Model[_]], avgMetrics.clone()) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 63d2fa31c7499..1f2c9b75b617b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -26,6 +26,7 @@ import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.HashingTF import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.sql.DataFrame class PipelineSuite extends SparkFunSuite { @@ -65,6 +66,8 @@ class PipelineSuite extends SparkFunSuite { .setStages(Array(estimator0, transformer1, estimator2, transformer3)) val pipelineModel = pipeline.fit(dataset0) + MLTestingUtils.checkCopy(pipelineModel) + assert(pipelineModel.stages.length === 4) assert(pipelineModel.stages(0).eq(model0)) assert(pipelineModel.stages(1).eq(transformer1)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index c7bbf1ce07a23..4b7c5d3f23d2c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} @@ -244,6 +245,9 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) val newTree = dt.fit(newData) + // copied model must have the same parent. + MLTestingUtils.checkCopy(newTree) + val predictions = newTree.transform(newData) .select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol) .collect() diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index d4b5896c12c06..e3909bccaa5ca 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -92,6 +93,9 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { .setCheckpointInterval(2) val model = gbt.fit(df) + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + sc.checkpointDir = None Utils.deleteRecursively(tempDir) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index e354e161c6dee..cce39f382f738 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -135,6 +136,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { lr.setFitIntercept(false) val model = lr.fit(dataset) assert(model.intercept === 0.0) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) } test("logistic regression with setters") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index bd8e819f6926c..977f0e0b70c1a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.feature.StringIndexer import org.apache.spark.ml.param.{ParamMap, ParamsSuite} -import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.ml.util.{MLTestingUtils, MetadataUtils} import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.evaluation.MulticlassMetrics @@ -70,6 +70,10 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(ova.getLabelCol === "label") assert(ova.getPredictionCol === "prediction") val ovaModel = ova.fit(dataset) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(ovaModel) + assert(ovaModel.models.size === numClasses) val transformedDataset = ovaModel.transform(dataset) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 6ca4b5aa5fde8..b4403ec30049a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} @@ -135,6 +136,9 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) val model = rf.fit(df) + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + val predictions = model.transform(df) .select(rf.getPredictionCol, rf.getRawPredictionCol, rf.getProbabilityCol) .collect() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index ec85e0d151e07..0eba34fda6228 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -21,6 +21,7 @@ import scala.util.Random import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala index c452054bec92f..c04dda41eea34 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Row, SQLContext} @@ -51,6 +52,9 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext { .foreach { case Row(vector1: Vector, vector2: Vector) => assert(vector1.equals(vector2), "Transformed vector is different with expected.") } + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) } test("MinMaxScaler arguments max must be larger than min") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index d0ae36b28c7a9..30c500f87a769 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, Matrices} import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -56,6 +57,9 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext { .setK(3) .fit(df) + // copied model must have the same parent. + MLTestingUtils.checkCopy(pca) + pca.transform(df).select("pca_features", "expected").collect().foreach { case Row(x: Vector, y: Vector) => assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index b111036087e6a..2d24914cb91f6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkException import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.util.MLlibTestSparkContext class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -38,6 +39,10 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { .setInputCol("label") .setOutputCol("labelIndex") .fit(df) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(indexer) + val transformed = indexer.transform(df) val attr = Attribute.fromStructField(transformed.schema("labelIndex")) .asInstanceOf[NominalAttribute] diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 03120c828ca96..8cb0a2cf14d37 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -22,6 +22,7 @@ import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD @@ -109,6 +110,10 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with L test("Throws error when given RDDs with different size vectors") { val vectorIndexer = getIndexer val model = vectorIndexer.fit(densePoints1) // vectors of length 3 + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + model.transform(densePoints1) // should work model.transform(sparsePoints1) // should work intercept[SparkException] { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index adcda0e623b25..a2e46f2029956 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -62,6 +63,9 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { .setSeed(42L) .fit(docDF) + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + model.transform(docDF).select("result", "expected").collect().foreach { case Row(vector1: Vector, vector2: Vector) => assert(vector1 ~== vector2 absTol 1E-5, "Transformed vector is different with expected.") diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 2e5cfe7027eb6..eadc80e0e62b1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -28,6 +28,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.ml.recommendation.ALS._ +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -374,6 +375,9 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { } logInfo(s"Test RMSE is $rmse.") assert(rmse < targetRMSE) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) } test("exact rank-1 matrix") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 33aa9d0d62343..b092bcd6a7e86 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} @@ -61,6 +62,16 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures) } + test("copied model must have the same parent") { + val categoricalFeatures = Map(0 -> 2, 1-> 2) + val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) + val model = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(8).fit(df) + MLTestingUtils.checkCopy(model) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index dbdce0c9dea54..a68197b59193d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} @@ -82,6 +83,9 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { .setMaxDepth(2) .setMaxIter(2) val model = gbt.fit(df) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) val preds = model.transform(df) val predictions = preds.select("prediction").map(_.getDouble(0)) // Checks based on SPARK-8736 (to ensure it is not doing classification) @@ -104,6 +108,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { sc.checkpointDir = None Utils.deleteRecursively(tempDir) + } // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 21ad8225bd9f7..2aaee71ecc734 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{DenseVector, Vectors} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ @@ -72,6 +73,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(lir.getFitIntercept) assert(lir.getStandardization) val model = lir.fit(dataset) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + model.transform(dataset) .select("label", "prediction") .collect() diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 992ce9562434e..7b1b3f11481de 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} @@ -91,7 +92,11 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex val categoricalFeatures = Map.empty[Int, Int] val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0) - val importances = rf.fit(df).featureImportances + val model = rf.fit(df) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + val importances = model.featureImportances val mostImportantFeature = importances.argmax assert(mostImportantFeature === 1) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index db64511a76055..aaca08bb61a45 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.tuning import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} @@ -53,6 +54,10 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { .setEvaluator(eval) .setNumFolds(3) val cvModel = cv.fit(dataset) + + // copied model must have the same paren. + MLTestingUtils.checkCopy(cvModel) + val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala new file mode 100644 index 0000000000000..d290cc9b06e73 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.ml.Model +import org.apache.spark.ml.param.ParamMap + +object MLTestingUtils { + def checkCopy(model: Model[_]): Unit = { + val copied = model.copy(ParamMap.empty) + .asInstanceOf[Model[_]] + assert(copied.parent.uid == model.parent.uid) + assert(copied.parent == model.parent) + } +} From 7a539ef3b1792764f866fa88c84c78ad59903f21 Mon Sep 17 00:00:00 2001 From: Rosstin Date: Thu, 13 Aug 2015 09:18:39 -0700 Subject: [PATCH 089/253] [SPARK-8965] [DOCS] Add ml-guide Python Example: Estimator, Transformer, and Param Added ml-guide Python Example: Estimator, Transformer, and Param /docs/_site/ml-guide.html Author: Rosstin Closes #8081 from Rosstin/SPARK-8965. --- docs/ml-guide.md | 68 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index b6ca50e98db02..a03ab4356a413 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -355,6 +355,74 @@ jsc.stop(); {% endhighlight %} +
+{% highlight python %} +from pyspark import SparkContext +from pyspark.mllib.regression import LabeledPoint +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.param import Param, Params +from pyspark.sql import Row, SQLContext + +sc = SparkContext(appName="SimpleParamsExample") +sqlContext = SQLContext(sc) + +# Prepare training data. +# We use LabeledPoint. +# Spark SQL can convert RDDs of LabeledPoints into DataFrames. +training = sc.parallelize([LabeledPoint(1.0, [0.0, 1.1, 0.1]), + LabeledPoint(0.0, [2.0, 1.0, -1.0]), + LabeledPoint(0.0, [2.0, 1.3, 1.0]), + LabeledPoint(1.0, [0.0, 1.2, -0.5])]) + +# Create a LogisticRegression instance. This instance is an Estimator. +lr = LogisticRegression(maxIter=10, regParam=0.01) +# Print out the parameters, documentation, and any default values. +print "LogisticRegression parameters:\n" + lr.explainParams() + "\n" + +# Learn a LogisticRegression model. This uses the parameters stored in lr. +model1 = lr.fit(training.toDF()) + +# Since model1 is a Model (i.e., a transformer produced by an Estimator), +# we can view the parameters it used during fit(). +# This prints the parameter (name: value) pairs, where names are unique IDs for this +# LogisticRegression instance. +print "Model 1 was fit using parameters: " +print model1.extractParamMap() + +# We may alternatively specify parameters using a Python dictionary as a paramMap +paramMap = {lr.maxIter: 20} +paramMap[lr.maxIter] = 30 # Specify 1 Param, overwriting the original maxIter. +paramMap.update({lr.regParam: 0.1, lr.threshold: 0.55}) # Specify multiple Params. + +# You can combine paramMaps, which are python dictionaries. +paramMap2 = {lr.probabilityCol: "myProbability"} # Change output column name +paramMapCombined = paramMap.copy() +paramMapCombined.update(paramMap2) + +# Now learn a new model using the paramMapCombined parameters. +# paramMapCombined overrides all parameters set earlier via lr.set* methods. +model2 = lr.fit(training.toDF(), paramMapCombined) +print "Model 2 was fit using parameters: " +print model2.extractParamMap() + +# Prepare test data +test = sc.parallelize([LabeledPoint(1.0, [-1.0, 1.5, 1.3]), + LabeledPoint(0.0, [ 3.0, 2.0, -0.1]), + LabeledPoint(1.0, [ 0.0, 2.2, -1.5])]) + +# Make predictions on test data using the Transformer.transform() method. +# LogisticRegression.transform will only use the 'features' column. +# Note that model2.transform() outputs a "myProbability" column instead of the usual +# 'probability' column since we renamed the lr.probabilityCol parameter previously. +prediction = model2.transform(test.toDF()) +selected = prediction.select("features", "label", "myProbability", "prediction") +for row in selected.collect(): + print row + +sc.stop() +{% endhighlight %} +
+ ## Example: Pipeline From 4b70798c96b0a784b85fda461426ec60f609be12 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 13 Aug 2015 09:31:14 -0700 Subject: [PATCH 090/253] [MINOR] [ML] change MultilayerPerceptronClassifierModel to MultilayerPerceptronClassificationModel To follow the naming rule of ML, change `MultilayerPerceptronClassifierModel` to `MultilayerPerceptronClassificationModel` like `DecisionTreeClassificationModel`, `GBTClassificationModel` and so on. Author: Yanbo Liang Closes #8164 from yanboliang/mlp-name. --- .../MultilayerPerceptronClassifier.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 8cd2103d7d5e6..c154561886585 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -131,7 +131,7 @@ private object LabelConverter { */ @Experimental class MultilayerPerceptronClassifier(override val uid: String) - extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassifierModel] + extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel] with MultilayerPerceptronParams { def this() = this(Identifiable.randomUID("mlpc")) @@ -146,7 +146,7 @@ class MultilayerPerceptronClassifier(override val uid: String) * @param dataset Training dataset * @return Fitted model */ - override protected def train(dataset: DataFrame): MultilayerPerceptronClassifierModel = { + override protected def train(dataset: DataFrame): MultilayerPerceptronClassificationModel = { val myLayers = $(layers) val labels = myLayers.last val lpData = extractLabeledPoints(dataset) @@ -156,13 +156,13 @@ class MultilayerPerceptronClassifier(override val uid: String) FeedForwardTrainer.LBFGSOptimizer.setConvergenceTol($(tol)).setNumIterations($(maxIter)) FeedForwardTrainer.setStackSize($(blockSize)) val mlpModel = FeedForwardTrainer.train(data) - new MultilayerPerceptronClassifierModel(uid, myLayers, mlpModel.weights()) + new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights()) } } /** * :: Experimental :: - * Classifier model based on the Multilayer Perceptron. + * Classification model based on the Multilayer Perceptron. * Each layer has sigmoid activation function, output layer has softmax. * @param uid uid * @param layers array of layer sizes including input and output layers @@ -170,11 +170,11 @@ class MultilayerPerceptronClassifier(override val uid: String) * @return prediction model */ @Experimental -class MultilayerPerceptronClassifierModel private[ml] ( +class MultilayerPerceptronClassificationModel private[ml] ( override val uid: String, layers: Array[Int], weights: Vector) - extends PredictionModel[Vector, MultilayerPerceptronClassifierModel] + extends PredictionModel[Vector, MultilayerPerceptronClassificationModel] with Serializable { private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights) @@ -187,7 +187,7 @@ class MultilayerPerceptronClassifierModel private[ml] ( LabelConverter.decodeLabel(mlpModel.predict(features)) } - override def copy(extra: ParamMap): MultilayerPerceptronClassifierModel = { - copyValues(new MultilayerPerceptronClassifierModel(uid, layers, weights), extra) + override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = { + copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra) } } From 65fec798ce52ca6b8b0fe14b78a16712778ad04c Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 13 Aug 2015 10:16:40 -0700 Subject: [PATCH 091/253] [MINOR] [DOC] fix mllib pydoc warnings Switch to correct Sphinx syntax. MechCoder Author: Xiangrui Meng Closes #8169 from mengxr/mllib-pydoc-fix. --- python/pyspark/mllib/regression.py | 14 ++++++++++---- python/pyspark/mllib/util.py | 1 + 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 5b7afc15ddfba..41946e3674fbe 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -207,8 +207,10 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, Train a linear regression model using Stochastic Gradient Descent (SGD). This solves the least squares regression formulation - f(weights) = 1/n ||A weights-y||^2^ - (which is the mean squared error). + + f(weights) = 1/(2n) ||A weights - y||^2, + + which is the mean squared error. Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with its corresponding right hand side label y. See also the documentation for the precise formulation. @@ -334,7 +336,9 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, Stochastic Gradient Descent. This solves the l1-regularized least squares regression formulation - f(weights) = 1/2n ||A weights-y||^2^ + regParam ||weights||_1 + + f(weights) = 1/(2n) ||A weights - y||^2 + regParam ||weights||_1. + Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with its corresponding right hand side label y. See also the documentation for the precise formulation. @@ -451,7 +455,9 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, Stochastic Gradient Descent. This solves the l2-regularized least squares regression formulation - f(weights) = 1/2n ||A weights-y||^2^ + regParam/2 ||weights||^2^ + + f(weights) = 1/(2n) ||A weights - y||^2 + regParam/2 ||weights||^2. + Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with its corresponding right hand side label y. See also the documentation for the precise formulation. diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 916de2d6fcdbd..10a1e4b3eb0fc 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -300,6 +300,7 @@ def generateLinearInput(intercept, weights, xMean, xVariance, :param: seed Random Seed :param: eps Used to scale the noise. If eps is set high, the amount of gaussian noise added is more. + Returns a list of LabeledPoints of length nPoints """ weights = [float(weight) for weight in weights] From 8815ba2f674dbb18eb499216df9942b411e10daa Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 13 Aug 2015 11:31:10 -0700 Subject: [PATCH 092/253] [SPARK-9649] Fix MasterSuite, third time's a charm This particular test did not load the default configurations so it continued to start the REST server, which causes port bind exceptions. --- .../test/scala/org/apache/spark/deploy/master/MasterSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 20d0201a364ab..242bf4b5566eb 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -40,6 +40,7 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually with Priva conf.set("spark.deploy.recoveryMode", "CUSTOM") conf.set("spark.deploy.recoveryMode.factory", classOf[CustomRecoveryModeFactory].getCanonicalName) + conf.set("spark.master.rest.enabled", "false") val instantiationAttempts = CustomRecoveryModeFactory.instantiationAttempts From 864de8eaf4b6ad5c9099f6f29e251c56b029f631 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Thu, 13 Aug 2015 13:42:35 -0700 Subject: [PATCH 093/253] [SPARK-9661] [MLLIB] [ML] Java compatibility I skimmed through the docs for various instance of Object and replaced them with Java compaible versions of the same. 1. Some methods in LDAModel. 2. runMiniBatchSGD 3. kolmogorovSmirnovTest Author: MechCoder Closes #8126 from MechCoder/java_incop. --- .../spark/mllib/clustering/LDAModel.scala | 27 +++++++++++++++++-- .../apache/spark/mllib/stat/Statistics.scala | 16 ++++++++++- .../spark/mllib/clustering/JavaLDASuite.java | 24 +++++++++++++++++ .../spark/mllib/stat/JavaStatisticsSuite.java | 22 +++++++++++++++ .../spark/mllib/clustering/LDASuite.scala | 13 +++++++++ 5 files changed, 99 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 5dc637ebdc133..f31949f13a4cf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -26,7 +26,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.JavaPairRDD +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} import org.apache.spark.graphx.{Edge, EdgeContext, Graph, VertexId} import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} @@ -228,6 +228,11 @@ class LocalLDAModel private[clustering] ( docConcentration, topicConcentration, topicsMatrix.toBreeze.toDenseMatrix, gammaShape, k, vocabSize) + /** Java-friendly version of [[logLikelihood]] */ + def logLikelihood(documents: JavaPairRDD[java.lang.Long, Vector]): Double = { + logLikelihood(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) + } + /** * Calculate an upper bound bound on perplexity. (Lower is better.) * See Equation (16) in original Online LDA paper. @@ -242,6 +247,11 @@ class LocalLDAModel private[clustering] ( -logLikelihood(documents) / corpusTokenCount } + /** Java-friendly version of [[logPerplexity]] */ + def logPerplexity(documents: JavaPairRDD[java.lang.Long, Vector]): Double = { + logPerplexity(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) + } + /** * Estimate the variational likelihood bound of from `documents`: * log p(documents) >= E_q[log p(documents)] - E_q[log q(documents)] @@ -341,8 +351,14 @@ class LocalLDAModel private[clustering] ( } } -} + /** Java-friendly version of [[topicDistributions]] */ + def topicDistributions( + documents: JavaPairRDD[java.lang.Long, Vector]): JavaPairRDD[java.lang.Long, Vector] = { + val distributions = topicDistributions(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) + JavaPairRDD.fromRDD(distributions.asInstanceOf[RDD[(java.lang.Long, Vector)]]) + } +} @Experimental object LocalLDAModel extends Loader[LocalLDAModel] { @@ -657,6 +673,13 @@ class DistributedLDAModel private[clustering] ( } } + /** Java-friendly version of [[topTopicsPerDocument]] */ + def javaTopTopicsPerDocument( + k: Int): JavaRDD[(java.lang.Long, Array[Int], Array[java.lang.Double])] = { + val topics = topTopicsPerDocument(k) + topics.asInstanceOf[RDD[(java.lang.Long, Array[Int], Array[java.lang.Double])]].toJavaRDD() + } + // TODO: // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index f84502919e381..24fe48cb8f71f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.stat import scala.annotation.varargs import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.java.{JavaRDD, JavaDoubleRDD} import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.linalg.{Matrix, Vector} import org.apache.spark.mllib.regression.LabeledPoint @@ -178,6 +178,9 @@ object Statistics { ChiSqTest.chiSquaredFeatures(data) } + /** Java-friendly version of [[chiSqTest()]] */ + def chiSqTest(data: JavaRDD[LabeledPoint]): Array[ChiSqTestResult] = chiSqTest(data.rdd) + /** * Conduct the two-sided Kolmogorov-Smirnov (KS) test for data sampled from a * continuous distribution. By comparing the largest difference between the empirical cumulative @@ -212,4 +215,15 @@ object Statistics { : KolmogorovSmirnovTestResult = { KolmogorovSmirnovTest.testOneSample(data, distName, params: _*) } + + /** Java-friendly version of [[kolmogorovSmirnovTest()]] */ + @varargs + def kolmogorovSmirnovTest( + data: JavaDoubleRDD, + distName: String, + params: java.lang.Double*): KolmogorovSmirnovTestResult = { + val javaParams = params.asInstanceOf[Seq[Double]] + KolmogorovSmirnovTest.testOneSample(data.rdd.asInstanceOf[RDD[Double]], + distName, javaParams: _*) + } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index d272a42c8576f..427be9430d820 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -124,6 +124,10 @@ public Boolean call(Tuple2 tuple2) { } }); assertEquals(topicDistributions.count(), nonEmptyCorpus.count()); + + // Check: javaTopTopicsPerDocuments + JavaRDD> topTopics = + model.javaTopTopicsPerDocument(3); } @Test @@ -160,11 +164,31 @@ public void OnlineOptimizerCompatibility() { assertEquals(roundedLocalTopicSummary.length, k); } + @Test + public void localLdaMethods() { + JavaRDD> docs = sc.parallelize(toyData, 2); + JavaPairRDD pairedDocs = JavaPairRDD.fromJavaRDD(docs); + + // check: topicDistributions + assertEquals(toyModel.topicDistributions(pairedDocs).count(), pairedDocs.count()); + + // check: logPerplexity + double logPerplexity = toyModel.logPerplexity(pairedDocs); + + // check: logLikelihood. + ArrayList> docsSingleWord = new ArrayList>(); + docsSingleWord.add(new Tuple2(Long.valueOf(0), Vectors.dense(1.0, 0.0, 0.0))); + JavaPairRDD single = JavaPairRDD.fromJavaRDD(sc.parallelize(docsSingleWord)); + double logLikelihood = toyModel.logLikelihood(single); + } + private static int tinyK = LDASuite$.MODULE$.tinyK(); private static int tinyVocabSize = LDASuite$.MODULE$.tinyVocabSize(); private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics(); private static Tuple2[] tinyTopicDescription = LDASuite$.MODULE$.tinyTopicDescription(); private JavaPairRDD corpus; + private LocalLDAModel toyModel = LDASuite$.MODULE$.toyModel(); + private ArrayList> toyData = LDASuite$.MODULE$.javaToyData(); } diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java index 62f7f26b7c98f..eb4e3698624bc 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java @@ -27,7 +27,12 @@ import static org.junit.Assert.assertEquals; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaDoubleRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.stat.test.ChiSqTestResult; +import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult; public class JavaStatisticsSuite implements Serializable { private transient JavaSparkContext sc; @@ -53,4 +58,21 @@ public void testCorr() { // Check default method assertEquals(corr1, corr2); } + + @Test + public void kolmogorovSmirnovTest() { + JavaDoubleRDD data = sc.parallelizeDoubles(Lists.newArrayList(0.2, 1.0, -1.0, 2.0)); + KolmogorovSmirnovTestResult testResult1 = Statistics.kolmogorovSmirnovTest(data, "norm"); + KolmogorovSmirnovTestResult testResult2 = Statistics.kolmogorovSmirnovTest( + data, "norm", 0.0, 1.0); + } + + @Test + public void chiSqTest() { + JavaRDD data = sc.parallelize(Lists.newArrayList( + new LabeledPoint(0.0, Vectors.dense(0.1, 2.3)), + new LabeledPoint(1.0, Vectors.dense(1.5, 5.1)), + new LabeledPoint(0.0, Vectors.dense(2.4, 8.1)))); + ChiSqTestResult[] testResults = Statistics.chiSqTest(data); + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index ce6a8eb8e8c46..926185e90bcf9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.clustering +import java.util.{ArrayList => JArrayList} + import breeze.linalg.{DenseMatrix => BDM, argtopk, max, argmax} import org.apache.spark.SparkFunSuite @@ -575,6 +577,17 @@ private[clustering] object LDASuite { Vectors.sparse(6, Array(4, 5), Array(1, 1)) ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } + /** Used in the Java Test Suite */ + def javaToyData: JArrayList[(java.lang.Long, Vector)] = { + val javaData = new JArrayList[(java.lang.Long, Vector)] + var i = 0 + while (i < toyData.size) { + javaData.add((toyData(i)._1, toyData(i)._2)) + i += 1 + } + javaData + } + def toyModel: LocalLDAModel = { val k = 2 val vocabSize = 6 From a8d2f4c5f92210a09c846711bd7cc89a43e2edd2 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 13 Aug 2015 14:03:55 -0700 Subject: [PATCH 094/253] [SPARK-9942] [PYSPARK] [SQL] ignore exceptions while try to import pandas If pandas is broken (can't be imported, raise other exceptions other than ImportError), pyspark can't be imported, we should ignore all the exceptions. Author: Davies Liu Closes #8173 from davies/fix_pandas. --- python/pyspark/sql/context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 917de24f3536b..0ef46c44644ab 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -39,7 +39,7 @@ try: import pandas has_pandas = True -except ImportError: +except Exception: has_pandas = False __all__ = ["SQLContext", "HiveContext", "UDFRegistration"] From c2520f501a200cf794bbe5dc9c385100f518d020 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 13 Aug 2015 16:07:03 -0700 Subject: [PATCH 095/253] [SPARK-9935] [SQL] EqualNotNull not processed in ORC https://issues.apache.org/jira/browse/SPARK-9935 Author: hyukjinkwon Closes #8163 from HyukjinKwon/master. --- .../scala/org/apache/spark/sql/hive/orc/OrcFilters.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index 86142e5d66f37..b3d9f7f71a27d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -107,6 +107,11 @@ private[orc] object OrcFilters extends Logging { .filter(isSearchableLiteral) .map(builder.equals(attribute, _)) + case EqualNullSafe(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.nullSafeEquals(attribute, _)) + case LessThan(attribute, value) => Option(value) .filter(isSearchableLiteral) From 0fed23b8dc525f62197d1cd332260a0752d7d35c Mon Sep 17 00:00:00 2001 From: Neelesh Srinivas Salian Date: Thu, 13 Aug 2015 16:12:06 -0700 Subject: [PATCH 096/253] Added deploy-mode section to YARN submission --- docs/running-on-yarn.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index cac08a91b97d9..0f77fca1a56ce 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -40,6 +40,22 @@ The above starts a YARN client program which starts the default Application Mast To launch a Spark application in `yarn-client` mode, do the same, but replace `yarn-cluster` with `yarn-client`. To run spark-shell: $ ./bin/spark-shell --master yarn-client + +The alternative to launching a Spark application on YARN is to explicitly set the deployment mode for the YARN master + +For example: + + $ ./bin/spark-submit --class org.apache.spark.examples.SparkPi \ + --master yarn \ + --deploy-mode client or cluster + --num-executors 3 \ + --driver-memory 4g \ + --executor-memory 2g \ + --executor-cores 1 \ + --queue thequeue \ + lib/spark-examples*.jar \ + +`--deploy-mode` can be either client or cluster. ## Adding Other JARs From 6c5858bc65c8a8602422b46bfa9cf0a1fb296b88 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 13 Aug 2015 16:52:17 -0700 Subject: [PATCH 097/253] [SPARK-9922] [ML] rename StringIndexerReverse to IndexToString What `StringIndexerInverse` does is not strictly associated with `StringIndexer`, and the name is not clearly describing the transformation. Renaming to `IndexToString` might be better. ~~I also changed `invert` to `inverse` without arguments. `inputCol` and `outputCol` could be set after.~~ I also removed `invert`. jkbradley holdenk Author: Xiangrui Meng Closes #8152 from mengxr/SPARK-9922. --- .../spark/ml/feature/StringIndexer.scala | 34 +++++-------- .../spark/ml/feature/StringIndexerSuite.scala | 50 +++++++++++++------ 2 files changed, 48 insertions(+), 36 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 9e4b0f0add612..9f6e7b6b6b274 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.Transformer -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType} @@ -59,6 +59,8 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha * If the input column is numeric, we cast it to string and index the string values. * The indices are in [0, numLabels), ordered by label frequencies. * So the most frequent label gets index 0. + * + * @see [[IndexToString]] for the inverse transformation */ @Experimental class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] @@ -170,34 +172,24 @@ class StringIndexerModel private[ml] ( val copied = new StringIndexerModel(uid, labels) copyValues(copied, extra).setParent(parent) } - - /** - * Return a model to perform the inverse transformation. - * Note: By default we keep the original columns during this transformation, so the inverse - * should only be used on new columns such as predicted labels. - */ - def invert(inputCol: String, outputCol: String): StringIndexerInverse = { - new StringIndexerInverse() - .setInputCol(inputCol) - .setOutputCol(outputCol) - .setLabels(labels) - } } /** * :: Experimental :: - * Transform a provided column back to the original input types using either the metadata - * on the input column, or if provided using the labels supplied by the user. - * Note: By default we keep the original columns during this transformation, - * so the inverse should only be used on new columns such as predicted labels. + * A [[Transformer]] that maps a column of string indices back to a new column of corresponding + * string values using either the ML attributes of the input column, or if provided using the labels + * supplied by the user. + * All original columns are kept during transformation. + * + * @see [[StringIndexer]] for converting strings into indices */ @Experimental -class StringIndexerInverse private[ml] ( +class IndexToString private[ml] ( override val uid: String) extends Transformer with HasInputCol with HasOutputCol { def this() = - this(Identifiable.randomUID("strIdxInv")) + this(Identifiable.randomUID("idxToStr")) /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -257,7 +249,7 @@ class StringIndexerInverse private[ml] ( } val indexer = udf { index: Double => val idx = index.toInt - if (0 <= idx && idx < values.size) { + if (0 <= idx && idx < values.length) { values(idx) } else { throw new SparkException(s"Unseen index: $index ??") @@ -268,7 +260,7 @@ class StringIndexerInverse private[ml] ( indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName)) } - override def copy(extra: ParamMap): StringIndexerInverse = { + override def copy(extra: ParamMap): IndexToString = { defaultCopy(extra) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 2d24914cb91f6..fa918ce64877c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkException -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.col class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -53,19 +54,6 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { // a -> 0, b -> 2, c -> 1 val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) assert(output === expected) - // convert reverse our transform - val reversed = indexer.invert("labelIndex", "label2") - .transform(transformed) - .select("id", "label2") - assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet === - reversed.collect().map(r => (r.getInt(0), r.getString(1))).toSet) - // Check invert using only metadata - val inverse2 = new StringIndexerInverse() - .setInputCol("labelIndex") - .setOutputCol("label2") - val reversed2 = inverse2.transform(transformed).select("id", "label2") - assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet === - reversed2.collect().map(r => (r.getInt(0), r.getString(1))).toSet) } test("StringIndexerUnseen") { @@ -125,4 +113,36 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { val df = sqlContext.range(0L, 10L) assert(indexerModel.transform(df).eq(df)) } + + test("IndexToString params") { + val idxToStr = new IndexToString() + ParamsSuite.checkParams(idxToStr) + } + + test("IndexToString.transform") { + val labels = Array("a", "b", "c") + val df0 = sqlContext.createDataFrame(Seq( + (0, "a"), (1, "b"), (2, "c"), (0, "a") + )).toDF("index", "expected") + + val idxToStr0 = new IndexToString() + .setInputCol("index") + .setOutputCol("actual") + .setLabels(labels) + idxToStr0.transform(df0).select("actual", "expected").collect().foreach { + case Row(actual, expected) => + assert(actual === expected) + } + + val attr = NominalAttribute.defaultAttr.withValues(labels) + val df1 = df0.select(col("index").as("indexWithAttr", attr.toMetadata()), col("expected")) + + val idxToStr1 = new IndexToString() + .setInputCol("indexWithAttr") + .setOutputCol("actual") + idxToStr1.transform(df1).select("actual", "expected").collect().foreach { + case Row(actual, expected) => + assert(actual === expected) + } + } } From 693949ba4096c01a0b41da2542ff316823464a16 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 13 Aug 2015 17:33:37 -0700 Subject: [PATCH 098/253] [SPARK-8976] [PYSPARK] fix open mode in python3 This bug only happen on Python 3 and Windows. I tested this manually with python 3 and disable python daemon, no unit test yet. Author: Davies Liu Closes #8181 from davies/open_mode. --- python/pyspark/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 93df9002be377..42c2f8b75933e 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -146,5 +146,5 @@ def process(): java_port = int(sys.stdin.readline()) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(("127.0.0.1", java_port)) - sock_file = sock.makefile("a+", 65536) + sock_file = sock.makefile("rwb", 65536) main(sock_file, sock_file) From c50f97dafd2d5bf5a8351efcc1c8d3e2b87efc72 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 13 Aug 2015 17:35:11 -0700 Subject: [PATCH 099/253] [SPARK-9943] [SQL] deserialized UnsafeHashedRelation should be serializable When the free memory in executor goes low, the cached broadcast objects need to serialized into disk, but currently the deserialized UnsafeHashedRelation can't be serialized , fail with NPE. This PR fixes that. cc rxin Author: Davies Liu Closes #8174 from davies/serialize_hashed. --- .../sql/execution/joins/HashedRelation.scala | 93 ++++++++++++------- .../execution/joins/HashedRelationSuite.scala | 14 +++ 2 files changed, 74 insertions(+), 33 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index ea02076b41a6f..6c0196c21a0d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.execution.metric.LongSQLMetric import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap -import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} +import org.apache.spark.unsafe.memory.{MemoryLocation, ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} import org.apache.spark.util.Utils import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.{SparkConf, SparkEnv} @@ -247,40 +247,67 @@ private[joins] final class UnsafeHashedRelation( } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { - out.writeInt(hashTable.size()) - - val iter = hashTable.entrySet().iterator() - while (iter.hasNext) { - val entry = iter.next() - val key = entry.getKey - val values = entry.getValue - - // write all the values as single byte array - var totalSize = 0L - var i = 0 - while (i < values.length) { - totalSize += values(i).getSizeInBytes + 4 + 4 - i += 1 + if (binaryMap != null) { + // This could happen when a cached broadcast object need to be dumped into disk to free memory + out.writeInt(binaryMap.numElements()) + + var buffer = new Array[Byte](64) + def write(addr: MemoryLocation, length: Int): Unit = { + if (buffer.length < length) { + buffer = new Array[Byte](length) + } + Platform.copyMemory(addr.getBaseObject, addr.getBaseOffset, + buffer, Platform.BYTE_ARRAY_OFFSET, length) + out.write(buffer, 0, length) } - assert(totalSize < Integer.MAX_VALUE, "values are too big") - - // [key size] [values size] [key bytes] [values bytes] - out.writeInt(key.getSizeInBytes) - out.writeInt(totalSize.toInt) - out.write(key.getBytes) - i = 0 - while (i < values.length) { - // [num of fields] [num of bytes] [row bytes] - // write the integer in native order, so they can be read by UNSAFE.getInt() - if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) { - out.writeInt(values(i).numFields()) - out.writeInt(values(i).getSizeInBytes) - } else { - out.writeInt(Integer.reverseBytes(values(i).numFields())) - out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes)) + + val iter = binaryMap.iterator() + while (iter.hasNext) { + val loc = iter.next() + // [key size] [values size] [key bytes] [values bytes] + out.writeInt(loc.getKeyLength) + out.writeInt(loc.getValueLength) + write(loc.getKeyAddress, loc.getKeyLength) + write(loc.getValueAddress, loc.getValueLength) + } + + } else { + assert(hashTable != null) + out.writeInt(hashTable.size()) + + val iter = hashTable.entrySet().iterator() + while (iter.hasNext) { + val entry = iter.next() + val key = entry.getKey + val values = entry.getValue + + // write all the values as single byte array + var totalSize = 0L + var i = 0 + while (i < values.length) { + totalSize += values(i).getSizeInBytes + 4 + 4 + i += 1 + } + assert(totalSize < Integer.MAX_VALUE, "values are too big") + + // [key size] [values size] [key bytes] [values bytes] + out.writeInt(key.getSizeInBytes) + out.writeInt(totalSize.toInt) + out.write(key.getBytes) + i = 0 + while (i < values.length) { + // [num of fields] [num of bytes] [row bytes] + // write the integer in native order, so they can be read by UNSAFE.getInt() + if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) { + out.writeInt(values(i).numFields()) + out.writeInt(values(i).getSizeInBytes) + } else { + out.writeInt(Integer.reverseBytes(values(i).numFields())) + out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes)) + } + out.write(values(i).getBytes) + i += 1 } - out.write(values(i).getBytes) - i += 1 } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index c635b2d51f464..d33a967093ca5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -102,6 +102,14 @@ class HashedRelationSuite extends SparkFunSuite { assert(hashed2.get(toUnsafe(InternalRow(10))) === null) assert(hashed2.get(unsafeData(2)) === data2) assert(numDataRows.value.value === data.length) + + val os2 = new ByteArrayOutputStream() + val out2 = new ObjectOutputStream(os2) + hashed2.asInstanceOf[UnsafeHashedRelation].writeExternal(out2) + out2.flush() + // This depends on that the order of items in BytesToBytesMap.iterator() is exactly the same + // as they are inserted + assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray)) } test("test serialization empty hash map") { @@ -119,5 +127,11 @@ class HashedRelationSuite extends SparkFunSuite { val toUnsafe = UnsafeProjection.create(schema) val row = toUnsafe(InternalRow(0)) assert(hashed2.get(row) === null) + + val os2 = new ByteArrayOutputStream() + val out2 = new ObjectOutputStream(os2) + hashed2.writeExternal(out2) + out2.flush() + assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray)) } } From 8187b3ae477e2b2987ae9acc5368d57b1d5653b2 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 13 Aug 2015 17:42:01 -0700 Subject: [PATCH 100/253] [SPARK-9580] [SQL] Replace singletons in SQL tests A fundamental limitation of the existing SQL tests is that *there is simply no way to create your own `SparkContext`*. This is a serious limitation because the user may wish to use a different master or config. As a case in point, `BroadcastJoinSuite` is entirely commented out because there is no way to make it pass with the existing infrastructure. This patch removes the singletons `TestSQLContext` and `TestData`, and instead introduces a `SharedSQLContext` that starts a context per suite. Unfortunately the singletons were so ingrained in the SQL tests that this patch necessarily needed to touch *all* the SQL test files. [Review on Reviewable](https://reviewable.io/reviews/apache/spark/8111) Author: Andrew Or Closes #8111 from andrewor14/sql-tests-refactor. --- project/MimaExcludes.scala | 10 + project/SparkBuild.scala | 16 +- .../analysis/AnalysisErrorSuite.scala | 6 +- .../org/apache/spark/sql/SQLContext.scala | 97 +----- .../org/apache/spark/sql/SQLImplicits.scala | 123 ++++++++ .../spark/sql/JavaApplySchemaSuite.java | 10 +- .../apache/spark/sql/JavaDataFrameSuite.java | 39 +-- .../org/apache/spark/sql/JavaUDFSuite.java | 10 +- .../spark/sql/sources/JavaSaveLoadSuite.java | 15 +- .../apache/spark/sql/CachedTableSuite.scala | 14 +- .../spark/sql/ColumnExpressionSuite.scala | 37 +-- .../spark/sql/DataFrameAggregateSuite.scala | 10 +- .../spark/sql/DataFrameFunctionsSuite.scala | 18 +- .../spark/sql/DataFrameImplicitsSuite.scala | 6 +- .../apache/spark/sql/DataFrameJoinSuite.scala | 10 +- .../spark/sql/DataFrameNaFunctionsSuite.scala | 6 +- .../apache/spark/sql/DataFrameStatSuite.scala | 19 +- .../org/apache/spark/sql/DataFrameSuite.scala | 12 +- .../spark/sql/DataFrameTungstenSuite.scala | 8 +- .../apache/spark/sql/DateFunctionsSuite.scala | 13 +- .../org/apache/spark/sql/JoinSuite.scala | 47 ++- .../apache/spark/sql/JsonFunctionsSuite.scala | 6 +- .../apache/spark/sql/ListTablesSuite.scala | 15 +- .../spark/sql/MathExpressionsSuite.scala | 28 +- .../org/apache/spark/sql/QueryTest.scala | 6 - .../scala/org/apache/spark/sql/RowSuite.scala | 7 +- .../org/apache/spark/sql/SQLConfSuite.scala | 19 +- .../apache/spark/sql/SQLContextSuite.scala | 13 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 29 +- .../sql/ScalaReflectionRelationSuite.scala | 17 +- .../apache/spark/sql/SerializationSuite.scala | 9 +- .../spark/sql/StringFunctionsSuite.scala | 7 +- .../scala/org/apache/spark/sql/TestData.scala | 197 ------------ .../scala/org/apache/spark/sql/UDFSuite.scala | 39 ++- .../spark/sql/UserDefinedTypeSuite.scala | 9 +- .../columnar/InMemoryColumnarQuerySuite.scala | 14 +- .../columnar/PartitionBatchPruningSuite.scala | 31 +- .../spark/sql/execution/ExchangeSuite.scala | 3 +- .../spark/sql/execution/PlannerSuite.scala | 39 +-- .../execution/RowFormatConvertersSuite.scala | 22 +- .../spark/sql/execution/SortSuite.scala | 3 +- .../spark/sql/execution/SparkPlanTest.scala | 36 +-- .../sql/execution/TungstenSortSuite.scala | 21 +- .../UnsafeFixedWidthAggregationMapSuite.scala | 14 +- .../UnsafeKVExternalSorterSuite.scala | 7 +- .../TungstenAggregationIteratorSuite.scala | 7 +- .../datasources/json/JsonSuite.scala | 42 ++- .../datasources/json/TestJsonData.scala | 40 ++- .../ParquetAvroCompatibilitySuite.scala | 14 +- .../parquet/ParquetCompatibilityTest.scala | 10 +- .../parquet/ParquetFilterSuite.scala | 8 +- .../datasources/parquet/ParquetIOSuite.scala | 6 +- .../ParquetPartitionDiscoverySuite.scala | 12 +- .../ParquetProtobufCompatibilitySuite.scala | 7 +- .../parquet/ParquetQuerySuite.scala | 43 ++- .../parquet/ParquetSchemaSuite.scala | 6 +- .../datasources/parquet/ParquetTest.scala | 15 +- .../ParquetThriftCompatibilitySuite.scala | 8 +- .../sql/execution/debug/DebuggingSuite.scala | 6 +- .../execution/joins/HashedRelationSuite.scala | 10 +- .../sql/execution/joins/InnerJoinSuite.scala | 248 ++++++++------- .../sql/execution/joins/OuterJoinSuite.scala | 125 ++++---- .../sql/execution/joins/SemiJoinSuite.scala | 113 +++---- .../execution/metric/SQLMetricsSuite.scala | 62 ++-- .../sql/execution/ui/SQLListenerSuite.scala | 14 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 9 +- .../spark/sql/jdbc/JDBCWriteSuite.scala | 20 +- .../sources/CreateTableAsSelectSuite.scala | 18 +- .../sql/sources/DDLSourceLoadSuite.scala | 3 +- .../spark/sql/sources/DDLTestSuite.scala | 11 +- .../spark/sql/sources/DataSourceTest.scala | 17 +- .../spark/sql/sources/FilteredScanSuite.scala | 9 +- .../spark/sql/sources/InsertSuite.scala | 30 +- .../sql/sources/PartitionedWriteSuite.scala | 14 +- .../spark/sql/sources/PrunedScanSuite.scala | 11 +- .../spark/sql/sources/SaveLoadSuite.scala | 26 +- .../spark/sql/sources/TableScanSuite.scala | 15 +- .../apache/spark/sql/test/SQLTestData.scala | 290 ++++++++++++++++++ .../apache/spark/sql/test/SQLTestUtils.scala | 92 +++++- .../spark/sql/test/SharedSQLContext.scala | 68 ++++ .../spark/sql/test/TestSQLContext.scala | 46 ++- .../hive/thriftserver/UISeleniumSuite.scala | 2 - .../sql/hive/HiveMetastoreCatalogSuite.scala | 5 +- .../spark/sql/hive/HiveParquetSuite.scala | 11 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 4 +- .../spark/sql/hive/MultiDatabaseSuite.scala | 5 +- .../hive/ParquetHiveCompatibilitySuite.scala | 3 +- .../org/apache/spark/sql/hive/UDFSuite.scala | 4 +- .../execution/AggregationQuerySuite.scala | 9 +- .../sql/hive/execution/HiveExplainSuite.scala | 4 +- .../sql/hive/execution/SQLQuerySuite.scala | 3 +- .../execution/ScriptTransformationSuite.scala | 3 +- .../apache/spark/sql/hive/orc/OrcTest.scala | 4 +- .../apache/spark/sql/hive/parquetSuites.scala | 3 +- .../CommitFailureTestRelationSuite.scala | 6 +- .../sql/sources/hadoopFsRelationSuites.scala | 5 +- 96 files changed, 1460 insertions(+), 1203 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/TestData.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala rename sql/core/src/{main => test}/scala/org/apache/spark/sql/test/TestSQLContext.scala (54%) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 784f83c10e023..88745dc086a04 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -178,6 +178,16 @@ object MimaExcludes { // SPARK-4751 Dynamic allocation for standalone mode ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.SparkContext.supportDynamicAllocation") + ) ++ Seq( + // SPARK-9580: Remove SQL test singletons + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.LocalSQLContext$SQLSession"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.LocalSQLContext"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.TestSQLContext"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.TestSQLContext$") ) ++ Seq( // SPARK-9704 Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs ProblemFilters.exclude[IncompatibleResultTypeProblem]( diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 74f815f941d5b..04e0d49b178cf 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -319,6 +319,8 @@ object SQL { lazy val settings = Seq( initialCommands in console := """ + |import org.apache.spark.SparkContext + |import org.apache.spark.sql.SQLContext |import org.apache.spark.sql.catalyst.analysis._ |import org.apache.spark.sql.catalyst.dsl._ |import org.apache.spark.sql.catalyst.errors._ @@ -328,9 +330,14 @@ object SQL { |import org.apache.spark.sql.catalyst.util._ |import org.apache.spark.sql.execution |import org.apache.spark.sql.functions._ - |import org.apache.spark.sql.test.TestSQLContext._ - |import org.apache.spark.sql.types._""".stripMargin, - cleanupCommands in console := "sparkContext.stop()" + |import org.apache.spark.sql.types._ + | + |val sc = new SparkContext("local[*]", "dev-shell") + |val sqlContext = new SQLContext(sc) + |import sqlContext.implicits._ + |import sqlContext._ + """.stripMargin, + cleanupCommands in console := "sc.stop()" ) } @@ -340,8 +347,6 @@ object Hive { javaOptions += "-XX:MaxPermSize=256m", // Specially disable assertions since some Hive tests fail them javaOptions in Test := (javaOptions in Test).value.filterNot(_ == "-ea"), - // Multiple queries rely on the TestHive singleton. See comments there for more details. - parallelExecution in Test := false, // Supporting all SerDes requires us to depend on deprecated APIs, so we turn off the warnings // only for this subproject. scalacOptions <<= scalacOptions map { currentOpts: Seq[String] => @@ -349,6 +354,7 @@ object Hive { }, initialCommands in console := """ + |import org.apache.spark.SparkContext |import org.apache.spark.sql.catalyst.analysis._ |import org.apache.spark.sql.catalyst.dsl._ |import org.apache.spark.sql.catalyst.errors._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 63b475b6366c2..f60d11c988ef8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -17,14 +17,10 @@ package org.apache.spark.sql.catalyst.analysis -import org.scalatest.BeforeAndAfter - -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.types._ @@ -42,7 +38,7 @@ case class UnresolvedTestPlan() extends LeafNode { override def output: Seq[Attribute] = Nil } -class AnalysisErrorSuite extends AnalysisTest with BeforeAndAfter { +class AnalysisErrorSuite extends AnalysisTest { import TestRelations._ def errorTest( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 4bf00b3399e7a..53de10d5fa9da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -23,7 +23,6 @@ import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConversions._ import scala.collection.immutable -import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal @@ -41,10 +40,9 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils /** @@ -334,97 +332,10 @@ class SQLContext(@transient val sparkContext: SparkContext) * @since 1.3.0 */ @Experimental - object implicits extends Serializable { - // scalastyle:on - - /** - * Converts $"col name" into an [[Column]]. - * @since 1.3.0 - */ - implicit class StringToColumn(val sc: StringContext) { - def $(args: Any*): ColumnName = { - new ColumnName(sc.s(args: _*)) - } - } - - /** - * An implicit conversion that turns a Scala `Symbol` into a [[Column]]. - * @since 1.3.0 - */ - implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) - - /** - * Creates a DataFrame from an RDD of case classes or tuples. - * @since 1.3.0 - */ - implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = { - DataFrameHolder(self.createDataFrame(rdd)) - } - - /** - * Creates a DataFrame from a local Seq of Product. - * @since 1.3.0 - */ - implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = - { - DataFrameHolder(self.createDataFrame(data)) - } - - // Do NOT add more implicit conversions. They are likely to break source compatibility by - // making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous - // because of [[DoubleRDDFunctions]]. - - /** - * Creates a single column DataFrame from an RDD[Int]. - * @since 1.3.0 - */ - implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = { - val dataType = IntegerType - val rows = data.mapPartitions { iter => - val row = new SpecificMutableRow(dataType :: Nil) - iter.map { v => - row.setInt(0, v) - row: InternalRow - } - } - DataFrameHolder( - self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) - } - - /** - * Creates a single column DataFrame from an RDD[Long]. - * @since 1.3.0 - */ - implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = { - val dataType = LongType - val rows = data.mapPartitions { iter => - val row = new SpecificMutableRow(dataType :: Nil) - iter.map { v => - row.setLong(0, v) - row: InternalRow - } - } - DataFrameHolder( - self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) - } - - /** - * Creates a single column DataFrame from an RDD[String]. - * @since 1.3.0 - */ - implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = { - val dataType = StringType - val rows = data.mapPartitions { iter => - val row = new SpecificMutableRow(dataType :: Nil) - iter.map { v => - row.update(0, UTF8String.fromString(v)) - row: InternalRow - } - } - DataFrameHolder( - self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) - } + object implicits extends SQLImplicits with Serializable { + protected override def _sqlContext: SQLContext = self } + // scalastyle:on /** * :: Experimental :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala new file mode 100644 index 0000000000000..5f82372700f2c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.implicitConversions +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow +import org.apache.spark.sql.types.StructField +import org.apache.spark.unsafe.types.UTF8String + +/** + * A collection of implicit methods for converting common Scala objects into [[DataFrame]]s. + */ +private[sql] abstract class SQLImplicits { + protected def _sqlContext: SQLContext + + /** + * Converts $"col name" into an [[Column]]. + * @since 1.3.0 + */ + implicit class StringToColumn(val sc: StringContext) { + def $(args: Any*): ColumnName = { + new ColumnName(sc.s(args: _*)) + } + } + + /** + * An implicit conversion that turns a Scala `Symbol` into a [[Column]]. + * @since 1.3.0 + */ + implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) + + /** + * Creates a DataFrame from an RDD of case classes or tuples. + * @since 1.3.0 + */ + implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = { + DataFrameHolder(_sqlContext.createDataFrame(rdd)) + } + + /** + * Creates a DataFrame from a local Seq of Product. + * @since 1.3.0 + */ + implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = + { + DataFrameHolder(_sqlContext.createDataFrame(data)) + } + + // Do NOT add more implicit conversions. They are likely to break source compatibility by + // making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous + // because of [[DoubleRDDFunctions]]. + + /** + * Creates a single column DataFrame from an RDD[Int]. + * @since 1.3.0 + */ + implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = { + val dataType = IntegerType + val rows = data.mapPartitions { iter => + val row = new SpecificMutableRow(dataType :: Nil) + iter.map { v => + row.setInt(0, v) + row: InternalRow + } + } + DataFrameHolder( + _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + } + + /** + * Creates a single column DataFrame from an RDD[Long]. + * @since 1.3.0 + */ + implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = { + val dataType = LongType + val rows = data.mapPartitions { iter => + val row = new SpecificMutableRow(dataType :: Nil) + iter.map { v => + row.setLong(0, v) + row: InternalRow + } + } + DataFrameHolder( + _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + } + + /** + * Creates a single column DataFrame from an RDD[String]. + * @since 1.3.0 + */ + implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = { + val dataType = StringType + val rows = data.mapPartitions { iter => + val row = new SpecificMutableRow(dataType :: Nil) + iter.map { v => + row.update(0, UTF8String.fromString(v)) + row: InternalRow + } + } + DataFrameHolder( + _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index e912eb835d169..bf693c7c393f6 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -27,6 +27,7 @@ import org.junit.Before; import org.junit.Test; +import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; @@ -34,7 +35,6 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -48,14 +48,16 @@ public class JavaApplySchemaSuite implements Serializable { @Before public void setUp() { - sqlContext = TestSQLContext$.MODULE$; - javaCtx = new JavaSparkContext(sqlContext.sparkContext()); + SparkContext context = new SparkContext("local[*]", "testing"); + javaCtx = new JavaSparkContext(context); + sqlContext = new SQLContext(context); } @After public void tearDown() { - javaCtx = null; + sqlContext.sparkContext().stop(); sqlContext = null; + javaCtx = null; } public static class Person implements Serializable { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 7302361ab9fdb..7abdd3db80341 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -17,44 +17,45 @@ package test.org.apache.spark.sql; +import java.io.Serializable; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.Map; + +import scala.collection.JavaConversions; +import scala.collection.Seq; + import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Ints; +import org.junit.*; +import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; +import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.test.TestSQLContext; -import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.types.*; -import org.junit.*; - -import scala.collection.JavaConversions; -import scala.collection.Seq; - -import java.io.Serializable; -import java.util.Arrays; -import java.util.Comparator; -import java.util.List; -import java.util.Map; - -import static org.apache.spark.sql.functions.*; public class JavaDataFrameSuite { private transient JavaSparkContext jsc; - private transient SQLContext context; + private transient TestSQLContext context; @Before public void setUp() { // Trigger static initializer of TestData - TestData$.MODULE$.testData(); - jsc = new JavaSparkContext(TestSQLContext.sparkContext()); - context = TestSQLContext$.MODULE$; + SparkContext sc = new SparkContext("local[*]", "testing"); + jsc = new JavaSparkContext(sc); + context = new TestSQLContext(sc); + context.loadTestData(); } @After public void tearDown() { - jsc = null; + context.sparkContext().stop(); context = null; + jsc = null; } @Test @@ -230,7 +231,7 @@ public void testCovariance() { @Test public void testSampleBy() { - DataFrame df = context.range(0, 100).select(col("id").mod(3).as("key")); + DataFrame df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); DataFrame sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); Row[] actual = sampled.groupBy("key").count().orderBy("key").collect(); Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)}; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index 79d92734ff375..bb02b58cca9be 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -23,12 +23,12 @@ import org.junit.Before; import org.junit.Test; +import org.apache.spark.SparkContext; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.api.java.UDF1; import org.apache.spark.sql.api.java.UDF2; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.types.DataTypes; // The test suite itself is Serializable so that anonymous Function implementations can be @@ -40,12 +40,16 @@ public class JavaUDFSuite implements Serializable { @Before public void setUp() { - sqlContext = TestSQLContext$.MODULE$; - sc = new JavaSparkContext(sqlContext.sparkContext()); + SparkContext _sc = new SparkContext("local[*]", "testing"); + sqlContext = new SQLContext(_sc); + sc = new JavaSparkContext(_sc); } @After public void tearDown() { + sqlContext.sparkContext().stop(); + sqlContext = null; + sc = null; } @SuppressWarnings("unchecked") diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java index 2706e01bd28af..6f9e7f68dc39c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -21,13 +21,14 @@ import java.io.IOException; import java.util.*; +import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.*; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; @@ -52,8 +53,9 @@ private void checkAnswer(DataFrame actual, List expected) { @Before public void setUp() throws IOException { - sqlContext = TestSQLContext$.MODULE$; - sc = new JavaSparkContext(sqlContext.sparkContext()); + SparkContext _sc = new SparkContext("local[*]", "testing"); + sqlContext = new SQLContext(_sc); + sc = new JavaSparkContext(_sc); originalDefaultSource = sqlContext.conf().defaultDataSourceName(); path = @@ -71,6 +73,13 @@ public void setUp() throws IOException { df.registerTempTable("jsonTable"); } + @After + public void tearDown() { + sqlContext.sparkContext().stop(); + sqlContext = null; + sc = null; + } + @Test public void saveAndLoad() { Map options = new HashMap(); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index a88df91b1001c..af7590c3d3c17 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -18,24 +18,20 @@ package org.apache.spark.sql import scala.concurrent.duration._ -import scala.language.{implicitConversions, postfixOps} +import scala.language.postfixOps import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar._ import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.storage.{StorageLevel, RDDBlockId} -case class BigData(s: String) +private case class BigData(s: String) -class CachedTableSuite extends QueryTest { - TestData // Load test tables. - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.sql +class CachedTableSuite extends QueryTest with SharedSQLContext { + import testImplicits._ def rddIdOf(tableName: String): Int = { val executedPlan = ctx.table(tableName).queryExecution.executedPlan diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 6a09a3b72c081..ee74e3e83da5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -21,16 +21,20 @@ import org.scalatest.Matchers._ import org.apache.spark.sql.execution.{Project, TungstenProject} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.SQLTestUtils -class ColumnExpressionSuite extends QueryTest with SQLTestUtils { - import org.apache.spark.sql.TestData._ +class ColumnExpressionSuite extends QueryTest with SharedSQLContext { + import testImplicits._ - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - - override def sqlContext(): SQLContext = ctx + private lazy val booleanData = { + ctx.createDataFrame(ctx.sparkContext.parallelize( + Row(false, false) :: + Row(false, true) :: + Row(true, false) :: + Row(true, true) :: Nil), + StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType)))) + } test("column names with space") { val df = Seq((1, "a")).toDF("name with space", "name.with.dot") @@ -258,7 +262,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { nullStrings.collect().toSeq.filter(r => r.getString(1) eq null)) checkAnswer( - ctx.sql("select isnull(null), isnull(1)"), + sql("select isnull(null), isnull(1)"), Row(true, false)) } @@ -268,7 +272,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { nullStrings.collect().toSeq.filter(r => r.getString(1) ne null)) checkAnswer( - ctx.sql("select isnotnull(null), isnotnull('a')"), + sql("select isnotnull(null), isnotnull('a')"), Row(false, true)) } @@ -289,7 +293,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil) checkAnswer( - ctx.sql("select isnan(15), isnan('invalid')"), + sql("select isnan(15), isnan('invalid')"), Row(false, false)) } @@ -309,7 +313,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { ) testData.registerTempTable("t") checkAnswer( - ctx.sql( + sql( "select nanvl(a, 5), nanvl(b, 10), nanvl(10, b), nanvl(c, null), nanvl(d, 10), " + " nanvl(b, e), nanvl(e, f) from t"), Row(null, 3.0, 10.0, null, Double.PositiveInfinity, 3.0, 1.0) @@ -433,13 +437,6 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { } } - val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize( - Row(false, false) :: - Row(false, true) :: - Row(true, false) :: - Row(true, true) :: Nil), - StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType)))) - test("&&") { checkAnswer( booleanData.filter($"a" && true), @@ -523,7 +520,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { ) checkAnswer( - ctx.sql("SELECT upper('aB'), ucase('cDe')"), + sql("SELECT upper('aB'), ucase('cDe')"), Row("AB", "CDE")) } @@ -544,7 +541,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { ) checkAnswer( - ctx.sql("SELECT lower('aB'), lcase('cDe')"), + sql("SELECT lower('aB'), lcase('cDe')"), Row("ab", "cde")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index f9cff7440a76e..72cf7aab0b977 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -17,15 +17,13 @@ package org.apache.spark.sql -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{BinaryType, DecimalType} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.DecimalType -class DataFrameAggregateSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("groupBy") { checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 03116a374f3be..9d965258e389d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -17,17 +17,15 @@ package org.apache.spark.sql -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ /** * Test suite for functions in [[org.apache.spark.sql.functions]]. */ -class DataFrameFunctionsSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("array with column name") { val df = Seq((0, 1)).toDF("a", "b") @@ -119,11 +117,11 @@ class DataFrameFunctionsSuite extends QueryTest { test("constant functions") { checkAnswer( - ctx.sql("SELECT E()"), + sql("SELECT E()"), Row(scala.math.E) ) checkAnswer( - ctx.sql("SELECT PI()"), + sql("SELECT PI()"), Row(scala.math.Pi) ) } @@ -153,7 +151,7 @@ class DataFrameFunctionsSuite extends QueryTest { test("nvl function") { checkAnswer( - ctx.sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"), + sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"), Row("x", "y", null)) } @@ -222,7 +220,7 @@ class DataFrameFunctionsSuite extends QueryTest { Row(-1) ) checkAnswer( - ctx.sql("SELECT least(a, 2) as l from testData2 order by l"), + sql("SELECT least(a, 2) as l from testData2 order by l"), Seq(Row(1), Row(1), Row(2), Row(2), Row(2), Row(2)) ) } @@ -233,7 +231,7 @@ class DataFrameFunctionsSuite extends QueryTest { Row(3) ) checkAnswer( - ctx.sql("SELECT greatest(a, 2) as g from testData2 order by g"), + sql("SELECT greatest(a, 2) as g from testData2 order by g"), Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3)) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala index fbb30706a4943..e5d7d63441a6b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql -class DataFrameImplicitsSuite extends QueryTest { +import org.apache.spark.sql.test.SharedSQLContext - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("RDD of tuples") { checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index e1c6c706242d2..e2716d7841d85 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -17,14 +17,12 @@ package org.apache.spark.sql -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext -class DataFrameJoinSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameJoinSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("join - join using") { val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") @@ -59,7 +57,7 @@ class DataFrameJoinSuite extends QueryTest { checkAnswer( df1.join(df2, $"df1.key" === $"df2.key"), - ctx.sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key") + sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key") .collect().toSeq) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index dbe3b44ee2c79..cdaa14ac80785 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql import scala.collection.JavaConversions._ +import org.apache.spark.sql.test.SharedSQLContext -class DataFrameNaFunctionsSuite extends QueryTest { - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ def createDF(): DataFrame = { Seq[(String, java.lang.Integer, java.lang.Double)]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 8f5984e4a8ce2..28bdd6f83b687 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -19,20 +19,17 @@ package org.apache.spark.sql import java.util.Random -import org.scalatest.Matchers._ - import org.apache.spark.sql.functions.col +import org.apache.spark.sql.test.SharedSQLContext -class DataFrameStatSuite extends QueryTest { - - private val sqlCtx = org.apache.spark.sql.test.TestSQLContext - import sqlCtx.implicits._ +class DataFrameStatSuite extends QueryTest with SharedSQLContext { + import testImplicits._ private def toLetter(i: Int): String = (i + 97).toChar.toString test("sample with replacement") { val n = 100 - val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") checkAnswer( data.sample(withReplacement = true, 0.05, seed = 13), Seq(5, 10, 52, 73).map(Row(_)) @@ -41,7 +38,7 @@ class DataFrameStatSuite extends QueryTest { test("sample without replacement") { val n = 100 - val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") checkAnswer( data.sample(withReplacement = false, 0.05, seed = 13), Seq(16, 23, 88, 100).map(Row(_)) @@ -50,7 +47,7 @@ class DataFrameStatSuite extends QueryTest { test("randomSplit") { val n = 600 - val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") for (seed <- 1 to 5) { val splits = data.randomSplit(Array[Double](1, 2, 3), seed) assert(splits.length == 3, "wrong number of splits") @@ -167,7 +164,7 @@ class DataFrameStatSuite extends QueryTest { } test("Frequent Items 2") { - val rows = sqlCtx.sparkContext.parallelize(Seq.empty[Int], 4) + val rows = ctx.sparkContext.parallelize(Seq.empty[Int], 4) // this is a regression test, where when merging partitions, we omitted values with higher // counts than those that existed in the map when the map was full. This test should also fail // if anything like SPARK-9614 is observed once again @@ -185,7 +182,7 @@ class DataFrameStatSuite extends QueryTest { } test("sampleBy") { - val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key")) + val df = ctx.range(0, 100).select((col("id") % 3).as("key")) val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L) checkAnswer( sampled.groupBy("key").count().orderBy("key"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 2feec29955bc8..10bfa9b64f00d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -23,18 +23,12 @@ import scala.language.postfixOps import scala.util.Random import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation -import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.functions._ -import org.apache.spark.sql.execution.datasources.json.JSONRelation -import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SQLTestUtils} +import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SharedSQLContext} -class DataFrameSuite extends QueryTest with SQLTestUtils { - import org.apache.spark.sql.TestData._ - - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ +class DataFrameSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("analysis error should be eagerly reported") { // Eager analysis. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala index bf8ef9a97bc60..77907e91363ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ /** @@ -27,10 +27,8 @@ import org.apache.spark.sql.types._ * This is here for now so I can make sure Tungsten project is tested without refactoring existing * end-to-end test infra. In the long run this should just go away. */ -class DataFrameTungstenSuite extends QueryTest with SQLTestUtils { - - override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ +class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("test simple types") { withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 17897caf952a3..9080c53c491ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -22,19 +22,18 @@ import java.text.SimpleDateFormat import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.unsafe.types.CalendarInterval -class DateFunctionsSuite extends QueryTest { - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - - import ctx.implicits._ +class DateFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("function current_date") { val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis()) val d1 = DateTimeUtils.fromJavaDate(df1.select(current_date()).collect().head.getDate(0)) val d2 = DateTimeUtils.fromJavaDate( - ctx.sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0)) + sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0)) val d3 = DateTimeUtils.millisToDays(System.currentTimeMillis()) assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1) } @@ -44,9 +43,9 @@ class DateFunctionsSuite extends QueryTest { val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1)) // Execution in one query should return the same value - checkAnswer(ctx.sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), + checkAnswer(sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), Row(true)) - assert(math.abs(ctx.sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp( + assert(math.abs(sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp( 0).getTime - System.currentTimeMillis()) < 5000) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index ae07eaf91c872..f5c5046a8ed88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -17,22 +17,15 @@ package org.apache.spark.sql -import org.scalatest.BeforeAndAfterEach - -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext -class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { - // Ensures tables are loaded. - TestData +class JoinSuite extends QueryTest with SharedSQLContext { + import testImplicits._ - override def sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext - lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.logicalPlanToSparkQuery + setupTestData() test("equi-join is hash-join") { val x = testData2.as("x") @@ -43,7 +36,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { } def assertJoin(sqlString: String, c: Class[_]): Any = { - val df = ctx.sql(sqlString) + val df = sql(sqlString) val physical = df.queryExecution.sparkPlan val operators = physical.collect { case j: ShuffledHashJoin => j @@ -126,7 +119,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { test("broadcasted hash join operator selection") { ctx.cacheManager.clearCache() - ctx.sql("CACHE TABLE testData") + sql("CACHE TABLE testData") for (sortMergeJoinEnabled <- Seq(true, false)) { withClue(s"sortMergeJoinEnabled=$sortMergeJoinEnabled") { withSQLConf(SQLConf.SORTMERGE_JOIN.key -> s"$sortMergeJoinEnabled") { @@ -141,12 +134,12 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { } } } - ctx.sql("UNCACHE TABLE testData") + sql("UNCACHE TABLE testData") } test("broadcasted hash outer join operator selection") { ctx.cacheManager.clearCache() - ctx.sql("CACHE TABLE testData") + sql("CACHE TABLE testData") withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { Seq( ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", @@ -167,7 +160,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { classOf[BroadcastHashOuterJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } } - ctx.sql("UNCACHE TABLE testData") + sql("UNCACHE TABLE testData") } test("multiple-key equi-join is hash-join") { @@ -279,7 +272,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { // Make sure we are choosing left.outputPartitioning as the // outputPartitioning for the outer join operator. checkAnswer( - ctx.sql( + sql( """ |SELECT l.N, count(*) |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) @@ -293,7 +286,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { Row(6, 1) :: Nil) checkAnswer( - ctx.sql( + sql( """ |SELECT r.a, count(*) |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) @@ -339,7 +332,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { // Make sure we are choosing right.outputPartitioning as the // outputPartitioning for the outer join operator. checkAnswer( - ctx.sql( + sql( """ |SELECT l.a, count(*) |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -348,7 +341,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { Row(null, 6)) checkAnswer( - ctx.sql( + sql( """ |SELECT r.N, count(*) |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -400,7 +393,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { // Make sure we are UnknownPartitioning as the outputPartitioning for the outer join operator. checkAnswer( - ctx.sql( + sql( """ |SELECT l.a, count(*) |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -409,7 +402,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { Row(null, 10)) checkAnswer( - ctx.sql( + sql( """ |SELECT r.N, count(*) |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -424,7 +417,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { Row(null, 4) :: Nil) checkAnswer( - ctx.sql( + sql( """ |SELECT l.N, count(*) |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) @@ -439,7 +432,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { Row(null, 4) :: Nil) checkAnswer( - ctx.sql( + sql( """ |SELECT r.a, count(*) |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) @@ -450,7 +443,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { test("broadcasted left semi join operator selection") { ctx.cacheManager.clearCache() - ctx.sql("CACHE TABLE testData") + sql("CACHE TABLE testData") withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") { Seq( @@ -469,11 +462,11 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { } } - ctx.sql("UNCACHE TABLE testData") + sql("UNCACHE TABLE testData") } test("left semi join") { - val df = ctx.sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") + val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") checkAnswer(df, Row(1, 1) :: Row(1, 2) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 71c26a6f8d367..045fea82e4c89 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql -class JsonFunctionsSuite extends QueryTest { +import org.apache.spark.sql.test.SharedSQLContext - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class JsonFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("function get_json_object") { val df: DataFrame = Seq(("""{"name": "alice", "age": 5}""", "")).toDF("a", "b") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index 2089660c52bf7..babf8835d2545 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -19,12 +19,11 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfter +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} -class ListTablesSuite extends QueryTest with BeforeAndAfter { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext { + import testImplicits._ private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value") @@ -42,7 +41,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter { Row("ListTablesSuiteTable", true)) checkAnswer( - ctx.sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), + sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) @@ -55,7 +54,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter { Row("ListTablesSuiteTable", true)) checkAnswer( - ctx.sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), + sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) @@ -67,13 +66,13 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter { StructField("tableName", StringType, false) :: StructField("isTemporary", BooleanType, false) :: Nil) - Seq(ctx.tables(), ctx.sql("SHOW TABLes")).foreach { + Seq(ctx.tables(), sql("SHOW TABLes")).foreach { case tableDF => assert(expectedSchema === tableDF.schema) tableDF.registerTempTable("tables") checkAnswer( - ctx.sql( + sql( "SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"), Row(true, "ListTablesSuiteTable") ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 8cf2ef5957d8d..30289c3c1d097 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -19,18 +19,16 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ import org.apache.spark.sql.functions.{log => logarithm} +import org.apache.spark.sql.test.SharedSQLContext private object MathExpressionsTestData { case class DoubleData(a: java.lang.Double, b: java.lang.Double) case class NullDoubles(a: java.lang.Double) } -class MathExpressionsSuite extends QueryTest { - +class MathExpressionsSuite extends QueryTest with SharedSQLContext { import MathExpressionsTestData._ - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ + import testImplicits._ private lazy val doubleData = (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1)).toDF() @@ -149,7 +147,7 @@ class MathExpressionsSuite extends QueryTest { test("toDegrees") { testOneToOneMathFunction(toDegrees, math.toDegrees) checkAnswer( - ctx.sql("SELECT degrees(0), degrees(1), degrees(1.5)"), + sql("SELECT degrees(0), degrees(1), degrees(1.5)"), Seq((1, 2)).toDF().select(toDegrees(lit(0)), toDegrees(lit(1)), toDegrees(lit(1.5))) ) } @@ -157,7 +155,7 @@ class MathExpressionsSuite extends QueryTest { test("toRadians") { testOneToOneMathFunction(toRadians, math.toRadians) checkAnswer( - ctx.sql("SELECT radians(0), radians(1), radians(1.5)"), + sql("SELECT radians(0), radians(1), radians(1.5)"), Seq((1, 2)).toDF().select(toRadians(lit(0)), toRadians(lit(1)), toRadians(lit(1.5))) ) } @@ -169,7 +167,7 @@ class MathExpressionsSuite extends QueryTest { test("ceil and ceiling") { testOneToOneMathFunction(ceil, math.ceil) checkAnswer( - ctx.sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"), + sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"), Row(0.0, 1.0, 2.0)) } @@ -214,7 +212,7 @@ class MathExpressionsSuite extends QueryTest { val pi = 3.1415 checkAnswer( - ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " + + sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " + s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"), Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3), BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142"))) @@ -233,7 +231,7 @@ class MathExpressionsSuite extends QueryTest { testOneToOneMathFunction[Double](signum, math.signum) checkAnswer( - ctx.sql("SELECT sign(10), signum(-11)"), + sql("SELECT sign(10), signum(-11)"), Row(1, -1)) } @@ -241,7 +239,7 @@ class MathExpressionsSuite extends QueryTest { testTwoToOneMathFunction(pow, pow, math.pow) checkAnswer( - ctx.sql("SELECT pow(1, 2), power(2, 1)"), + sql("SELECT pow(1, 2), power(2, 1)"), Seq((1, 2)).toDF().select(pow(lit(1), lit(2)), pow(lit(2), lit(1))) ) } @@ -280,7 +278,7 @@ class MathExpressionsSuite extends QueryTest { test("log / ln") { testOneToOneNonNegativeMathFunction(org.apache.spark.sql.functions.log, math.log) checkAnswer( - ctx.sql("SELECT ln(0), ln(1), ln(1.5)"), + sql("SELECT ln(0), ln(1), ln(1.5)"), Seq((1, 2)).toDF().select(logarithm(lit(0)), logarithm(lit(1)), logarithm(lit(1.5))) ) } @@ -375,7 +373,7 @@ class MathExpressionsSuite extends QueryTest { df.select(log2("b") + log2("a")), Row(1)) - checkAnswer(ctx.sql("SELECT LOG2(8), LOG2(null)"), Row(3, null)) + checkAnswer(sql("SELECT LOG2(8), LOG2(null)"), Row(3, null)) } test("sqrt") { @@ -384,13 +382,13 @@ class MathExpressionsSuite extends QueryTest { df.select(sqrt("a"), sqrt("b")), Row(1.0, 2.0)) - checkAnswer(ctx.sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null)) + checkAnswer(sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null)) checkAnswer(df.selectExpr("sqrt(a)", "sqrt(b)", "sqrt(null)"), Row(1.0, 2.0, null)) } test("negative") { checkAnswer( - ctx.sql("SELECT negative(1), negative(0), negative(-1)"), + sql("SELECT negative(1), negative(0), negative(-1)"), Row(-1, 0, 1)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 98ba3c99283a1..4adcefb7dc4b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -71,12 +71,6 @@ class QueryTest extends PlanTest { checkAnswer(df, expectedAnswer.collect()) } - def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext) { - test(sqlString) { - checkAnswer(sqlContext.sql(sqlString), expectedAnswer) - } - } - /** * Asserts that a given [[DataFrame]] will be executed using the given number of cached results. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 8a679c7865d6a..795d4e983f27e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -20,13 +20,12 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -class RowSuite extends SparkFunSuite { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class RowSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ test("create row") { val expected = new GenericMutableRow(4) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 75791e9d53c20..7699adadd9cc8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql +import org.apache.spark.sql.test.SharedSQLContext -class SQLConfSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class SQLConfSuite extends QueryTest with SharedSQLContext { private val testKey = "test.key.0" private val testVal = "test.val.0" @@ -52,21 +51,21 @@ class SQLConfSuite extends QueryTest { test("parse SQL set commands") { ctx.conf.clear() - ctx.sql(s"set $testKey=$testVal") + sql(s"set $testKey=$testVal") assert(ctx.getConf(testKey, testVal + "_") === testVal) assert(ctx.getConf(testKey, testVal + "_") === testVal) - ctx.sql("set some.property=20") + sql("set some.property=20") assert(ctx.getConf("some.property", "0") === "20") - ctx.sql("set some.property = 40") + sql("set some.property = 40") assert(ctx.getConf("some.property", "0") === "40") val key = "spark.sql.key" val vs = "val0,val_1,val2.3,my_table" - ctx.sql(s"set $key=$vs") + sql(s"set $key=$vs") assert(ctx.getConf(key, "0") === vs) - ctx.sql(s"set $key=") + sql(s"set $key=") assert(ctx.getConf(key, "0") === "") ctx.conf.clear() @@ -74,14 +73,14 @@ class SQLConfSuite extends QueryTest { test("deprecated property") { ctx.conf.clear() - ctx.sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") + sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") assert(ctx.conf.numShufflePartitions === 10) } test("invalid conf value") { ctx.conf.clear() val e = intercept[IllegalArgumentException] { - ctx.sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10") + sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10") } assert(e.getMessage === s"${SQLConf.CASE_SENSITIVE.key} should be boolean, but was 10") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index c8d8796568a41..007be12950774 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -17,16 +17,17 @@ package org.apache.spark.sql -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.SharedSQLContext -class SQLContextSuite extends SparkFunSuite with BeforeAndAfterAll { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class SQLContextSuite extends SparkFunSuite with SharedSQLContext { override def afterAll(): Unit = { - SQLContext.setLastInstantiatedContext(ctx) + try { + SQLContext.setLastInstantiatedContext(ctx) + } finally { + super.afterAll() + } } test("getOrCreate instantiates SQLContext") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index b14ef9bab90cb..8c2c328f8191c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,28 +19,23 @@ package org.apache.spark.sql import java.sql.Timestamp -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.functions._ -import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ /** A SQL Dialect for testing purpose, and it can not be nested type */ class MyDialect extends DefaultParserDialect -class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { - // Make sure the tables are loaded. - TestData +class SQLQuerySuite extends QueryTest with SharedSQLContext { + import testImplicits._ - val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ - import sqlContext.sql + setupTestData() test("having clause") { Seq(("one", 1), ("two", 2), ("three", 3), ("one", 5)).toDF("k", "v").registerTempTable("hav") @@ -60,7 +55,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("show functions") { - checkAnswer(sql("SHOW functions"), FunctionRegistry.builtin.listFunction().sorted.map(Row(_))) + checkAnswer(sql("SHOW functions"), + FunctionRegistry.builtin.listFunction().sorted.map(Row(_))) } test("describe functions") { @@ -178,7 +174,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { val df = Seq(Tuple1(1), Tuple1(2), Tuple1(3)).toDF("index") // we except the id is materialized once - val idUDF = udf(() => UUID.randomUUID().toString) + val idUDF = org.apache.spark.sql.functions.udf(() => UUID.randomUUID().toString) val dfWithId = df.withColumn("id", idUDF()) // Make a new DataFrame (actually the same reference to the old one) @@ -712,9 +708,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { checkAnswer( sql( - """ - |SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3 - """.stripMargin), + "SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3"), Row(2, 1, 2, 2, 1)) } @@ -1161,7 +1155,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { validateMetadata(sql("SELECT * FROM personWithMeta")) validateMetadata(sql("SELECT id, name FROM personWithMeta")) validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON id = personId")) - validateMetadata(sql("SELECT name, salary FROM personWithMeta JOIN salary ON id = personId")) + validateMetadata(sql( + "SELECT name, salary FROM personWithMeta JOIN salary ON id = personId")) } test("SPARK-3371 Renaming a function expression with group by gives error") { @@ -1627,7 +1622,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { .toDF("num", "str") df.registerTempTable("1one") - checkAnswer(sqlContext.sql("select count(num) from 1one"), Row(10)) + checkAnswer(sql("select count(num) from 1one"), Row(10)) sqlContext.dropTempTable("1one") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index ab6d3dd96d271..295f02f9a7b5d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.SharedSQLContext case class ReflectData( stringField: String, @@ -71,17 +72,15 @@ case class ComplexReflectData( mapFieldContainsNull: Map[Int, Option[Long]], dataField: Data) -class ScalaReflectionRelationSuite extends SparkFunSuite { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3)) Seq(data).toDF().registerTempTable("reflectData") - assert(ctx.sql("SELECT * FROM reflectData").collect().head === + assert(sql("SELECT * FROM reflectData").collect().head === Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3))) @@ -91,7 +90,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { val data = NullReflectData(null, null, null, null, null, null, null) Seq(data).toDF().registerTempTable("reflectNullData") - assert(ctx.sql("SELECT * FROM reflectNullData").collect().head === + assert(sql("SELECT * FROM reflectNullData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } @@ -99,7 +98,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { val data = OptionalReflectData(None, None, None, None, None, None, None) Seq(data).toDF().registerTempTable("reflectOptionalData") - assert(ctx.sql("SELECT * FROM reflectOptionalData").collect().head === + assert(sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } @@ -107,7 +106,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { test("query binary data") { Seq(ReflectBinary(Array[Byte](1))).toDF().registerTempTable("reflectBinary") - val result = ctx.sql("SELECT data FROM reflectBinary") + val result = sql("SELECT data FROM reflectBinary") .collect().head(0).asInstanceOf[Array[Byte]] assert(result.toSeq === Seq[Byte](1)) } @@ -126,7 +125,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { Nested(None, "abc"))) Seq(data).toDF().registerTempTable("reflectComplexData") - assert(ctx.sql("SELECT * FROM reflectComplexData").collect().head === + assert(sql("SELECT * FROM reflectComplexData").collect().head === Row( Seq(1, 2, 3), Seq(1, 2, null), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala index e55c9e460b791..45d0ee4a8e749 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.sql import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.sql.test.SharedSQLContext -class SerializationSuite extends SparkFunSuite { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class SerializationSuite extends SparkFunSuite with SharedSQLContext { test("[SPARK-5235] SQLContext should be serializable") { - val sqlContext = new SQLContext(ctx.sparkContext) - new JavaSerializer(new SparkConf()).newInstance().serialize(sqlContext) + val _sqlContext = new SQLContext(sqlContext.sparkContext) + new JavaSerializer(new SparkConf()).newInstance().serialize(_sqlContext) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index ca298b2434410..cc95eede005d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -18,13 +18,12 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.Decimal -class StringFunctionsSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class StringFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("string concat") { val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala deleted file mode 100644 index bd9729c431f30..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ /dev/null @@ -1,197 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import org.apache.spark.sql.test.TestSQLContext.implicits._ -import org.apache.spark.sql.test._ - - -case class TestData(key: Int, value: String) - -object TestData { - val testData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(i, i.toString))).toDF() - testData.registerTempTable("testData") - - val negativeData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(-i, (-i).toString))).toDF() - negativeData.registerTempTable("negativeData") - - case class LargeAndSmallInts(a: Int, b: Int) - val largeAndSmallInts = - TestSQLContext.sparkContext.parallelize( - LargeAndSmallInts(2147483644, 1) :: - LargeAndSmallInts(1, 2) :: - LargeAndSmallInts(2147483645, 1) :: - LargeAndSmallInts(2, 2) :: - LargeAndSmallInts(2147483646, 1) :: - LargeAndSmallInts(3, 2) :: Nil).toDF() - largeAndSmallInts.registerTempTable("largeAndSmallInts") - - case class TestData2(a: Int, b: Int) - val testData2 = - TestSQLContext.sparkContext.parallelize( - TestData2(1, 1) :: - TestData2(1, 2) :: - TestData2(2, 1) :: - TestData2(2, 2) :: - TestData2(3, 1) :: - TestData2(3, 2) :: Nil, 2).toDF() - testData2.registerTempTable("testData2") - - case class DecimalData(a: BigDecimal, b: BigDecimal) - - val decimalData = - TestSQLContext.sparkContext.parallelize( - DecimalData(1, 1) :: - DecimalData(1, 2) :: - DecimalData(2, 1) :: - DecimalData(2, 2) :: - DecimalData(3, 1) :: - DecimalData(3, 2) :: Nil).toDF() - decimalData.registerTempTable("decimalData") - - case class BinaryData(a: Array[Byte], b: Int) - val binaryData = - TestSQLContext.sparkContext.parallelize( - BinaryData("12".getBytes(), 1) :: - BinaryData("22".getBytes(), 5) :: - BinaryData("122".getBytes(), 3) :: - BinaryData("121".getBytes(), 2) :: - BinaryData("123".getBytes(), 4) :: Nil).toDF() - binaryData.registerTempTable("binaryData") - - case class TestData3(a: Int, b: Option[Int]) - val testData3 = - TestSQLContext.sparkContext.parallelize( - TestData3(1, None) :: - TestData3(2, Some(2)) :: Nil).toDF() - testData3.registerTempTable("testData3") - - case class UpperCaseData(N: Int, L: String) - val upperCaseData = - TestSQLContext.sparkContext.parallelize( - UpperCaseData(1, "A") :: - UpperCaseData(2, "B") :: - UpperCaseData(3, "C") :: - UpperCaseData(4, "D") :: - UpperCaseData(5, "E") :: - UpperCaseData(6, "F") :: Nil).toDF() - upperCaseData.registerTempTable("upperCaseData") - - case class LowerCaseData(n: Int, l: String) - val lowerCaseData = - TestSQLContext.sparkContext.parallelize( - LowerCaseData(1, "a") :: - LowerCaseData(2, "b") :: - LowerCaseData(3, "c") :: - LowerCaseData(4, "d") :: Nil).toDF() - lowerCaseData.registerTempTable("lowerCaseData") - - case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) - val arrayData = - TestSQLContext.sparkContext.parallelize( - ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: - ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) - arrayData.toDF().registerTempTable("arrayData") - - case class MapData(data: scala.collection.Map[Int, String]) - val mapData = - TestSQLContext.sparkContext.parallelize( - MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: - MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) :: - MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: - MapData(Map(1 -> "a4", 2 -> "b4")) :: - MapData(Map(1 -> "a5")) :: Nil) - mapData.toDF().registerTempTable("mapData") - - case class StringData(s: String) - val repeatedData = - TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) - repeatedData.toDF().registerTempTable("repeatedData") - - val nullableRepeatedData = - TestSQLContext.sparkContext.parallelize( - List.fill(2)(StringData(null)) ++ - List.fill(2)(StringData("test"))) - nullableRepeatedData.toDF().registerTempTable("nullableRepeatedData") - - case class NullInts(a: Integer) - val nullInts = - TestSQLContext.sparkContext.parallelize( - NullInts(1) :: - NullInts(2) :: - NullInts(3) :: - NullInts(null) :: Nil - ).toDF() - nullInts.registerTempTable("nullInts") - - val allNulls = - TestSQLContext.sparkContext.parallelize( - NullInts(null) :: - NullInts(null) :: - NullInts(null) :: - NullInts(null) :: Nil).toDF() - allNulls.registerTempTable("allNulls") - - case class NullStrings(n: Int, s: String) - val nullStrings = - TestSQLContext.sparkContext.parallelize( - NullStrings(1, "abc") :: - NullStrings(2, "ABC") :: - NullStrings(3, null) :: Nil).toDF() - nullStrings.registerTempTable("nullStrings") - - case class TableName(tableName: String) - TestSQLContext - .sparkContext - .parallelize(TableName("test") :: Nil) - .toDF() - .registerTempTable("tableName") - - val unparsedStrings = - TestSQLContext.sparkContext.parallelize( - "1, A1, true, null" :: - "2, B2, false, null" :: - "3, C3, true, null" :: - "4, D4, true, 2147483644" :: Nil) - - case class IntField(i: Int) - // An RDD with 4 elements and 8 partitions - val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8) - withEmptyParts.toDF().registerTempTable("withEmptyParts") - - case class Person(id: Int, name: String, age: Int) - case class Salary(personId: Int, salary: Double) - val person = TestSQLContext.sparkContext.parallelize( - Person(0, "mike", 30) :: - Person(1, "jim", 20) :: Nil).toDF() - person.registerTempTable("person") - val salary = TestSQLContext.sparkContext.parallelize( - Salary(0, 2000.0) :: - Salary(1, 1000.0) :: Nil).toDF() - salary.registerTempTable("salary") - - case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) - val complexData = - TestSQLContext.sparkContext.parallelize( - ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) - :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) - :: Nil).toDF() - complexData.registerTempTable("complexData") -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 183dc3407b3ab..eb275af101e2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,16 +17,13 @@ package org.apache.spark.sql -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestData._ -case class FunctionResult(f1: String, f2: String) +private case class FunctionResult(f1: String, f2: String) -class UDFSuite extends QueryTest with SQLTestUtils { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - - override def sqlContext(): SQLContext = ctx +class UDFSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("built-in fixed arity expressions") { val df = ctx.emptyDataFrame @@ -57,7 +54,7 @@ class UDFSuite extends QueryTest with SQLTestUtils { test("SPARK-8003 spark_partition_id") { val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying") df.registerTempTable("tmp_table") - checkAnswer(ctx.sql("select spark_partition_id() from tmp_table").toDF(), Row(0)) + checkAnswer(sql("select spark_partition_id() from tmp_table").toDF(), Row(0)) ctx.dropTempTable("tmp_table") } @@ -66,9 +63,9 @@ class UDFSuite extends QueryTest with SQLTestUtils { val data = ctx.sparkContext.parallelize(0 to 10, 2).toDF("id") data.write.parquet(dir.getCanonicalPath) ctx.read.parquet(dir.getCanonicalPath).registerTempTable("test_table") - val answer = ctx.sql("select input_file_name() from test_table").head().getString(0) + val answer = sql("select input_file_name() from test_table").head().getString(0) assert(answer.contains(dir.getCanonicalPath)) - assert(ctx.sql("select input_file_name() from test_table").distinct().collect().length >= 2) + assert(sql("select input_file_name() from test_table").distinct().collect().length >= 2) ctx.dropTempTable("test_table") } } @@ -91,17 +88,17 @@ class UDFSuite extends QueryTest with SQLTestUtils { test("Simple UDF") { ctx.udf.register("strLenScala", (_: String).length) - assert(ctx.sql("SELECT strLenScala('test')").head().getInt(0) === 4) + assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4) } test("ZeroArgument UDF") { ctx.udf.register("random0", () => { Math.random()}) - assert(ctx.sql("SELECT random0()").head().getDouble(0) >= 0.0) + assert(sql("SELECT random0()").head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { ctx.udf.register("strLenScala", (_: String).length + (_: Int)) - assert(ctx.sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) + assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) } test("UDF in a WHERE") { @@ -112,7 +109,7 @@ class UDFSuite extends QueryTest with SQLTestUtils { df.registerTempTable("integerData") val result = - ctx.sql("SELECT * FROM integerData WHERE oneArgFilter(key)") + sql("SELECT * FROM integerData WHERE oneArgFilter(key)") assert(result.count() === 20) } @@ -124,7 +121,7 @@ class UDFSuite extends QueryTest with SQLTestUtils { df.registerTempTable("groupData") val result = - ctx.sql( + sql( """ | SELECT g, SUM(v) as s | FROM groupData @@ -143,7 +140,7 @@ class UDFSuite extends QueryTest with SQLTestUtils { df.registerTempTable("groupData") val result = - ctx.sql( + sql( """ | SELECT SUM(v) | FROM groupData @@ -163,7 +160,7 @@ class UDFSuite extends QueryTest with SQLTestUtils { df.registerTempTable("groupData") val result = - ctx.sql( + sql( """ | SELECT timesHundred(SUM(v)) as v100 | FROM groupData @@ -178,7 +175,7 @@ class UDFSuite extends QueryTest with SQLTestUtils { ctx.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) val result = - ctx.sql("SELECT returnStruct('test', 'test2') as ret") + sql("SELECT returnStruct('test', 'test2') as ret") .select($"ret.f1").head().getString(0) assert(result === "test") } @@ -186,12 +183,12 @@ class UDFSuite extends QueryTest with SQLTestUtils { test("udf that is transformed") { ctx.udf.register("makeStruct", (x: Int, y: Int) => (x, y)) // 1 + 1 is constant folded causing a transformation. - assert(ctx.sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) + assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) } test("type coercion for udf inputs") { ctx.udf.register("intExpected", (x: Int) => x) // pass a decimal to intExpected. - assert(ctx.sql("SELECT intExpected(1.0)").head().getInt(0) === 1) + assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 9181222f6922b..b6d279ae47268 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -24,6 +24,7 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashSet @@ -66,10 +67,8 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { private[spark] override def asNullable: MyDenseVectorUDT = this } -class UserDefinedTypeSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class UserDefinedTypeSuite extends QueryTest with SharedSQLContext { + import testImplicits._ private lazy val pointsRDD = Seq( MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), @@ -94,7 +93,7 @@ class UserDefinedTypeSuite extends QueryTest { ctx.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) pointsRDD.registerTempTable("points") checkAnswer( - ctx.sql("SELECT testType(features) from points"), + sql("SELECT testType(features) from points"), Seq(Row(true), Row(true))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 9bca4e7e660d6..952637c5f9cb8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -19,18 +19,16 @@ package org.apache.spark.sql.columnar import java.sql.{Date, Timestamp} -import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.{QueryTest, Row, TestData} import org.apache.spark.storage.StorageLevel.MEMORY_ONLY -class InMemoryColumnarQuerySuite extends QueryTest { - // Make sure the tables are loaded. - TestData +class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { + import testImplicits._ - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.{logicalPlanToSparkQuery, sql} + setupTestData() test("simple columnar query") { val plan = ctx.executePlan(testData.logicalPlan).executedPlan diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 2c0879927a129..ab2644eb4581d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -17,20 +17,19 @@ package org.apache.spark.sql.columnar -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} - import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestData._ -class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning override protected def beforeAll(): Unit = { + super.beforeAll() // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, 10) @@ -44,19 +43,17 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) // Enable in-memory table scan accumulators ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") - } - - override protected def afterAll(): Unit = { - ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) - ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) - } - - before { ctx.cacheTable("pruningData") } - after { - ctx.uncacheTable("pruningData") + override protected def afterAll(): Unit = { + try { + ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) + ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) + ctx.uncacheTable("pruningData") + } finally { + super.afterAll() + } } // Comparisons @@ -110,7 +107,7 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi expectedQueryResult: => Seq[Int]): Unit = { test(query) { - val df = ctx.sql(query) + val df = sql(query) val queryExecution = df.queryExecution assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 79e903c2bbd40..8998f5111124c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.physical.SinglePartition +import org.apache.spark.sql.test.SharedSQLContext -class ExchangeSuite extends SparkPlanTest { +class ExchangeSuite extends SparkPlanTest with SharedSQLContext { test("shuffling UnsafeRows in exchange") { val input = (1 to 1000).map(Tuple1.apply) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 5582caa0d366e..937a108543531 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.SparkFunSuite import org.apache.spark.rdd.RDD -import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.{execution, Row, SQLConf} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans._ @@ -27,19 +27,18 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext} -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ -import org.apache.spark.sql.test.TestSQLContext.planner._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql.{SQLContext, Row, SQLConf, execution} -class PlannerSuite extends SparkFunSuite with SQLTestUtils { +class PlannerSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ - override def sqlContext: SQLContext = TestSQLContext + setupTestData() private def testPartialAggregationPlan(query: LogicalPlan): Unit = { + val _ctx = ctx + import _ctx.planner._ val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption) val planned = plannedOption.getOrElse( @@ -54,6 +53,8 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { } test("unions are collapsed") { + val _ctx = ctx + import _ctx.planner._ val query = testData.unionAll(testData).unionAll(testData).logicalPlan val planned = BasicOperators(query).head val logicalUnions = query collect { case u: logical.Union => u } @@ -81,14 +82,14 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { test("sizeInBytes estimation of limit operator for broadcast hash join optimization") { def checkPlan(fieldTypes: Seq[DataType], newThreshold: Int): Unit = { - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold) + ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold) val fields = fieldTypes.zipWithIndex.map { case (dataType, index) => StructField(s"c${index}", dataType, true) } :+ StructField("key", IntegerType, true) val schema = StructType(fields) val row = Row.fromSeq(Seq.fill(fields.size)(null)) - val rowRDD = org.apache.spark.sql.test.TestSQLContext.sparkContext.parallelize(row :: Nil) - createDataFrame(rowRDD, schema).registerTempTable("testLimit") + val rowRDD = ctx.sparkContext.parallelize(row :: Nil) + ctx.createDataFrame(rowRDD, schema).registerTempTable("testLimit") val planned = sql( """ @@ -102,10 +103,10 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") - dropTempTable("testLimit") + ctx.dropTempTable("testLimit") } - val origThreshold = conf.autoBroadcastJoinThreshold + val origThreshold = ctx.conf.autoBroadcastJoinThreshold val simpleTypes = NullType :: @@ -137,18 +138,18 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { checkPlan(complexTypes, newThreshold = 901617) - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) + ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) } test("InMemoryRelation statistics propagation") { - val origThreshold = conf.autoBroadcastJoinThreshold - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920) + val origThreshold = ctx.conf.autoBroadcastJoinThreshold + ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920) testData.limit(3).registerTempTable("tiny") sql("CACHE TABLE tiny") val a = testData.as("a") - val b = table("tiny").as("b") + val b = ctx.table("tiny").as("b") val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } @@ -157,12 +158,12 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) + ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) } test("efficient limit -> project -> sort") { val query = testData.sort('key).select('value).limit(2).logicalPlan - val planned = planner.TakeOrderedAndProject(query) + val planned = ctx.planner.TakeOrderedAndProject(query) assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index dd08e9025a927..ef6ad59b71fb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -21,11 +21,11 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull} -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StructType, StringType} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StringType} import org.apache.spark.unsafe.types.UTF8String -class RowFormatConvertersSuite extends SparkPlanTest { +class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { private def getConverters(plan: SparkPlan): Seq[SparkPlan] = plan.collect { case c: ConvertToUnsafe => c @@ -39,20 +39,20 @@ class RowFormatConvertersSuite extends SparkPlanTest { test("planner should insert unsafe->safe conversions when required") { val plan = Limit(10, outputsUnsafe) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = ctx.prepareForExecution.execute(plan) assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe]) } test("filter can process unsafe rows") { val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = ctx.prepareForExecution.execute(plan) assert(getConverters(preparedPlan).size === 1) assert(preparedPlan.outputsUnsafeRows) } test("filter can process safe rows") { val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = ctx.prepareForExecution.execute(plan) assert(getConverters(preparedPlan).isEmpty) assert(!preparedPlan.outputsUnsafeRows) } @@ -67,33 +67,33 @@ class RowFormatConvertersSuite extends SparkPlanTest { test("union requires all of its input rows' formats to agree") { val plan = Union(Seq(outputsSafe, outputsUnsafe)) assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = ctx.prepareForExecution.execute(plan) assert(preparedPlan.outputsUnsafeRows) } test("union can process safe rows") { val plan = Union(Seq(outputsSafe, outputsSafe)) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = ctx.prepareForExecution.execute(plan) assert(!preparedPlan.outputsUnsafeRows) } test("union can process unsafe rows") { val plan = Union(Seq(outputsUnsafe, outputsUnsafe)) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = ctx.prepareForExecution.execute(plan) assert(preparedPlan.outputsUnsafeRows) } test("round trip with ConvertToUnsafe and ConvertToSafe") { val input = Seq(("hello", 1), ("world", 2)) checkAnswer( - TestSQLContext.createDataFrame(input), + ctx.createDataFrame(input), plan => ConvertToSafe(ConvertToUnsafe(plan)), input.map(Row.fromTuple) ) } test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") { - SparkPlan.currentContext.set(TestSQLContext) + SparkPlan.currentContext.set(ctx) val schema = ArrayType(StringType) val rows = (1 to 100).map { i => InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index a2c10fdaf6cdb..8fa77b0fcb7b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.test.SharedSQLContext -class SortSuite extends SparkPlanTest { +class SortSuite extends SparkPlanTest with SharedSQLContext { // This test was originally added as an example of how to use [[SparkPlanTest]]; // it's not designed to be a comprehensive test of ExternalSort. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index f46855edfe0de..3a87f374d94b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -17,29 +17,27 @@ package org.apache.spark.sql.execution -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{SQLContext, DataFrame, DataFrameHolder, Row} - import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row, SQLContext} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.util._ + /** * Base class for writing tests for individual physical operators. For an example of how this * class's test helper methods can be used, see [[SortSuite]]. */ -class SparkPlanTest extends SparkFunSuite { - - protected def sqlContext: SQLContext = TestSQLContext +private[sql] abstract class SparkPlanTest extends SparkFunSuite { + protected def _sqlContext: SQLContext /** * Creates a DataFrame from a local Seq of Product. */ implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = { - sqlContext.implicits.localSeqToDataFrameHolder(data) + _sqlContext.implicits.localSeqToDataFrameHolder(data) } /** @@ -100,7 +98,7 @@ class SparkPlanTest extends SparkFunSuite { planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row], sortAnswers: Boolean = true): Unit = { - SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, sqlContext) match { + SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, _sqlContext) match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -124,7 +122,7 @@ class SparkPlanTest extends SparkFunSuite { expectedPlanFunction: SparkPlan => SparkPlan, sortAnswers: Boolean = true): Unit = { SparkPlanTest.checkAnswer( - input, planFunction, expectedPlanFunction, sortAnswers, sqlContext) match { + input, planFunction, expectedPlanFunction, sortAnswers, _sqlContext) match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -151,13 +149,13 @@ object SparkPlanTest { planFunction: SparkPlan => SparkPlan, expectedPlanFunction: SparkPlan => SparkPlan, sortAnswers: Boolean, - sqlContext: SQLContext): Option[String] = { + _sqlContext: SQLContext): Option[String] = { val outputPlan = planFunction(input.queryExecution.sparkPlan) val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan) val expectedAnswer: Seq[Row] = try { - executePlan(expectedOutputPlan, sqlContext) + executePlan(expectedOutputPlan, _sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -172,7 +170,7 @@ object SparkPlanTest { } val actualAnswer: Seq[Row] = try { - executePlan(outputPlan, sqlContext) + executePlan(outputPlan, _sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -212,12 +210,12 @@ object SparkPlanTest { planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row], sortAnswers: Boolean, - sqlContext: SQLContext): Option[String] = { + _sqlContext: SQLContext): Option[String] = { val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan)) val sparkAnswer: Seq[Row] = try { - executePlan(outputPlan, sqlContext) + executePlan(outputPlan, _sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -280,10 +278,10 @@ object SparkPlanTest { } } - private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = { + private def executePlan(outputPlan: SparkPlan, _sqlContext: SQLContext): Seq[Row] = { // A very simple resolver to make writing tests easier. In contrast to the real resolver // this is always case sensitive and does not try to handle scoping or complex type resolution. - val resolvedPlan = sqlContext.prepareForExecution.execute( + val resolvedPlan = _sqlContext.prepareForExecution.execute( outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala index 88bce0e319f9e..3158458edb831 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala @@ -19,25 +19,28 @@ package org.apache.spark.sql.execution import scala.util.Random -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf} import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ /** * A test suite that generates randomized data to test the [[TungstenSort]] operator. */ -class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll { +class TungstenSortSuite extends SparkPlanTest with SharedSQLContext { override def beforeAll(): Unit = { - TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true) + super.beforeAll() + ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, true) } override def afterAll(): Unit = { - TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get) + try { + ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get) + } finally { + super.afterAll() + } } test("sort followed by limit") { @@ -61,7 +64,7 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll { } test("sorting updates peak execution memory") { - val sc = TestSQLContext.sparkContext + val sc = ctx.sparkContext AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "unsafe external sort") { checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), @@ -80,8 +83,8 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll { ) { test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { val inputData = Seq.fill(1000)(randomDataGenerator()) - val inputDf = TestSQLContext.createDataFrame( - TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), + val inputDf = ctx.createDataFrame( + ctx.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), StructType(StructField("a", dataType, nullable = true) :: Nil) ) assert(TungstenSort.supportsSchema(inputDf.schema)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index e03473041c3e9..d1f0b2b1fc52f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -26,7 +26,7 @@ import org.scalatest.Matchers import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} import org.apache.spark.unsafe.types.UTF8String @@ -36,7 +36,10 @@ import org.apache.spark.unsafe.types.UTF8String * * Use [[testWithMemoryLeakDetection]] rather than [[test]] to construct test cases. */ -class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { +class UnsafeFixedWidthAggregationMapSuite + extends SparkFunSuite + with Matchers + with SharedSQLContext { import UnsafeFixedWidthAggregationMap._ @@ -171,9 +174,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { } testWithMemoryLeakDetection("test external sorting") { - // Calling this make sure we have block manager and everything else setup. - TestSQLContext - // Memory consumption in the beginning of the task. val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask() @@ -233,8 +233,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { } testWithMemoryLeakDetection("test external sorting with an empty map") { - // Calling this make sure we have block manager and everything else setup. - TestSQLContext val map = new UnsafeFixedWidthAggregationMap( emptyAggregationBuffer, @@ -282,8 +280,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { } testWithMemoryLeakDetection("test external sorting with empty records") { - // Calling this make sure we have block manager and everything else setup. - TestSQLContext // Memory consumption in the beginning of the task. val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index a9515a03acf2c..d3be568a8758c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -23,15 +23,14 @@ import org.apache.spark._ import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeRow, UnsafeProjection} -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} /** * Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data. */ -class UnsafeKVExternalSorterSuite extends SparkFunSuite { - +class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { private val keyTypes = Seq(IntegerType, FloatType, DoubleType, StringType) private val valueTypes = Seq(IntegerType, FloatType, DoubleType, StringType) @@ -109,8 +108,6 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite { inputData: Seq[(InternalRow, InternalRow)], pageSize: Long, spill: Boolean): Unit = { - // Calling this make sure we have block manager and everything else setup. - TestSQLContext val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) val shuffleMemMgr = new TestShuffleMemoryManager diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala index ac22c2f3c0a58..5fdb82b067728 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala @@ -21,15 +21,12 @@ import org.apache.spark._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.unsafe.memory.TaskMemoryManager -class TungstenAggregationIteratorSuite extends SparkFunSuite { +class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLContext { test("memory acquired on construction") { - // set up environment - val ctx = TestSQLContext - val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager) val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty) TaskContext.setTaskContext(taskContext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 73d5621897819..1174b27732f22 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -24,22 +24,16 @@ import com.fasterxml.jackson.core.JsonFactory import org.apache.spark.rdd.RDD import org.scalactic.Tolerance._ -import org.apache.spark.sql.{SQLContext, QueryTest, Row, SQLConf} -import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.{QueryTest, Row, SQLConf} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.util.Utils -class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData { - - protected lazy val ctx = org.apache.spark.sql.test.TestSQLContext - override def sqlContext: SQLContext = ctx // used by SQLTestUtils - - import ctx.sql - import ctx.implicits._ +class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { + import testImplicits._ test("Type promotion") { def checkTypePromotion(expected: Any, actual: Any) { @@ -596,7 +590,8 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData { val schema = StructType(StructField("a", LongType, true) :: Nil) val logicalRelation = - ctx.read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation] + ctx.read.schema(schema).json(path) + .queryExecution.analyzed.asInstanceOf[LogicalRelation] val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation] assert(relationWithSchema.paths === Array(path)) assert(relationWithSchema.schema === schema) @@ -1040,31 +1035,29 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData { } test("JSONRelation equality test") { - val context = org.apache.spark.sql.test.TestSQLContext - val relation0 = new JSONRelation( Some(empty), 1.0, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - None, None)(context) + None, None)(ctx) val logicalRelation0 = LogicalRelation(relation0) val relation1 = new JSONRelation( Some(singleRow), 1.0, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - None, None)(context) + None, None)(ctx) val logicalRelation1 = LogicalRelation(relation1) val relation2 = new JSONRelation( Some(singleRow), 0.5, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - None, None)(context) + None, None)(ctx) val logicalRelation2 = LogicalRelation(relation2) val relation3 = new JSONRelation( Some(singleRow), 1.0, Some(StructType(StructField("b", IntegerType, true) :: Nil)), - None, None)(context) + None, None)(ctx) val logicalRelation3 = LogicalRelation(relation3) assert(relation0 !== relation1) @@ -1089,14 +1082,14 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData { .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) val d1 = ResolvedDataSource( - context, + ctx, userSpecifiedSchema = None, partitionColumns = Array.empty[String], provider = classOf[DefaultSource].getCanonicalName, options = Map("path" -> path)) val d2 = ResolvedDataSource( - context, + ctx, userSpecifiedSchema = None, partitionColumns = Array.empty[String], provider = classOf[DefaultSource].getCanonicalName, @@ -1162,11 +1155,12 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData { "abd") ctx.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part") - checkAnswer( - sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4)) - checkAnswer( - sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abd'"), Row(5)) - checkAnswer(sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9)) + checkAnswer(sql( + "SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4)) + checkAnswer(sql( + "SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abd'"), Row(5)) + checkAnswer(sql( + "SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9)) }) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index 6b62c9a003df6..2864181cf91d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -20,12 +20,11 @@ package org.apache.spark.sql.execution.datasources.json import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext -trait TestJsonData { - - protected def ctx: SQLContext +private[json] trait TestJsonData { + protected def _sqlContext: SQLContext def primitiveFieldAndType: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"string":"this is a simple string.", "integer":10, "long":21474836470, @@ -36,7 +35,7 @@ trait TestJsonData { }""" :: Nil) def primitiveFieldValueTypeConflict: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1, "num_bool":true, "num_str":13.1, "str_bool":"str1"}""" :: """{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null, @@ -47,14 +46,14 @@ trait TestJsonData { "num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil) def jsonNullStruct: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":{}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":""}""" :: """{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil) def complexFieldValueTypeConflict: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"num_struct":11, "str_array":[1, 2, 3], "array":[], "struct_array":[], "struct": {}}""" :: """{"num_struct":{"field":false}, "str_array":null, @@ -65,14 +64,14 @@ trait TestJsonData { "array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil) def arrayElementTypeConflict: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}], "array2": [{"field":214748364700}, {"field":1}]}""" :: """{"array3": [{"field":"str"}, {"field":1}]}""" :: """{"array3": [1, 2, 3]}""" :: Nil) def missingFields: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"a":true}""" :: """{"b":21474836470}""" :: """{"c":[33, 44]}""" :: @@ -80,7 +79,7 @@ trait TestJsonData { """{"e":"str"}""" :: Nil) def complexFieldAndType1: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"struct":{"field1": true, "field2": 92233720368547758070}, "structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]}, "arrayOfString":["str1", "str2"], @@ -96,7 +95,7 @@ trait TestJsonData { }""" :: Nil) def complexFieldAndType2: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], "complexArrayOfStruct": [ { @@ -150,7 +149,7 @@ trait TestJsonData { }""" :: Nil) def mapType1: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"map": {"a": 1}}""" :: """{"map": {"b": 2}}""" :: """{"map": {"c": 3}}""" :: @@ -158,7 +157,7 @@ trait TestJsonData { """{"map": {"e": null}}""" :: Nil) def mapType2: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"map": {"a": {"field1": [1, 2, 3, null]}}}""" :: """{"map": {"b": {"field2": 2}}}""" :: """{"map": {"c": {"field1": [], "field2": 4}}}""" :: @@ -167,21 +166,21 @@ trait TestJsonData { """{"map": {"f": {"field1": null}}}""" :: Nil) def nullsInArrays: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"field1":[[null], [[["Test"]]]]}""" :: """{"field2":[null, [{"Test":1}]]}""" :: """{"field3":[[null], [{"Test":"2"}]]}""" :: """{"field4":[[null, [1,2,3]]]}""" :: Nil) def jsonArray: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """[{"a":"str_a_1"}]""" :: """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" :: """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """[]""" :: Nil) def corruptRecords: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{""" :: """""" :: """{"a":1, b:2}""" :: @@ -190,7 +189,7 @@ trait TestJsonData { """]""" :: Nil) def emptyRecords: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{""" :: """""" :: """{"a": {}}""" :: @@ -198,9 +197,8 @@ trait TestJsonData { """{"b": [{"c": {}}]}""" :: """]""" :: Nil) - lazy val singleRow: RDD[String] = - ctx.sparkContext.parallelize( - """{"a":123}""" :: Nil) - def empty: RDD[String] = ctx.sparkContext.parallelize(Seq[String]()) + lazy val singleRow: RDD[String] = _sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil) + + def empty: RDD[String] = _sqlContext.sparkContext.parallelize(Seq[String]()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala index 866a975ad5404..82d40e2b61a10 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala @@ -27,18 +27,16 @@ import org.apache.avro.generic.IndexedRecord import org.apache.hadoop.fs.Path import org.apache.parquet.avro.AvroParquetWriter -import org.apache.spark.sql.execution.datasources.parquet.test.avro.{Nested, ParquetAvroCompat, ParquetEnum, Suit} -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.datasources.parquet.test.avro._ +import org.apache.spark.sql.test.SharedSQLContext -class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest { +class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { import ParquetCompatibilityTest._ - override val sqlContext: SQLContext = TestSQLContext - private def withWriter[T <: IndexedRecord] (path: String, schema: Schema) - (f: AvroParquetWriter[T] => Unit) = { + (f: AvroParquetWriter[T] => Unit): Unit = { val writer = new AvroParquetWriter[T](new Path(path), schema) try f(writer) finally writer.close() } @@ -129,7 +127,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest { } test("SPARK-9407 Don't push down predicates involving Parquet ENUM columns") { - import sqlContext.implicits._ + import testImplicits._ withTempPath { dir => val path = dir.getCanonicalPath diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala index 0ea64aa2a509b..b3406729fcc5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala @@ -22,16 +22,18 @@ import scala.collection.JavaConversions._ import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.parquet.hadoop.ParquetFileReader import org.apache.parquet.schema.MessageType -import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.QueryTest -abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest with BeforeAndAfterAll { - def readParquetSchema(path: String): MessageType = { +/** + * Helper class for testing Parquet compatibility. + */ +private[sql] abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest { + protected def readParquetSchema(path: String): MessageType = { readParquetSchema(path, { path => !path.getName.startsWith("_") }) } - def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = { + protected def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = { val fsPath = new Path(path) val fs = fsPath.getFileSystem(configuration) val parquetFiles = fs.listStatus(fsPath, new PathFilter { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 7dd9680d8cd65..5b4e568bb9838 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -20,12 +20,13 @@ package org.apache.spark.sql.execution.datasources.parquet import org.apache.parquet.filter2.predicate.Operators._ import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} +import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. @@ -39,8 +40,7 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} * 2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred * data type is nullable. */ -class ParquetFilterSuite extends QueryTest with ParquetTest { - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext +class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext { private def checkFilterPredicate( df: DataFrame, @@ -301,7 +301,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { } test("SPARK-6554: don't push down predicates which reference partition columns") { - import sqlContext.implicits._ + import testImplicits._ withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { withTempPath { dir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index cb166349fdb26..d819f3ab5e6ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -37,6 +37,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ // Write support class for nested groups: ParquetWriter initializes GroupWriteSupport @@ -62,9 +63,8 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS /** * A test suite that tests basic Parquet I/O. */ -class ParquetIOSuite extends QueryTest with ParquetTest { - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ +class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { + import testImplicits._ /** * Writes `data` to a Parquet file, reads it back and check file contents. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 73152de244759..ed8bafb10c60b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -26,13 +26,13 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.io.Files import org.apache.hadoop.fs.Path +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.execution.datasources.{LogicalRelation, PartitionSpec, Partition, PartitioningUtils} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql._ import org.apache.spark.unsafe.types.UTF8String -import PartitioningUtils._ // The data where the partitioning key exists only in the directory structure. case class ParquetData(intField: Int, stringField: String) @@ -40,11 +40,9 @@ case class ParquetData(intField: Int, stringField: String) // The data that also includes the partitioning key case class ParquetDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) -class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { - - override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ - import sqlContext.sql +class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with SharedSQLContext { + import PartitioningUtils._ + import testImplicits._ val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala index 981334cf771cf..b290429c2a021 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.execution.datasources.parquet -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.test.SharedSQLContext -class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest { - override def sqlContext: SQLContext = TestSQLContext +class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { private def readParquetProtobufFile(name: String): DataFrame = { val url = Thread.currentThread().getContextClassLoader.getResource(name) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 5e6d9c1cd44a8..e2f2a8c744783 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -21,16 +21,15 @@ import java.io.File import org.apache.hadoop.fs.Path -import org.apache.spark.sql.types._ import org.apache.spark.sql.{QueryTest, Row, SQLConf} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils /** * A test suite that tests various Parquet queries. */ -class ParquetQuerySuite extends QueryTest with ParquetTest { - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.sql +class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext { test("simple select queries") { withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { @@ -41,22 +40,22 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { test("appending") { val data = (0 until 10).map(i => (i, i.toString)) - sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { sql("INSERT INTO TABLE t SELECT * FROM tmp") - checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple)) + checkAnswer(ctx.table("t"), (data ++ data).map(Row.fromTuple)) } - sqlContext.catalog.unregisterTable(Seq("tmp")) + ctx.catalog.unregisterTable(Seq("tmp")) } test("overwriting") { val data = (0 until 10).map(i => (i, i.toString)) - sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") - checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple)) + checkAnswer(ctx.table("t"), data.map(Row.fromTuple)) } - sqlContext.catalog.unregisterTable(Seq("tmp")) + ctx.catalog.unregisterTable(Seq("tmp")) } test("self-join") { @@ -119,9 +118,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { val schema = StructType(List(StructField("d", DecimalType(18, 0), false), StructField("time", TimestampType, false)).toArray) withTempPath { file => - val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data), schema) + val df = ctx.createDataFrame(ctx.sparkContext.parallelize(data), schema) df.write.parquet(file.getCanonicalPath) - val df2 = sqlContext.read.parquet(file.getCanonicalPath) + val df2 = ctx.read.parquet(file.getCanonicalPath) checkAnswer(df2, df.collect().toSeq) } } @@ -130,12 +129,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => val basePath = dir.getCanonicalPath - sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) // delete summary files, so if we don't merge part-files, one column will not be included. Utils.deleteRecursively(new File(basePath + "/foo=1/_metadata")) Utils.deleteRecursively(new File(basePath + "/foo=1/_common_metadata")) - assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) + assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber) } } @@ -154,9 +153,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => val basePath = dir.getCanonicalPath - sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) - assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) + ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber) } } @@ -172,19 +171,19 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { test("SPARK-8990 DataFrameReader.parquet() should respect user specified options") { withTempPath { dir => val basePath = dir.getCanonicalPath - sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString) + ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString) // Disables the global SQL option for schema merging withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false") { assertResult(2) { // Disables schema merging via data source option - sqlContext.read.option("mergeSchema", "false").parquet(basePath).columns.length + ctx.read.option("mergeSchema", "false").parquet(basePath).columns.length } assertResult(3) { // Enables schema merging via data source option - sqlContext.read.option("mergeSchema", "true").parquet(basePath).columns.length + ctx.read.option("mergeSchema", "true").parquet(basePath).columns.length } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 971f71e27bfc6..9dcbc1a047bea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -22,13 +22,11 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.parquet.schema.MessageTypeParser -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -abstract class ParquetSchemaTest extends SparkFunSuite with ParquetTest { - val sqlContext = TestSQLContext +abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { /** * Checks whether the reflected Parquet message type for product type `T` conforms `messageType`. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 3c6e54db4bca7..5dbc7d1630f27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -22,9 +22,8 @@ import java.io.File import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.{DataFrame, SaveMode} +import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext} /** * A helper trait that provides convenient facilities for Parquet testing. @@ -33,7 +32,9 @@ import org.apache.spark.sql.{DataFrame, SaveMode} * convenient to use tuples rather than special case classes when writing test cases/suites. * Especially, `Tuple1.apply` can be used to easily wrap a single type/value. */ -private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite => +private[sql] trait ParquetTest extends SQLTestUtils { + protected def _sqlContext: SQLContext + /** * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f` * returns. @@ -42,7 +43,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite => (data: Seq[T]) (f: String => Unit): Unit = { withTempPath { file => - sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath) + _sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath) f(file.getCanonicalPath) } } @@ -54,7 +55,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite => protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag] (data: Seq[T]) (f: DataFrame => Unit): Unit = { - withParquetFile(data)(path => f(sqlContext.read.parquet(path))) + withParquetFile(data)(path => f(_sqlContext.read.parquet(path))) } /** @@ -66,14 +67,14 @@ private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite => (data: Seq[T], tableName: String) (f: => Unit): Unit = { withParquetDataFrame(data) { df => - sqlContext.registerDataFrameAsTable(df, tableName) + _sqlContext.registerDataFrameAsTable(df, tableName) withTempTable(tableName)(f) } } protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { - sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) + _sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) } protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala index 92b1d822172d5..b789c5a106e56 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala @@ -17,14 +17,12 @@ package org.apache.spark.sql.execution.datasources.parquet -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.Row +import org.apache.spark.sql.test.SharedSQLContext -class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest { +class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { import ParquetCompatibilityTest._ - override val sqlContext: SQLContext = TestSQLContext - private val parquetFilePath = Thread.currentThread().getContextClassLoader.getResource("parquet-thrift-compat.snappy.parquet") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 239deb7973845..22189477d277d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.execution.debug import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.SharedSQLContext + +class DebuggingSuite extends SparkFunSuite with SharedSQLContext { -class DebuggingSuite extends SparkFunSuite { test("DataFrame.debug()") { testData.debug() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index d33a967093ca5..4c9187a9a7106 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -23,12 +23,12 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.collection.CompactBuffer -class HashedRelationSuite extends SparkFunSuite { +class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { // Key is simply the record itself private val keyProjection = new Projection { @@ -37,7 +37,7 @@ class HashedRelationSuite extends SparkFunSuite { test("GeneralHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) - val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data") + val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data") val hashed = HashedRelation(data.iterator, numDataRows, keyProjection) assert(hashed.isInstanceOf[GeneralHashedRelation]) @@ -53,7 +53,7 @@ class HashedRelationSuite extends SparkFunSuite { test("UniqueKeyHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2)) - val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data") + val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data") val hashed = HashedRelation(data.iterator, numDataRows, keyProjection) assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) @@ -73,7 +73,7 @@ class HashedRelationSuite extends SparkFunSuite { test("UnsafeHashedRelation") { val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) - val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data") + val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data") val toUnsafe = UnsafeProjection.create(schema) val unsafeData = data.map(toUnsafe(_).copy()).toArray diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index ddff7cebcc17d..cc649b9bd4c45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -17,97 +17,19 @@ package org.apache.spark.sql.execution.joins +import org.apache.spark.sql.{DataFrame, execution, Row, SQLConf} +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.Join -import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types.{IntegerType, StringType, StructType} -import org.apache.spark.sql.{SQLConf, execution, Row, DataFrame} -import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.execution._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} -class InnerJoinSuite extends SparkPlanTest with SQLTestUtils { - - private def testInnerJoin( - testName: String, - leftRows: DataFrame, - rightRows: DataFrame, - condition: Expression, - expectedAnswer: Seq[Product]): Unit = { - val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) - ExtractEquiJoinKeys.unapply(join).foreach { - case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) => - - def makeBroadcastHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = { - val broadcastHashJoin = - execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, left, right) - boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) - } - - def makeShuffledHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = { - val shuffledHashJoin = - execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, left, right) - val filteredJoin = - boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin) - EnsureRequirements(sqlContext).apply(filteredJoin) - } - - def makeSortMergeJoin(left: SparkPlan, right: SparkPlan) = { - val sortMergeJoin = - execution.joins.SortMergeJoin(leftKeys, rightKeys, left, right) - val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin) - EnsureRequirements(sqlContext).apply(filteredJoin) - } - - test(s"$testName using BroadcastHashJoin (build=left)") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - makeBroadcastHashJoin(left, right, joins.BuildLeft), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } - - test(s"$testName using BroadcastHashJoin (build=right)") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - makeBroadcastHashJoin(left, right, joins.BuildRight), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } - - test(s"$testName using ShuffledHashJoin (build=left)") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - makeShuffledHashJoin(left, right, joins.BuildLeft), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } - - test(s"$testName using ShuffledHashJoin (build=right)") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - makeShuffledHashJoin(left, right, joins.BuildRight), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } +class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { - test(s"$testName using SortMergeJoin") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - makeSortMergeJoin(left, right), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } - } - } - - { - val upperCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + private lazy val myUpperCaseData = ctx.createDataFrame( + ctx.sparkContext.parallelize(Seq( Row(1, "A"), Row(2, "B"), Row(3, "C"), @@ -117,7 +39,8 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils { Row(null, "G") )), new StructType().add("N", IntegerType).add("L", StringType)) - val lowerCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + private lazy val myLowerCaseData = ctx.createDataFrame( + ctx.sparkContext.parallelize(Seq( Row(1, "a"), Row(2, "b"), Row(3, "c"), @@ -125,21 +48,7 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils { Row(null, "e") )), new StructType().add("n", IntegerType).add("l", StringType)) - testInnerJoin( - "inner join, one match per row", - upperCaseData, - lowerCaseData, - (upperCaseData.col("N") === lowerCaseData.col("n")).expr, - Seq( - (1, "A", 1, "a"), - (2, "B", 2, "b"), - (3, "C", 3, "c"), - (4, "D", 4, "d") - ) - ) - } - - private val testData2 = Seq( + private lazy val myTestData = Seq( (1, 1), (1, 2), (2, 1), @@ -148,14 +57,139 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils { (3, 2) ).toDF("a", "b") + // Note: the input dataframes and expression must be evaluated lazily because + // the SQLContext should be used only within a test to keep SQL tests stable + private def testInnerJoin( + testName: String, + leftRows: => DataFrame, + rightRows: => DataFrame, + condition: () => Expression, + expectedAnswer: Seq[Product]): Unit = { + + def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition())) + ExtractEquiJoinKeys.unapply(join) + } + + def makeBroadcastHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + boundCondition: Option[Expression], + leftPlan: SparkPlan, + rightPlan: SparkPlan, + side: BuildSide) = { + val broadcastHashJoin = + execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan) + boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) + } + + def makeShuffledHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + boundCondition: Option[Expression], + leftPlan: SparkPlan, + rightPlan: SparkPlan, + side: BuildSide) = { + val shuffledHashJoin = + execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan) + val filteredJoin = + boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin) + EnsureRequirements(sqlContext).apply(filteredJoin) + } + + def makeSortMergeJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + boundCondition: Option[Expression], + leftPlan: SparkPlan, + rightPlan: SparkPlan) = { + val sortMergeJoin = + execution.joins.SortMergeJoin(leftKeys, rightKeys, leftPlan, rightPlan) + val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin) + EnsureRequirements(sqlContext).apply(filteredJoin) + } + + test(s"$testName using BroadcastHashJoin (build=left)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeBroadcastHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using BroadcastHashJoin (build=right)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeBroadcastHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using ShuffledHashJoin (build=left)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeShuffledHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using ShuffledHashJoin (build=right)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeShuffledHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using SortMergeJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeSortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + } + + testInnerJoin( + "inner join, one match per row", + myUpperCaseData, + myLowerCaseData, + () => (myUpperCaseData.col("N") === myLowerCaseData.col("n")).expr, + Seq( + (1, "A", 1, "a"), + (2, "B", 2, "b"), + (3, "C", 3, "c"), + (4, "D", 4, "d") + ) + ) + { - val left = testData2.where("a = 1") - val right = testData2.where("a = 1") + lazy val left = myTestData.where("a = 1") + lazy val right = myTestData.where("a = 1") testInnerJoin( "inner join, multiple matches", left, right, - (left.col("a") === right.col("a")).expr, + () => (left.col("a") === right.col("a")).expr, Seq( (1, 1, 1, 1), (1, 1, 1, 2), @@ -166,13 +200,13 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils { } { - val left = testData2.where("a = 1") - val right = testData2.where("a = 2") + lazy val left = myTestData.where("a = 1") + lazy val right = myTestData.where("a = 2") testInnerJoin( "inner join, no matches", left, right, - (left.col("a") === right.col("a")).expr, + () => (left.col("a") === right.col("a")).expr, Seq.empty ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index e16f5e39aa2f4..a1a617d7b7398 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -17,28 +17,65 @@ package org.apache.spark.sql.execution.joins +import org.apache.spark.sql.{DataFrame, Row, SQLConf} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.Join -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} +import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType} -import org.apache.spark.sql.{SQLConf, DataFrame, Row} -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.execution.{EnsureRequirements, joins, SparkPlan, SparkPlanTest} -class OuterJoinSuite extends SparkPlanTest with SQLTestUtils { +class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { + + private lazy val left = ctx.createDataFrame( + ctx.sparkContext.parallelize(Seq( + Row(1, 2.0), + Row(2, 100.0), + Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches + Row(2, 1.0), + Row(3, 3.0), + Row(5, 1.0), + Row(6, 6.0), + Row(null, null) + )), new StructType().add("a", IntegerType).add("b", DoubleType)) + + private lazy val right = ctx.createDataFrame( + ctx.sparkContext.parallelize(Seq( + Row(0, 0.0), + Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches + Row(2, -1.0), + Row(2, -1.0), + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(5, 3.0), + Row(7, 7.0), + Row(null, null) + )), new StructType().add("c", IntegerType).add("d", DoubleType)) + + private lazy val condition = { + And((left.col("a") === right.col("c")).expr, + LessThan(left.col("b").expr, right.col("d").expr)) + } + // Note: the input dataframes and expression must be evaluated lazily because + // the SQLContext should be used only within a test to keep SQL tests stable private def testOuterJoin( testName: String, - leftRows: DataFrame, - rightRows: DataFrame, + leftRows: => DataFrame, + rightRows: => DataFrame, joinType: JoinType, - condition: Expression, + condition: => Expression, expectedAnswer: Seq[Product]): Unit = { - val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) - ExtractEquiJoinKeys.unapply(join).foreach { - case (_, leftKeys, rightKeys, boundCondition, leftChild, rightChild) => - test(s"$testName using ShuffledHashOuterJoin") { + + def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join) + } + + test(s"$testName using ShuffledHashOuterJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements(sqlContext).apply( @@ -46,19 +83,23 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils { expectedAnswer.map(Row.fromTuple), sortAnswers = true) } - } + } + } - if (joinType != FullOuter) { - test(s"$testName using BroadcastHashOuterJoin") { + if (joinType != FullOuter) { + test(s"$testName using BroadcastHashOuterJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } - } + } + } - test(s"$testName using SortMergeOuterJoin") { + test(s"$testName using SortMergeOuterJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements(sqlContext).apply( @@ -66,57 +107,9 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils { expectedAnswer.map(Row.fromTuple), sortAnswers = false) } - } } - } - - test(s"$testName using BroadcastNestedLoopJoin (build=left)") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - joins.BroadcastNestedLoopJoin(left, right, joins.BuildLeft, joinType, Some(condition)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) } } - - test(s"$testName using BroadcastNestedLoopJoin (build=right)") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - joins.BroadcastNestedLoopJoin(left, right, joins.BuildRight, joinType, Some(condition)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } - } - - val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( - Row(1, 2.0), - Row(2, 100.0), - Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches - Row(2, 1.0), - Row(3, 3.0), - Row(5, 1.0), - Row(6, 6.0), - Row(null, null) - )), new StructType().add("a", IntegerType).add("b", DoubleType)) - - val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( - Row(0, 0.0), - Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches - Row(2, -1.0), - Row(2, -1.0), - Row(2, 3.0), - Row(3, 2.0), - Row(4, 1.0), - Row(5, 3.0), - Row(7, 7.0), - Row(null, null) - )), new StructType().add("c", IntegerType).add("d", DoubleType)) - - val condition = { - And( - (left.col("a") === right.col("c")).expr, - LessThan(left.col("b").expr, right.col("d").expr)) } // --- Basic outer joins ------------------------------------------------------------------------ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index 4503ed251fcb1..baa86e320d986 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -17,44 +17,80 @@ package org.apache.spark.sql.execution.joins +import org.apache.spark.sql.{SQLConf, DataFrame, Row} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.Join -import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} -import org.apache.spark.sql.{SQLConf, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression} import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} + +class SemiJoinSuite extends SparkPlanTest with SharedSQLContext { -class SemiJoinSuite extends SparkPlanTest with SQLTestUtils { + private lazy val left = ctx.createDataFrame( + ctx.sparkContext.parallelize(Seq( + Row(1, 2.0), + Row(1, 2.0), + Row(2, 1.0), + Row(2, 1.0), + Row(3, 3.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("a", IntegerType).add("b", DoubleType)) + private lazy val right = ctx.createDataFrame( + ctx.sparkContext.parallelize(Seq( + Row(2, 3.0), + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("c", IntegerType).add("d", DoubleType)) + + private lazy val condition = { + And((left.col("a") === right.col("c")).expr, + LessThan(left.col("b").expr, right.col("d").expr)) + } + + // Note: the input dataframes and expression must be evaluated lazily because + // the SQLContext should be used only within a test to keep SQL tests stable private def testLeftSemiJoin( testName: String, - leftRows: DataFrame, - rightRows: DataFrame, - condition: Expression, + leftRows: => DataFrame, + rightRows: => DataFrame, + condition: => Expression, expectedAnswer: Seq[Product]): Unit = { - val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) - ExtractEquiJoinKeys.unapply(join).foreach { - case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) => - test(s"$testName using LeftSemiJoinHash") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(left.sqlContext).apply( - LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } + + def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join) + } + + test(s"$testName using LeftSemiJoinHash") { + extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext).apply( + LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) } + } + } - test(s"$testName using BroadcastLeftSemiJoinHash") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } + test(s"$testName using BroadcastLeftSemiJoinHash") { + extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) } + } } test(s"$testName using LeftSemiJoinBNL") { @@ -67,33 +103,6 @@ class SemiJoinSuite extends SparkPlanTest with SQLTestUtils { } } - val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( - Row(1, 2.0), - Row(1, 2.0), - Row(2, 1.0), - Row(2, 1.0), - Row(3, 3.0), - Row(null, null), - Row(null, 5.0), - Row(6, null) - )), new StructType().add("a", IntegerType).add("b", DoubleType)) - - val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( - Row(2, 3.0), - Row(2, 3.0), - Row(3, 2.0), - Row(4, 1.0), - Row(null, null), - Row(null, 5.0), - Row(6, null) - )), new StructType().add("c", IntegerType).add("d", DoubleType)) - - val condition = { - And( - (left.col("a") === right.col("c")).expr, - LessThan(left.col("b").expr, right.col("d").expr)) - } - testLeftSemiJoin( "basic test", left, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 7383d3f8fe024..80006bf077fe8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -28,17 +28,15 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { - override val sqlContext = TestSQLContext - - import sqlContext.implicits._ +class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ test("LongSQLMetric should not box Long") { - val l = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "long") + val l = SQLMetrics.createLongMetric(ctx.sparkContext, "long") val f = () => { l += 1L l.add(1L) @@ -52,7 +50,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { test("Normal accumulator should do boxing") { // We need this test to make sure BoxingFinder works. - val l = TestSQLContext.sparkContext.accumulator(0L) + val l = ctx.sparkContext.accumulator(0L) val f = () => { l += 1L } BoxingFinder.getClassReader(f.getClass).foreach { cl => val boxingFinder = new BoxingFinder() @@ -73,19 +71,19 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { df: DataFrame, expectedNumOfJobs: Int, expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { - val previousExecutionIds = TestSQLContext.listener.executionIdToData.keySet + val previousExecutionIds = ctx.listener.executionIdToData.keySet df.collect() - TestSQLContext.sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = TestSQLContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + ctx.sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = ctx.listener.executionIdToData.keySet.diff(previousExecutionIds) assert(executionIds.size === 1) val executionId = executionIds.head - val jobs = TestSQLContext.listener.getExecution(executionId).get.jobs + val jobs = ctx.listener.getExecution(executionId).get.jobs // Use "<=" because there is a race condition that we may miss some jobs // TODO Change it to "=" once we fix the race condition that missing the JobStarted event. assert(jobs.size <= expectedNumOfJobs) if (jobs.size == expectedNumOfJobs) { // If we can track all jobs, check the metric values - val metricValues = TestSQLContext.listener.getExecutionMetrics(executionId) + val metricValues = ctx.listener.getExecutionMetrics(executionId) val actualMetrics = SparkPlanGraph(df.queryExecution.executedPlan).nodes.filter { node => expectedMetrics.contains(node.id) }.map { node => @@ -111,7 +109,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { SQLConf.TUNGSTEN_ENABLED.key -> "false") { // Assume the execution plan is // PhysicalRDD(nodeId = 1) -> Project(nodeId = 0) - val df = TestData.person.select('name) + val df = person.select('name) testSparkPlanMetrics(df, 1, Map( 0L ->("Project", Map( "number of rows" -> 2L))) @@ -126,7 +124,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { SQLConf.TUNGSTEN_ENABLED.key -> "true") { // Assume the execution plan is // PhysicalRDD(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = TestData.person.select('name) + val df = person.select('name) testSparkPlanMetrics(df, 1, Map( 0L ->("TungstenProject", Map( "number of rows" -> 2L))) @@ -137,7 +135,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { test("Filter metrics") { // Assume the execution plan is // PhysicalRDD(nodeId = 1) -> Filter(nodeId = 0) - val df = TestData.person.filter('age < 25) + val df = person.filter('age < 25) testSparkPlanMetrics(df, 1, Map( 0L -> ("Filter", Map( "number of input rows" -> 2L, @@ -152,7 +150,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { SQLConf.TUNGSTEN_ENABLED.key -> "false") { // Assume the execution plan is // ... -> Aggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> Aggregate(nodeId = 0) - val df = TestData.testData2.groupBy().count() // 2 partitions + val df = testData2.groupBy().count() // 2 partitions testSparkPlanMetrics(df, 1, Map( 2L -> ("Aggregate", Map( "number of input rows" -> 6L, @@ -163,7 +161,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { ) // 2 partitions and each partition contains 2 keys - val df2 = TestData.testData2.groupBy('a).count() + val df2 = testData2.groupBy('a).count() testSparkPlanMetrics(df2, 1, Map( 2L -> ("Aggregate", Map( "number of input rows" -> 6L, @@ -185,7 +183,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { // Assume the execution plan is // ... -> SortBasedAggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> // SortBasedAggregate(nodeId = 0) - val df = TestData.testData2.groupBy().count() // 2 partitions + val df = testData2.groupBy().count() // 2 partitions testSparkPlanMetrics(df, 1, Map( 2L -> ("SortBasedAggregate", Map( "number of input rows" -> 6L, @@ -199,7 +197,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { // ... -> SortBasedAggregate(nodeId = 3) -> TungstenExchange(nodeId = 2) // -> ExternalSort(nodeId = 1)-> SortBasedAggregate(nodeId = 0) // 2 partitions and each partition contains 2 keys - val df2 = TestData.testData2.groupBy('a).count() + val df2 = testData2.groupBy('a).count() testSparkPlanMetrics(df2, 1, Map( 3L -> ("SortBasedAggregate", Map( "number of input rows" -> 6L, @@ -219,7 +217,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { // Assume the execution plan is // ... -> TungstenAggregate(nodeId = 2) -> Exchange(nodeId = 1) // -> TungstenAggregate(nodeId = 0) - val df = TestData.testData2.groupBy().count() // 2 partitions + val df = testData2.groupBy().count() // 2 partitions testSparkPlanMetrics(df, 1, Map( 2L -> ("TungstenAggregate", Map( "number of input rows" -> 6L, @@ -230,7 +228,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { ) // 2 partitions and each partition contains 2 keys - val df2 = TestData.testData2.groupBy('a).count() + val df2 = testData2.groupBy('a).count() testSparkPlanMetrics(df2, 1, Map( 2L -> ("TungstenAggregate", Map( "number of input rows" -> 6L, @@ -246,7 +244,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { // Because SortMergeJoin may skip different rows if the number of partitions is different, this // test should use the deterministic number of partitions. withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { - val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) testDataForJoin.registerTempTable("testDataForJoin") withTempTable("testDataForJoin") { // Assume the execution plan is @@ -268,7 +266,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { // Because SortMergeOuterJoin may skip different rows if the number of partitions is different, // this test should use the deterministic number of partitions. withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { - val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) testDataForJoin.registerTempTable("testDataForJoin") withTempTable("testDataForJoin") { // Assume the execution plan is @@ -314,7 +312,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { test("ShuffledHashJoin metrics") { withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { - val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) testDataForJoin.registerTempTable("testDataForJoin") withTempTable("testDataForJoin") { // Assume the execution plan is @@ -390,7 +388,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { test("BroadcastNestedLoopJoin metrics") { withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { - val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) testDataForJoin.registerTempTable("testDataForJoin") withTempTable("testDataForJoin") { // Assume the execution plan is @@ -458,7 +456,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { } test("CartesianProduct metrics") { - val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) testDataForJoin.registerTempTable("testDataForJoin") withTempTable("testDataForJoin") { // Assume the execution plan is @@ -476,19 +474,19 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { test("save metrics") { withTempPath { file => - val previousExecutionIds = TestSQLContext.listener.executionIdToData.keySet + val previousExecutionIds = ctx.listener.executionIdToData.keySet // Assume the execution plan is // PhysicalRDD(nodeId = 0) - TestData.person.select('name).write.format("json").save(file.getAbsolutePath) - TestSQLContext.sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = TestSQLContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + person.select('name).write.format("json").save(file.getAbsolutePath) + ctx.sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = ctx.listener.executionIdToData.keySet.diff(previousExecutionIds) assert(executionIds.size === 1) val executionId = executionIds.head - val jobs = TestSQLContext.listener.getExecution(executionId).get.jobs + val jobs = ctx.listener.getExecution(executionId).get.jobs // Use "<=" because there is a race condition that we may miss some jobs // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. assert(jobs.size <= 1) - val metricValues = TestSQLContext.listener.getExecutionMetrics(executionId) + val metricValues = ctx.listener.getExecutionMetrics(executionId) // Because "save" will create a new DataFrame internally, we cannot get the real metric id. // However, we still can check the value. assert(metricValues.values.toSeq === Seq(2L)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 41dd1896c15df..80d1e88956949 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -25,12 +25,12 @@ import org.apache.spark.sql.execution.metric.LongSQLMetricValue import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.SharedSQLContext -class SQLListenerSuite extends SparkFunSuite { +class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ private def createTestDataFrame: DataFrame = { - import TestSQLContext.implicits._ Seq( (1, 1), (2, 2) @@ -74,7 +74,7 @@ class SQLListenerSuite extends SparkFunSuite { } test("basic") { - val listener = new SQLListener(TestSQLContext) + val listener = new SQLListener(ctx) val executionId = 0 val df = createTestDataFrame val accumulatorIds = @@ -212,7 +212,7 @@ class SQLListenerSuite extends SparkFunSuite { } test("onExecutionEnd happens before onJobEnd(JobSucceeded)") { - val listener = new SQLListener(TestSQLContext) + val listener = new SQLListener(ctx) val executionId = 0 val df = createTestDataFrame listener.onExecutionStart( @@ -241,7 +241,7 @@ class SQLListenerSuite extends SparkFunSuite { } test("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") { - val listener = new SQLListener(TestSQLContext) + val listener = new SQLListener(ctx) val executionId = 0 val df = createTestDataFrame listener.onExecutionStart( @@ -281,7 +281,7 @@ class SQLListenerSuite extends SparkFunSuite { } test("onExecutionEnd happens before onJobEnd(JobFailed)") { - val listener = new SQLListener(TestSQLContext) + val listener = new SQLListener(ctx) val executionId = 0 val df = createTestDataFrame listener.onExecutionStart( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index e4dcf4c75d208..0edac0848c3bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -25,10 +25,13 @@ import org.h2.jdbc.JdbcSQLException import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JDBCSuite extends SparkFunSuite with BeforeAndAfter { +class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext { + import testImplicits._ + val url = "jdbc:h2:mem:testdb0" val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass" var conn: java.sql.Connection = null @@ -42,10 +45,6 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { Some(StringType) } - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.sql - before { Utils.classForName("org.h2.Driver") // Extra properties that will be specified for our database. We need these to test diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 84b52ca2c733c..5dc3a2c07b8c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -23,11 +23,13 @@ import java.util.Properties import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{SaveMode, Row} +import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { +class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext { + val url = "jdbc:h2:mem:testdb2" var conn: java.sql.Connection = null val url1 = "jdbc:h2:mem:testdb3" @@ -37,10 +39,6 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { properties.setProperty("password", "testPass") properties.setProperty("rowId", "false") - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.sql - before { Utils.classForName("org.h2.Driver") conn = DriverManager.getConnection(url) @@ -58,14 +56,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { "create table test.people1 (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() conn1.commit() - ctx.sql( + sql( s""" |CREATE TEMPORARY TABLE PEOPLE |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) - ctx.sql( + sql( s""" |CREATE TEMPORARY TABLE PEOPLE1 |USING org.apache.spark.sql.jdbc @@ -144,14 +142,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { } test("INSERT to JDBC Datasource") { - ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") + sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } test("INSERT to JDBC Datasource with overwrite") { - ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") - ctx.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") + sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") + sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 562c279067048..9bc3f6bcf6fce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -19,28 +19,32 @@ package org.apache.spark.sql.sources import java.io.{File, IOException} -import org.scalatest.BeforeAndAfterAll +import org.scalatest.BeforeAndAfter import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.DDLException +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { - - import caseInsensitiveContext.sql +class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { + protected override lazy val sql = caseInsensitiveContext.sql _ private lazy val sparkContext = caseInsensitiveContext.sparkContext - - var path: File = null + private var path: File = null override def beforeAll(): Unit = { + super.beforeAll() path = Utils.createTempDir() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) caseInsensitiveContext.read.json(rdd).registerTempTable("jt") } override def afterAll(): Unit = { - caseInsensitiveContext.dropTempTable("jt") + try { + caseInsensitiveContext.dropTempTable("jt") + } finally { + super.afterAll() + } } after { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala index 392da0b0826b5..853707c036c9a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala @@ -18,11 +18,12 @@ package org.apache.spark.sql.sources import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{StringType, StructField, StructType} // please note that the META-INF/services had to be modified for the test directory for this to work -class DDLSourceLoadSuite extends DataSourceTest { +class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext { test("data sources with the same name") { intercept[RuntimeException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index 84855ce45e918..5f8514e1a2411 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.sources import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -68,10 +69,12 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo } } -class DDLTestSuite extends DataSourceTest { +class DDLTestSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ - before { - caseInsensitiveContext.sql( + override def beforeAll(): Unit = { + super.beforeAll() + sql( """ |CREATE TEMPORARY TABLE ddlPeople |USING org.apache.spark.sql.sources.DDLScanSource @@ -105,7 +108,7 @@ class DDLTestSuite extends DataSourceTest { )) test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") { - val attributes = caseInsensitiveContext.sql("describe ddlPeople") + val attributes = sql("describe ddlPeople") .queryExecution.executedPlan.output assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment")) assert(attributes.map(_.dataType).toSet === Set(StringType)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 00cc7d5ea580f..d74d29fb0beb0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -17,18 +17,23 @@ package org.apache.spark.sql.sources -import org.scalatest.BeforeAndAfter - import org.apache.spark.sql._ -import org.apache.spark.sql.test.TestSQLContext -abstract class DataSourceTest extends QueryTest with BeforeAndAfter { +private[sql] abstract class DataSourceTest extends QueryTest { + protected def _sqlContext: SQLContext + // We want to test some edge cases. - protected implicit lazy val caseInsensitiveContext = { - val ctx = new SQLContext(TestSQLContext.sparkContext) + protected lazy val caseInsensitiveContext: SQLContext = { + val ctx = new SQLContext(_sqlContext.sparkContext) ctx.setConf(SQLConf.CASE_SENSITIVE, false) ctx } + protected def sqlTest(sqlString: String, expectedAnswer: Seq[Row]) { + test(sqlString) { + checkAnswer(caseInsensitiveContext.sql(sqlString), expectedAnswer) + } + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 5ef365797eace..c81c3d3982805 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -21,6 +21,7 @@ import scala.language.existentials import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -96,11 +97,11 @@ object FiltersPushed { var list: Seq[Filter] = Nil } -class FilteredScanSuite extends DataSourceTest { +class FilteredScanSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ - import caseInsensitiveContext.sql - - before { + override def beforeAll(): Unit = { + super.beforeAll() sql( """ |CREATE TEMPORARY TABLE oneToTenFiltered diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index cdbfaf6455fe4..78bd3e5582964 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -19,20 +19,17 @@ package org.apache.spark.sql.sources import java.io.File -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.sql.{SaveMode, AnalysisException, Row} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -class InsertSuite extends DataSourceTest with BeforeAndAfterAll { - - import caseInsensitiveContext.sql - +class InsertSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ private lazy val sparkContext = caseInsensitiveContext.sparkContext - - var path: File = null + private var path: File = null override def beforeAll(): Unit = { + super.beforeAll() path = Utils.createTempDir() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) caseInsensitiveContext.read.json(rdd).registerTempTable("jt") @@ -47,9 +44,13 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { } override def afterAll(): Unit = { - caseInsensitiveContext.dropTempTable("jsonTable") - caseInsensitiveContext.dropTempTable("jt") - Utils.deleteRecursively(path) + try { + caseInsensitiveContext.dropTempTable("jsonTable") + caseInsensitiveContext.dropTempTable("jt") + Utils.deleteRecursively(path) + } finally { + super.afterAll() + } } test("Simple INSERT OVERWRITE a JSONRelation") { @@ -221,9 +222,10 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { sql("SELECT a * 2 FROM jsonTable"), (1 to 10).map(i => Row(i * 2)).toSeq) - assertCached(sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2) - checkAnswer( - sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), + assertCached(sql( + "SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2) + checkAnswer(sql( + "SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), (2 to 10).map(i => Row(i, i - 1)).toSeq) // Insert overwrite and keep the same schema. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index c86ddd7c83e53..79b6e9b45c009 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -19,21 +19,21 @@ package org.apache.spark.sql.sources import org.apache.spark.sql.{Row, QueryTest} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -class PartitionedWriteSuite extends QueryTest { - import TestSQLContext.implicits._ +class PartitionedWriteSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("write many partitions") { val path = Utils.createTempDir() path.delete() - val df = TestSQLContext.range(100).select($"id", lit(1).as("data")) + val df = ctx.range(100).select($"id", lit(1).as("data")) df.write.partitionBy("id").save(path.getCanonicalPath) checkAnswer( - TestSQLContext.read.load(path.getCanonicalPath), + ctx.read.load(path.getCanonicalPath), (0 to 99).map(Row(1, _)).toSeq) Utils.deleteRecursively(path) @@ -43,12 +43,12 @@ class PartitionedWriteSuite extends QueryTest { val path = Utils.createTempDir() path.delete() - val base = TestSQLContext.range(100) + val base = ctx.range(100) val df = base.unionAll(base).select($"id", lit(1).as("data")) df.write.partitionBy("id").save(path.getCanonicalPath) checkAnswer( - TestSQLContext.read.load(path.getCanonicalPath), + ctx.read.load(path.getCanonicalPath), (0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq) Utils.deleteRecursively(path) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index 0d5183444af78..a89c5f8007e78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -21,6 +21,7 @@ import scala.language.existentials import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ class PrunedScanSource extends RelationProvider { @@ -51,10 +52,12 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLCo } } -class PrunedScanSuite extends DataSourceTest { +class PrunedScanSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ - before { - caseInsensitiveContext.sql( + override def beforeAll(): Unit = { + super.beforeAll() + sql( """ |CREATE TEMPORARY TABLE oneToTenPruned |USING org.apache.spark.sql.sources.PrunedScanSource @@ -114,7 +117,7 @@ class PrunedScanSuite extends DataSourceTest { def testPruning(sqlString: String, expectedColumns: String*): Unit = { test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") { - val queryExecution = caseInsensitiveContext.sql(sqlString).queryExecution + val queryExecution = sql(sqlString).queryExecution val rawPlan = queryExecution.executedPlan.collect { case p: execution.PhysicalRDD => p } match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index 31730a3d3f8d3..f18546b4c2d9b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -19,25 +19,22 @@ package org.apache.spark.sql.sources import java.io.File -import org.scalatest.BeforeAndAfterAll +import org.scalatest.BeforeAndAfter import org.apache.spark.sql.{AnalysisException, SaveMode, SQLConf, DataFrame} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { - - import caseInsensitiveContext.sql - +class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { + protected override lazy val sql = caseInsensitiveContext.sql _ private lazy val sparkContext = caseInsensitiveContext.sparkContext - - var originalDefaultSource: String = null - - var path: File = null - - var df: DataFrame = null + private var originalDefaultSource: String = null + private var path: File = null + private var df: DataFrame = null override def beforeAll(): Unit = { + super.beforeAll() originalDefaultSource = caseInsensitiveContext.conf.defaultDataSourceName path = Utils.createTempDir() @@ -49,11 +46,14 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { } override def afterAll(): Unit = { - caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + try { + caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + } finally { + super.afterAll() + } } after { - caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) Utils.deleteRecursively(path) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index e34e0956d1fdd..12af8068c398f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ class DefaultSource extends SimpleScanSource @@ -95,8 +96,8 @@ case class AllDataTypesScan( } } -class TableScanSuite extends DataSourceTest { - import caseInsensitiveContext.sql +class TableScanSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ private lazy val tableWithSchemaExpected = (1 to 10).map { i => Row( @@ -122,7 +123,8 @@ class TableScanSuite extends DataSourceTest { Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(Date.valueOf(s"1970-01-${i + 1}"))))) }.toSeq - before { + override def beforeAll(): Unit = { + super.beforeAll() sql( """ |CREATE TEMPORARY TABLE oneToTen @@ -303,9 +305,10 @@ class TableScanSuite extends DataSourceTest { sql("SELECT i * 2 FROM oneToTen"), (1 to 10).map(i => Row(i * 2)).toSeq) - assertCached(sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2) - checkAnswer( - sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), + assertCached(sql( + "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2) + checkAnswer(sql( + "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), (2 to 10).map(i => Row(i, i - 1)).toSeq) // Verify uncaching diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala new file mode 100644 index 0000000000000..1374a97476ca1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -0,0 +1,290 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits} + +/** + * A collection of sample data used in SQL tests. + */ +private[sql] trait SQLTestData { self => + protected def _sqlContext: SQLContext + + // Helper object to import SQL implicits without a concrete SQLContext + private object internalImplicits extends SQLImplicits { + protected override def _sqlContext: SQLContext = self._sqlContext + } + + import internalImplicits._ + import SQLTestData._ + + // Note: all test data should be lazy because the SQLContext is not set up yet. + + protected lazy val testData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString))).toDF() + df.registerTempTable("testData") + df + } + + protected lazy val testData2: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + TestData2(1, 1) :: + TestData2(1, 2) :: + TestData2(2, 1) :: + TestData2(2, 2) :: + TestData2(3, 1) :: + TestData2(3, 2) :: Nil, 2).toDF() + df.registerTempTable("testData2") + df + } + + protected lazy val testData3: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + TestData3(1, None) :: + TestData3(2, Some(2)) :: Nil).toDF() + df.registerTempTable("testData3") + df + } + + protected lazy val negativeData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + (1 to 100).map(i => TestData(-i, (-i).toString))).toDF() + df.registerTempTable("negativeData") + df + } + + protected lazy val largeAndSmallInts: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + LargeAndSmallInts(2147483644, 1) :: + LargeAndSmallInts(1, 2) :: + LargeAndSmallInts(2147483645, 1) :: + LargeAndSmallInts(2, 2) :: + LargeAndSmallInts(2147483646, 1) :: + LargeAndSmallInts(3, 2) :: Nil).toDF() + df.registerTempTable("largeAndSmallInts") + df + } + + protected lazy val decimalData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + DecimalData(1, 1) :: + DecimalData(1, 2) :: + DecimalData(2, 1) :: + DecimalData(2, 2) :: + DecimalData(3, 1) :: + DecimalData(3, 2) :: Nil).toDF() + df.registerTempTable("decimalData") + df + } + + protected lazy val binaryData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + BinaryData("12".getBytes, 1) :: + BinaryData("22".getBytes, 5) :: + BinaryData("122".getBytes, 3) :: + BinaryData("121".getBytes, 2) :: + BinaryData("123".getBytes, 4) :: Nil).toDF() + df.registerTempTable("binaryData") + df + } + + protected lazy val upperCaseData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + UpperCaseData(1, "A") :: + UpperCaseData(2, "B") :: + UpperCaseData(3, "C") :: + UpperCaseData(4, "D") :: + UpperCaseData(5, "E") :: + UpperCaseData(6, "F") :: Nil).toDF() + df.registerTempTable("upperCaseData") + df + } + + protected lazy val lowerCaseData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + LowerCaseData(1, "a") :: + LowerCaseData(2, "b") :: + LowerCaseData(3, "c") :: + LowerCaseData(4, "d") :: Nil).toDF() + df.registerTempTable("lowerCaseData") + df + } + + protected lazy val arrayData: RDD[ArrayData] = { + val rdd = _sqlContext.sparkContext.parallelize( + ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: + ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) + rdd.toDF().registerTempTable("arrayData") + rdd + } + + protected lazy val mapData: RDD[MapData] = { + val rdd = _sqlContext.sparkContext.parallelize( + MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: + MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) :: + MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: + MapData(Map(1 -> "a4", 2 -> "b4")) :: + MapData(Map(1 -> "a5")) :: Nil) + rdd.toDF().registerTempTable("mapData") + rdd + } + + protected lazy val repeatedData: RDD[StringData] = { + val rdd = _sqlContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) + rdd.toDF().registerTempTable("repeatedData") + rdd + } + + protected lazy val nullableRepeatedData: RDD[StringData] = { + val rdd = _sqlContext.sparkContext.parallelize( + List.fill(2)(StringData(null)) ++ + List.fill(2)(StringData("test"))) + rdd.toDF().registerTempTable("nullableRepeatedData") + rdd + } + + protected lazy val nullInts: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + NullInts(1) :: + NullInts(2) :: + NullInts(3) :: + NullInts(null) :: Nil).toDF() + df.registerTempTable("nullInts") + df + } + + protected lazy val allNulls: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + NullInts(null) :: + NullInts(null) :: + NullInts(null) :: + NullInts(null) :: Nil).toDF() + df.registerTempTable("allNulls") + df + } + + protected lazy val nullStrings: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + NullStrings(1, "abc") :: + NullStrings(2, "ABC") :: + NullStrings(3, null) :: Nil).toDF() + df.registerTempTable("nullStrings") + df + } + + protected lazy val tableName: DataFrame = { + val df = _sqlContext.sparkContext.parallelize(TableName("test") :: Nil).toDF() + df.registerTempTable("tableName") + df + } + + protected lazy val unparsedStrings: RDD[String] = { + _sqlContext.sparkContext.parallelize( + "1, A1, true, null" :: + "2, B2, false, null" :: + "3, C3, true, null" :: + "4, D4, true, 2147483644" :: Nil) + } + + // An RDD with 4 elements and 8 partitions + protected lazy val withEmptyParts: RDD[IntField] = { + val rdd = _sqlContext.sparkContext.parallelize((1 to 4).map(IntField), 8) + rdd.toDF().registerTempTable("withEmptyParts") + rdd + } + + protected lazy val person: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + Person(0, "mike", 30) :: + Person(1, "jim", 20) :: Nil).toDF() + df.registerTempTable("person") + df + } + + protected lazy val salary: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + Salary(0, 2000.0) :: + Salary(1, 1000.0) :: Nil).toDF() + df.registerTempTable("salary") + df + } + + protected lazy val complexData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) :: + ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) :: + Nil).toDF() + df.registerTempTable("complexData") + df + } + + /** + * Initialize all test data such that all temp tables are properly registered. + */ + def loadTestData(): Unit = { + assert(_sqlContext != null, "attempted to initialize test data before SQLContext.") + testData + testData2 + testData3 + negativeData + largeAndSmallInts + decimalData + binaryData + upperCaseData + lowerCaseData + arrayData + mapData + repeatedData + nullableRepeatedData + nullInts + allNulls + nullStrings + tableName + unparsedStrings + withEmptyParts + person + salary + complexData + } +} + +/** + * Case classes used in test data. + */ +private[sql] object SQLTestData { + case class TestData(key: Int, value: String) + case class TestData2(a: Int, b: Int) + case class TestData3(a: Int, b: Option[Int]) + case class LargeAndSmallInts(a: Int, b: Int) + case class DecimalData(a: BigDecimal, b: BigDecimal) + case class BinaryData(a: Array[Byte], b: Int) + case class UpperCaseData(N: Int, L: String) + case class LowerCaseData(n: Int, l: String) + case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) + case class MapData(data: scala.collection.Map[Int, String]) + case class StringData(s: String) + case class IntField(i: Int) + case class NullInts(a: Integer) + case class NullStrings(n: Int, s: String) + case class TableName(tableName: String) + case class Person(id: Int, name: String, age: Int) + case class Salary(personId: Int, salary: Double) + case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 1066695589778..cdd691e035897 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -21,15 +21,71 @@ import java.io.File import java.util.UUID import scala.util.Try +import scala.language.implicitConversions + +import org.apache.hadoop.conf.Configuration +import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.util.Utils -trait SQLTestUtils { this: SparkFunSuite => - protected def sqlContext: SQLContext +/** + * Helper trait that should be extended by all SQL test suites. + * + * This allows subclasses to plugin a custom [[SQLContext]]. It comes with test data + * prepared in advance as well as all implicit conversions used extensively by dataframes. + * To use implicit methods, import `testImplicits._` instead of through the [[SQLContext]]. + * + * Subclasses should *not* create [[SQLContext]]s in the test suite constructor, which is + * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. + */ +private[sql] trait SQLTestUtils + extends SparkFunSuite + with BeforeAndAfterAll + with SQLTestData { self => + + protected def _sqlContext: SQLContext + + // Whether to materialize all test data before the first test is run + private var loadTestDataBeforeTests = false + + // Shorthand for running a query using our SQLContext + protected lazy val sql = _sqlContext.sql _ + + /** + * A helper object for importing SQL implicits. + * + * Note that the alternative of importing `sqlContext.implicits._` is not possible here. + * This is because we create the [[SQLContext]] immediately before the first test is run, + * but the implicits import is needed in the constructor. + */ + protected object testImplicits extends SQLImplicits { + protected override def _sqlContext: SQLContext = self._sqlContext + } + + /** + * Materialize the test data immediately after the [[SQLContext]] is set up. + * This is necessary if the data is accessed by name but not through direct reference. + */ + protected def setupTestData(): Unit = { + loadTestDataBeforeTests = true + } - protected def configuration = sqlContext.sparkContext.hadoopConfiguration + protected override def beforeAll(): Unit = { + super.beforeAll() + if (loadTestDataBeforeTests) { + loadTestData() + } + } + + /** + * The Hadoop configuration used by the active [[SQLContext]]. + */ + protected def configuration: Configuration = { + _sqlContext.sparkContext.hadoopConfiguration + } /** * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL @@ -39,12 +95,12 @@ trait SQLTestUtils { this: SparkFunSuite => */ protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption) - (keys, values).zipped.foreach(sqlContext.conf.setConfString) + val currentValues = keys.map(key => Try(_sqlContext.conf.getConfString(key)).toOption) + (keys, values).zipped.foreach(_sqlContext.conf.setConfString) try f finally { keys.zip(currentValues).foreach { - case (key, Some(value)) => sqlContext.conf.setConfString(key, value) - case (key, None) => sqlContext.conf.unsetConf(key) + case (key, Some(value)) => _sqlContext.conf.setConfString(key, value) + case (key, None) => _sqlContext.conf.unsetConf(key) } } } @@ -76,7 +132,7 @@ trait SQLTestUtils { this: SparkFunSuite => * Drops temporary table `tableName` after calling `f`. */ protected def withTempTable(tableNames: String*)(f: => Unit): Unit = { - try f finally tableNames.foreach(sqlContext.dropTempTable) + try f finally tableNames.foreach(_sqlContext.dropTempTable) } /** @@ -85,7 +141,7 @@ trait SQLTestUtils { this: SparkFunSuite => protected def withTable(tableNames: String*)(f: => Unit): Unit = { try f finally { tableNames.foreach { name => - sqlContext.sql(s"DROP TABLE IF EXISTS $name") + _sqlContext.sql(s"DROP TABLE IF EXISTS $name") } } } @@ -98,12 +154,12 @@ trait SQLTestUtils { this: SparkFunSuite => val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}" try { - sqlContext.sql(s"CREATE DATABASE $dbName") + _sqlContext.sql(s"CREATE DATABASE $dbName") } catch { case cause: Throwable => fail("Failed to create temporary database", cause) } - try f(dbName) finally sqlContext.sql(s"DROP DATABASE $dbName CASCADE") + try f(dbName) finally _sqlContext.sql(s"DROP DATABASE $dbName CASCADE") } /** @@ -111,7 +167,15 @@ trait SQLTestUtils { this: SparkFunSuite => * `f` returns. */ protected def activateDatabase(db: String)(f: => Unit): Unit = { - sqlContext.sql(s"USE $db") - try f finally sqlContext.sql(s"USE default") + _sqlContext.sql(s"USE $db") + try f finally _sqlContext.sql(s"USE default") + } + + /** + * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier + * way to construct [[DataFrame]] directly out of local data without relying on implicits. + */ + protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { + DataFrame(_sqlContext, plan) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala new file mode 100644 index 0000000000000..3cfd822e2a747 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.apache.spark.sql.SQLContext + + +/** + * Helper trait for SQL test suites where all tests share a single [[TestSQLContext]]. + */ +private[sql] trait SharedSQLContext extends SQLTestUtils { + + /** + * The [[TestSQLContext]] to use for all tests in this suite. + * + * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local + * mode with the default test configurations. + */ + private var _ctx: TestSQLContext = null + + /** + * The [[TestSQLContext]] to use for all tests in this suite. + */ + protected def ctx: TestSQLContext = _ctx + protected def sqlContext: TestSQLContext = _ctx + protected override def _sqlContext: SQLContext = _ctx + + /** + * Initialize the [[TestSQLContext]]. + */ + protected override def beforeAll(): Unit = { + if (_ctx == null) { + _ctx = new TestSQLContext + } + // Ensure we have initialized the context before calling parent code + super.beforeAll() + } + + /** + * Stop the underlying [[org.apache.spark.SparkContext]], if any. + */ + protected override def afterAll(): Unit = { + try { + if (_ctx != null) { + _ctx.sparkContext.stop() + _ctx = null + } + } finally { + super.afterAll() + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala similarity index 54% rename from sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala rename to sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index b3a4231da91c2..92ef2f7d74ba1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -17,40 +17,36 @@ package org.apache.spark.sql.test -import scala.language.implicitConversions - import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan - -/** A SQLContext that can be used for local testing. */ -class LocalSQLContext - extends SQLContext( - new SparkContext("local[2]", "TestSQLContext", new SparkConf() - .set("spark.sql.testkey", "true") - // SPARK-8910 - .set("spark.ui.enabled", "false"))) { - - override protected[sql] def createSession(): SQLSession = { - new this.SQLSession() +import org.apache.spark.sql.{SQLConf, SQLContext} + + +/** + * A special [[SQLContext]] prepared for testing. + */ +private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { self => + + def this() { + this(new SparkContext("local[2]", "test-sql-context", + new SparkConf().set("spark.sql.testkey", "true"))) } + // Use fewer partitions to speed up testing + protected[sql] override def createSession(): SQLSession = new this.SQLSession() + + /** A special [[SQLSession]] that uses fewer shuffle partitions than normal. */ protected[sql] class SQLSession extends super.SQLSession { protected[sql] override lazy val conf: SQLConf = new SQLConf { - /** Fewer partitions to speed up testing. */ override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, 5) } } - /** - * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier way to - * construct [[DataFrame]] directly out of local data without relying on implicits. - */ - protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { - DataFrame(this, plan) + // Needed for Java tests + def loadTestData(): Unit = { + testData.loadTestData() } + private object testData extends SQLTestData { + protected override def _sqlContext: SQLContext = self + } } - -object TestSQLContext extends LocalSQLContext - diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala index 806240e6de458..bf431cd6b0260 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala @@ -27,7 +27,6 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.selenium.WebBrowser import org.scalatest.time.SpanSugar._ -import org.apache.spark.sql.hive.HiveContext import org.apache.spark.ui.SparkUICssErrorHandler class UISeleniumSuite @@ -36,7 +35,6 @@ class UISeleniumSuite implicit var webDriver: WebDriver = _ var server: HiveThriftServer2 = _ - var hc: HiveContext = _ val uiPort = 20000 + Random.nextInt(10000) override def mode: ServerMode.Value = ServerMode.binary diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 59e65ff97b8e0..574624d501f22 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.sources.DataSourceTest import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils} import org.apache.spark.sql.types.{DecimalType, StringType, StructType} -import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.{Row, SaveMode, SQLContext} import org.apache.spark.{Logging, SparkFunSuite} @@ -53,7 +53,8 @@ class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging { } class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTestUtils { - override val sqlContext = TestHive + override def _sqlContext: SQLContext = TestHive + import testImplicits._ private val testDF = range(1, 3).select( ('id + 0.1) cast DecimalType(10, 3) as 'd1, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index 1fa005d5f9a15..fe0db5228de16 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -19,14 +19,13 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.execution.datasources.parquet.ParquetTest -import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.{QueryTest, Row, SQLContext} case class Cases(lower: String, UPPER: String) class HiveParquetSuite extends QueryTest with ParquetTest { - val sqlContext = TestHive - - import sqlContext._ + private val ctx = TestHive + override def _sqlContext: SQLContext = ctx test("Case insensitive attribute names") { withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") { @@ -54,7 +53,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest { test("Converting Hive to Parquet Table via saveAsParquetFile") { withTempPath { dir => sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath) - read.parquet(dir.getCanonicalPath).registerTempTable("p") + ctx.read.parquet(dir.getCanonicalPath).registerTempTable("p") withTempTable("p") { checkAnswer( sql("SELECT * FROM src ORDER BY key"), @@ -67,7 +66,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest { withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") { withTempPath { file => sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath) - read.parquet(file.getCanonicalPath).registerTempTable("p") + ctx.read.parquet(file.getCanonicalPath).registerTempTable("p") withTempTable("p") { // let's do three overwrites for good measure sql("INSERT OVERWRITE TABLE p SELECT * FROM t") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 7f36a483a3965..20a50586d5201 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -22,7 +22,6 @@ import java.io.{IOException, File} import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapred.InvalidInputException import org.scalatest.BeforeAndAfterAll import org.apache.spark.Logging @@ -42,7 +41,8 @@ import org.apache.spark.util.Utils */ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll with Logging { - override val sqlContext = TestHive + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext var jsonFilePath: String = _ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index 73852f13ad20d..417e8b07917cc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -22,9 +22,8 @@ import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.{QueryTest, SQLContext, SaveMode} class MultiDatabaseSuite extends QueryTest with SQLTestUtils { - override val sqlContext: SQLContext = TestHive - - import sqlContext.sql + override val _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext private val df = sqlContext.range(10).coalesce(1) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index 251e0324bfa5f..13452e71a1b3b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -26,7 +26,8 @@ import org.apache.spark.sql.{Row, SQLConf, SQLContext} class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest { import ParquetCompatibilityTest.makeNullable - override val sqlContext: SQLContext = TestHive + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext /** * Set the staging directory (and hence path to ignore Parquet files under) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 9b3ede43ee2d1..7ee1c8d13aa3f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -17,14 +17,12 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.QueryTest case class FunctionResult(f1: String, f2: String) class UDFSuite extends QueryTest { - private lazy val ctx = org.apache.spark.sql.hive.test.TestHive - import ctx.implicits._ test("UDF case insensitive") { ctx.udf.register("random0", () => { Math.random() }) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 7b5aa4763fd9e..a312f84958248 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -17,17 +17,18 @@ package org.apache.spark.sql.hive.execution +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql._ import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} -import org.apache.spark.sql._ -import org.scalatest.BeforeAndAfterAll import _root_.test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { - - override val sqlContext = TestHive + override def _sqlContext: SQLContext = TestHive + protected val sqlContext = _sqlContext import sqlContext.implicits._ var originalUseAggregate2: Boolean = _ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 44c5b80392fa5..11d7a872dff09 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -26,8 +26,8 @@ import org.apache.spark.sql.test.SQLTestUtils * A set of tests that validates support for Hive Explain command. */ class HiveExplainSuite extends QueryTest with SQLTestUtils { - - def sqlContext: SQLContext = TestHive + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext test("explain extended command") { checkExistence(sql(" explain select * from src where key=123 "), true, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 79a136ae6f619..8b8f520776e70 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -66,7 +66,8 @@ class MyDialect extends DefaultParserDialect * valid, but Hive currently cannot execute it. */ class SQLQuerySuite extends QueryTest with SQLTestUtils { - override def sqlContext: SQLContext = TestHive + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext test("UDTF") { sql(s"ADD JAR ${TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath()}") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala index 0875232aede3e..9aca40f15ac15 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -31,7 +31,8 @@ import org.apache.spark.sql.types.StringType class ScriptTransformationSuite extends SparkPlanTest { - override def sqlContext: SQLContext = TestHive + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext private val noSerdeIOSchema = HiveScriptIOSchema( inputRowFormat = Seq.empty, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index 145965388da01..f7ba20ff41d8d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql._ import org.apache.spark.sql.test.SQLTestUtils private[sql] trait OrcTest extends SQLTestUtils { this: SparkFunSuite => - lazy val sqlContext = org.apache.spark.sql.hive.test.TestHive - + protected override def _sqlContext: SQLContext = org.apache.spark.sql.hive.test.TestHive + protected val sqlContext = _sqlContext import sqlContext.implicits._ import sqlContext.sparkContext diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 50f02432dacce..34d3434569f58 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -685,7 +685,8 @@ class ParquetSourceSuite extends ParquetPartitioningTest { * A collection of tests for parquet data with various forms of partitioning. */ abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with BeforeAndAfterAll { - override def sqlContext: SQLContext = TestHive + override def _sqlContext: SQLContext = TestHive + protected val sqlContext = _sqlContext var partitionedTableDir: File = null var normalTableDir: File = null diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala index e976125b3706d..b4640b1616281 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala @@ -18,14 +18,16 @@ package org.apache.spark.sql.sources import org.apache.hadoop.fs.Path -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils { - override val sqlContext = TestHive + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext // When committing a task, `CommitFailureTestSource` throws an exception for testing purpose. val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 2a69d331b6e52..af445626fbe4d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -34,9 +34,8 @@ import org.apache.spark.sql.types._ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { - override lazy val sqlContext: SQLContext = TestHive - - import sqlContext.sql + override def _sqlContext: SQLContext = TestHive + protected val sqlContext = _sqlContext import sqlContext.implicits._ val dataSourceName: String From bd35385d53a6b039e0241e3e73092b8b0a8e455a Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 13 Aug 2015 21:12:59 -0700 Subject: [PATCH 101/253] [SPARK-9945] [SQL] pageSize should be calculated from executor.memory Currently, pageSize of TungstenSort is calculated from driver.memory, it should use executor.memory instead. Also, in the worst case, the safeFactor could be 4 (because of rounding), increase it to 16. cc rxin Author: Davies Liu Closes #8175 from davies/page_size. --- .../org/apache/spark/shuffle/ShuffleMemoryManager.scala | 4 +++- .../main/scala/org/apache/spark/sql/execution/sort.scala | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala index 8c3a72644c38a..a0d8abc2eecb3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala @@ -175,7 +175,9 @@ private[spark] object ShuffleMemoryManager { val minPageSize = 1L * 1024 * 1024 // 1MB val maxPageSize = 64L * minPageSize // 64MB val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors() - val safetyFactor = 8 + // Because of rounding to next power of 2, we may have safetyFactor as 8 in worst case + val safetyFactor = 16 + // TODO(davies): don't round to next power of 2 val size = ByteArrayMethods.nextPowerOf2(maxMemory / cores / safetyFactor) val default = math.min(maxPageSize, math.max(minPageSize, size)) conf.getSizeAsBytes("spark.buffer.pageSize", default) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index e316930470127..40ef7c3b53530 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -17,15 +17,15 @@ package org.apache.spark.sql.execution -import org.apache.spark.{SparkEnv, InternalAccumulator, TaskContext} import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, OrderedDistribution, Distribution} +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution} import org.apache.spark.sql.types.StructType import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter +import org.apache.spark.{SparkEnv, InternalAccumulator, TaskContext} //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines various sort operators. @@ -122,7 +122,6 @@ case class TungstenSort( protected override def doExecute(): RDD[InternalRow] = { val schema = child.schema val childOutput = child.output - val pageSize = SparkEnv.get.shuffleMemoryManager.pageSizeBytes /** * Set up the sorter in each partition before computing the parent partition. @@ -143,6 +142,7 @@ case class TungstenSort( } } + val pageSize = SparkEnv.get.shuffleMemoryManager.pageSizeBytes val sorter = new UnsafeExternalRowSorter( schema, ordering, prefixComparator, prefixComputer, pageSize) if (testSpillFrequency > 0) { From 7c7c7529a16c0e79778e522a3df82a0f1c3762a3 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 13 Aug 2015 22:06:09 -0700 Subject: [PATCH 102/253] [MINOR] [SQL] Remove canEqual in Row As `InternalRow` does not extend `Row` now, I think we can remove it. Author: Liang-Chi Hsieh Closes #8170 from viirya/remove_canequal. --- .../main/scala/org/apache/spark/sql/Row.scala | 21 ------------------- 1 file changed, 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 40159aaf14d34..ec895af9c3037 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -364,31 +364,10 @@ trait Row extends Serializable { false } - /** - * Returns true if we can check equality for these 2 rows. - * Equality check between external row and internal row is not allowed. - * Here we do this check to prevent call `equals` on external row with internal row. - */ - protected def canEqual(other: Row) = { - // Note that `Row` is not only the interface of external row but also the parent - // of `InternalRow`, so we have to ensure `other` is not a internal row here to prevent - // call `equals` on external row with internal row. - // `InternalRow` overrides canEqual, and these two canEquals together makes sure that - // equality check between external Row and InternalRow will always fail. - // In the future, InternalRow should not extend Row. In that case, we can remove these - // canEqual methods. - !other.isInstanceOf[InternalRow] - } - override def equals(o: Any): Boolean = { if (!o.isInstanceOf[Row]) return false val other = o.asInstanceOf[Row] - if (!canEqual(other)) { - throw new UnsupportedOperationException( - "cannot check equality between external and internal rows") - } - if (other eq null) return false if (length != other.length) { From c8677d73666850b37ff937520e538650632ce304 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 14 Aug 2015 14:41:53 +0800 Subject: [PATCH 103/253] [SPARK-9958] [SQL] Make HiveThriftServer2Listener thread-safe and update the tab name to "JDBC/ODBC Server" This PR fixed the thread-safe issue of HiveThriftServer2Listener, and also changed the tab name to "JDBC/ODBC Server" since it's conflict with the new SQL tab. thriftserver Author: zsxwing Closes #8185 from zsxwing/SPARK-9958. --- .../hive/thriftserver/HiveThriftServer2.scala | 64 +++++++++++-------- .../thriftserver/ui/ThriftServerPage.scala | 32 +++++----- .../ui/ThriftServerSessionPage.scala | 38 +++++------ .../thriftserver/ui/ThriftServerTab.scala | 4 +- 4 files changed, 78 insertions(+), 60 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 2c9fa595b2dad..dd9fef9206d0b 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -152,16 +152,26 @@ object HiveThriftServer2 extends Logging { override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { server.stop() } - var onlineSessionNum: Int = 0 - val sessionList = new mutable.LinkedHashMap[String, SessionInfo] - val executionList = new mutable.LinkedHashMap[String, ExecutionInfo] - val retainedStatements = - conf.getConf(SQLConf.THRIFTSERVER_UI_STATEMENT_LIMIT) - val retainedSessions = - conf.getConf(SQLConf.THRIFTSERVER_UI_SESSION_LIMIT) - var totalRunning = 0 - - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + private var onlineSessionNum: Int = 0 + private val sessionList = new mutable.LinkedHashMap[String, SessionInfo] + private val executionList = new mutable.LinkedHashMap[String, ExecutionInfo] + private val retainedStatements = conf.getConf(SQLConf.THRIFTSERVER_UI_STATEMENT_LIMIT) + private val retainedSessions = conf.getConf(SQLConf.THRIFTSERVER_UI_SESSION_LIMIT) + private var totalRunning = 0 + + def getOnlineSessionNum: Int = synchronized { onlineSessionNum } + + def getTotalRunning: Int = synchronized { totalRunning } + + def getSessionList: Seq[SessionInfo] = synchronized { sessionList.values.toSeq } + + def getSession(sessionId: String): Option[SessionInfo] = synchronized { + sessionList.get(sessionId) + } + + def getExecutionList: Seq[ExecutionInfo] = synchronized { executionList.values.toSeq } + + override def onJobStart(jobStart: SparkListenerJobStart): Unit = synchronized { for { props <- Option(jobStart.properties) groupId <- Option(props.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) @@ -173,13 +183,15 @@ object HiveThriftServer2 extends Logging { } def onSessionCreated(ip: String, sessionId: String, userName: String = "UNKNOWN"): Unit = { - val info = new SessionInfo(sessionId, System.currentTimeMillis, ip, userName) - sessionList.put(sessionId, info) - onlineSessionNum += 1 - trimSessionIfNecessary() + synchronized { + val info = new SessionInfo(sessionId, System.currentTimeMillis, ip, userName) + sessionList.put(sessionId, info) + onlineSessionNum += 1 + trimSessionIfNecessary() + } } - def onSessionClosed(sessionId: String): Unit = { + def onSessionClosed(sessionId: String): Unit = synchronized { sessionList(sessionId).finishTimestamp = System.currentTimeMillis onlineSessionNum -= 1 trimSessionIfNecessary() @@ -190,7 +202,7 @@ object HiveThriftServer2 extends Logging { sessionId: String, statement: String, groupId: String, - userName: String = "UNKNOWN"): Unit = { + userName: String = "UNKNOWN"): Unit = synchronized { val info = new ExecutionInfo(statement, sessionId, System.currentTimeMillis, userName) info.state = ExecutionState.STARTED executionList.put(id, info) @@ -200,27 +212,29 @@ object HiveThriftServer2 extends Logging { totalRunning += 1 } - def onStatementParsed(id: String, executionPlan: String): Unit = { + def onStatementParsed(id: String, executionPlan: String): Unit = synchronized { executionList(id).executePlan = executionPlan executionList(id).state = ExecutionState.COMPILED } def onStatementError(id: String, errorMessage: String, errorTrace: String): Unit = { - executionList(id).finishTimestamp = System.currentTimeMillis - executionList(id).detail = errorMessage - executionList(id).state = ExecutionState.FAILED - totalRunning -= 1 - trimExecutionIfNecessary() + synchronized { + executionList(id).finishTimestamp = System.currentTimeMillis + executionList(id).detail = errorMessage + executionList(id).state = ExecutionState.FAILED + totalRunning -= 1 + trimExecutionIfNecessary() + } } - def onStatementFinish(id: String): Unit = { + def onStatementFinish(id: String): Unit = synchronized { executionList(id).finishTimestamp = System.currentTimeMillis executionList(id).state = ExecutionState.FINISHED totalRunning -= 1 trimExecutionIfNecessary() } - private def trimExecutionIfNecessary() = synchronized { + private def trimExecutionIfNecessary() = { if (executionList.size > retainedStatements) { val toRemove = math.max(retainedStatements / 10, 1) executionList.filter(_._2.finishTimestamp != 0).take(toRemove).foreach { s => @@ -229,7 +243,7 @@ object HiveThriftServer2 extends Logging { } } - private def trimSessionIfNecessary() = synchronized { + private def trimSessionIfNecessary() = { if (sessionList.size > retainedSessions) { val toRemove = math.max(retainedSessions / 10, 1) sessionList.filter(_._2.finishTimestamp != 0).take(toRemove).foreach { s => diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index 10c83d8b27a2a..e990bd06011ff 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -39,14 +39,16 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" /** Render the page */ def render(request: HttpServletRequest): Seq[Node] = { val content = - generateBasicStats() ++ -
++ -

- {listener.onlineSessionNum} session(s) are online, - running {listener.totalRunning} SQL statement(s) -

++ - generateSessionStatsTable() ++ - generateSQLStatsTable() + listener.synchronized { // make sure all parts in this page are consistent + generateBasicStats() ++ +
++ +

+ {listener.getOnlineSessionNum} session(s) are online, + running {listener.getTotalRunning} SQL statement(s) +

++ + generateSessionStatsTable() ++ + generateSQLStatsTable() + } UIUtils.headerSparkPage("JDBC/ODBC Server", content, parent, Some(5000)) } @@ -65,11 +67,11 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" /** Generate stats of batch statements of the thrift server program */ private def generateSQLStatsTable(): Seq[Node] = { - val numStatement = listener.executionList.size + val numStatement = listener.getExecutionList.size val table = if (numStatement > 0) { val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Duration", "Statement", "State", "Detail") - val dataRows = listener.executionList.values + val dataRows = listener.getExecutionList def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => @@ -136,15 +138,15 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" /** Generate stats of batch sessions of the thrift server program */ private def generateSessionStatsTable(): Seq[Node] = { - val numBatches = listener.sessionList.size + val sessionList = listener.getSessionList + val numBatches = sessionList.size val table = if (numBatches > 0) { - val dataRows = - listener.sessionList.values + val dataRows = sessionList val headerRow = Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", "Total Execute") def generateDataRow(session: SessionInfo): Seq[Node] = { - val sessionLink = "%s/sql/session?id=%s" - .format(UIUtils.prependBaseUri(parent.basePath), session.sessionId) + val sessionLink = "%s/%s/session?id=%s" + .format(UIUtils.prependBaseUri(parent.basePath), parent.prefix, session.sessionId) {session.userName} {session.ip} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index 3b01afa603cea..af16cb31df187 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -40,21 +40,22 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) def render(request: HttpServletRequest): Seq[Node] = { val parameterId = request.getParameter("id") require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") - val sessionStat = listener.sessionList.find(stat => { - stat._1 == parameterId - }).getOrElse(null) - require(sessionStat != null, "Invalid sessionID[" + parameterId + "]") val content = - generateBasicStats() ++ -
++ -

- User {sessionStat._2.userName}, - IP {sessionStat._2.ip}, - Session created at {formatDate(sessionStat._2.startTimestamp)}, - Total run {sessionStat._2.totalExecution} SQL -

++ - generateSQLStatsTable(sessionStat._2.sessionId) + listener.synchronized { // make sure all parts in this page are consistent + val sessionStat = listener.getSession(parameterId).getOrElse(null) + require(sessionStat != null, "Invalid sessionID[" + parameterId + "]") + + generateBasicStats() ++ +
++ +

+ User {sessionStat.userName}, + IP {sessionStat.ip}, + Session created at {formatDate(sessionStat.startTimestamp)}, + Total run {sessionStat.totalExecution} SQL +

++ + generateSQLStatsTable(sessionStat.sessionId) + } UIUtils.headerSparkPage("JDBC/ODBC Session", content, parent, Some(5000)) } @@ -73,13 +74,13 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) /** Generate stats of batch statements of the thrift server program */ private def generateSQLStatsTable(sessionID: String): Seq[Node] = { - val executionList = listener.executionList - .filter(_._2.sessionId == sessionID) + val executionList = listener.getExecutionList + .filter(_.sessionId == sessionID) val numStatement = executionList.size val table = if (numStatement > 0) { val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Duration", "Statement", "State", "Detail") - val dataRows = executionList.values.toSeq.sortBy(_.startTimestamp).reverse + val dataRows = executionList.sortBy(_.startTimestamp).reverse def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => @@ -146,10 +147,11 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) /** Generate stats of batch sessions of the thrift server program */ private def generateSessionStatsTable(): Seq[Node] = { - val numBatches = listener.sessionList.size + val sessionList = listener.getSessionList + val numBatches = sessionList.size val table = if (numBatches > 0) { val dataRows = - listener.sessionList.values.toSeq.sortBy(_.startTimestamp).reverse.map ( session => + sessionList.sortBy(_.startTimestamp).reverse.map ( session => Seq( session.userName, session.ip, diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala index 94fd8a6bb60b9..4eabeaa6735e6 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala @@ -27,9 +27,9 @@ import org.apache.spark.{SparkContext, Logging, SparkException} * This assumes the given SparkContext has enabled its SparkUI. */ private[thriftserver] class ThriftServerTab(sparkContext: SparkContext) - extends SparkUITab(getSparkUI(sparkContext), "sql") with Logging { + extends SparkUITab(getSparkUI(sparkContext), "sqlserver") with Logging { - override val name = "SQL" + override val name = "JDBC/ODBC Server" val parent = getSparkUI(sparkContext) val listener = HiveThriftServer2.listener From a0e1abbd010b9e73d472ce12ff1d987678005d32 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 14 Aug 2015 10:25:11 -0700 Subject: [PATCH 104/253] [SPARK-9661] [MLLIB] minor clean-up of SPARK-9661 Some minor clean-ups after SPARK-9661. See my inline comments. MechCoder jkbradley Author: Xiangrui Meng Closes #8190 from mengxr/SPARK-9661-fix. --- .../spark/mllib/clustering/LDAModel.scala | 5 +-- .../apache/spark/mllib/stat/Statistics.scala | 6 +-- .../spark/mllib/clustering/JavaLDASuite.java | 40 +++++++++++-------- .../spark/mllib/clustering/LDASuite.scala | 2 +- 4 files changed, 28 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index f31949f13a4cf..82f05e4a18cee 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -674,10 +674,9 @@ class DistributedLDAModel private[clustering] ( } /** Java-friendly version of [[topTopicsPerDocument]] */ - def javaTopTopicsPerDocument( - k: Int): JavaRDD[(java.lang.Long, Array[Int], Array[java.lang.Double])] = { + def javaTopTopicsPerDocument(k: Int): JavaRDD[(java.lang.Long, Array[Int], Array[Double])] = { val topics = topTopicsPerDocument(k) - topics.asInstanceOf[RDD[(java.lang.Long, Array[Int], Array[java.lang.Double])]].toJavaRDD() + topics.asInstanceOf[RDD[(java.lang.Long, Array[Int], Array[Double])]].toJavaRDD() } // TODO: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index 24fe48cb8f71f..ef8d78607048f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -221,9 +221,7 @@ object Statistics { def kolmogorovSmirnovTest( data: JavaDoubleRDD, distName: String, - params: java.lang.Double*): KolmogorovSmirnovTestResult = { - val javaParams = params.asInstanceOf[Seq[Double]] - KolmogorovSmirnovTest.testOneSample(data.rdd.asInstanceOf[RDD[Double]], - distName, javaParams: _*) + params: Double*): KolmogorovSmirnovTestResult = { + kolmogorovSmirnovTest(data.rdd.asInstanceOf[RDD[Double]], distName, params: _*) } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index 427be9430d820..6e91cde2eabb5 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -22,12 +22,14 @@ import java.util.Arrays; import scala.Tuple2; +import scala.Tuple3; import org.junit.After; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertArrayEquals; import org.junit.Before; import org.junit.Test; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.JavaPairRDD; @@ -44,9 +46,9 @@ public class JavaLDASuite implements Serializable { public void setUp() { sc = new JavaSparkContext("local", "JavaLDA"); ArrayList> tinyCorpus = new ArrayList>(); - for (int i = 0; i < LDASuite$.MODULE$.tinyCorpus().length; i++) { - tinyCorpus.add(new Tuple2((Long)LDASuite$.MODULE$.tinyCorpus()[i]._1(), - LDASuite$.MODULE$.tinyCorpus()[i]._2())); + for (int i = 0; i < LDASuite.tinyCorpus().length; i++) { + tinyCorpus.add(new Tuple2((Long)LDASuite.tinyCorpus()[i]._1(), + LDASuite.tinyCorpus()[i]._2())); } JavaRDD> tmpCorpus = sc.parallelize(tinyCorpus, 2); corpus = JavaPairRDD.fromJavaRDD(tmpCorpus); @@ -60,7 +62,7 @@ public void tearDown() { @Test public void localLDAModel() { - Matrix topics = LDASuite$.MODULE$.tinyTopics(); + Matrix topics = LDASuite.tinyTopics(); double[] topicConcentration = new double[topics.numRows()]; Arrays.fill(topicConcentration, 1.0D / topics.numRows()); LocalLDAModel model = new LocalLDAModel(topics, Vectors.dense(topicConcentration), 1D, 100D); @@ -110,8 +112,8 @@ public void distributedLDAModel() { assertEquals(roundedLocalTopicSummary.length, k); // Check: log probabilities - assert(model.logLikelihood() < 0.0); - assert(model.logPrior() < 0.0); + assertTrue(model.logLikelihood() < 0.0); + assertTrue(model.logPrior() < 0.0); // Check: topic distributions JavaPairRDD topicDistributions = model.javaTopicDistributions(); @@ -126,8 +128,12 @@ public Boolean call(Tuple2 tuple2) { assertEquals(topicDistributions.count(), nonEmptyCorpus.count()); // Check: javaTopTopicsPerDocuments - JavaRDD> topTopics = - model.javaTopTopicsPerDocument(3); + Tuple3 topTopics = model.javaTopTopicsPerDocument(3).first(); + Long docId = topTopics._1(); // confirm doc ID type + int[] topicIndices = topTopics._2(); + double[] topicWeights = topTopics._3(); + assertEquals(3, topicIndices.length); + assertEquals(3, topicWeights.length); } @Test @@ -177,18 +183,18 @@ public void localLdaMethods() { // check: logLikelihood. ArrayList> docsSingleWord = new ArrayList>(); - docsSingleWord.add(new Tuple2(Long.valueOf(0), Vectors.dense(1.0, 0.0, 0.0))); + docsSingleWord.add(new Tuple2(0L, Vectors.dense(1.0, 0.0, 0.0))); JavaPairRDD single = JavaPairRDD.fromJavaRDD(sc.parallelize(docsSingleWord)); double logLikelihood = toyModel.logLikelihood(single); } - private static int tinyK = LDASuite$.MODULE$.tinyK(); - private static int tinyVocabSize = LDASuite$.MODULE$.tinyVocabSize(); - private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics(); + private static int tinyK = LDASuite.tinyK(); + private static int tinyVocabSize = LDASuite.tinyVocabSize(); + private static Matrix tinyTopics = LDASuite.tinyTopics(); private static Tuple2[] tinyTopicDescription = - LDASuite$.MODULE$.tinyTopicDescription(); + LDASuite.tinyTopicDescription(); private JavaPairRDD corpus; - private LocalLDAModel toyModel = LDASuite$.MODULE$.toyModel(); - private ArrayList> toyData = LDASuite$.MODULE$.javaToyData(); + private LocalLDAModel toyModel = LDASuite.toyModel(); + private ArrayList> toyData = LDASuite.javaToyData(); } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 926185e90bcf9..99e28499fd316 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -581,7 +581,7 @@ private[clustering] object LDASuite { def javaToyData: JArrayList[(java.lang.Long, Vector)] = { val javaData = new JArrayList[(java.lang.Long, Vector)] var i = 0 - while (i < toyData.size) { + while (i < toyData.length) { javaData.add((toyData(i)._1, toyData(i)._2)) i += 1 } From 7ecf0c46990c39df8aeddbd64ca33d01824bcc0a Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 14 Aug 2015 10:48:02 -0700 Subject: [PATCH 105/253] [SPARK-9956] [ML] Make trees work with one-category features This modifies DecisionTreeMetadata construction to treat 1-category features as continuous, so that trees do not fail with such features. It is important for the pipelines API, where VectorIndexer can automatically categorize certain features as categorical. As stated in the JIRA, this is a temp fix which we can improve upon later by automatically filtering out those features. That will take longer, though, since it will require careful indexing. Targeted for 1.5 and master CC: manishamde mengxr yanboliang Author: Joseph K. Bradley Closes #8187 from jkbradley/tree-1cat. --- .../tree/impl/DecisionTreeMetadata.scala | 27 ++++++++++++------- .../DecisionTreeClassifierSuite.scala | 13 +++++++++ 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index 9fe264656ede7..21ee49c45788c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -144,21 +144,28 @@ private[spark] object DecisionTreeMetadata extends Logging { val maxCategoriesForUnorderedFeature = ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => - // Decide if some categorical features should be treated as unordered features, - // which require 2 * ((1 << numCategories - 1) - 1) bins. - // We do this check with log values to prevent overflows in case numCategories is large. - // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins - if (numCategories <= maxCategoriesForUnorderedFeature) { - unorderedFeatures.add(featureIndex) - numBins(featureIndex) = numUnorderedBins(numCategories) - } else { - numBins(featureIndex) = numCategories + // Hack: If a categorical feature has only 1 category, we treat it as continuous. + // TODO(SPARK-9957): Handle this properly by filtering out those features. + if (numCategories > 1) { + // Decide if some categorical features should be treated as unordered features, + // which require 2 * ((1 << numCategories - 1) - 1) bins. + // We do this check with log values to prevent overflows in case numCategories is large. + // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins + if (numCategories <= maxCategoriesForUnorderedFeature) { + unorderedFeatures.add(featureIndex) + numBins(featureIndex) = numUnorderedBins(numCategories) + } else { + numBins(featureIndex) = numCategories + } } } } else { // Binary classification or regression strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => - numBins(featureIndex) = numCategories + // If a categorical feature has only 1 category, we treat it as continuous: SPARK-9957 + if (numCategories > 1) { + numBins(featureIndex) = numCategories + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 4b7c5d3f23d2c..f680d8d3c4cc2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -261,6 +261,19 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte } } + test("training with 1-category categorical feature") { + val data = sc.parallelize(Seq( + LabeledPoint(0, Vectors.dense(0, 2, 3)), + LabeledPoint(1, Vectors.dense(0, 3, 1)), + LabeledPoint(0, Vectors.dense(0, 2, 2)), + LabeledPoint(1, Vectors.dense(0, 3, 9)), + LabeledPoint(0, Vectors.dense(0, 2, 6)) + )) + val df = TreeTests.setMetadata(data, Map(0 -> 1), 2) + val dt = new DecisionTreeClassifier().setMaxDepth(3) + val model = dt.fit(df) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// From a7317ccdc20d001e5b7f5277b0535923468bfbc6 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 14 Aug 2015 11:22:10 -0700 Subject: [PATCH 106/253] [SPARK-8744] [ML] Add a public constructor to StringIndexer It would be helpful to allow users to pass a pre-computed index to create an indexer, rather than always going through StringIndexer to create the model. Author: Holden Karau Closes #7267 from holdenk/SPARK-8744-StringIndexerModel-should-have-public-constructor. --- .../scala/org/apache/spark/ml/feature/StringIndexer.scala | 4 +++- .../org/apache/spark/ml/feature/StringIndexerSuite.scala | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 9f6e7b6b6b274..63475780a6ff9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -102,10 +102,12 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod * This is a temporary fix for the case when target labels do not exist during prediction. */ @Experimental -class StringIndexerModel private[ml] ( +class StringIndexerModel ( override val uid: String, labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { + def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels) + private val labelToIndex: OpenHashMap[String, Double] = { val n = labels.length val map = new OpenHashMap[String, Double](n) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index fa918ce64877c..0b4c8ba71ee61 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -30,7 +30,9 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { test("params") { ParamsSuite.checkParams(new StringIndexer) val model = new StringIndexerModel("indexer", Array("a", "b")) + val modelWithoutUid = new StringIndexerModel(Array("a", "b")) ParamsSuite.checkParams(model) + ParamsSuite.checkParams(modelWithoutUid) } test("StringIndexer") { From 34d610be854d2a975d9c1e232d87433b85add6fd Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 14 Aug 2015 12:00:01 -0700 Subject: [PATCH 107/253] [SPARK-9929] [SQL] support metadata in withColumn in MLlib sometimes we need to set metadata for the new column, thus we will alias the new column with metadata before call `withColumn` and in `withColumn` we alias this clolumn again. Here I overloaded `withColumn` to allow user set metadata, just like what we did for `Column.as`. Author: Wenchen Fan Closes #8159 from cloud-fan/withColumn. --- .../spark/ml/classification/OneVsRest.scala | 6 +++--- .../apache/spark/ml/feature/Bucketizer.scala | 2 +- .../apache/spark/ml/feature/VectorIndexer.scala | 2 +- .../apache/spark/ml/feature/VectorSlicer.scala | 3 +-- .../scala/org/apache/spark/sql/DataFrame.scala | 17 +++++++++++++++++ 5 files changed, 23 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 1132d8046df67..c62e132f5d533 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -131,7 +131,7 @@ final class OneVsRestModel private[ml] ( // output label and label metadata as prediction aggregatedDataset - .withColumn($(predictionCol), labelUDF(col(accColName)).as($(predictionCol), labelMetadata)) + .withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata) .drop(accColName) } @@ -203,8 +203,8 @@ final class OneVsRest(override val uid: String) // TODO: use when ... otherwise after SPARK-7321 is merged val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata() val labelColName = "mc2b$" + index - val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta) - val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta) + val trainingDataset = + multiclassLabeled.withColumn(labelColName, labelUDF(col($(labelCol))), newLabelMeta) val classifier = getClassifier val paramMap = new ParamMap() paramMap.put(classifier.labelCol -> labelColName) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index cfca494dcf468..6fdf25b015b0b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -75,7 +75,7 @@ final class Bucketizer(override val uid: String) } val newCol = bucketizer(dataset($(inputCol))) val newField = prepOutputField(dataset.schema) - dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata)) + dataset.withColumn($(outputCol), newCol, newField.metadata) } private def prepOutputField(schema: StructType): StructField = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 6875aefe065bb..61b925c0fdc07 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -341,7 +341,7 @@ class VectorIndexerModel private[ml] ( val newField = prepOutputField(dataset.schema) val transformUDF = udf { (vector: Vector) => transformFunc(vector) } val newCol = transformUDF(dataset($(inputCol))) - dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata)) + dataset.withColumn($(outputCol), newCol, newField.metadata) } override def transformSchema(schema: StructType): StructType = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala index 772bebeff214b..c5c2272270792 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -119,8 +119,7 @@ final class VectorSlicer(override val uid: String) case features: SparseVector => features.slice(inds) } } - dataset.withColumn($(outputCol), - slicer(dataset($(inputCol))).as($(outputCol), outputAttr.toMetadata())) + dataset.withColumn($(outputCol), slicer(dataset($(inputCol))), outputAttr.toMetadata()) } /** Get the feature indices in order: indices, names */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index c466d9e6cb349..cf75e64e884b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1149,6 +1149,23 @@ class DataFrame private[sql]( } } + /** + * Returns a new [[DataFrame]] by adding a column with metadata. + */ + private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = { + val resolver = sqlContext.analyzer.resolver + val replaced = schema.exists(f => resolver(f.name, colName)) + if (replaced) { + val colNames = schema.map { field => + val name = field.name + if (resolver(name, colName)) col.as(colName, metadata) else Column(name) + } + select(colNames : _*) + } else { + select(Column("*"), col.as(colName, metadata)) + } + } + /** * Returns a new [[DataFrame]] with a column renamed. * This is a no-op if schema doesn't contain existingName. From 57c2d08800a506614b461b5a505a1dd1a28a8908 Mon Sep 17 00:00:00 2001 From: Neelesh Srinivas Salian Date: Fri, 14 Aug 2015 20:03:50 +0100 Subject: [PATCH 108/253] [SPARK-9923] [CORE] ShuffleMapStage.numAvailableOutputs should be an Int instead of Long Modified type of ShuffleMapStage.numAvailableOutputs from Long to Int Author: Neelesh Srinivas Salian Closes #8183 from nssalian/SPARK-9923. --- .../main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala index 66c75f325fcde..48d8d8e9c4b78 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -37,7 +37,7 @@ private[spark] class ShuffleMapStage( override def toString: String = "ShuffleMapStage " + id - var numAvailableOutputs: Long = 0 + var numAvailableOutputs: Int = 0 def isAvailable: Boolean = numAvailableOutputs == numPartitions From 3bc55287220b1248e935bf817d880ff176ad4d3b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 14 Aug 2015 12:32:35 -0700 Subject: [PATCH 109/253] [SPARK-9946] [SPARK-9589] [SQL] fix NPE and thread-safety in TaskMemoryManager Currently, we access the `page.pageNumer` after it's freed, that could be modified by other thread, cause NPE. The same TaskMemoryManager could be used by multiple threads (for example, Python UDF and TransportScript), so it should be thread safe to allocate/free memory/page. The underlying Bitset and HashSet are not thread safe, we should put them inside a synchronized block. cc JoshRosen Author: Davies Liu Closes #8177 from davies/memory_manager. --- .../unsafe/memory/TaskMemoryManager.java | 44 ++++++++++++------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java index 358bb37250158..ca70d7f4a4311 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java @@ -144,14 +144,16 @@ public MemoryBlock allocatePage(long size) { public void freePage(MemoryBlock page) { assert (page.pageNumber != -1) : "Called freePage() on memory that wasn't allocated with allocatePage()"; - executorMemoryManager.free(page); + assert(allocatedPages.get(page.pageNumber)); + pageTable[page.pageNumber] = null; synchronized (this) { allocatedPages.clear(page.pageNumber); } - pageTable[page.pageNumber] = null; if (logger.isTraceEnabled()) { logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size()); } + // Cannot access a page once it's freed. + executorMemoryManager.free(page); } /** @@ -166,7 +168,9 @@ public void freePage(MemoryBlock page) { public MemoryBlock allocate(long size) throws OutOfMemoryError { assert(size > 0) : "Size must be positive, but got " + size; final MemoryBlock memory = executorMemoryManager.allocate(size); - allocatedNonPageMemory.add(memory); + synchronized(allocatedNonPageMemory) { + allocatedNonPageMemory.add(memory); + } return memory; } @@ -176,8 +180,10 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { public void free(MemoryBlock memory) { assert (memory.pageNumber == -1) : "Should call freePage() for pages, not free()"; executorMemoryManager.free(memory); - final boolean wasAlreadyRemoved = !allocatedNonPageMemory.remove(memory); - assert (!wasAlreadyRemoved) : "Called free() on memory that was already freed!"; + synchronized(allocatedNonPageMemory) { + final boolean wasAlreadyRemoved = !allocatedNonPageMemory.remove(memory); + assert (!wasAlreadyRemoved) : "Called free() on memory that was already freed!"; + } } /** @@ -223,9 +229,10 @@ public Object getPage(long pagePlusOffsetAddress) { if (inHeap) { final int pageNumber = decodePageNumber(pagePlusOffsetAddress); assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); - final Object page = pageTable[pageNumber].getBaseObject(); + final MemoryBlock page = pageTable[pageNumber]; assert (page != null); - return page; + assert (page.getBaseObject() != null); + return page.getBaseObject(); } else { return null; } @@ -244,7 +251,9 @@ public long getOffsetInPage(long pagePlusOffsetAddress) { // converted the absolute address into a relative address. Here, we invert that operation: final int pageNumber = decodePageNumber(pagePlusOffsetAddress); assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); - return pageTable[pageNumber].getBaseOffset() + offsetInPage; + final MemoryBlock page = pageTable[pageNumber]; + assert (page != null); + return page.getBaseOffset() + offsetInPage; } } @@ -260,14 +269,17 @@ public long cleanUpAllAllocatedMemory() { freePage(page); } } - final Iterator iter = allocatedNonPageMemory.iterator(); - while (iter.hasNext()) { - final MemoryBlock memory = iter.next(); - freedBytes += memory.size(); - // We don't call free() here because that calls Set.remove, which would lead to a - // ConcurrentModificationException here. - executorMemoryManager.free(memory); - iter.remove(); + + synchronized (allocatedNonPageMemory) { + final Iterator iter = allocatedNonPageMemory.iterator(); + while (iter.hasNext()) { + final MemoryBlock memory = iter.next(); + freedBytes += memory.size(); + // We don't call free() here because that calls Set.remove, which would lead to a + // ConcurrentModificationException here. + executorMemoryManager.free(memory); + iter.remove(); + } } return freedBytes; } From ece00566e4d5f38585f2810bef38e526cae7d25e Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 14 Aug 2015 12:37:21 -0700 Subject: [PATCH 110/253] [SPARK-9561] Re-enable BroadcastJoinSuite We can do this now that SPARK-9580 is resolved. Author: Andrew Or Closes #8208 from andrewor14/reenable-sql-tests. --- .../execution/joins/BroadcastJoinSuite.scala | 145 +++++++++--------- 1 file changed, 69 insertions(+), 76 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 0554e11d252ba..53a0e53fd7719 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -15,80 +15,73 @@ * limitations under the License. */ -// TODO: uncomment the test here! It is currently failing due to -// bad interaction with org.apache.spark.sql.test.TestSQLContext. +package org.apache.spark.sql.execution.joins -// scalastyle:off -//package org.apache.spark.sql.execution.joins -// -//import scala.reflect.ClassTag -// -//import org.scalatest.BeforeAndAfterAll -// -//import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext} -//import org.apache.spark.sql.functions._ -//import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest} -// -///** -// * Test various broadcast join operators with unsafe enabled. -// * -// * This needs to be its own suite because [[org.apache.spark.sql.test.TestSQLContext]] runs -// * in local mode, but for tests in this suite we need to run Spark in local-cluster mode. -// * In particular, the use of [[org.apache.spark.unsafe.map.BytesToBytesMap]] in -// * [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered without -// * serializing the hashed relation, which does not happen in local mode. -// */ -//class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { -// private var sc: SparkContext = null -// private var sqlContext: SQLContext = null -// -// /** -// * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled. -// */ -// override def beforeAll(): Unit = { -// super.beforeAll() -// val conf = new SparkConf() -// .setMaster("local-cluster[2,1,1024]") -// .setAppName("testing") -// sc = new SparkContext(conf) -// sqlContext = new SQLContext(sc) -// sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true) -// sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) -// } -// -// override def afterAll(): Unit = { -// sc.stop() -// sc = null -// sqlContext = null -// } -// -// /** -// * Test whether the specified broadcast join updates the peak execution memory accumulator. -// */ -// private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = { -// AccumulatorSuite.verifyPeakExecutionMemorySet(sc, name) { -// val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") -// val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") -// // Comparison at the end is for broadcast left semi join -// val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") -// val df3 = df1.join(broadcast(df2), joinExpression, joinType) -// val plan = df3.queryExecution.executedPlan -// assert(plan.collect { case p: T => p }.size === 1) -// plan.executeCollect() -// } -// } -// -// test("unsafe broadcast hash join updates peak execution memory") { -// testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash join", "inner") -// } -// -// test("unsafe broadcast hash outer join updates peak execution memory") { -// testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer") -// } -// -// test("unsafe broadcast left semi join updates peak execution memory") { -// testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi") -// } -// -//} -// scalastyle:on +import scala.reflect.ClassTag + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest} + +/** + * Test various broadcast join operators with unsafe enabled. + * + * Tests in this suite we need to run Spark in local-cluster mode. In particular, the use of + * unsafe map in [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered + * without serializing the hashed relation, which does not happen in local mode. + */ +class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { + private var sc: SparkContext = null + private var sqlContext: SQLContext = null + + /** + * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled. + */ + override def beforeAll(): Unit = { + super.beforeAll() + val conf = new SparkConf() + .setMaster("local-cluster[2,1,1024]") + .setAppName("testing") + sc = new SparkContext(conf) + sqlContext = new SQLContext(sc) + sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true) + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) + } + + override def afterAll(): Unit = { + sc.stop() + sc = null + sqlContext = null + } + + /** + * Test whether the specified broadcast join updates the peak execution memory accumulator. + */ + private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = { + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, name) { + val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") + val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + // Comparison at the end is for broadcast left semi join + val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") + val df3 = df1.join(broadcast(df2), joinExpression, joinType) + val plan = df3.queryExecution.executedPlan + assert(plan.collect { case p: T => p }.size === 1) + plan.executeCollect() + } + } + + test("unsafe broadcast hash join updates peak execution memory") { + testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash join", "inner") + } + + test("unsafe broadcast hash outer join updates peak execution memory") { + testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer") + } + + test("unsafe broadcast left semi join updates peak execution memory") { + testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi") + } + +} From ffa05c84fe75663fc33f3d954d1cb1e084ab3280 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Fri, 14 Aug 2015 12:46:05 -0700 Subject: [PATCH 111/253] [SPARK-9828] [PYSPARK] Mutable values should not be default arguments Author: MechCoder Closes #8110 from MechCoder/spark-9828. --- python/pyspark/ml/evaluation.py | 4 +++- python/pyspark/ml/param/__init__.py | 26 +++++++++++++++++--------- python/pyspark/ml/pipeline.py | 4 ++-- python/pyspark/ml/tuning.py | 8 ++++++-- python/pyspark/rdd.py | 5 ++++- python/pyspark/sql/readwriter.py | 8 ++++++-- python/pyspark/statcounter.py | 4 +++- python/pyspark/streaming/kafka.py | 12 +++++++++--- 8 files changed, 50 insertions(+), 21 deletions(-) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 2734092575ad9..e23ce053baeb7 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -46,7 +46,7 @@ def _evaluate(self, dataset): """ raise NotImplementedError() - def evaluate(self, dataset, params={}): + def evaluate(self, dataset, params=None): """ Evaluates the output with optional parameters. @@ -56,6 +56,8 @@ def evaluate(self, dataset, params={}): params :return: metric """ + if params is None: + params = dict() if isinstance(params, dict): if params: return self.copy(params)._evaluate(dataset) diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 7845536161e07..eeeac49b21980 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -60,14 +60,16 @@ class Params(Identifiable): __metaclass__ = ABCMeta - #: internal param map for user-supplied values param map - _paramMap = {} + def __init__(self): + super(Params, self).__init__() + #: internal param map for user-supplied values param map + self._paramMap = {} - #: internal param map for default values - _defaultParamMap = {} + #: internal param map for default values + self._defaultParamMap = {} - #: value returned by :py:func:`params` - _params = None + #: value returned by :py:func:`params` + self._params = None @property def params(self): @@ -155,7 +157,7 @@ def getOrDefault(self, param): else: return self._defaultParamMap[param] - def extractParamMap(self, extra={}): + def extractParamMap(self, extra=None): """ Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into @@ -165,12 +167,14 @@ def extractParamMap(self, extra={}): :param extra: extra param values :return: merged param map """ + if extra is None: + extra = dict() paramMap = self._defaultParamMap.copy() paramMap.update(self._paramMap) paramMap.update(extra) return paramMap - def copy(self, extra={}): + def copy(self, extra=None): """ Creates a copy of this instance with the same uid and some extra params. The default implementation creates a @@ -181,6 +185,8 @@ def copy(self, extra={}): :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ + if extra is None: + extra = dict() that = copy.copy(self) that._paramMap = self.extractParamMap(extra) return that @@ -233,7 +239,7 @@ def _setDefault(self, **kwargs): self._defaultParamMap[getattr(self, param)] = value return self - def _copyValues(self, to, extra={}): + def _copyValues(self, to, extra=None): """ Copies param values from this instance to another instance for params shared by them. @@ -241,6 +247,8 @@ def _copyValues(self, to, extra={}): :param extra: extra params to be copied :return: the target instance with param values copied """ + if extra is None: + extra = dict() paramMap = self.extractParamMap(extra) for p in self.params: if p in paramMap and to.hasParam(p.name): diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 9889f56cac9e4..13cf2b0f7bbd9 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -141,7 +141,7 @@ class Pipeline(Estimator): @keyword_only def __init__(self, stages=None): """ - __init__(self, stages=[]) + __init__(self, stages=None) """ if stages is None: stages = [] @@ -170,7 +170,7 @@ def getStages(self): @keyword_only def setParams(self, stages=None): """ - setParams(self, stages=[]) + setParams(self, stages=None) Sets params for Pipeline. """ if stages is None: diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 0bf988fd72f14..dcfee6a3170ab 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -227,7 +227,9 @@ def _fit(self, dataset): bestModel = est.fit(dataset, epm[bestIndex]) return CrossValidatorModel(bestModel) - def copy(self, extra={}): + def copy(self, extra=None): + if extra is None: + extra = dict() newCV = Params.copy(self, extra) if self.isSet(self.estimator): newCV.setEstimator(self.getEstimator().copy(extra)) @@ -250,7 +252,7 @@ def __init__(self, bestModel): def _transform(self, dataset): return self.bestModel.transform(dataset) - def copy(self, extra={}): + def copy(self, extra=None): """ Creates a copy of this instance with a randomly generated uid and some extra params. This copies the underlying bestModel, @@ -259,6 +261,8 @@ def copy(self, extra={}): :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ + if extra is None: + extra = dict() return CrossValidatorModel(self.bestModel.copy(extra)) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index fa8e0a0574a62..9ef60a7e2c84b 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -700,7 +700,7 @@ def groupBy(self, f, numPartitions=None): return self.map(lambda x: (f(x), x)).groupByKey(numPartitions) @ignore_unicode_prefix - def pipe(self, command, env={}, checkCode=False): + def pipe(self, command, env=None, checkCode=False): """ Return an RDD created by piping elements to a forked external process. @@ -709,6 +709,9 @@ def pipe(self, command, env={}, checkCode=False): :param checkCode: whether or not to check the return value of the shell command. """ + if env is None: + env = dict() + def func(iterator): pipe = Popen( shlex.split(command), env=env, stdin=PIPE, stdout=PIPE) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index bf6ac084bbbf8..78247c8fa7372 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -182,7 +182,7 @@ def orc(self, path): @since(1.4) def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPartitions=None, - predicates=None, properties={}): + predicates=None, properties=None): """ Construct a :class:`DataFrame` representing the database table accessible via JDBC URL `url` named `table` and connection `properties`. @@ -208,6 +208,8 @@ def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPar should be included. :return: a DataFrame """ + if properties is None: + properties = dict() jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)() for k in properties: jprop.setProperty(k, properties[k]) @@ -427,7 +429,7 @@ def orc(self, path, mode=None, partitionBy=None): self._jwrite.orc(path) @since(1.4) - def jdbc(self, url, table, mode=None, properties={}): + def jdbc(self, url, table, mode=None, properties=None): """Saves the content of the :class:`DataFrame` to a external database table via JDBC. .. note:: Don't create too many partitions in parallel on a large cluster;\ @@ -445,6 +447,8 @@ def jdbc(self, url, table, mode=None, properties={}): arbitrary string tag/value. Normally at least a "user" and "password" property should be included. """ + if properties is None: + properties = dict() jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)() for k in properties: jprop.setProperty(k, properties[k]) diff --git a/python/pyspark/statcounter.py b/python/pyspark/statcounter.py index 944fa414b0c0e..0fee3b2096826 100644 --- a/python/pyspark/statcounter.py +++ b/python/pyspark/statcounter.py @@ -30,7 +30,9 @@ class StatCounter(object): - def __init__(self, values=[]): + def __init__(self, values=None): + if values is None: + values = list() self.n = 0 # Running count of our values self.mu = 0.0 # Running mean of our values self.m2 = 0.0 # Running variance numerator (sum of (x - mean)^2) diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 33dd596335b47..dc5b7fd878aef 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -35,7 +35,7 @@ def utf8_decoder(s): class KafkaUtils(object): @staticmethod - def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, + def createStream(ssc, zkQuorum, groupId, topics, kafkaParams=None, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): """ @@ -52,6 +52,8 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, :param valueDecoder: A function used to decode value (default is utf8_decoder) :return: A DStream object """ + if kafkaParams is None: + kafkaParams = dict() kafkaParams.update({ "zookeeper.connect": zkQuorum, "group.id": groupId, @@ -77,7 +79,7 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) @staticmethod - def createDirectStream(ssc, topics, kafkaParams, fromOffsets={}, + def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None, keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): """ .. note:: Experimental @@ -105,6 +107,8 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets={}, :param valueDecoder: A function used to decode value (default is utf8_decoder). :return: A DStream object """ + if fromOffsets is None: + fromOffsets = dict() if not isinstance(topics, list): raise TypeError("topics should be list") if not isinstance(kafkaParams, dict): @@ -129,7 +133,7 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets={}, return KafkaDStream(stream._jdstream, ssc, stream._jrdd_deserializer) @staticmethod - def createRDD(sc, kafkaParams, offsetRanges, leaders={}, + def createRDD(sc, kafkaParams, offsetRanges, leaders=None, keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): """ .. note:: Experimental @@ -145,6 +149,8 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders={}, :param valueDecoder: A function used to decode value (default is utf8_decoder) :return: A RDD object """ + if leaders is None: + leaders = dict() if not isinstance(kafkaParams, dict): raise TypeError("kafkaParams should be dict") if not isinstance(offsetRanges, list): From 33bae585d4cb25aed2ac32e0d1248f78cc65318b Mon Sep 17 00:00:00 2001 From: Carson Wang Date: Fri, 14 Aug 2015 13:38:25 -0700 Subject: [PATCH 112/253] [SPARK-9809] Task crashes because the internal accumulators are not properly initialized When a stage failed and another stage was resubmitted with only part of partitions to compute, all the tasks failed with error message: java.util.NoSuchElementException: key not found: peakExecutionMemory. This is because the internal accumulators are not properly initialized for this stage while other codes assume the internal accumulators always exist. Author: Carson Wang Closes #8090 from carsonwang/SPARK-9809. --- .../main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 7ab5ccf50adb7..f1c63d08766c2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -790,9 +790,10 @@ class DAGScheduler( } } + // Create internal accumulators if the stage has no accumulators initialized. // Reset internal accumulators only if this stage is not partially submitted // Otherwise, we may override existing accumulator values from some tasks - if (allPartitions == partitionsToCompute) { + if (stage.internalAccumulators.isEmpty || allPartitions == partitionsToCompute) { stage.resetInternalAccumulators() } From 6518ef63037aa56b541927f99ad26744f91098ce Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 14 Aug 2015 13:42:53 -0700 Subject: [PATCH 113/253] [SPARK-9948] Fix flaky AccumulatorSuite - internal accumulators In these tests, we use a custom listener and we assert on fields in the stage / task completion events. However, these events are posted in a separate thread so they're not guaranteed to be posted in time. This commit fixes this flakiness through a job end registration callback. Author: Andrew Or Closes #8176 from andrewor14/fix-accumulator-suite. --- .../org/apache/spark/AccumulatorSuite.scala | 153 +++++++++++------- 1 file changed, 92 insertions(+), 61 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 0eb2293a9d063..5b84acf40be4e 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -182,26 +182,30 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex sc = new SparkContext("local", "test") sc.addSparkListener(listener) // Have each task add 1 to the internal accumulator - sc.parallelize(1 to 100, numPartitions).mapPartitions { iter => + val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitions { iter => TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 iter - }.count() - val stageInfos = listener.getCompletedStageInfos - val taskInfos = listener.getCompletedTaskInfos - assert(stageInfos.size === 1) - assert(taskInfos.size === numPartitions) - // The accumulator values should be merged in the stage - val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR) - assert(stageAccum.value.toLong === numPartitions) - // The accumulator should be updated locally on each task - val taskAccumValues = taskInfos.map { taskInfo => - val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR) - assert(taskAccum.update.isDefined) - assert(taskAccum.update.get.toLong === 1) - taskAccum.value.toLong } - // Each task should keep track of the partial value on the way, i.e. 1, 2, ... numPartitions - assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { _ => + val stageInfos = listener.getCompletedStageInfos + val taskInfos = listener.getCompletedTaskInfos + assert(stageInfos.size === 1) + assert(taskInfos.size === numPartitions) + // The accumulator values should be merged in the stage + val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR) + assert(stageAccum.value.toLong === numPartitions) + // The accumulator should be updated locally on each task + val taskAccumValues = taskInfos.map { taskInfo => + val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR) + assert(taskAccum.update.isDefined) + assert(taskAccum.update.get.toLong === 1) + taskAccum.value.toLong + } + // Each task should keep track of the partial value on the way, i.e. 1, 2, ... numPartitions + assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) + } + rdd.count() } test("internal accumulators in multiple stages") { @@ -211,7 +215,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex sc.addSparkListener(listener) // Each stage creates its own set of internal accumulators so the // values for the same metric should not be mixed up across stages - sc.parallelize(1 to 100, numPartitions) + val rdd = sc.parallelize(1 to 100, numPartitions) .map { i => (i, i) } .mapPartitions { iter => TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 @@ -227,16 +231,20 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 100 iter } - .count() - // We ran 3 stages, and the accumulator values should be distinct - val stageInfos = listener.getCompletedStageInfos - assert(stageInfos.size === 3) - val firstStageAccum = findAccumulableInfo(stageInfos(0).accumulables.values, TEST_ACCUMULATOR) - val secondStageAccum = findAccumulableInfo(stageInfos(1).accumulables.values, TEST_ACCUMULATOR) - val thirdStageAccum = findAccumulableInfo(stageInfos(2).accumulables.values, TEST_ACCUMULATOR) - assert(firstStageAccum.value.toLong === numPartitions) - assert(secondStageAccum.value.toLong === numPartitions * 10) - assert(thirdStageAccum.value.toLong === numPartitions * 2 * 100) + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { _ => + // We ran 3 stages, and the accumulator values should be distinct + val stageInfos = listener.getCompletedStageInfos + assert(stageInfos.size === 3) + val (firstStageAccum, secondStageAccum, thirdStageAccum) = + (findAccumulableInfo(stageInfos(0).accumulables.values, TEST_ACCUMULATOR), + findAccumulableInfo(stageInfos(1).accumulables.values, TEST_ACCUMULATOR), + findAccumulableInfo(stageInfos(2).accumulables.values, TEST_ACCUMULATOR)) + assert(firstStageAccum.value.toLong === numPartitions) + assert(secondStageAccum.value.toLong === numPartitions * 10) + assert(thirdStageAccum.value.toLong === numPartitions * 2 * 100) + } + rdd.count() } test("internal accumulators in fully resubmitted stages") { @@ -268,7 +276,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex // This says use 1 core and retry tasks up to 2 times sc = new SparkContext("local[1, 2]", "test") sc.addSparkListener(listener) - sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) => + val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) => val taskContext = TaskContext.get() taskContext.internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 // Fail the first attempts of a subset of the tasks @@ -276,28 +284,32 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex throw new Exception("Failing a task intentionally.") } iter - }.count() - val stageInfos = listener.getCompletedStageInfos - val taskInfos = listener.getCompletedTaskInfos - assert(stageInfos.size === 1) - assert(taskInfos.size === numPartitions + numFailedPartitions) - val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR) - // We should not double count values in the merged accumulator - assert(stageAccum.value.toLong === numPartitions) - val taskAccumValues = taskInfos.flatMap { taskInfo => - if (!taskInfo.failed) { - // If a task succeeded, its update value should always be 1 - val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR) - assert(taskAccum.update.isDefined) - assert(taskAccum.update.get.toLong === 1) - Some(taskAccum.value.toLong) - } else { - // If a task failed, we should not get its accumulator values - assert(taskInfo.accumulables.isEmpty) - None + } + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { _ => + val stageInfos = listener.getCompletedStageInfos + val taskInfos = listener.getCompletedTaskInfos + assert(stageInfos.size === 1) + assert(taskInfos.size === numPartitions + numFailedPartitions) + val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR) + // We should not double count values in the merged accumulator + assert(stageAccum.value.toLong === numPartitions) + val taskAccumValues = taskInfos.flatMap { taskInfo => + if (!taskInfo.failed) { + // If a task succeeded, its update value should always be 1 + val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR) + assert(taskAccum.update.isDefined) + assert(taskAccum.update.get.toLong === 1) + Some(taskAccum.value.toLong) + } else { + // If a task failed, we should not get its accumulator values + assert(taskInfo.accumulables.isEmpty) + None + } } + assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) } - assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) + rdd.count() } } @@ -313,20 +325,27 @@ private[spark] object AccumulatorSuite { testName: String)(testBody: => Unit): Unit = { val listener = new SaveInfoListener sc.addSparkListener(listener) - // Verify that the accumulator does not already exist + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { jobId => + if (jobId == 0) { + // The first job is a dummy one to verify that the accumulator does not already exist + val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values) + assert(!accums.exists(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY)) + } else { + // In the subsequent jobs, verify that peak execution memory is updated + val accum = listener.getCompletedStageInfos + .flatMap(_.accumulables.values) + .find(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY) + .getOrElse { + throw new TestFailedException( + s"peak execution memory accumulator not set in '$testName'", 0) + } + assert(accum.value.toLong > 0) + } + } + // Run the jobs sc.parallelize(1 to 10).count() - val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values) - assert(!accums.exists(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY)) testBody - // Verify that peak execution memory is updated - val accum = listener.getCompletedStageInfos - .flatMap(_.accumulables.values) - .find(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY) - .getOrElse { - throw new TestFailedException( - s"peak execution memory accumulator not set in '$testName'", 0) - } - assert(accum.value.toLong > 0) } } @@ -336,10 +355,22 @@ private[spark] object AccumulatorSuite { private class SaveInfoListener extends SparkListener { private val completedStageInfos: ArrayBuffer[StageInfo] = new ArrayBuffer[StageInfo] private val completedTaskInfos: ArrayBuffer[TaskInfo] = new ArrayBuffer[TaskInfo] + private var jobCompletionCallback: (Int => Unit) = null // parameter is job ID def getCompletedStageInfos: Seq[StageInfo] = completedStageInfos.toArray.toSeq def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.toArray.toSeq + /** Register a callback to be called on job end. */ + def registerJobCompletionCallback(callback: (Int => Unit)): Unit = { + jobCompletionCallback = callback + } + + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + if (jobCompletionCallback != null) { + jobCompletionCallback(jobEnd.jobId) + } + } + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { completedStageInfos += stageCompleted.stageInfo } From 9407baa2a7c26f527f2d043715d313d75bd765bb Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 14 Aug 2015 13:44:38 -0700 Subject: [PATCH 114/253] [SPARK-9877] [CORE] Fix StandaloneRestServer NPE when submitting application Detailed exception log can be seen in [SPARK-9877](https://issues.apache.org/jira/browse/SPARK-9877), the problem is when creating `StandaloneRestServer`, `self` (`masterEndpoint`) is null. So this fix is creating `StandaloneRestServer` when `self` is available. Author: jerryshao Closes #8127 from jerryshao/SPARK-9877. --- .../org/apache/spark/deploy/master/Master.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 9217202b69a66..26904d39a9bec 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -127,14 +127,8 @@ private[deploy] class Master( // Alternative application submission gateway that is stable across Spark versions private val restServerEnabled = conf.getBoolean("spark.master.rest.enabled", true) - private val restServer = - if (restServerEnabled) { - val port = conf.getInt("spark.master.rest.port", 6066) - Some(new StandaloneRestServer(address.host, port, conf, self, masterUrl)) - } else { - None - } - private val restServerBoundPort = restServer.map(_.start()) + private var restServer: Option[StandaloneRestServer] = None + private var restServerBoundPort: Option[Int] = None override def onStart(): Unit = { logInfo("Starting Spark master at " + masterUrl) @@ -148,6 +142,12 @@ private[deploy] class Master( } }, 0, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS) + if (restServerEnabled) { + val port = conf.getInt("spark.master.rest.port", 6066) + restServer = Some(new StandaloneRestServer(address.host, port, conf, self, masterUrl)) + } + restServerBoundPort = restServer.map(_.start()) + masterMetricsSystem.registerSource(masterSource) masterMetricsSystem.start() applicationMetricsSystem.start() From 11ed2b180ec86523a94679a8b8132fadb911ccd5 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 14 Aug 2015 13:55:29 -0700 Subject: [PATCH 115/253] [SPARK-9978] [PYSPARK] [SQL] fix Window.orderBy and doc of ntile() Author: Davies Liu Closes #8213 from davies/fix_window. --- python/pyspark/sql/functions.py | 7 ++++--- python/pyspark/sql/tests.py | 23 +++++++++++++++++++++++ python/pyspark/sql/window.py | 2 +- 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index e98979533f901..41dfee9f54f7a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -530,9 +530,10 @@ def lead(col, count=1, default=None): @since(1.4) def ntile(n): """ - Window function: returns a group id from 1 to `n` (inclusive) in a round-robin fashion in - a window partition. Fow example, if `n` is 3, the first row will get 1, the second row will - get 2, the third row will get 3, and the fourth row will get 1... + Window function: returns the ntile group id (from 1 to `n` inclusive) + in an ordered window partition. Fow example, if `n` is 4, the first + quarter of the rows will get value 1, the second quarter will get 2, + the third quarter will get 3, and the last quarter will get 4. This is equivalent to the NTILE function in SQL. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 38c83c427a747..9b748101b5e53 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1124,5 +1124,28 @@ def test_window_functions(self): for r, ex in zip(rs, expected): self.assertEqual(tuple(r), ex[:len(r)]) + def test_window_functions_without_partitionBy(self): + df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + w = Window.orderBy("key", df.value) + from pyspark.sql import functions as F + sel = df.select(df.value, df.key, + F.max("key").over(w.rowsBetween(0, 1)), + F.min("key").over(w.rowsBetween(0, 1)), + F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))), + F.rowNumber().over(w), + F.rank().over(w), + F.denseRank().over(w), + F.ntile(2).over(w)) + rs = sorted(sel.collect()) + expected = [ + ("1", 1, 1, 1, 4, 1, 1, 1, 1), + ("2", 1, 1, 1, 4, 2, 2, 2, 1), + ("2", 1, 2, 1, 4, 3, 2, 2, 2), + ("2", 2, 2, 2, 4, 4, 4, 3, 2) + ] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) + + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index c74745c726a0c..eaf4d7e98620a 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -64,7 +64,7 @@ def orderBy(*cols): Creates a :class:`WindowSpec` with the partitioning defined. """ sc = SparkContext._active_spark_context - jspec = sc._jvm.org.apache.spark.sql.expressions.Window.partitionBy(_to_java_cols(cols)) + jspec = sc._jvm.org.apache.spark.sql.expressions.Window.orderBy(_to_java_cols(cols)) return WindowSpec(jspec) From 2a6590e510aba3bfc6603d280023128b3f5ac702 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 14 Aug 2015 14:05:03 -0700 Subject: [PATCH 116/253] [SPARK-9981] [ML] Made labels public for StringIndexerModel Also added unit test for integration between StringIndexerModel and IndexToString CC: holdenk We realized we should have left in your unit test (to catch the issue with removing the inverse() method), so this adds it back. mengxr Author: Joseph K. Bradley Closes #8211 from jkbradley/stridx-labels. --- .../spark/ml/feature/StringIndexer.scala | 5 ++++- .../spark/ml/feature/StringIndexerSuite.scala | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 63475780a6ff9..24250e4c4cf92 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -97,14 +97,17 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod /** * :: Experimental :: * Model fitted by [[StringIndexer]]. + * * NOTE: During transformation, if the input column does not exist, * [[StringIndexerModel.transform]] would return the input dataset unmodified. * This is a temporary fix for the case when target labels do not exist during prediction. + * + * @param labels Ordered list of labels, corresponding to indices to be assigned */ @Experimental class StringIndexerModel ( override val uid: String, - labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { + val labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 0b4c8ba71ee61..05e05bdc64bb1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -147,4 +147,22 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { assert(actual === expected) } } + + test("StringIndexer, IndexToString are inverses") { + val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + val transformed = indexer.transform(df) + val idx2str = new IndexToString() + .setInputCol("labelIndex") + .setOutputCol("sameLabel") + .setLabels(indexer.labels) + idx2str.transform(transformed).select("label", "sameLabel").collect().foreach { + case Row(a: String, b: String) => + assert(a === b) + } + } } From 1150a19b188a075166899fdb1e107b2ba1e505d8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 14 Aug 2015 14:09:46 -0700 Subject: [PATCH 117/253] [SPARK-8670] [SQL] Nested columns can't be referenced in pyspark This bug is caused by a wrong column-exist-check in `__getitem__` of pyspark dataframe. `DataFrame.apply` accepts not only top level column names, but also nested column name like `a.b`, so we should remove that check from `__getitem__`. Author: Wenchen Fan Closes #8202 from cloud-fan/nested. --- python/pyspark/sql/dataframe.py | 2 -- python/pyspark/sql/tests.py | 4 +++- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 2 ++ 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 09647ff6d0749..da742d7ce7d13 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -722,8 +722,6 @@ def __getitem__(self, item): [Row(age=5, name=u'Bob')] """ if isinstance(item, basestring): - if item not in self.columns: - raise IndexError("no such column: %s" % item) jc = self._jdf.apply(item) return Column(jc) elif isinstance(item, Column): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 9b748101b5e53..13cf647b66da8 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -770,7 +770,7 @@ def test_access_column(self): self.assertTrue(isinstance(df['key'], Column)) self.assertTrue(isinstance(df[0], Column)) self.assertRaises(IndexError, lambda: df[2]) - self.assertRaises(IndexError, lambda: df["bad_key"]) + self.assertRaises(AnalysisException, lambda: df["bad_key"]) self.assertRaises(TypeError, lambda: df[{}]) def test_column_name_with_non_ascii(self): @@ -794,7 +794,9 @@ def test_field_accessor(self): df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF() self.assertEqual(1, df.select(df.l[0]).first()[0]) self.assertEqual(1, df.select(df.r["a"]).first()[0]) + self.assertEqual(1, df.select(df["r.a"]).first()[0]) self.assertEqual("b", df.select(df.r["b"]).first()[0]) + self.assertEqual("b", df.select(df["r.b"]).first()[0]) self.assertEqual("v", df.select(df.d["k"]).first()[0]) def test_infer_long_type(self): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index cf75e64e884b4..fd0ead4401193 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -634,6 +634,7 @@ class DataFrame private[sql]( /** * Selects column based on the column name and return it as a [[Column]]. + * Note that the column name can also reference to a nested column like `a.b`. * @group dfops * @since 1.3.0 */ @@ -641,6 +642,7 @@ class DataFrame private[sql]( /** * Selects column based on the column name and return it as a [[Column]]. + * Note that the column name can also reference to a nested column like `a.b`. * @group dfops * @since 1.3.0 */ From f3bfb711c1742d0915e43bda8230b4d1d22b4190 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 14 Aug 2015 15:10:01 -0700 Subject: [PATCH 118/253] [SPARK-9966] [STREAMING] Handle couple of corner cases in PIDRateEstimator 1. The rate estimator should not estimate any rate when there are no records in the batch, as there is no data to estimate the rate. In the current state, it estimates and set the rate to zero. That is incorrect. 2. The rate estimator should not never set the rate to zero under any circumstances. Otherwise the system will stop receiving data, and stop generating useful estimates (see reason 1). So the fix is to define a parameters that sets a lower bound on the estimated rate, so that the system always receives some data. Author: Tathagata Das Closes #8199 from tdas/SPARK-9966 and squashes the following commits: 829f793 [Tathagata Das] Fixed unit test and added comments 3a994db [Tathagata Das] Added min rate and updated tests in PIDRateEstimator --- .../scheduler/rate/PIDRateEstimator.scala | 46 ++++++++--- .../scheduler/rate/RateEstimator.scala | 4 +- .../rate/PIDRateEstimatorSuite.scala | 79 ++++++++++++------- 3 files changed, 87 insertions(+), 42 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala index 6ae56a68ad88c..84a3ca9d74e58 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala @@ -17,6 +17,8 @@ package org.apache.spark.streaming.scheduler.rate +import org.apache.spark.Logging + /** * Implements a proportional-integral-derivative (PID) controller which acts on * the speed of ingestion of elements into Spark Streaming. A PID controller works @@ -26,7 +28,7 @@ package org.apache.spark.streaming.scheduler.rate * * @see https://en.wikipedia.org/wiki/PID_controller * - * @param batchDurationMillis the batch duration, in milliseconds + * @param batchIntervalMillis the batch duration, in milliseconds * @param proportional how much the correction should depend on the current * error. This term usually provides the bulk of correction and should be positive or zero. * A value too large would make the controller overshoot the setpoint, while a small value @@ -39,13 +41,17 @@ package org.apache.spark.streaming.scheduler.rate * of future errors, based on current rate of change. This value should be positive or 0. * This term is not used very often, as it impacts stability of the system. The default * value is 0. + * @param minRate what is the minimum rate that can be estimated. + * This must be greater than zero, so that the system always receives some data for rate + * estimation to work. */ private[streaming] class PIDRateEstimator( batchIntervalMillis: Long, - proportional: Double = 1D, - integral: Double = .2D, - derivative: Double = 0D) - extends RateEstimator { + proportional: Double, + integral: Double, + derivative: Double, + minRate: Double + ) extends RateEstimator with Logging { private var firstRun: Boolean = true private var latestTime: Long = -1L @@ -64,16 +70,23 @@ private[streaming] class PIDRateEstimator( require( derivative >= 0, s"Derivative term $derivative in PIDRateEstimator should be >= 0.") + require( + minRate > 0, + s"Minimum rate in PIDRateEstimator should be > 0") + logInfo(s"Created PIDRateEstimator with proportional = $proportional, integral = $integral, " + + s"derivative = $derivative, min rate = $minRate") - def compute(time: Long, // in milliseconds + def compute( + time: Long, // in milliseconds numElements: Long, processingDelay: Long, // in milliseconds schedulingDelay: Long // in milliseconds ): Option[Double] = { - + logTrace(s"\ntime = $time, # records = $numElements, " + + s"processing time = $processingDelay, scheduling delay = $schedulingDelay") this.synchronized { - if (time > latestTime && processingDelay > 0 && batchIntervalMillis > 0) { + if (time > latestTime && numElements > 0 && processingDelay > 0) { // in seconds, should be close to batchDuration val delaySinceUpdate = (time - latestTime).toDouble / 1000 @@ -104,21 +117,30 @@ private[streaming] class PIDRateEstimator( val newRate = (latestRate - proportional * error - integral * historicalError - - derivative * dError).max(0.0) + derivative * dError).max(minRate) + logTrace(s""" + | latestRate = $latestRate, error = $error + | latestError = $latestError, historicalError = $historicalError + | delaySinceUpdate = $delaySinceUpdate, dError = $dError + """.stripMargin) + latestTime = time if (firstRun) { latestRate = processingRate latestError = 0D firstRun = false - + logTrace("First run, rate estimation skipped") None } else { latestRate = newRate latestError = error - + logTrace(s"New rate = $newRate") Some(newRate) } - } else None + } else { + logTrace("Rate estimation skipped") + None + } } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala index 17ccebc1ed41b..d7210f64fcc36 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala @@ -18,7 +18,6 @@ package org.apache.spark.streaming.scheduler.rate import org.apache.spark.SparkConf -import org.apache.spark.SparkException import org.apache.spark.streaming.Duration /** @@ -61,7 +60,8 @@ object RateEstimator { val proportional = conf.getDouble("spark.streaming.backpressure.pid.proportional", 1.0) val integral = conf.getDouble("spark.streaming.backpressure.pid.integral", 0.2) val derived = conf.getDouble("spark.streaming.backpressure.pid.derived", 0.0) - new PIDRateEstimator(batchInterval.milliseconds, proportional, integral, derived) + val minRate = conf.getDouble("spark.streaming.backpressure.pid.minRate", 100) + new PIDRateEstimator(batchInterval.milliseconds, proportional, integral, derived, minRate) case estimator => throw new IllegalArgumentException(s"Unkown rate estimator: $estimator") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala index 97c32d8f2d59e..a1af95be81c8e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala @@ -36,72 +36,89 @@ class PIDRateEstimatorSuite extends SparkFunSuite with Matchers { test("estimator checks ranges") { intercept[IllegalArgumentException] { - new PIDRateEstimator(0, 1, 2, 3) + new PIDRateEstimator(batchIntervalMillis = 0, 1, 2, 3, 10) } intercept[IllegalArgumentException] { - new PIDRateEstimator(100, -1, 2, 3) + new PIDRateEstimator(100, proportional = -1, 2, 3, 10) } intercept[IllegalArgumentException] { - new PIDRateEstimator(100, 0, -1, 3) + new PIDRateEstimator(100, 0, integral = -1, 3, 10) } intercept[IllegalArgumentException] { - new PIDRateEstimator(100, 0, 0, -1) + new PIDRateEstimator(100, 0, 0, derivative = -1, 10) + } + intercept[IllegalArgumentException] { + new PIDRateEstimator(100, 0, 0, 0, minRate = 0) + } + intercept[IllegalArgumentException] { + new PIDRateEstimator(100, 0, 0, 0, minRate = -10) } } - private def createDefaultEstimator: PIDRateEstimator = { - new PIDRateEstimator(20, 1D, 0D, 0D) - } - - test("first bound is None") { - val p = createDefaultEstimator + test("first estimate is None") { + val p = createDefaultEstimator() p.compute(0, 10, 10, 0) should equal(None) } - test("second bound is rate") { - val p = createDefaultEstimator + test("second estimate is not None") { + val p = createDefaultEstimator() p.compute(0, 10, 10, 0) // 1000 elements / s p.compute(10, 10, 10, 0) should equal(Some(1000)) } - test("works even with no time between updates") { - val p = createDefaultEstimator + test("no estimate when no time difference between successive calls") { + val p = createDefaultEstimator() + p.compute(0, 10, 10, 0) + p.compute(time = 10, 10, 10, 0) shouldNot equal(None) + p.compute(time = 10, 10, 10, 0) should equal(None) + } + + test("no estimate when no records in previous batch") { + val p = createDefaultEstimator() p.compute(0, 10, 10, 0) - p.compute(10, 10, 10, 0) - p.compute(10, 10, 10, 0) should equal(None) + p.compute(10, numElements = 0, 10, 0) should equal(None) + p.compute(20, numElements = -10, 10, 0) should equal(None) } - test("bound is never negative") { - val p = new PIDRateEstimator(20, 1D, 1D, 0D) + test("no estimate when there is no processing delay") { + val p = createDefaultEstimator() + p.compute(0, 10, 10, 0) + p.compute(10, 10, processingDelay = 0, 0) should equal(None) + p.compute(20, 10, processingDelay = -10, 0) should equal(None) + } + + test("estimate is never less than min rate") { + val minRate = 5D + val p = new PIDRateEstimator(20, 1D, 1D, 0D, minRate) // prepare a series of batch updates, one every 20ms, 0 processed elements, 2ms of processing // this might point the estimator to try and decrease the bound, but we test it never - // goes below zero, which would be nonsensical. + // goes below the min rate, which would be nonsensical. val times = List.tabulate(50)(x => x * 20) // every 20ms - val elements = List.fill(50)(0) // no processing + val elements = List.fill(50)(1) // no processing val proc = List.fill(50)(20) // 20ms of processing val sched = List.fill(50)(100) // strictly positive accumulation val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i)) res.head should equal(None) - res.tail should equal(List.fill(49)(Some(0D))) + res.tail should equal(List.fill(49)(Some(minRate))) } test("with no accumulated or positive error, |I| > 0, follow the processing speed") { - val p = new PIDRateEstimator(20, 1D, 1D, 0D) + val p = new PIDRateEstimator(20, 1D, 1D, 0D, 10) // prepare a series of batch updates, one every 20ms with an increasing number of processed // elements in each batch, but constant processing time, and no accumulated error. Even though // the integral part is non-zero, the estimated rate should follow only the proportional term val times = List.tabulate(50)(x => x * 20) // every 20ms - val elements = List.tabulate(50)(x => x * 20) // increasing + val elements = List.tabulate(50)(x => (x + 1) * 20) // increasing val proc = List.fill(50)(20) // 20ms of processing val sched = List.fill(50)(0) val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i)) res.head should equal(None) - res.tail should equal(List.tabulate(50)(x => Some(x * 1000D)).tail) + res.tail should equal(List.tabulate(50)(x => Some((x + 1) * 1000D)).tail) } test("with no accumulated but some positive error, |I| > 0, follow the processing speed") { - val p = new PIDRateEstimator(20, 1D, 1D, 0D) + val p = new PIDRateEstimator(20, 1D, 1D, 0D, 10) // prepare a series of batch updates, one every 20ms with an decreasing number of processed // elements in each batch, but constant processing time, and no accumulated error. Even though // the integral part is non-zero, the estimated rate should follow only the proportional term, @@ -116,13 +133,14 @@ class PIDRateEstimatorSuite extends SparkFunSuite with Matchers { } test("with some accumulated and some positive error, |I| > 0, stay below the processing speed") { - val p = new PIDRateEstimator(20, 1D, .01D, 0D) + val minRate = 10D + val p = new PIDRateEstimator(20, 1D, .01D, 0D, minRate) val times = List.tabulate(50)(x => x * 20) // every 20ms val rng = new Random() - val elements = List.tabulate(50)(x => rng.nextInt(1000)) + val elements = List.tabulate(50)(x => rng.nextInt(1000) + 1000) val procDelayMs = 20 val proc = List.fill(50)(procDelayMs) // 20ms of processing - val sched = List.tabulate(50)(x => rng.nextInt(19)) // random wait + val sched = List.tabulate(50)(x => rng.nextInt(19) + 1) // random wait val speeds = elements map ((x) => x.toDouble / procDelayMs * 1000) val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i)) @@ -131,7 +149,12 @@ class PIDRateEstimatorSuite extends SparkFunSuite with Matchers { res(n) should not be None if (res(n).get > 0 && sched(n) > 0) { res(n).get should be < speeds(n) + res(n).get should be >= minRate } } } + + private def createDefaultEstimator(): PIDRateEstimator = { + new PIDRateEstimator(20, 1D, 0D, 0D, 10) + } } From 18a761ef7a01a4dfa1dd91abe78cd68f2f8fdb67 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 14 Aug 2015 15:54:14 -0700 Subject: [PATCH 119/253] [SPARK-9968] [STREAMING] Reduced time spent within synchronized block to prevent lock starvation When the rate limiter is actually limiting the rate at which data is inserted into the buffer, the synchronized block of BlockGenerator.addData stays blocked for long time. This causes the thread switching the buffer and generating blocks (synchronized with addData) to starve and not generate blocks for seconds. The correct solution is to not block on the rate limiter within the synchronized block for adding data to the buffer. Author: Tathagata Das Closes #8204 from tdas/SPARK-9968 and squashes the following commits: 8cbcc1b [Tathagata Das] Removed unused val a73b645 [Tathagata Das] Reduced time spent within synchronized block --- .../streaming/receiver/BlockGenerator.scala | 40 +++++++++++++++---- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index 794dece370b2c..300e820d01777 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -155,10 +155,17 @@ private[streaming] class BlockGenerator( /** * Push a single data item into the buffer. */ - def addData(data: Any): Unit = synchronized { + def addData(data: Any): Unit = { if (state == Active) { waitToPush() - currentBuffer += data + synchronized { + if (state == Active) { + currentBuffer += data + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } + } } else { throw new SparkException( "Cannot add data as BlockGenerator has not been started or has been stopped") @@ -169,11 +176,18 @@ private[streaming] class BlockGenerator( * Push a single data item into the buffer. After buffering the data, the * `BlockGeneratorListener.onAddData` callback will be called. */ - def addDataWithCallback(data: Any, metadata: Any): Unit = synchronized { + def addDataWithCallback(data: Any, metadata: Any): Unit = { if (state == Active) { waitToPush() - currentBuffer += data - listener.onAddData(data, metadata) + synchronized { + if (state == Active) { + currentBuffer += data + listener.onAddData(data, metadata) + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } + } } else { throw new SparkException( "Cannot add data as BlockGenerator has not been started or has been stopped") @@ -185,13 +199,23 @@ private[streaming] class BlockGenerator( * `BlockGeneratorListener.onAddData` callback will be called. Note that all the data items * are atomically added to the buffer, and are hence guaranteed to be present in a single block. */ - def addMultipleDataWithCallback(dataIterator: Iterator[Any], metadata: Any): Unit = synchronized { + def addMultipleDataWithCallback(dataIterator: Iterator[Any], metadata: Any): Unit = { if (state == Active) { + // Unroll iterator into a temp buffer, and wait for pushing in the process + val tempBuffer = new ArrayBuffer[Any] dataIterator.foreach { data => waitToPush() - currentBuffer += data + tempBuffer += data + } + synchronized { + if (state == Active) { + currentBuffer ++= tempBuffer + listener.onAddData(tempBuffer, metadata) + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } } - listener.onAddData(dataIterator, metadata) } else { throw new SparkException( "Cannot add data as BlockGenerator has not been started or has been stopped") From 932b24fd144232fb08184f0bd0a46369ecba164e Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 14 Aug 2015 17:35:17 -0700 Subject: [PATCH 120/253] [SPARK-9949] [SQL] Fix TakeOrderedAndProject's output. https://issues.apache.org/jira/browse/SPARK-9949 Author: Yin Huai Closes #8179 from yhuai/SPARK-9949. --- .../spark/sql/execution/basicOperators.scala | 12 ++++++++++- .../spark/sql/execution/PlannerSuite.scala | 20 ++++++++++++++++--- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 247c900baae9e..77b98064a9e16 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -237,7 +237,10 @@ case class TakeOrderedAndProject( projectList: Option[Seq[NamedExpression]], child: SparkPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output + override def output: Seq[Attribute] = { + val projectOutput = projectList.map(_.map(_.toAttribute)) + projectOutput.getOrElse(child.output) + } override def outputPartitioning: Partitioning = SinglePartition @@ -263,6 +266,13 @@ case class TakeOrderedAndProject( protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(collectData(), 1) override def outputOrdering: Seq[SortOrder] = sortOrder + + override def simpleString: String = { + val orderByString = sortOrder.mkString("[", ",", "]") + val outputString = output.mkString("[", ",", "]") + + s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)" + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 937a108543531..fad93b014c237 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -162,9 +162,23 @@ class PlannerSuite extends SparkFunSuite with SharedSQLContext { } test("efficient limit -> project -> sort") { - val query = testData.sort('key).select('value).limit(2).logicalPlan - val planned = ctx.planner.TakeOrderedAndProject(query) - assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) + { + val query = + testData.select('key, 'value).sort('key).limit(2).logicalPlan + val planned = ctx.planner.TakeOrderedAndProject(query) + assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) + assert(planned.head.output === testData.select('key, 'value).logicalPlan.output) + } + + { + // We need to make sure TakeOrderedAndProject's output is correct when we push a project + // into it. + val query = + testData.select('key, 'value).sort('key).select('value, 'key).limit(2).logicalPlan + val planned = ctx.planner.TakeOrderedAndProject(query) + assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) + assert(planned.head.output === testData.select('value, 'key).logicalPlan.output) + } } test("PartitioningCollection") { From e5fd60415fbfea2c5c02602f7ddbc999dd058064 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 14 Aug 2015 20:55:32 -0700 Subject: [PATCH 121/253] [SPARK-9934] Deprecate NIO ConnectionManager. Deprecate NIO ConnectionManager in Spark 1.5.0, before removing it in Spark 1.6.0. Author: Reynold Xin Closes #8162 from rxin/SPARK-9934. --- core/src/main/scala/org/apache/spark/SparkEnv.scala | 2 ++ docs/configuration.md | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index a796e72850191..0f1e2e069568d 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -331,6 +331,8 @@ object SparkEnv extends Logging { case "netty" => new NettyBlockTransferService(conf, securityManager, numUsableCores) case "nio" => + logWarning("NIO-based block transfer service is deprecated, " + + "and will be removed in Spark 1.6.0.") new NioBlockTransferService(conf, securityManager) } diff --git a/docs/configuration.md b/docs/configuration.md index c60dd16839c02..32147098ae64a 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -389,7 +389,8 @@ Apart from these, the following properties are also available, and may be useful Implementation to use for transferring shuffle and cached blocks between executors. There are two implementations available: netty and nio. Netty-based block transfer is intended to be simpler but equally efficient and is the default option - starting in 1.2. + starting in 1.2, and nio block transfer is deprecated in Spark 1.5.0 and will + be removed in Spark 1.6.0. From 37586e5449ff8f892d41f0b6b8fa1de83dd3849e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 14 Aug 2015 20:56:55 -0700 Subject: [PATCH 122/253] [HOTFIX] fix duplicated braces Author: Davies Liu Closes #8219 from davies/fix_typo. --- .../main/scala/org/apache/spark/storage/BlockManager.scala | 2 +- .../scala/org/apache/spark/storage/BlockManagerMaster.scala | 6 +++--- .../main/scala/org/apache/spark/util/ClosureCleaner.scala | 2 +- core/src/main/scala/org/apache/spark/util/Utils.scala | 2 +- .../scala/org/apache/spark/examples/ml/MovieLensALS.scala | 2 +- .../apache/spark/examples/mllib/DecisionTreeRunner.scala | 2 +- .../org/apache/spark/examples/mllib/MovieLensALS.scala | 2 +- .../src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala | 2 +- .../spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 2 +- .../spark/sql/execution/datasources/jdbc/JdbcUtils.scala | 2 +- .../org/apache/spark/sql/execution/ui/SQLListener.scala | 2 +- .../spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala | 2 +- .../apache/spark/streaming/scheduler/InputInfoTracker.scala | 2 +- 13 files changed, 15 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 86493673d958d..eedb27942e841 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -222,7 +222,7 @@ private[spark] class BlockManager( return } catch { case e: Exception if i < MAX_ATTEMPTS => - logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}}" + logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}" + s" more times after waiting $SLEEP_TIME_SECS seconds...", e) Thread.sleep(SLEEP_TIME_SECS * 1000) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index f70f701494dbf..2a11f371b9d6e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -103,7 +103,7 @@ class BlockManagerMaster( val future = driverEndpoint.askWithRetry[Future[Seq[Int]]](RemoveRdd(rddId)) future.onFailure { case e: Exception => - logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}", e) + logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}", e) }(ThreadUtils.sameThread) if (blocking) { timeout.awaitResult(future) @@ -115,7 +115,7 @@ class BlockManagerMaster( val future = driverEndpoint.askWithRetry[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) future.onFailure { case e: Exception => - logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}", e) + logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}", e) }(ThreadUtils.sameThread) if (blocking) { timeout.awaitResult(future) @@ -129,7 +129,7 @@ class BlockManagerMaster( future.onFailure { case e: Exception => logWarning(s"Failed to remove broadcast $broadcastId" + - s" with removeFromMaster = $removeFromMaster - ${e.getMessage}}", e) + s" with removeFromMaster = $removeFromMaster - ${e.getMessage}", e) }(ThreadUtils.sameThread) if (blocking) { timeout.awaitResult(future) diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index ebead830c6466..150d82b3930ef 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -181,7 +181,7 @@ private[spark] object ClosureCleaner extends Logging { return } - logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}}) +++") + logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}) +++") // A list of classes that represents closures enclosed in the given one val innerClasses = getInnerClosureClasses(func) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index f2abf227dc129..fddc24dbfc237 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1366,7 +1366,7 @@ private[spark] object Utils extends Logging { file.getAbsolutePath, effectiveStartIndex, effectiveEndIndex)) } sum += fileToLength(file) - logDebug(s"After processing file $file, string built is ${stringBuffer.toString}}") + logDebug(s"After processing file $file, string built is ${stringBuffer.toString}") } stringBuffer.toString } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala index cd411397a4b9d..3ae53e57dbdb8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala @@ -76,7 +76,7 @@ object MovieLensALS { .text("path to a MovieLens dataset of movies") .action((x, c) => c.copy(movies = x)) opt[Int]("rank") - .text(s"rank, default: ${defaultParams.rank}}") + .text(s"rank, default: ${defaultParams.rank}") .action((x, c) => c.copy(rank = x)) opt[Int]("maxIter") .text(s"max number of iterations, default: ${defaultParams.maxIter}") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 57ffe3dd2524f..cc6bce3cb7c9c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -100,7 +100,7 @@ object DecisionTreeRunner { .action((x, c) => c.copy(numTrees = x)) opt[String]("featureSubsetStrategy") .text(s"feature subset sampling strategy" + - s" (${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}}), " + + s" (${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}), " + s"default: ${defaultParams.featureSubsetStrategy}") .action((x, c) => c.copy(featureSubsetStrategy = x)) opt[Double]("fracTest") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index e43a6f2864c73..69691ae297f64 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -55,7 +55,7 @@ object MovieLensALS { val parser = new OptionParser[Params]("MovieLensALS") { head("MovieLensALS: an example app for ALS on MovieLens data.") opt[Int]("rank") - .text(s"rank, default: ${defaultParams.rank}}") + .text(s"rank, default: ${defaultParams.rank}") .action((x, c) => c.copy(rank = x)) opt[Int]("numIterations") .text(s"number of iterations, default: ${defaultParams.numIterations}") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 9029093e0fa08..bbbcc8436b7c2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -469,7 +469,7 @@ private[spark] object BLAS extends Serializable with Logging { require(A.numCols == x.size, s"The columns of A don't match the number of elements of x. A: ${A.numCols}, x: ${x.size}") require(A.numRows == y.size, - s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}}") + s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}") if (alpha == 0.0) { logDebug("gemv: alpha is equal to 0. Returning y.") } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 970f3c8282c81..8581d6b496c15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -164,7 +164,7 @@ object HiveTypeCoercion { // Leave the same if the dataTypes match. case Some(newType) if a.dataType == newType.dataType => a case Some(newType) => - logDebug(s"Promoting $a to $newType in ${q.simpleString}}") + logDebug(s"Promoting $a to $newType in ${q.simpleString}") newType } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 039c13bf163ca..8ee3b8bda8fc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -170,7 +170,7 @@ object JdbcUtils extends Logging { case BinaryType => "BLOB" case TimestampType => "TIMESTAMP" case DateType => "DATE" - case t: DecimalType => s"DECIMAL(${t.precision}},${t.scale}})" + case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})" case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") }) val nullable = if (field.nullable) "" else "NOT NULL" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 0b9bad987c488..5779c71f64e9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -162,7 +162,7 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit // A task of an old stage attempt. Because a new stage is submitted, we can ignore it. } else if (stageAttemptID > stageMetrics.stageAttemptId) { logWarning(s"A task should not have a higher stageAttemptID ($stageAttemptID) then " + - s"what we have seen (${stageMetrics.stageAttemptId}})") + s"what we have seen (${stageMetrics.stageAttemptId})") } else { // TODO We don't know the attemptId. Currently, what we can do is overriding the // accumulator updates. However, if there are two same task are running, such as diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index 31ce8e1ec14d7..620b8a36a2baf 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -84,7 +84,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( require( blockIds.length == walRecordHandles.length, s"Number of block Ids (${blockIds.length}) must be " + - s" same as number of WAL record handles (${walRecordHandles.length}})") + s" same as number of WAL record handles (${walRecordHandles.length})") require( isBlockIdValid.isEmpty || isBlockIdValid.length == blockIds.length, diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala index 363c03d431f04..deb15d075975c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala @@ -66,7 +66,7 @@ private[streaming] class InputInfoTracker(ssc: StreamingContext) extends Logging new mutable.HashMap[Int, StreamInputInfo]()) if (inputInfos.contains(inputInfo.inputStreamId)) { - throw new IllegalStateException(s"Input stream ${inputInfo.inputStreamId}} for batch" + + throw new IllegalStateException(s"Input stream ${inputInfo.inputStreamId} for batch" + s"$batchTime is already added into InputInfoTracker, this is a illegal state") } inputInfos += ((inputInfo.inputStreamId, inputInfo)) From ec29f2034a3306cc0afdc4c160b42c2eefa0897c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 14 Aug 2015 20:59:54 -0700 Subject: [PATCH 123/253] [SPARK-9634] [SPARK-9323] [SQL] cleanup unnecessary Aliases in LogicalPlan at the end of analysis Also alias the ExtractValue instead of wrapping it with UnresolvedAlias when resolve attribute in LogicalPlan, as this alias will be trimmed if it's unnecessary. Based on #7957 without the changes to mllib, but instead maintaining earlier behavior when using `withColumn` on expressions that already have metadata. Author: Wenchen Fan Author: Michael Armbrust Closes #8215 from marmbrus/pr/7957. --- .../sql/catalyst/analysis/Analyzer.scala | 75 ++++++++++++++++--- .../expressions/complexTypeCreator.scala | 2 - .../sql/catalyst/optimizer/Optimizer.scala | 9 ++- .../catalyst/plans/logical/LogicalPlan.scala | 8 +- .../plans/logical/basicOperators.scala | 2 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 17 +++++ .../scala/org/apache/spark/sql/Column.scala | 16 +++- .../spark/sql/ColumnExpressionSuite.scala | 9 +++ .../org/apache/spark/sql/DataFrameSuite.scala | 6 ++ 9 files changed, 120 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a684dbc3afa42..4bc1c1af40bf4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -82,7 +82,9 @@ class Analyzer( HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, - PullOutNondeterministic) + PullOutNondeterministic), + Batch("Cleanup", fixedPoint, + CleanupAliases) ) /** @@ -146,8 +148,6 @@ class Analyzer( child match { case _: UnresolvedAttribute => u case ne: NamedExpression => ne - case g: GetStructField => Alias(g, g.field.name)() - case g: GetArrayStructFields => Alias(g, g.field.name)() case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil) case e if !e.resolved => u case other => Alias(other, s"_c$i")() @@ -384,9 +384,7 @@ class Analyzer( case u @ UnresolvedAttribute(nameParts) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. val result = - withPosition(u) { - q.resolveChildren(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u) - } + withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } logDebug(s"Resolving $u to $result") result case UnresolvedExtractValue(child, fieldExpr) if child.resolved => @@ -412,11 +410,6 @@ class Analyzer( exprs.exists(_.collect { case _: Star => true }.nonEmpty) } - private def trimUnresolvedAlias(ne: NamedExpression) = ne match { - case UnresolvedAlias(child) => child - case other => other - } - private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = { ordering.map { order => // Resolve SortOrder in one round. @@ -426,7 +419,7 @@ class Analyzer( try { val newOrder = order transformUp { case u @ UnresolvedAttribute(nameParts) => - plan.resolve(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u) + plan.resolve(nameParts, resolver).getOrElse(u) case UnresolvedExtractValue(child, fieldName) if child.resolved => ExtractValue(child, fieldName, resolver) } @@ -968,3 +961,61 @@ object EliminateSubQueries extends Rule[LogicalPlan] { case Subquery(_, child) => child } } + +/** + * Cleans up unnecessary Aliases inside the plan. Basically we only need Alias as a top level + * expression in Project(project list) or Aggregate(aggregate expressions) or + * Window(window expressions). + */ +object CleanupAliases extends Rule[LogicalPlan] { + private def trimAliases(e: Expression): Expression = { + var stop = false + e.transformDown { + // CreateStruct is a special case, we need to retain its top level Aliases as they decide the + // name of StructField. We also need to stop transform down this expression, or the Aliases + // under CreateStruct will be mistakenly trimmed. + case c: CreateStruct if !stop => + stop = true + c.copy(children = c.children.map(trimNonTopLevelAliases)) + case c: CreateStructUnsafe if !stop => + stop = true + c.copy(children = c.children.map(trimNonTopLevelAliases)) + case Alias(child, _) if !stop => child + } + } + + def trimNonTopLevelAliases(e: Expression): Expression = e match { + case a: Alias => + Alias(trimAliases(a.child), a.name)(a.exprId, a.qualifiers, a.explicitMetadata) + case other => trimAliases(other) + } + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case Project(projectList, child) => + val cleanedProjectList = + projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) + Project(cleanedProjectList, child) + + case Aggregate(grouping, aggs, child) => + val cleanedAggs = aggs.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) + Aggregate(grouping.map(trimAliases), cleanedAggs, child) + + case w @ Window(projectList, windowExprs, partitionSpec, orderSpec, child) => + val cleanedWindowExprs = + windowExprs.map(e => trimNonTopLevelAliases(e).asInstanceOf[NamedExpression]) + Window(projectList, cleanedWindowExprs, partitionSpec.map(trimAliases), + orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child) + + case other => + var stop = false + other transformExpressionsDown { + case c: CreateStruct if !stop => + stop = true + c.copy(children = c.children.map(trimNonTopLevelAliases)) + case c: CreateStructUnsafe if !stop => + stop = true + c.copy(children = c.children.map(trimNonTopLevelAliases)) + case Alias(child, _) if !stop => child + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 4a071e663e0d1..298aee3499275 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -75,8 +75,6 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) - override lazy val resolved: Boolean = childrenResolved - override lazy val dataType: StructType = { val fields = children.zipWithIndex.map { case (child, idx) => child match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 4ab5ac2c61e3c..47b06cae15436 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet -import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.FullOuter @@ -260,8 +260,11 @@ object ProjectCollapsing extends Rule[LogicalPlan] { val substitutedProjection = projectList1.map(_.transform { case a: Attribute => aliasMap.getOrElse(a, a) }).asInstanceOf[Seq[NamedExpression]] - - Project(substitutedProjection, child) + // collapse 2 projects may introduce unnecessary Aliases, trim them here. + val cleanedProjection = substitutedProjection.map(p => + CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression] + ) + Project(cleanedProjection, child) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index c290e6acb361c..9bb466ac2d29c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -259,13 +259,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // One match, but we also need to extract the requested nested field. case Seq((a, nestedFields)) => // The foldLeft adds ExtractValues for every remaining parts of the identifier, - // and wrap it with UnresolvedAlias which will be removed later. + // and aliased it with the last part of the name. // For example, consider "a.b.c", where "a" is resolved to an existing attribute. - // Then this will add ExtractValue("c", ExtractValue("b", a)), and wrap it as - // UnresolvedAlias(ExtractValue("c", ExtractValue("b", a))). + // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final + // expression as "c". val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) => ExtractValue(expr, Literal(fieldName), resolver)) - Some(UnresolvedAlias(fieldExprs)) + Some(Alias(fieldExprs, nestedFields.last)()) // No matches. case Seq() => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 7c404722d811c..73b8261260acb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -228,7 +228,7 @@ case class Window( child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = - (projectList ++ windowExpressions).map(_.toAttribute) + projectList ++ windowExpressions.map(_.toAttribute) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index c944bc69e25b0..1e0cc81dae974 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -119,4 +119,21 @@ class AnalysisSuite extends AnalysisTest { Project(testRelation.output :+ projected, testRelation))) checkAnalysis(plan, expected) } + + test("SPARK-9634: cleanup unnecessary Aliases in LogicalPlan") { + val a = testRelation.output.head + var plan = testRelation.select(((a + 1).as("a+1") + 2).as("col")) + var expected = testRelation.select((a + 1 + 2).as("col")) + checkAnalysis(plan, expected) + + plan = testRelation.groupBy(a.as("a1").as("a2"))((min(a).as("min_a") + 1).as("col")) + expected = testRelation.groupBy(a)((min(a) + 1).as("col")) + checkAnalysis(plan, expected) + + // CreateStruct is a special case that we should not trim Alias for it. + plan = testRelation.select(CreateStruct(Seq(a, (a + 1).as("a+1"))).as("col")) + checkAnalysis(plan, plan) + plan = testRelation.select(CreateStructUnsafe(Seq(a, (a + 1).as("a+1"))).as("col")) + checkAnalysis(plan, plan) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 27bd084847346..807bc8c30c12d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -753,10 +753,16 @@ class Column(protected[sql] val expr: Expression) extends Logging { * df.select($"colA".as("colB")) * }}} * + * If the current column has metadata associated with it, this metadata will be propagated + * to the new column. If this not desired, use `as` with explicitly empty metadata. + * * @group expr_ops * @since 1.3.0 */ - def as(alias: String): Column = Alias(expr, alias)() + def as(alias: String): Column = expr match { + case ne: NamedExpression => Alias(expr, alias)(explicitMetadata = Some(ne.metadata)) + case other => Alias(other, alias)() + } /** * (Scala-specific) Assigns the given aliases to the results of a table generating function. @@ -789,10 +795,16 @@ class Column(protected[sql] val expr: Expression) extends Logging { * df.select($"colA".as('colB)) * }}} * + * If the current column has metadata associated with it, this metadata will be propagated + * to the new column. If this not desired, use `as` with explicitly empty metadata. + * * @group expr_ops * @since 1.3.0 */ - def as(alias: Symbol): Column = Alias(expr, alias.name)() + def as(alias: Symbol): Column = expr match { + case ne: NamedExpression => Alias(expr, alias.name)(explicitMetadata = Some(ne.metadata)) + case other => Alias(other, alias.name)() + } /** * Gives the column an alias with metadata. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index ee74e3e83da5a..37738ec5b3c1d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.scalatest.Matchers._ import org.apache.spark.sql.execution.{Project, TungstenProject} @@ -110,6 +111,14 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { assert(df.select(df("a").alias("b")).columns.head === "b") } + test("as propagates metadata") { + val metadata = new MetadataBuilder + metadata.putString("key", "value") + val origCol = $"a".as("b", metadata.build()) + val newCol = origCol.as("c") + assert(newCol.expr.asInstanceOf[NamedExpression].metadata.getString("key") === "value") + } + test("single explode") { val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 10bfa9b64f00d..cf22797752b97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -867,4 +867,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val actual = df.sort(rand(seed)).collect().map(_.getInt(0)) assert(expected === actual) } + + test("SPARK-9323: DataFrame.orderBy should support nested column name") { + val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD( + """{"a": {"b": 1}}""" :: Nil)) + checkAnswer(df.orderBy("a.b"), Row(Row(1))) + } } From 6c4fdbec33af287d24cd0995ecbd7191545d05c9 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 14 Aug 2015 21:03:14 -0700 Subject: [PATCH 124/253] [SPARK-8887] [SQL] Explicit define which data types can be used as dynamic partition columns This PR enforce dynamic partition column data type requirements by adding analysis rules. JIRA: https://issues.apache.org/jira/browse/SPARK-8887 Author: Yijie Shen Closes #8201 from yjshen/dynamic_partition_columns. --- .../datasources/PartitioningUtils.scala | 13 +++++++++++++ .../datasources/ResolvedDataSource.scala | 5 ++++- .../execution/datasources/WriterContainer.scala | 2 +- .../spark/sql/execution/datasources/rules.scala | 8 ++++++-- .../sql/sources/hadoopFsRelationSuites.scala | 17 +++++++++++++++++ 5 files changed, 41 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 66dfcc308ceca..0a2007e15843c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -26,6 +26,7 @@ import scala.util.Try import org.apache.hadoop.fs.Path import org.apache.hadoop.util.Shell +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.types._ @@ -270,6 +271,18 @@ private[sql] object PartitioningUtils { private val upCastingOrder: Seq[DataType] = Seq(NullType, IntegerType, LongType, FloatType, DoubleType, StringType) + def validatePartitionColumnDataTypes( + schema: StructType, + partitionColumns: Array[String]): Unit = { + + ResolvedDataSource.partitionColumnsSchema(schema, partitionColumns).foreach { field => + field.dataType match { + case _: AtomicType => // OK + case _ => throw new AnalysisException(s"Cannot use ${field.dataType} for partition column") + } + } + } + /** * Given a collection of [[Literal]]s, resolves possible type conflicts by up-casting "lower" * types. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala index 7770bbd712f04..8fbaf3a3059db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -143,7 +143,7 @@ object ResolvedDataSource extends Logging { new ResolvedDataSource(clazz, relation) } - private def partitionColumnsSchema( + def partitionColumnsSchema( schema: StructType, partitionColumns: Array[String]): StructType = { StructType(partitionColumns.map { col => @@ -179,6 +179,9 @@ object ResolvedDataSource extends Logging { val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) path.makeQualified(fs.getUri, fs.getWorkingDirectory) } + + PartitioningUtils.validatePartitionColumnDataTypes(data.schema, partitionColumns) + val dataSchema = StructType(data.schema.filterNot(f => partitionColumns.contains(f.name))) val r = dataSource.createRelation( sqlContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 2f11f40422402..d36197e50d448 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -287,7 +287,7 @@ private[sql] class DynamicPartitionWriterContainer( PartitioningUtils.escapePathName _, StringType, Seq(Cast(c, StringType)), Seq(StringType)) val str = If(IsNull(c), Literal(defaultPartitionName), escaped) val partitionName = Literal(c.name + "=") :: str :: Nil - if (i == 0) partitionName else Literal(Path.SEPARATOR_CHAR.toString) :: partitionName + if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName } // Returns the partition path given a partition key. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 40ca8bf4095d8..9d3d35692ffcc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -116,6 +116,8 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => // OK } + PartitioningUtils.validatePartitionColumnDataTypes(r.schema, part.keySet.toArray) + // Get all input data source relations of the query. val srcRelations = query.collect { case LogicalRelation(src: BaseRelation) => src @@ -138,10 +140,10 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => // OK } - case CreateTableUsingAsSelect(tableName, _, _, _, SaveMode.Overwrite, _, query) => + case CreateTableUsingAsSelect(tableName, _, _, partitionColumns, mode, _, query) => // When the SaveMode is Overwrite, we need to check if the table is an input table of // the query. If so, we will throw an AnalysisException to let users know it is not allowed. - if (catalog.tableExists(Seq(tableName))) { + if (mode == SaveMode.Overwrite && catalog.tableExists(Seq(tableName))) { // Need to remove SubQuery operator. EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) match { // Only do the check if the table is a data source table @@ -164,6 +166,8 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => // OK } + PartitioningUtils.validatePartitionColumnDataTypes(query.schema, partitionColumns) + case _ => // OK } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index af445626fbe4d..8d0d9218ddd6a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.sources +import java.sql.Date + import scala.collection.JavaConversions._ import org.apache.hadoop.conf.Configuration @@ -553,6 +555,21 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) } } + + test("SPARK-8887: Explicitly define which data types can be used as dynamic partition columns") { + val df = Seq( + (1, "v1", Array(1, 2, 3), Map("k1" -> "v1"), Tuple2(1, "4")), + (2, "v2", Array(4, 5, 6), Map("k2" -> "v2"), Tuple2(2, "5")), + (3, "v3", Array(7, 8, 9), Map("k3" -> "v3"), Tuple2(3, "6"))).toDF("a", "b", "c", "d", "e") + withTempDir { file => + intercept[AnalysisException] { + df.write.format(dataSourceName).partitionBy("c", "d", "e").save(file.getCanonicalPath) + } + } + intercept[AnalysisException] { + df.write.format(dataSourceName).partitionBy("c", "d", "e").saveAsTable("t") + } + } } // This class is used to test SPARK-8578. We should not use any custom output committer when From 609ce3c07d4962a9242e488ad0ed48c183896802 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 14 Aug 2015 21:12:11 -0700 Subject: [PATCH 125/253] [SPARK-9984] [SQL] Create local physical operator interface. This pull request creates a new operator interface that is more similar to traditional database query iterators (with open/close/next/get). These local operators are not currently used anywhere, but will become the basis for SPARK-9983 (local physical operators for query execution). cc zsxwing Author: Reynold Xin Closes #8212 from rxin/SPARK-9984. --- .../sql/execution/local/FilterNode.scala | 47 ++++++++++ .../spark/sql/execution/local/LocalNode.scala | 86 +++++++++++++++++++ .../sql/execution/local/ProjectNode.scala | 42 +++++++++ .../sql/execution/local/SeqScanNode.scala | 49 +++++++++++ 4 files changed, 224 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala new file mode 100644 index 0000000000000..a485a1a1d7ae4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala @@ -0,0 +1,47 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate + + +case class FilterNode(condition: Expression, child: LocalNode) extends UnaryLocalNode { + + private[this] var predicate: (InternalRow) => Boolean = _ + + override def output: Seq[Attribute] = child.output + + override def open(): Unit = { + child.open() + predicate = GeneratePredicate.generate(condition, child.output) + } + + override def next(): Boolean = { + var found = false + while (child.next() && !found) { + found = predicate.apply(child.get()) + } + found + } + + override def get(): InternalRow = child.get() + + override def close(): Unit = child.close() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala new file mode 100644 index 0000000000000..341c81438e6d6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -0,0 +1,86 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.types.StructType + +/** + * A local physical operator, in the form of an iterator. + * + * Before consuming the iterator, open function must be called. + * After consuming the iterator, close function must be called. + */ +abstract class LocalNode extends TreeNode[LocalNode] { + + def output: Seq[Attribute] + + /** + * Initializes the iterator state. Must be called before calling `next()`. + * + * Implementations of this must also call the `open()` function of its children. + */ + def open(): Unit + + /** + * Advances the iterator to the next tuple. Returns true if there is at least one more tuple. + */ + def next(): Boolean + + /** + * Returns the current tuple. + */ + def get(): InternalRow + + /** + * Closes the iterator and releases all resources. + * + * Implementations of this must also call the `close()` function of its children. + */ + def close(): Unit + + /** + * Returns the content of the iterator from the beginning to the end in the form of a Scala Seq. + */ + def collect(): Seq[Row] = { + val converter = CatalystTypeConverters.createToScalaConverter(StructType.fromAttributes(output)) + val result = new scala.collection.mutable.ArrayBuffer[Row] + open() + while (next()) { + result += converter.apply(get()).asInstanceOf[Row] + } + close() + result + } +} + + +abstract class LeafLocalNode extends LocalNode { + override def children: Seq[LocalNode] = Seq.empty +} + + +abstract class UnaryLocalNode extends LocalNode { + + def child: LocalNode + + override def children: Seq[LocalNode] = Seq(child) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala new file mode 100644 index 0000000000000..e574d1473cdcb --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala @@ -0,0 +1,42 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, Attribute, NamedExpression} + + +case class ProjectNode(projectList: Seq[NamedExpression], child: LocalNode) extends UnaryLocalNode { + + private[this] var project: UnsafeProjection = _ + + override def output: Seq[Attribute] = projectList.map(_.toAttribute) + + override def open(): Unit = { + project = UnsafeProjection.create(projectList, child.output) + child.open() + } + + override def next(): Boolean = child.next() + + override def get(): InternalRow = { + project.apply(child.get()) + } + + override def close(): Unit = child.close() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala new file mode 100644 index 0000000000000..994de8afa9a02 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala @@ -0,0 +1,49 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute + +/** + * An operator that scans some local data collection in the form of Scala Seq. + */ +case class SeqScanNode(output: Seq[Attribute], data: Seq[InternalRow]) extends LeafLocalNode { + + private[this] var iterator: Iterator[InternalRow] = _ + private[this] var currentRow: InternalRow = _ + + override def open(): Unit = { + iterator = data.iterator + } + + override def next(): Boolean = { + if (iterator.hasNext) { + currentRow = iterator.next() + true + } else { + false + } + } + + override def get(): InternalRow = currentRow + + override def close(): Unit = { + // Do nothing + } +} From 71a3af8a94f900a26ac7094f22ec1216cab62e15 Mon Sep 17 00:00:00 2001 From: zc he Date: Fri, 14 Aug 2015 21:28:50 -0700 Subject: [PATCH 126/253] [SPARK-9960] [GRAPHX] sendMessage type fix in LabelPropagation.scala Author: zc he Closes #8188 from farseer90718/farseer-patch-1. --- .../scala/org/apache/spark/graphx/lib/LabelPropagation.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala index 2bcf8684b8b8e..a3ad6bed1c998 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala @@ -43,7 +43,7 @@ object LabelPropagation { */ def run[VD, ED: ClassTag](graph: Graph[VD, ED], maxSteps: Int): Graph[VertexId, ED] = { val lpaGraph = graph.mapVertices { case (vid, _) => vid } - def sendMessage(e: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, Map[VertexId, VertexId])] = { + def sendMessage(e: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, Map[VertexId, Long])] = { Iterator((e.srcId, Map(e.dstAttr -> 1L)), (e.dstId, Map(e.srcAttr -> 1L))) } def mergeMessage(count1: Map[VertexId, Long], count2: Map[VertexId, Long]) From 7c1e56825b716a7d703dff38254b4739755ac0c4 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 14 Aug 2015 22:30:35 -0700 Subject: [PATCH 127/253] [SPARK-9725] [SQL] fix serialization of UTF8String across different JVM The BYTE_ARRAY_OFFSET could be different in JVM with different configurations (for example, different heap size, 24 if heap > 32G, otherwise 16), so offset of UTF8String is not portable, we should handler that during serialization. Author: Davies Liu Closes #8210 from davies/serialize_utf8string. --- .../apache/spark/unsafe/types/UTF8String.java | 31 +++++++++++++++---- 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 667c00900f2c5..cbcab958c05a9 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -18,8 +18,7 @@ package org.apache.spark.unsafe.types; import javax.annotation.Nonnull; -import java.io.Serializable; -import java.io.UnsupportedEncodingException; +import java.io.*; import java.nio.ByteOrder; import java.util.Arrays; import java.util.Map; @@ -38,12 +37,13 @@ *

* Note: This is not designed for general use cases, should not be used outside SQL. */ -public final class UTF8String implements Comparable, Serializable { +public final class UTF8String implements Comparable, Externalizable { + // These are only updated by readExternal() @Nonnull - private final Object base; - private final long offset; - private final int numBytes; + private Object base; + private long offset; + private int numBytes; public Object getBaseObject() { return base; } public long getBaseOffset() { return offset; } @@ -127,6 +127,11 @@ protected UTF8String(Object base, long offset, int numBytes) { this.numBytes = numBytes; } + // for serialization + public UTF8String() { + this(null, 0, 0); + } + /** * Writes the content of this string into a memory address, identified by an object and an offset. * The target memory address must already been allocated, and have enough space to hold all the @@ -978,4 +983,18 @@ public UTF8String soundex() { } return UTF8String.fromBytes(sx); } + + public void writeExternal(ObjectOutput out) throws IOException { + byte[] bytes = getBytes(); + out.writeInt(bytes.length); + out.write(bytes); + } + + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + offset = BYTE_ARRAY_OFFSET; + numBytes = in.readInt(); + base = new byte[numBytes]; + in.readFully((byte[]) base); + } + } From a85fb6c07fdda5c74d53d6373910dcf5db3ff111 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sat, 15 Aug 2015 10:46:04 +0100 Subject: [PATCH 128/253] [SPARK-9980] [BUILD] Fix SBT publishLocal error due to invalid characters in doc Tiny modification to a few comments ```sbt publishLocal``` work again. Author: Herman van Hovell Closes #8209 from hvanhovell/SPARK-9980. --- .../java/org/apache/spark/unsafe/map/BytesToBytesMap.java | 6 +++--- .../apache/spark/examples/ml/JavaDeveloperApiExample.java | 4 ++-- .../examples/streaming/JavaStatefulNetworkWordCount.java | 2 +- .../src/main/java/org/apache/spark/launcher/Main.java | 4 ++-- .../apache/spark/launcher/SparkClassCommandBuilder.java | 2 +- .../java/org/apache/spark/launcher/SparkLauncher.java | 6 +++--- .../apache/spark/launcher/SparkSubmitCommandBuilder.java | 4 ++-- .../apache/spark/launcher/SparkSubmitOptionParser.java | 8 ++++---- .../org/apache/spark/unsafe/memory/TaskMemoryManager.java | 2 +- 9 files changed, 19 insertions(+), 19 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 5f3a4fcf4d585..b24eed3952fd6 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -92,9 +92,9 @@ public final class BytesToBytesMap { /** * The maximum number of keys that BytesToBytesMap supports. The hash table has to be - * power-of-2-sized and its backing Java array can contain at most (1 << 30) elements, since - * that's the largest power-of-2 that's less than Integer.MAX_VALUE. We need two long array - * entries per key, giving us a maximum capacity of (1 << 29). + * power-of-2-sized and its backing Java array can contain at most (1 << 30) elements, + * since that's the largest power-of-2 that's less than Integer.MAX_VALUE. We need two long array + * entries per key, giving us a maximum capacity of (1 << 29). */ @VisibleForTesting static final int MAX_CAPACITY = (1 << 29); diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index 3f1fe900b0008..a377694507d29 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -124,7 +124,7 @@ public String uid() { /** * Param for max number of iterations - *

+ *

* NOTE: The usual way to add a parameter to a model or algorithm is to include: * - val myParamName: ParamType * - def getMyParamName @@ -222,7 +222,7 @@ public Vector predictRaw(Vector features) { /** * Create a copy of the model. * The copy is shallow, except for the embedded paramMap, which gets a deep copy. - *

+ *

* This is used for the defaul implementation of [[transform()]]. * * In Java, we have to make this method public since Java does not understand Scala's protected diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java index 02f58f48b07ab..99b63a2590ae2 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java @@ -45,7 +45,7 @@ * Usage: JavaStatefulNetworkWordCount * and describe the TCP server that Spark Streaming would connect to receive * data. - *

+ *

* To run this on your local machine, you need to first run a Netcat server * `$ nc -lk 9999` * and then run the example diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index 62492f9baf3bb..a4e3acc674f36 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -32,7 +32,7 @@ class Main { /** * Usage: Main [class] [class args] - *

+ *

* This CLI works in two different modes: *