diff --git a/LICENSE b/LICENSE index 820f14dbdeed0..b771bd552b762 100644 --- a/LICENSE +++ b/LICENSE @@ -201,102 +201,61 @@ limitations under the License. -======================================================================= -Apache Spark Subcomponents: - -The Apache Spark project contains subcomponents with separate copyright -notices and license terms. Your use of the source code for the these -subcomponents is subject to the terms and conditions of the following -licenses. - - -======================================================================== -For heapq (pyspark/heapq3.py): -======================================================================== - -See license/LICENSE-heapq.txt - -======================================================================== -For SnapTree: -======================================================================== - -See license/LICENSE-SnapTree.txt - -======================================================================== -For jbcrypt: -======================================================================== - -See license/LICENSE-jbcrypt.txt - -======================================================================== -BSD-style licenses -======================================================================== - -The following components are provided under a BSD-style license. See project link for details. -The text of each license is also included at licenses/LICENSE-[project].txt. - - (BSD 3 Clause) netlib core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core) - (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.2.7 - https://github.com/jpmml/jpmml-model) - (BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/) - (BSD License) ANTLR 4.5.2-1 (org.antlr:antlr4:4.5.2-1 - http://wwww.antlr.org/) - (BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org) - (BSD licence) ANTLR StringTemplate (org.antlr:stringtemplate:3.2.1 - http://www.stringtemplate.org) - (BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org) - (BSD) JLine (jline:jline:0.9.94 - http://jline.sourceforge.net) - (BSD) ParaNamer Core (com.thoughtworks.paranamer:paranamer:2.3 - http://paranamer.codehaus.org/paranamer) - (BSD) ParaNamer Core (com.thoughtworks.paranamer:paranamer:2.6 - http://paranamer.codehaus.org/paranamer) - (BSD 3 Clause) Scala (http://www.scala-lang.org/download/#License) - (Interpreter classes (all .scala files in repl/src/main/scala - except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala), - and for SerializableMapWrapper in JavaUtils.scala) - (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.11.8 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.11.8 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.11.8 - http://www.scala-lang.org/) - (BSD-like) Scala Library (org.scala-lang:scala-library:2.11.8 - http://www.scala-lang.org/) - (BSD-like) Scalap (org.scala-lang:scalap:2.11.8 - http://www.scala-lang.org/) - (BSD-style) scalacheck (org.scalacheck:scalacheck_2.11:1.10.0 - http://www.scalacheck.org) - (BSD-style) spire (org.spire-math:spire_2.11:0.7.1 - http://spire-math.org) - (BSD-style) spire-macros (org.spire-math:spire-macros_2.11:0.7.1 - http://spire-math.org) - (New BSD License) Kryo (com.esotericsoftware:kryo:3.0.3 - https://github.com/EsotericSoftware/kryo) - (New BSD License) MinLog (com.esotericsoftware:minlog:1.3.0 - https://github.com/EsotericSoftware/minlog) - (New BSD license) Protocol Buffer Java API (com.google.protobuf:protobuf-java:2.5.0 - http://code.google.com/p/protobuf) - (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf) - (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net) - (The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net) - (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.7 - http://py4j.sourceforge.net/) - (Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/) - (BSD licence) sbt and sbt-launch-lib.bash - (BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE) - (BSD 3 Clause) DPark (https://github.com/douban/dpark/blob/master/LICENSE) - (BSD 3 Clause) CloudPickle (https://github.com/cloudpipe/cloudpickle/blob/master/LICENSE) - (BSD 2 Clause) Zstd-jni (https://github.com/luben/zstd-jni/blob/master/LICENSE) - (BSD license) Zstd (https://github.com/facebook/zstd/blob/v1.3.1/LICENSE) - -======================================================================== -MIT licenses -======================================================================== - -The following components are provided under the MIT License. See project link for details. -The text of each license is also included at licenses/LICENSE-[project].txt. - - (MIT License) JCL 1.1.1 implemented over SLF4J (org.slf4j:jcl-over-slf4j:1.7.5 - http://www.slf4j.org) - (MIT License) JUL to SLF4J bridge (org.slf4j:jul-to-slf4j:1.7.5 - http://www.slf4j.org) - (MIT License) SLF4J API Module (org.slf4j:slf4j-api:1.7.5 - http://www.slf4j.org) - (MIT License) SLF4J LOG4J-12 Binding (org.slf4j:slf4j-log4j12:1.7.5 - http://www.slf4j.org) - (MIT License) pyrolite (org.spark-project:pyrolite:2.0.1 - http://pythonhosted.org/Pyro4/) - (MIT License) scopt (com.github.scopt:scopt_2.11:3.2.0 - https://github.com/scopt/scopt) - (The MIT License) Mockito (org.mockito:mockito-core:1.9.5 - http://www.mockito.org) - (MIT License) jquery (https://jquery.org/license/) - (MIT License) AnchorJS (https://github.com/bryanbraun/anchorjs) - (MIT License) graphlib-dot (https://github.com/cpettitt/graphlib-dot) - (MIT License) dagre-d3 (https://github.com/cpettitt/dagre-d3) - (MIT License) sorttable (https://github.com/stuartlangridge/sorttable) - (MIT License) boto (https://github.com/boto/boto/blob/develop/LICENSE) - (MIT License) datatables (http://datatables.net/license) - (MIT License) mustache (https://github.com/mustache/mustache/blob/master/LICENSE) - (MIT License) cookies (http://code.google.com/p/cookies/wiki/License) - (MIT License) blockUI (http://jquery.malsup.com/block/) - (MIT License) RowsGroup (http://datatables.net/license/mit) - (MIT License) jsonFormatter (http://www.jqueryscript.net/other/jQuery-Plugin-For-Pretty-JSON-Formatting-jsonFormatter.html) - (MIT License) modernizr (https://github.com/Modernizr/Modernizr/blob/master/LICENSE) - (MIT License) machinist (https://github.com/typelevel/machinist) +------------------------------------------------------------------------------------ +This product bundles various third-party components under other open source licenses. +This section summarizes those components and their licenses. See licenses/ +for text of these licenses. + + +Apache Software Foundation License 2.0 +-------------------------------------- + +common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java +core/src/main/java/org/apache/spark/util/collection/TimSort.java +core/src/main/resources/org/apache/spark/ui/static/bootstrap* +core/src/main/resources/org/apache/spark/ui/static/jsonFormatter* +core/src/main/resources/org/apache/spark/ui/static/vis* +docs/js/vendor/bootstrap.js + + +Python Software Foundation License +---------------------------------- + +pyspark/heapq3.py + + +BSD 3-Clause +------------ + +python/lib/py4j-*-src.zip +python/pyspark/cloudpickle.py +python/pyspark/join.py +core/src/main/resources/org/apache/spark/ui/static/d3.min.js + +The CSS style for the navigation sidebar of the documentation was originally +submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project +is distributed under the 3-Clause BSD license. + + +MIT License +----------- + +core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js +core/src/main/resources/org/apache/spark/ui/static/*dataTables* +core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js +ore/src/main/resources/org/apache/spark/ui/static/jquery* +core/src/main/resources/org/apache/spark/ui/static/sorttable.js +docs/js/vendor/anchor.min.js +docs/js/vendor/jquery* +docs/js/vendor/modernizer* + + +Creative Commons CC0 1.0 Universal Public Domain Dedication +----------------------------------------------------------- +(see LICENSE-CC0.txt) + +data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg +data/mllib/images/kittens/54893.jpg +data/mllib/images/kittens/DP153539.jpg +data/mllib/images/kittens/DP802813.jpg +data/mllib/images/multi-channel/chr30.4.184.jpg \ No newline at end of file diff --git a/LICENSE-binary b/LICENSE-binary new file mode 100644 index 0000000000000..c033dd8ad2e6a --- /dev/null +++ b/LICENSE-binary @@ -0,0 +1,520 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +------------------------------------------------------------------------------------ +This project bundles some components that are also licensed under the Apache +License Version 2.0: + +commons-beanutils:commons-beanutils +org.apache.zookeeper:zookeeper +oro:oro +commons-configuration:commons-configuration +commons-digester:commons-digester +com.chuusai:shapeless_2.11 +com.googlecode.javaewah:JavaEWAH +com.twitter:chill-java +com.twitter:chill_2.11 +com.univocity:univocity-parsers +javax.jdo:jdo-api +joda-time:joda-time +net.sf.opencsv:opencsv +org.apache.derby:derby +org.objenesis:objenesis +org.roaringbitmap:RoaringBitmap +org.scalanlp:breeze-macros_2.11 +org.scalanlp:breeze_2.11 +org.typelevel:macro-compat_2.11 +org.yaml:snakeyaml +org.apache.xbean:xbean-asm5-shaded +com.squareup.okhttp3:logging-interceptor +com.squareup.okhttp3:okhttp +com.squareup.okio:okio +net.java.dev.jets3t:jets3t +org.apache.spark:spark-catalyst_2.11 +org.apache.spark:spark-kvstore_2.11 +org.apache.spark:spark-launcher_2.11 +org.apache.spark:spark-mllib-local_2.11 +org.apache.spark:spark-network-common_2.11 +org.apache.spark:spark-network-shuffle_2.11 +org.apache.spark:spark-sketch_2.11 +org.apache.spark:spark-tags_2.11 +org.apache.spark:spark-unsafe_2.11 +commons-httpclient:commons-httpclient +com.vlkan:flatbuffers +com.ning:compress-lzf +io.airlift:aircompressor +io.dropwizard.metrics:metrics-core +io.dropwizard.metrics:metrics-ganglia +io.dropwizard.metrics:metrics-graphite +io.dropwizard.metrics:metrics-json +io.dropwizard.metrics:metrics-jvm +org.iq80.snappy:snappy +com.clearspring.analytics:stream +com.jamesmurty.utils:java-xmlbuilder +commons-codec:commons-codec +commons-collections:commons-collections +io.fabric8:kubernetes-client +io.fabric8:kubernetes-model +io.netty:netty +io.netty:netty-all +net.hydromatic:eigenbase-properties +net.sf.supercsv:super-csv +org.apache.arrow:arrow-format +org.apache.arrow:arrow-memory +org.apache.arrow:arrow-vector +org.apache.calcite:calcite-avatica +org.apache.calcite:calcite-core +org.apache.calcite:calcite-linq4j +org.apache.commons:commons-crypto +org.apache.commons:commons-lang3 +org.apache.hadoop:hadoop-annotations +org.apache.hadoop:hadoop-auth +org.apache.hadoop:hadoop-client +org.apache.hadoop:hadoop-common +org.apache.hadoop:hadoop-hdfs +org.apache.hadoop:hadoop-mapreduce-client-app +org.apache.hadoop:hadoop-mapreduce-client-common +org.apache.hadoop:hadoop-mapreduce-client-core +org.apache.hadoop:hadoop-mapreduce-client-jobclient +org.apache.hadoop:hadoop-mapreduce-client-shuffle +org.apache.hadoop:hadoop-yarn-api +org.apache.hadoop:hadoop-yarn-client +org.apache.hadoop:hadoop-yarn-common +org.apache.hadoop:hadoop-yarn-server-common +org.apache.hadoop:hadoop-yarn-server-web-proxy +org.apache.httpcomponents:httpclient +org.apache.httpcomponents:httpcore +org.apache.orc:orc-core +org.apache.orc:orc-mapreduce +org.mortbay.jetty:jetty +org.mortbay.jetty:jetty-util +com.jolbox:bonecp +org.json4s:json4s-ast_2.11 +org.json4s:json4s-core_2.11 +org.json4s:json4s-jackson_2.11 +org.json4s:json4s-scalap_2.11 +com.carrotsearch:hppc +com.fasterxml.jackson.core:jackson-annotations +com.fasterxml.jackson.core:jackson-core +com.fasterxml.jackson.core:jackson-databind +com.fasterxml.jackson.dataformat:jackson-dataformat-yaml +com.fasterxml.jackson.module:jackson-module-jaxb-annotations +com.fasterxml.jackson.module:jackson-module-paranamer +com.fasterxml.jackson.module:jackson-module-scala_2.11 +com.github.mifmif:generex +com.google.code.findbugs:jsr305 +com.google.code.gson:gson +com.google.inject:guice +com.google.inject.extensions:guice-servlet +com.twitter:parquet-hadoop-bundle +commons-beanutils:commons-beanutils-core +commons-cli:commons-cli +commons-dbcp:commons-dbcp +commons-io:commons-io +commons-lang:commons-lang +commons-logging:commons-logging +commons-net:commons-net +commons-pool:commons-pool +io.fabric8:zjsonpatch +javax.inject:javax.inject +javax.validation:validation-api +log4j:apache-log4j-extras +log4j:log4j +net.sf.jpam:jpam +org.apache.avro:avro +org.apache.avro:avro-ipc +org.apache.avro:avro-mapred +org.apache.commons:commons-compress +org.apache.commons:commons-math3 +org.apache.curator:curator-client +org.apache.curator:curator-framework +org.apache.curator:curator-recipes +org.apache.directory.api:api-asn1-api +org.apache.directory.api:api-util +org.apache.directory.server:apacheds-i18n +org.apache.directory.server:apacheds-kerberos-codec +org.apache.htrace:htrace-core +org.apache.ivy:ivy +org.apache.mesos:mesos +org.apache.parquet:parquet-column +org.apache.parquet:parquet-common +org.apache.parquet:parquet-encoding +org.apache.parquet:parquet-format +org.apache.parquet:parquet-hadoop +org.apache.parquet:parquet-jackson +org.apache.thrift:libfb303 +org.apache.thrift:libthrift +org.codehaus.jackson:jackson-core-asl +org.codehaus.jackson:jackson-mapper-asl +org.datanucleus:datanucleus-api-jdo +org.datanucleus:datanucleus-core +org.datanucleus:datanucleus-rdbms +org.lz4:lz4-java +org.spark-project.hive:hive-beeline +org.spark-project.hive:hive-cli +org.spark-project.hive:hive-exec +org.spark-project.hive:hive-jdbc +org.spark-project.hive:hive-metastore +org.xerial.snappy:snappy-java +stax:stax-api +xerces:xercesImpl +org.codehaus.jackson:jackson-jaxrs +org.codehaus.jackson:jackson-xc +org.eclipse.jetty:jetty-client +org.eclipse.jetty:jetty-continuation +org.eclipse.jetty:jetty-http +org.eclipse.jetty:jetty-io +org.eclipse.jetty:jetty-jndi +org.eclipse.jetty:jetty-plus +org.eclipse.jetty:jetty-proxy +org.eclipse.jetty:jetty-security +org.eclipse.jetty:jetty-server +org.eclipse.jetty:jetty-servlet +org.eclipse.jetty:jetty-servlets +org.eclipse.jetty:jetty-util +org.eclipse.jetty:jetty-webapp +org.eclipse.jetty:jetty-xml + +core/src/main/java/org/apache/spark/util/collection/TimSort.java +core/src/main/resources/org/apache/spark/ui/static/bootstrap* +core/src/main/resources/org/apache/spark/ui/static/jsonFormatter* +core/src/main/resources/org/apache/spark/ui/static/vis* +docs/js/vendor/bootstrap.js + + +------------------------------------------------------------------------------------ +This product bundles various third-party components under other open source licenses. +This section summarizes those components and their licenses. See licenses-binary/ +for text of these licenses. + + +BSD 2-Clause +------------ + +com.github.luben:zstd-jni +javolution:javolution +com.esotericsoftware:kryo-shaded +com.esotericsoftware:minlog +com.esotericsoftware:reflectasm +com.google.protobuf:protobuf-java +org.codehaus.janino:commons-compiler +org.codehaus.janino:janino +jline:jline +org.jodd:jodd-core + + +BSD 3-Clause +------------ + +dk.brics.automaton:automaton +org.antlr:antlr-runtime +org.antlr:ST4 +org.antlr:stringtemplate +org.antlr:antlr4-runtime +antlr:antlr +com.github.fommil.netlib:core +com.thoughtworks.paranamer:paranamer +org.scala-lang:scala-compiler +org.scala-lang:scala-library +org.scala-lang:scala-reflect +org.scala-lang.modules:scala-parser-combinators_2.11 +org.scala-lang.modules:scala-xml_2.11 +org.fusesource.leveldbjni:leveldbjni-all +net.sourceforge.f2j:arpack_combined_all +xmlenc:xmlenc +net.sf.py4j:py4j +org.jpmml:pmml-model +org.jpmml:pmml-schema + +python/lib/py4j-*-src.zip +python/pyspark/cloudpickle.py +python/pyspark/join.py +core/src/main/resources/org/apache/spark/ui/static/d3.min.js + +The CSS style for the navigation sidebar of the documentation was originally +submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project +is distributed under the 3-Clause BSD license. + + +MIT License +----------- + +org.spire-math:spire-macros_2.11 +org.spire-math:spire_2.11 +org.typelevel:machinist_2.11 +net.razorvine:pyrolite +org.slf4j:jcl-over-slf4j +org.slf4j:jul-to-slf4j +org.slf4j:slf4j-api +org.slf4j:slf4j-log4j12 +com.github.scopt:scopt_2.11 +org.bouncycastle:bcprov-jdk15on + +core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js +core/src/main/resources/org/apache/spark/ui/static/*dataTables* +core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js +ore/src/main/resources/org/apache/spark/ui/static/jquery* +core/src/main/resources/org/apache/spark/ui/static/sorttable.js +docs/js/vendor/anchor.min.js +docs/js/vendor/jquery* +docs/js/vendor/modernizer* + + +Common Development and Distribution License (CDDL) 1.0 +------------------------------------------------------ + +javax.activation:activation http://www.oracle.com/technetwork/java/javase/tech/index-jsp-138795.html +javax.xml.stream:stax-api https://jcp.org/en/jsr/detail?id=173 + + +Common Development and Distribution License (CDDL) 1.1 +------------------------------------------------------ + +javax.annotation:javax.annotation-api https://jcp.org/en/jsr/detail?id=250 +javax.servlet:javax.servlet-api https://javaee.github.io/servlet-spec/ +javax.transaction:jta http://www.oracle.com/technetwork/java/index.html +javax.ws.rs:javax.ws.rs-api https://github.com/jax-rs +javax.xml.bind:jaxb-api https://github.com/javaee/jaxb-v2 +org.glassfish.hk2:hk2-api https://github.com/javaee/glassfish +org.glassfish.hk2:hk2-locator (same) +org.glassfish.hk2:hk2-utils +org.glassfish.hk2:osgi-resource-locator +org.glassfish.hk2.external:aopalliance-repackaged +org.glassfish.hk2.external:javax.inject +org.glassfish.jersey.bundles.repackaged:jersey-guava +org.glassfish.jersey.containers:jersey-container-servlet +org.glassfish.jersey.containers:jersey-container-servlet-core +org.glassfish.jersey.core:jersey-client +org.glassfish.jersey.core:jersey-common +org.glassfish.jersey.core:jersey-server +org.glassfish.jersey.media:jersey-media-jaxb + + +Mozilla Public License (MPL) 1.1 +-------------------------------- + +com.github.rwl:jtransforms https://sourceforge.net/projects/jtransforms/ + + +Python Software Foundation License +---------------------------------- + +pyspark/heapq3.py + + +Public Domain +------------- + +aopalliance:aopalliance +net.iharder:base64 +org.tukaani:xz + + +Creative Commons CC0 1.0 Universal Public Domain Dedication +----------------------------------------------------------- +(see LICENSE-CC0.txt) + +data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg +data/mllib/images/kittens/54893.jpg +data/mllib/images/kittens/DP153539.jpg +data/mllib/images/kittens/DP802813.jpg +data/mllib/images/multi-channel/chr30.4.184.jpg diff --git a/NOTICE b/NOTICE index 6ec240efbf12e..9246cc54caa3a 100644 --- a/NOTICE +++ b/NOTICE @@ -4,664 +4,3 @@ Copyright 2014 and onwards The Apache Software Foundation. This product includes software developed at The Apache Software Foundation (http://www.apache.org/). - -======================================================================== -Common Development and Distribution License 1.0 -======================================================================== - -The following components are provided under the Common Development and Distribution License 1.0. See project link for details. - - (CDDL 1.0) Glassfish Jasper (org.mortbay.jetty:jsp-2.1:6.1.14 - http://jetty.mortbay.org/project/modules/jsp-2.1) - (CDDL 1.0) JAX-RS (https://jax-rs-spec.java.net/) - (CDDL 1.0) Servlet Specification 2.5 API (org.mortbay.jetty:servlet-api-2.5:6.1.14 - http://jetty.mortbay.org/project/modules/servlet-api-2.5) - (CDDL 1.0) (GPL2 w/ CPE) javax.annotation API (https://glassfish.java.net/nonav/public/CDDL+GPL.html) - (COMMON DEVELOPMENT AND DISTRIBUTION LICENSE (CDDL) Version 1.0) (GNU General Public Library) Streaming API for XML (javax.xml.stream:stax-api:1.0-2 - no url defined) - (Common Development and Distribution License (CDDL) v1.0) JavaBeans Activation Framework (JAF) (javax.activation:activation:1.1 - http://java.sun.com/products/javabeans/jaf/index.jsp) - -======================================================================== -Common Development and Distribution License 1.1 -======================================================================== - -The following components are provided under the Common Development and Distribution License 1.1. See project link for details. - - (CDDL 1.1) (GPL2 w/ CPE) org.glassfish.hk2 (https://hk2.java.net) - (CDDL 1.1) (GPL2 w/ CPE) JAXB API bundle for GlassFish V3 (javax.xml.bind:jaxb-api:2.2.2 - https://jaxb.dev.java.net/) - (CDDL 1.1) (GPL2 w/ CPE) JAXB RI (com.sun.xml.bind:jaxb-impl:2.2.3-1 - http://jaxb.java.net/) - (CDDL 1.1) (GPL2 w/ CPE) Jersey 2 (https://jersey.java.net) - -======================================================================== -Common Public License 1.0 -======================================================================== - -The following components are provided under the Common Public 1.0 License. See project link for details. - - (Common Public License Version 1.0) JUnit (junit:junit-dep:4.10 - http://junit.org) - (Common Public License Version 1.0) JUnit (junit:junit:3.8.1 - http://junit.org) - (Common Public License Version 1.0) JUnit (junit:junit:4.8.2 - http://junit.org) - -======================================================================== -Eclipse Public License 1.0 -======================================================================== - -The following components are provided under the Eclipse Public License 1.0. See project link for details. - - (Eclipse Public License v1.0) Eclipse JDT Core (org.eclipse.jdt:core:3.1.1 - http://www.eclipse.org/jdt/) - -======================================================================== -Mozilla Public License 1.0 -======================================================================== - -The following components are provided under the Mozilla Public License 1.0. See project link for details. - - (GPL) (LGPL) (MPL) JTransforms (com.github.rwl:jtransforms:2.4.0 - http://sourceforge.net/projects/jtransforms/) - (Mozilla Public License Version 1.1) jamon-runtime (org.jamon:jamon-runtime:2.3.1 - http://www.jamon.org/jamon-runtime/) - - - -======================================================================== -NOTICE files -======================================================================== - -The following NOTICEs are pertain to software distributed with this project. - - -// ------------------------------------------------------------------ -// NOTICE file corresponding to the section 4d of The Apache License, -// Version 2.0, in this case for -// ------------------------------------------------------------------ - -Apache Avro -Copyright 2009-2013 The Apache Software Foundation - -This product includes software developed at -The Apache Software Foundation (http://www.apache.org/). - -Apache Commons Codec -Copyright 2002-2009 The Apache Software Foundation - -This product includes software developed by -The Apache Software Foundation (http://www.apache.org/). - --------------------------------------------------------------------------------- -src/test/org/apache/commons/codec/language/DoubleMetaphoneTest.java contains -test data from http://aspell.sourceforge.net/test/batch0.tab. - -Copyright (C) 2002 Kevin Atkinson (kevina@gnu.org). Verbatim copying -and distribution of this entire article is permitted in any medium, -provided this notice is preserved. --------------------------------------------------------------------------------- - -Apache HttpComponents HttpClient -Copyright 1999-2011 The Apache Software Foundation - -This project contains annotations derived from JCIP-ANNOTATIONS -Copyright (c) 2005 Brian Goetz and Tim Peierls. See http://www.jcip.net - -Apache HttpComponents HttpCore -Copyright 2005-2011 The Apache Software Foundation - -Curator Recipes -Copyright 2011-2014 The Apache Software Foundation - -Curator Framework -Copyright 2011-2014 The Apache Software Foundation - -Curator Client -Copyright 2011-2014 The Apache Software Foundation - -Apache Geronimo -Copyright 2003-2008 The Apache Software Foundation - -Activation 1.1 -Copyright 2003-2007 The Apache Software Foundation - -Apache Commons Lang -Copyright 2001-2014 The Apache Software Foundation - -This product includes software from the Spring Framework, -under the Apache License 2.0 (see: StringUtils.containsWhitespace()) - -Apache log4j -Copyright 2007 The Apache Software Foundation - -# Compress LZF - -This library contains efficient implementation of LZF compression format, -as well as additional helper classes that build on JDK-provided gzip (deflat) -codec. - -## Licensing - -Library is licensed under Apache License 2.0, as per accompanying LICENSE file. - -## Credit - -Library has been written by Tatu Saloranta (tatu.saloranta@iki.fi). -It was started at Ning, inc., as an official Open Source process used by -platform backend, but after initial versions has been developed outside of -Ning by supporting community. - -Other contributors include: - -* Jon Hartlaub (first versions of streaming reader/writer; unit tests) -* Cedrik Lime: parallel LZF implementation - -Various community members have contributed bug reports, and suggested minor -fixes; these can be found from file "VERSION.txt" in SCM. - -Objenesis -Copyright 2006-2009 Joe Walnes, Henri Tremblay, Leonardo Mesquita - -Apache Commons Net -Copyright 2001-2010 The Apache Software Foundation - - The Netty Project - ================= - -Please visit the Netty web site for more information: - - * http://netty.io/ - -Copyright 2011 The Netty Project - -The Netty Project licenses this file to you under the Apache License, -version 2.0 (the "License"); you may not use this file except in compliance -with the License. You may obtain a copy of the License at: - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -License for the specific language governing permissions and limitations -under the License. - -Also, please refer to each LICENSE..txt file, which is located in -the 'license' directory of the distribution file, for the license terms of the -components that this product depends on. - -------------------------------------------------------------------------------- -This product contains the extensions to Java Collections Framework which has -been derived from the works by JSR-166 EG, Doug Lea, and Jason T. Greene: - - * LICENSE: - * license/LICENSE.jsr166y.txt (Public Domain) - * HOMEPAGE: - * http://gee.cs.oswego.edu/cgi-bin/viewcvs.cgi/jsr166/ - * http://viewvc.jboss.org/cgi-bin/viewvc.cgi/jbosscache/experimental/jsr166/ - -This product contains a modified version of Robert Harder's Public Domain -Base64 Encoder and Decoder, which can be obtained at: - - * LICENSE: - * license/LICENSE.base64.txt (Public Domain) - * HOMEPAGE: - * http://iharder.sourceforge.net/current/java/base64/ - -This product contains a modified version of 'JZlib', a re-implementation of -zlib in pure Java, which can be obtained at: - - * LICENSE: - * license/LICENSE.jzlib.txt (BSD Style License) - * HOMEPAGE: - * http://www.jcraft.com/jzlib/ - -This product optionally depends on 'Protocol Buffers', Google's data -interchange format, which can be obtained at: - - * LICENSE: - * license/LICENSE.protobuf.txt (New BSD License) - * HOMEPAGE: - * http://code.google.com/p/protobuf/ - -This product optionally depends on 'SLF4J', a simple logging facade for Java, -which can be obtained at: - - * LICENSE: - * license/LICENSE.slf4j.txt (MIT License) - * HOMEPAGE: - * http://www.slf4j.org/ - -This product optionally depends on 'Apache Commons Logging', a logging -framework, which can be obtained at: - - * LICENSE: - * license/LICENSE.commons-logging.txt (Apache License 2.0) - * HOMEPAGE: - * http://commons.apache.org/logging/ - -This product optionally depends on 'Apache Log4J', a logging framework, -which can be obtained at: - - * LICENSE: - * license/LICENSE.log4j.txt (Apache License 2.0) - * HOMEPAGE: - * http://logging.apache.org/log4j/ - -This product optionally depends on 'JBoss Logging', a logging framework, -which can be obtained at: - - * LICENSE: - * license/LICENSE.jboss-logging.txt (GNU LGPL 2.1) - * HOMEPAGE: - * http://anonsvn.jboss.org/repos/common/common-logging-spi/ - -This product optionally depends on 'Apache Felix', an open source OSGi -framework implementation, which can be obtained at: - - * LICENSE: - * license/LICENSE.felix.txt (Apache License 2.0) - * HOMEPAGE: - * http://felix.apache.org/ - -This product optionally depends on 'Webbit', a Java event based -WebSocket and HTTP server: - - * LICENSE: - * license/LICENSE.webbit.txt (BSD License) - * HOMEPAGE: - * https://github.com/joewalnes/webbit - -# Jackson JSON processor - -Jackson is a high-performance, Free/Open Source JSON processing library. -It was originally written by Tatu Saloranta (tatu.saloranta@iki.fi), and has -been in development since 2007. -It is currently developed by a community of developers, as well as supported -commercially by FasterXML.com. - -Jackson core and extension components may be licensed under different licenses. -To find the details that apply to this artifact see the accompanying LICENSE file. -For more information, including possible other licensing options, contact -FasterXML.com (http://fasterxml.com). - -## Credits - -A list of contributors may be found from CREDITS file, which is included -in some artifacts (usually source distributions); but is always available -from the source code management (SCM) system project uses. - -Jackson core and extension components may licensed under different licenses. -To find the details that apply to this artifact see the accompanying LICENSE file. -For more information, including possible other licensing options, contact -FasterXML.com (http://fasterxml.com). - -mesos -Copyright 2014 The Apache Software Foundation - -Apache Thrift -Copyright 2006-2010 The Apache Software Foundation. - - Apache Ant - Copyright 1999-2013 The Apache Software Foundation - - The task is based on code Copyright (c) 2002, Landmark - Graphics Corp that has been kindly donated to the Apache Software - Foundation. - -Apache Commons IO -Copyright 2002-2012 The Apache Software Foundation - -Apache Commons Math -Copyright 2001-2013 The Apache Software Foundation - -=============================================================================== - -The inverse error function implementation in the Erf class is based on CUDA -code developed by Mike Giles, Oxford-Man Institute of Quantitative Finance, -and published in GPU Computing Gems, volume 2, 2010. -=============================================================================== - -The BracketFinder (package org.apache.commons.math3.optimization.univariate) -and PowellOptimizer (package org.apache.commons.math3.optimization.general) -classes are based on the Python code in module "optimize.py" (version 0.5) -developed by Travis E. Oliphant for the SciPy library (http://www.scipy.org/) -Copyright © 2003-2009 SciPy Developers. -=============================================================================== - -The LinearConstraint, LinearObjectiveFunction, LinearOptimizer, -RelationShip, SimplexSolver and SimplexTableau classes in package -org.apache.commons.math3.optimization.linear include software developed by -Benjamin McCann (http://www.benmccann.com) and distributed with -the following copyright: Copyright 2009 Google Inc. -=============================================================================== - -This product includes software developed by the -University of Chicago, as Operator of Argonne National -Laboratory. -The LevenbergMarquardtOptimizer class in package -org.apache.commons.math3.optimization.general includes software -translated from the lmder, lmpar and qrsolv Fortran routines -from the Minpack package -Minpack Copyright Notice (1999) University of Chicago. All rights reserved -=============================================================================== - -The GraggBulirschStoerIntegrator class in package -org.apache.commons.math3.ode.nonstiff includes software translated -from the odex Fortran routine developed by E. Hairer and G. Wanner. -Original source copyright: -Copyright (c) 2004, Ernst Hairer -=============================================================================== - -The EigenDecompositionImpl class in package -org.apache.commons.math3.linear includes software translated -from some LAPACK Fortran routines. Original source copyright: -Copyright (c) 1992-2008 The University of Tennessee. All rights reserved. -=============================================================================== - -The MersenneTwister class in package org.apache.commons.math3.random -includes software translated from the 2002-01-26 version of -the Mersenne-Twister generator written in C by Makoto Matsumoto and Takuji -Nishimura. Original source copyright: -Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura, -All rights reserved -=============================================================================== - -The LocalizedFormatsTest class in the unit tests is an adapted version of -the OrekitMessagesTest class from the orekit library distributed under the -terms of the Apache 2 licence. Original source copyright: -Copyright 2010 CS Systèmes d'Information -=============================================================================== - -The HermiteInterpolator class and its corresponding test have been imported from -the orekit library distributed under the terms of the Apache 2 licence. Original -source copyright: -Copyright 2010-2012 CS Systèmes d'Information -=============================================================================== - -The creation of the package "o.a.c.m.analysis.integration.gauss" was inspired -by an original code donated by Sébastien Brisard. -=============================================================================== - -The complete text of licenses and disclaimers associated with the the original -sources enumerated above at the time of code translation are in the LICENSE.txt -file. - -This product currently only contains code developed by authors -of specific components, as identified by the source code files; -if such notes are missing files have been created by -Tatu Saloranta. - -For additional credits (generally to people who reported problems) -see CREDITS file. - -Apache Commons Lang -Copyright 2001-2011 The Apache Software Foundation - -Apache Commons Compress -Copyright 2002-2012 The Apache Software Foundation - -Apache Commons CLI -Copyright 2001-2009 The Apache Software Foundation - -Google Guice - Extensions - Servlet -Copyright 2006-2011 Google, Inc. - -Google Guice - Core Library -Copyright 2006-2011 Google, Inc. - -Apache Jakarta HttpClient -Copyright 1999-2007 The Apache Software Foundation - -Apache Hive -Copyright 2008-2013 The Apache Software Foundation - -This product includes software developed by The Apache Software -Foundation (http://www.apache.org/). - -This product includes software developed by The JDBM Project -(http://jdbm.sourceforge.net/). - -This product includes/uses ANTLR (http://www.antlr.org/), -Copyright (c) 2003-2011, Terrence Parr. - -This product includes/uses StringTemplate (http://www.stringtemplate.org/), -Copyright (c) 2011, Terrence Parr. - -This product includes/uses ASM (http://asm.ow2.org/), -Copyright (c) 2000-2007 INRIA, France Telecom. - -This product includes/uses JLine (http://jline.sourceforge.net/), -Copyright (c) 2002-2006, Marc Prud'hommeaux . - -This product includes/uses SQLLine (http://sqlline.sourceforge.net), -Copyright (c) 2002, 2003, 2004, 2005 Marc Prud'hommeaux . - -This product includes/uses SLF4J (http://www.slf4j.org/), -Copyright (c) 2004-2010 QOS.ch - -This product includes/uses Bootstrap (http://twitter.github.com/bootstrap/), -Copyright (c) 2012 Twitter, Inc. - -This product includes/uses Glyphicons (http://glyphicons.com/), -Copyright (c) 2010 - 2012 Jan Kovarík - -This product includes DataNucleus (http://www.datanucleus.org/) -Copyright 2008-2008 DataNucleus - -This product includes Guava (http://code.google.com/p/guava-libraries/) -Copyright (C) 2006 Google Inc. - -This product includes JavaEWAH (http://code.google.com/p/javaewah/) -Copyright (C) 2011 Google Inc. - -Apache Commons Pool -Copyright 1999-2009 The Apache Software Foundation - -This product includes/uses Kubernetes & OpenShift 3 Java Client (https://github.com/fabric8io/kubernetes-client) -Copyright (C) 2015 Red Hat, Inc. - -This product includes/uses OkHttp (https://github.com/square/okhttp) -Copyright (C) 2012 The Android Open Source Project - -========================================================================= -== NOTICE file corresponding to section 4(d) of the Apache License, == -== Version 2.0, in this case for the DataNucleus distribution. == -========================================================================= - -=================================================================== -This product includes software developed by many individuals, -including the following: -=================================================================== -Erik Bengtson -Andy Jefferson - -=================================================================== -This product has included contributions from some individuals, -including the following: -=================================================================== - -=================================================================== -This product has included contributions from some individuals, -including the following: -=================================================================== -Joerg von Frantzius -Thomas Marti -Barry Haddow -Marco Schulze -Ralph Ullrich -David Ezzio -Brendan de Beer -David Eaves -Martin Taal -Tony Lai -Roland Szabo -Marcus Mennemeier -Xuan Baldauf -Eric Sultan - -=================================================================== -This product also includes software developed by the TJDO project -(http://tjdo.sourceforge.net/). -=================================================================== - -=================================================================== -This product includes software developed by many individuals, -including the following: -=================================================================== -Andy Jefferson -Erik Bengtson -Joerg von Frantzius -Marco Schulze - -=================================================================== -This product has included contributions from some individuals, -including the following: -=================================================================== -Barry Haddow -Ralph Ullrich -David Ezzio -Brendan de Beer -David Eaves -Martin Taal -Tony Lai -Roland Szabo -Anton Troshin (Timesten) - -=================================================================== -This product also includes software developed by the Apache Commons project -(http://commons.apache.org/). -=================================================================== - -Apache Java Data Objects (JDO) -Copyright 2005-2006 The Apache Software Foundation - -========================================================================= -== NOTICE file corresponding to section 4(d) of the Apache License, == -== Version 2.0, in this case for the Apache Derby distribution. == -========================================================================= - -Apache Derby -Copyright 2004-2008 The Apache Software Foundation - -Portions of Derby were originally developed by -International Business Machines Corporation and are -licensed to the Apache Software Foundation under the -"Software Grant and Corporate Contribution License Agreement", -informally known as the "Derby CLA". -The following copyright notice(s) were affixed to portions of the code -with which this file is now or was at one time distributed -and are placed here unaltered. - -(C) Copyright 1997,2004 International Business Machines Corporation. All rights reserved. - -(C) Copyright IBM Corp. 2003. - -The portion of the functionTests under 'nist' was originally -developed by the National Institute of Standards and Technology (NIST), -an agency of the United States Department of Commerce, and adapted by -International Business Machines Corporation in accordance with the NIST -Software Acknowledgment and Redistribution document at -http://www.itl.nist.gov/div897/ctg/sql_form.htm - -Apache Commons Collections -Copyright 2001-2008 The Apache Software Foundation - -Apache Commons Configuration -Copyright 2001-2008 The Apache Software Foundation - -Apache Jakarta Commons Digester -Copyright 2001-2006 The Apache Software Foundation - -Apache Commons BeanUtils -Copyright 2000-2008 The Apache Software Foundation - -Apache Avro Mapred API -Copyright 2009-2013 The Apache Software Foundation - -Apache Avro IPC -Copyright 2009-2013 The Apache Software Foundation - - -Vis.js -Copyright 2010-2015 Almende B.V. - -Vis.js is dual licensed under both - - * The Apache 2.0 License - http://www.apache.org/licenses/LICENSE-2.0 - - and - - * The MIT License - http://opensource.org/licenses/MIT - -Vis.js may be distributed under either license. - - -Vis.js uses and redistributes the following third-party libraries: - -- component-emitter - https://github.com/component/emitter - The MIT License - -- hammer.js - http://hammerjs.github.io/ - The MIT License - -- moment.js - http://momentjs.com/ - The MIT License - -- keycharm - https://github.com/AlexDM0/keycharm - The MIT License - -=============================================================================== - -The CSS style for the navigation sidebar of the documentation was originally -submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project -is distributed under the 3-Clause BSD license. -=============================================================================== - -For CSV functionality: - -/* - * Copyright 2014 Databricks - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Copyright 2015 Ayasdi Inc - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -=============================================================================== -For dev/sparktestsupport/toposort.py: - -Copyright 2014 True Blade Systems, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. diff --git a/NOTICE-binary b/NOTICE-binary new file mode 100644 index 0000000000000..d56f99bdb55a6 --- /dev/null +++ b/NOTICE-binary @@ -0,0 +1,1170 @@ +Apache Spark +Copyright 2014 and onwards The Apache Software Foundation. + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +// ------------------------------------------------------------------ +// NOTICE file corresponding to the section 4d of The Apache License, +// Version 2.0, in this case for +// ------------------------------------------------------------------ + +Hive Beeline +Copyright 2016 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + +Apache Avro +Copyright 2009-2014 The Apache Software Foundation + +This product currently only contains code developed by authors +of specific components, as identified by the source code files; +if such notes are missing files have been created by +Tatu Saloranta. + +For additional credits (generally to people who reported problems) +see CREDITS file. + +Apache Commons Compress +Copyright 2002-2012 The Apache Software Foundation + +This product includes software developed by +The Apache Software Foundation (http://www.apache.org/). + +Apache Avro Mapred API +Copyright 2009-2014 The Apache Software Foundation + +Apache Avro IPC +Copyright 2009-2014 The Apache Software Foundation + +Objenesis +Copyright 2006-2013 Joe Walnes, Henri Tremblay, Leonardo Mesquita + +Apache XBean :: ASM 5 shaded (repackaged) +Copyright 2005-2015 The Apache Software Foundation + +-------------------------------------- + +This product includes software developed at +OW2 Consortium (http://asm.ow2.org/) + +This product includes software developed by The Apache Software +Foundation (http://www.apache.org/). + +The binary distribution of this product bundles binaries of +org.iq80.leveldb:leveldb-api (https://github.com/dain/leveldb), which has the +following notices: +* Copyright 2011 Dain Sundstrom +* Copyright 2011 FuseSource Corp. http://fusesource.com + +The binary distribution of this product bundles binaries of +org.fusesource.hawtjni:hawtjni-runtime (https://github.com/fusesource/hawtjni), +which has the following notices: +* This product includes software developed by FuseSource Corp. + http://fusesource.com +* This product includes software developed at + Progress Software Corporation and/or its subsidiaries or affiliates. +* This product includes software developed by IBM Corporation and others. + +The binary distribution of this product bundles binaries of +Gson 2.2.4, +which has the following notices: + + The Netty Project + ================= + +Please visit the Netty web site for more information: + + * http://netty.io/ + +Copyright 2014 The Netty Project + +The Netty Project licenses this file to you under the Apache License, +version 2.0 (the "License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at: + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +License for the specific language governing permissions and limitations +under the License. + +Also, please refer to each LICENSE..txt file, which is located in +the 'license' directory of the distribution file, for the license terms of the +components that this product depends on. + +------------------------------------------------------------------------------- +This product contains the extensions to Java Collections Framework which has +been derived from the works by JSR-166 EG, Doug Lea, and Jason T. Greene: + + * LICENSE: + * license/LICENSE.jsr166y.txt (Public Domain) + * HOMEPAGE: + * http://gee.cs.oswego.edu/cgi-bin/viewcvs.cgi/jsr166/ + * http://viewvc.jboss.org/cgi-bin/viewvc.cgi/jbosscache/experimental/jsr166/ + +This product contains a modified version of Robert Harder's Public Domain +Base64 Encoder and Decoder, which can be obtained at: + + * LICENSE: + * license/LICENSE.base64.txt (Public Domain) + * HOMEPAGE: + * http://iharder.sourceforge.net/current/java/base64/ + +This product contains a modified portion of 'Webbit', an event based +WebSocket and HTTP server, which can be obtained at: + + * LICENSE: + * license/LICENSE.webbit.txt (BSD License) + * HOMEPAGE: + * https://github.com/joewalnes/webbit + +This product contains a modified portion of 'SLF4J', a simple logging +facade for Java, which can be obtained at: + + * LICENSE: + * license/LICENSE.slf4j.txt (MIT License) + * HOMEPAGE: + * http://www.slf4j.org/ + +This product contains a modified portion of 'ArrayDeque', written by Josh +Bloch of Google, Inc: + + * LICENSE: + * license/LICENSE.deque.txt (Public Domain) + +This product contains a modified portion of 'Apache Harmony', an open source +Java SE, which can be obtained at: + + * LICENSE: + * license/LICENSE.harmony.txt (Apache License 2.0) + * HOMEPAGE: + * http://archive.apache.org/dist/harmony/ + +This product contains a modified version of Roland Kuhn's ASL2 +AbstractNodeQueue, which is based on Dmitriy Vyukov's non-intrusive MPSC queue. +It can be obtained at: + + * LICENSE: + * license/LICENSE.abstractnodequeue.txt (Public Domain) + * HOMEPAGE: + * https://github.com/akka/akka/blob/wip-2.2.3-for-scala-2.11/akka-actor/src/main/java/akka/dispatch/AbstractNodeQueue.java + +This product contains a modified portion of 'jbzip2', a Java bzip2 compression +and decompression library written by Matthew J. Francis. It can be obtained at: + + * LICENSE: + * license/LICENSE.jbzip2.txt (MIT License) + * HOMEPAGE: + * https://code.google.com/p/jbzip2/ + +This product contains a modified portion of 'libdivsufsort', a C API library to construct +the suffix array and the Burrows-Wheeler transformed string for any input string of +a constant-size alphabet written by Yuta Mori. It can be obtained at: + + * LICENSE: + * license/LICENSE.libdivsufsort.txt (MIT License) + * HOMEPAGE: + * https://code.google.com/p/libdivsufsort/ + +This product contains a modified portion of Nitsan Wakart's 'JCTools', Java Concurrency Tools for the JVM, + which can be obtained at: + + * LICENSE: + * license/LICENSE.jctools.txt (ASL2 License) + * HOMEPAGE: + * https://github.com/JCTools/JCTools + +This product optionally depends on 'JZlib', a re-implementation of zlib in +pure Java, which can be obtained at: + + * LICENSE: + * license/LICENSE.jzlib.txt (BSD style License) + * HOMEPAGE: + * http://www.jcraft.com/jzlib/ + +This product optionally depends on 'Compress-LZF', a Java library for encoding and +decoding data in LZF format, written by Tatu Saloranta. It can be obtained at: + + * LICENSE: + * license/LICENSE.compress-lzf.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/ning/compress + +This product optionally depends on 'lz4', a LZ4 Java compression +and decompression library written by Adrien Grand. It can be obtained at: + + * LICENSE: + * license/LICENSE.lz4.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/jpountz/lz4-java + +This product optionally depends on 'lzma-java', a LZMA Java compression +and decompression library, which can be obtained at: + + * LICENSE: + * license/LICENSE.lzma-java.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/jponge/lzma-java + +This product contains a modified portion of 'jfastlz', a Java port of FastLZ compression +and decompression library written by William Kinney. It can be obtained at: + + * LICENSE: + * license/LICENSE.jfastlz.txt (MIT License) + * HOMEPAGE: + * https://code.google.com/p/jfastlz/ + +This product contains a modified portion of and optionally depends on 'Protocol Buffers', Google's data +interchange format, which can be obtained at: + + * LICENSE: + * license/LICENSE.protobuf.txt (New BSD License) + * HOMEPAGE: + * http://code.google.com/p/protobuf/ + +This product optionally depends on 'Bouncy Castle Crypto APIs' to generate +a temporary self-signed X.509 certificate when the JVM does not provide the +equivalent functionality. It can be obtained at: + + * LICENSE: + * license/LICENSE.bouncycastle.txt (MIT License) + * HOMEPAGE: + * http://www.bouncycastle.org/ + +This product optionally depends on 'Snappy', a compression library produced +by Google Inc, which can be obtained at: + + * LICENSE: + * license/LICENSE.snappy.txt (New BSD License) + * HOMEPAGE: + * http://code.google.com/p/snappy/ + +This product optionally depends on 'JBoss Marshalling', an alternative Java +serialization API, which can be obtained at: + + * LICENSE: + * license/LICENSE.jboss-marshalling.txt (GNU LGPL 2.1) + * HOMEPAGE: + * http://www.jboss.org/jbossmarshalling + +This product optionally depends on 'Caliper', Google's micro- +benchmarking framework, which can be obtained at: + + * LICENSE: + * license/LICENSE.caliper.txt (Apache License 2.0) + * HOMEPAGE: + * http://code.google.com/p/caliper/ + +This product optionally depends on 'Apache Commons Logging', a logging +framework, which can be obtained at: + + * LICENSE: + * license/LICENSE.commons-logging.txt (Apache License 2.0) + * HOMEPAGE: + * http://commons.apache.org/logging/ + +This product optionally depends on 'Apache Log4J', a logging framework, which +can be obtained at: + + * LICENSE: + * license/LICENSE.log4j.txt (Apache License 2.0) + * HOMEPAGE: + * http://logging.apache.org/log4j/ + +This product optionally depends on 'Aalto XML', an ultra-high performance +non-blocking XML processor, which can be obtained at: + + * LICENSE: + * license/LICENSE.aalto-xml.txt (Apache License 2.0) + * HOMEPAGE: + * http://wiki.fasterxml.com/AaltoHome + +This product contains a modified version of 'HPACK', a Java implementation of +the HTTP/2 HPACK algorithm written by Twitter. It can be obtained at: + + * LICENSE: + * license/LICENSE.hpack.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/twitter/hpack + +This product contains a modified portion of 'Apache Commons Lang', a Java library +provides utilities for the java.lang API, which can be obtained at: + + * LICENSE: + * license/LICENSE.commons-lang.txt (Apache License 2.0) + * HOMEPAGE: + * https://commons.apache.org/proper/commons-lang/ + +The binary distribution of this product bundles binaries of +Commons Codec 1.4, +which has the following notices: + * src/test/org/apache/commons/codec/language/DoubleMetaphoneTest.javacontains test data from http://aspell.net/test/orig/batch0.tab.Copyright (C) 2002 Kevin Atkinson (kevina@gnu.org) + =============================================================================== + The content of package org.apache.commons.codec.language.bm has been translated + from the original php source code available at http://stevemorse.org/phoneticinfo.htm + with permission from the original authors. + Original source copyright:Copyright (c) 2008 Alexander Beider & Stephen P. Morse. + +The binary distribution of this product bundles binaries of +Commons Lang 2.6, +which has the following notices: + * This product includes software from the Spring Framework,under the Apache License 2.0 (see: StringUtils.containsWhitespace()) + +The binary distribution of this product bundles binaries of +Apache Log4j 1.2.17, +which has the following notices: + * ResolverUtil.java + Copyright 2005-2006 Tim Fennell + Dumbster SMTP test server + Copyright 2004 Jason Paul Kitchen + TypeUtil.java + Copyright 2002-2012 Ramnivas Laddad, Juergen Hoeller, Chris Beams + +The binary distribution of this product bundles binaries of +Jetty 6.1.26, +which has the following notices: + * ============================================================== + Jetty Web Container + Copyright 1995-2016 Mort Bay Consulting Pty Ltd. + ============================================================== + + The Jetty Web Container is Copyright Mort Bay Consulting Pty Ltd + unless otherwise noted. + + Jetty is dual licensed under both + + * The Apache 2.0 License + http://www.apache.org/licenses/LICENSE-2.0.html + + and + + * The Eclipse Public 1.0 License + http://www.eclipse.org/legal/epl-v10.html + + Jetty may be distributed under either license. + + ------ + Eclipse + + The following artifacts are EPL. + * org.eclipse.jetty.orbit:org.eclipse.jdt.core + + The following artifacts are EPL and ASL2. + * org.eclipse.jetty.orbit:javax.security.auth.message + + The following artifacts are EPL and CDDL 1.0. + * org.eclipse.jetty.orbit:javax.mail.glassfish + + ------ + Oracle + + The following artifacts are CDDL + GPLv2 with classpath exception. + https://glassfish.dev.java.net/nonav/public/CDDL+GPL.html + + * javax.servlet:javax.servlet-api + * javax.annotation:javax.annotation-api + * javax.transaction:javax.transaction-api + * javax.websocket:javax.websocket-api + + ------ + Oracle OpenJDK + + If ALPN is used to negotiate HTTP/2 connections, then the following + artifacts may be included in the distribution or downloaded when ALPN + module is selected. + + * java.sun.security.ssl + + These artifacts replace/modify OpenJDK classes. The modififications + are hosted at github and both modified and original are under GPL v2 with + classpath exceptions. + http://openjdk.java.net/legal/gplv2+ce.html + + ------ + OW2 + + The following artifacts are licensed by the OW2 Foundation according to the + terms of http://asm.ow2.org/license.html + + org.ow2.asm:asm-commons + org.ow2.asm:asm + + ------ + Apache + + The following artifacts are ASL2 licensed. + + org.apache.taglibs:taglibs-standard-spec + org.apache.taglibs:taglibs-standard-impl + + ------ + MortBay + + The following artifacts are ASL2 licensed. Based on selected classes from + following Apache Tomcat jars, all ASL2 licensed. + + org.mortbay.jasper:apache-jsp + org.apache.tomcat:tomcat-jasper + org.apache.tomcat:tomcat-juli + org.apache.tomcat:tomcat-jsp-api + org.apache.tomcat:tomcat-el-api + org.apache.tomcat:tomcat-jasper-el + org.apache.tomcat:tomcat-api + org.apache.tomcat:tomcat-util-scan + org.apache.tomcat:tomcat-util + + org.mortbay.jasper:apache-el + org.apache.tomcat:tomcat-jasper-el + org.apache.tomcat:tomcat-el-api + + ------ + Mortbay + + The following artifacts are CDDL + GPLv2 with classpath exception. + + https://glassfish.dev.java.net/nonav/public/CDDL+GPL.html + + org.eclipse.jetty.toolchain:jetty-schemas + + ------ + Assorted + + The UnixCrypt.java code implements the one way cryptography used by + Unix systems for simple password protection. Copyright 1996 Aki Yoshida, + modified April 2001 by Iris Van den Broeke, Daniel Deville. + Permission to use, copy, modify and distribute UnixCrypt + for non-commercial or commercial purposes and without fee is + granted provided that the copyright notice appears in all copies./ + +The binary distribution of this product bundles binaries of +Snappy for Java 1.0.4.1, +which has the following notices: + * This product includes software developed by Google + Snappy: http://code.google.com/p/snappy/ (New BSD License) + + This product includes software developed by Apache + PureJavaCrc32C from apache-hadoop-common http://hadoop.apache.org/ + (Apache 2.0 license) + + This library containd statically linked libstdc++. This inclusion is allowed by + "GCC RUntime Library Exception" + http://gcc.gnu.org/onlinedocs/libstdc++/manual/license.html + + == Contributors == + * Tatu Saloranta + * Providing benchmark suite + * Alec Wysoker + * Performance and memory usage improvement + +The binary distribution of this product bundles binaries of +Xerces2 Java Parser 2.9.1, +which has the following notices: + * ========================================================================= + == NOTICE file corresponding to section 4(d) of the Apache License, == + == Version 2.0, in this case for the Apache Xerces Java distribution. == + ========================================================================= + + Apache Xerces Java + Copyright 1999-2007 The Apache Software Foundation + + This product includes software developed at + The Apache Software Foundation (http://www.apache.org/). + + Portions of this software were originally based on the following: + - software copyright (c) 1999, IBM Corporation., http://www.ibm.com. + - software copyright (c) 1999, Sun Microsystems., http://www.sun.com. + - voluntary contributions made by Paul Eng on behalf of the + Apache Software Foundation that were originally developed at iClick, Inc., + software copyright (c) 1999. + +Apache Commons Collections +Copyright 2001-2015 The Apache Software Foundation + +Apache Commons Configuration +Copyright 2001-2008 The Apache Software Foundation + +Apache Jakarta Commons Digester +Copyright 2001-2006 The Apache Software Foundation + +Apache Commons BeanUtils +Copyright 2000-2008 The Apache Software Foundation + +ApacheDS Protocol Kerberos Codec +Copyright 2003-2013 The Apache Software Foundation + +ApacheDS I18n +Copyright 2003-2013 The Apache Software Foundation + +Apache Directory API ASN.1 API +Copyright 2003-2013 The Apache Software Foundation + +Apache Directory LDAP API Utilities +Copyright 2003-2013 The Apache Software Foundation + +Curator Client +Copyright 2011-2015 The Apache Software Foundation + +htrace-core +Copyright 2015 The Apache Software Foundation + + ========================================================================= + == NOTICE file corresponding to section 4(d) of the Apache License, == + == Version 2.0, in this case for the Apache Xerces Java distribution. == + ========================================================================= + + Portions of this software were originally based on the following: + - software copyright (c) 1999, IBM Corporation., http://www.ibm.com. + - software copyright (c) 1999, Sun Microsystems., http://www.sun.com. + - voluntary contributions made by Paul Eng on behalf of the + Apache Software Foundation that were originally developed at iClick, Inc., + software copyright (c) 1999. + +# Jackson JSON processor + +Jackson is a high-performance, Free/Open Source JSON processing library. +It was originally written by Tatu Saloranta (tatu.saloranta@iki.fi), and has +been in development since 2007. +It is currently developed by a community of developers, as well as supported +commercially by FasterXML.com. + +## Licensing + +Jackson core and extension components may licensed under different licenses. +To find the details that apply to this artifact see the accompanying LICENSE file. +For more information, including possible other licensing options, contact +FasterXML.com (http://fasterxml.com). + +## Credits + +A list of contributors may be found from CREDITS file, which is included +in some artifacts (usually source distributions); but is always available +from the source code management (SCM) system project uses. + +Apache HttpCore +Copyright 2005-2017 The Apache Software Foundation + +Curator Recipes +Copyright 2011-2015 The Apache Software Foundation + +Curator Framework +Copyright 2011-2015 The Apache Software Foundation + +Apache Commons Lang +Copyright 2001-2016 The Apache Software Foundation + +This product includes software from the Spring Framework, +under the Apache License 2.0 (see: StringUtils.containsWhitespace()) + +Apache Commons Math +Copyright 2001-2015 The Apache Software Foundation + +This product includes software developed for Orekit by +CS Systèmes d'Information (http://www.c-s.fr/) +Copyright 2010-2012 CS Systèmes d'Information + +Apache log4j +Copyright 2007 The Apache Software Foundation + +# Compress LZF + +This library contains efficient implementation of LZF compression format, +as well as additional helper classes that build on JDK-provided gzip (deflat) +codec. + +Library is licensed under Apache License 2.0, as per accompanying LICENSE file. + +## Credit + +Library has been written by Tatu Saloranta (tatu.saloranta@iki.fi). +It was started at Ning, inc., as an official Open Source process used by +platform backend, but after initial versions has been developed outside of +Ning by supporting community. + +Other contributors include: + +* Jon Hartlaub (first versions of streaming reader/writer; unit tests) +* Cedrik Lime: parallel LZF implementation + +Various community members have contributed bug reports, and suggested minor +fixes; these can be found from file "VERSION.txt" in SCM. + +Apache Commons Net +Copyright 2001-2012 The Apache Software Foundation + +Copyright 2011 The Netty Project + +http://www.apache.org/licenses/LICENSE-2.0 + +This product contains a modified version of 'JZlib', a re-implementation of +zlib in pure Java, which can be obtained at: + + * LICENSE: + * license/LICENSE.jzlib.txt (BSD Style License) + * HOMEPAGE: + * http://www.jcraft.com/jzlib/ + +This product contains a modified version of 'Webbit', a Java event based +WebSocket and HTTP server: + +This product optionally depends on 'Protocol Buffers', Google's data +interchange format, which can be obtained at: + +This product optionally depends on 'SLF4J', a simple logging facade for Java, +which can be obtained at: + +This product optionally depends on 'Apache Log4J', a logging framework, +which can be obtained at: + +This product optionally depends on 'JBoss Logging', a logging framework, +which can be obtained at: + + * LICENSE: + * license/LICENSE.jboss-logging.txt (GNU LGPL 2.1) + * HOMEPAGE: + * http://anonsvn.jboss.org/repos/common/common-logging-spi/ + +This product optionally depends on 'Apache Felix', an open source OSGi +framework implementation, which can be obtained at: + + * LICENSE: + * license/LICENSE.felix.txt (Apache License 2.0) + * HOMEPAGE: + * http://felix.apache.org/ + +Jackson core and extension components may be licensed under different licenses. +To find the details that apply to this artifact see the accompanying LICENSE file. +For more information, including possible other licensing options, contact +FasterXML.com (http://fasterxml.com). + +Apache Ivy (TM) +Copyright 2007-2014 The Apache Software Foundation + +Portions of Ivy were originally developed at +Jayasoft SARL (http://www.jayasoft.fr/) +and are licensed to the Apache Software Foundation under the +"Software Grant License Agreement" + +SSH and SFTP support is provided by the JCraft JSch package, +which is open source software, available under +the terms of a BSD style license. +The original software and related information is available +at http://www.jcraft.com/jsch/. + + +ORC Core +Copyright 2013-2018 The Apache Software Foundation + +Apache Commons Lang +Copyright 2001-2011 The Apache Software Foundation + +ORC MapReduce +Copyright 2013-2018 The Apache Software Foundation + +Apache Parquet Format +Copyright 2017 The Apache Software Foundation + +Arrow Vectors +Copyright 2017 The Apache Software Foundation + +Arrow Format +Copyright 2017 The Apache Software Foundation + +Arrow Memory +Copyright 2017 The Apache Software Foundation + +Apache Commons CLI +Copyright 2001-2009 The Apache Software Foundation + +Google Guice - Extensions - Servlet +Copyright 2006-2011 Google, Inc. + +Apache Commons IO +Copyright 2002-2012 The Apache Software Foundation + +Google Guice - Core Library +Copyright 2006-2011 Google, Inc. + +mesos +Copyright 2017 The Apache Software Foundation + +Apache Parquet Hadoop Bundle (Incubating) +Copyright 2015 The Apache Software Foundation + +Hive Query Language +Copyright 2016 The Apache Software Foundation + +Apache Extras Companion for log4j 1.2. +Copyright 2007 The Apache Software Foundation + +Hive Metastore +Copyright 2016 The Apache Software Foundation + +Apache Commons Logging +Copyright 2003-2013 The Apache Software Foundation + +========================================================================= +== NOTICE file corresponding to section 4(d) of the Apache License, == +== Version 2.0, in this case for the DataNucleus distribution. == +========================================================================= + +=================================================================== +This product includes software developed by many individuals, +including the following: +=================================================================== +Erik Bengtson +Andy Jefferson + +=================================================================== +This product has included contributions from some individuals, +including the following: +=================================================================== + +=================================================================== +This product includes software developed by many individuals, +including the following: +=================================================================== +Andy Jefferson +Erik Bengtson +Joerg von Frantzius +Marco Schulze + +=================================================================== +This product has included contributions from some individuals, +including the following: +=================================================================== +Barry Haddow +Ralph Ullrich +David Ezzio +Brendan de Beer +David Eaves +Martin Taal +Tony Lai +Roland Szabo +Anton Troshin (Timesten) + +=================================================================== +This product also includes software developed by the TJDO project +(http://tjdo.sourceforge.net/). +=================================================================== + +=================================================================== +This product also includes software developed by the Apache Commons project +(http://commons.apache.org/). +=================================================================== + +Apache Commons Pool +Copyright 1999-2009 The Apache Software Foundation + +Apache Commons DBCP +Copyright 2001-2010 The Apache Software Foundation + +Apache Java Data Objects (JDO) +Copyright 2005-2006 The Apache Software Foundation + +Apache Jakarta HttpClient +Copyright 1999-2007 The Apache Software Foundation + +Calcite Avatica +Copyright 2012-2015 The Apache Software Foundation + +Calcite Core +Copyright 2012-2015 The Apache Software Foundation + +Calcite Linq4j +Copyright 2012-2015 The Apache Software Foundation + +Apache HttpClient +Copyright 1999-2017 The Apache Software Foundation + +Apache Commons Codec +Copyright 2002-2014 The Apache Software Foundation + +src/test/org/apache/commons/codec/language/DoubleMetaphoneTest.java +contains test data from http://aspell.net/test/orig/batch0.tab. +Copyright (C) 2002 Kevin Atkinson (kevina@gnu.org) + +=============================================================================== + +The content of package org.apache.commons.codec.language.bm has been translated +from the original php source code available at http://stevemorse.org/phoneticinfo.htm +with permission from the original authors. +Original source copyright: +Copyright (c) 2008 Alexander Beider & Stephen P. Morse. + +============================================================================= += NOTICE file corresponding to section 4d of the Apache License Version 2.0 = +============================================================================= +This product includes software developed by +Joda.org (http://www.joda.org/). + +=================================================================== +This product has included contributions from some individuals, +including the following: +=================================================================== +Joerg von Frantzius +Thomas Marti +Barry Haddow +Marco Schulze +Ralph Ullrich +David Ezzio +Brendan de Beer +David Eaves +Martin Taal +Tony Lai +Roland Szabo +Marcus Mennemeier +Xuan Baldauf +Eric Sultan + +Apache Thrift +Copyright 2006-2010 The Apache Software Foundation. + +========================================================================= +== NOTICE file corresponding to section 4(d) of the Apache License, +== Version 2.0, in this case for the Apache Derby distribution. +== +== DO NOT EDIT THIS FILE DIRECTLY. IT IS GENERATED +== BY THE buildnotice TARGET IN THE TOP LEVEL build.xml FILE. +== +========================================================================= + +Apache Derby +Copyright 2004-2015 The Apache Software Foundation + +========================================================================= + +Portions of Derby were originally developed by +International Business Machines Corporation and are +licensed to the Apache Software Foundation under the +"Software Grant and Corporate Contribution License Agreement", +informally known as the "Derby CLA". +The following copyright notice(s) were affixed to portions of the code +with which this file is now or was at one time distributed +and are placed here unaltered. + +(C) Copyright 1997,2004 International Business Machines Corporation. All rights reserved. + +(C) Copyright IBM Corp. 2003. + +The portion of the functionTests under 'nist' was originally +developed by the National Institute of Standards and Technology (NIST), +an agency of the United States Department of Commerce, and adapted by +International Business Machines Corporation in accordance with the NIST +Software Acknowledgment and Redistribution document at +http://www.itl.nist.gov/div897/ctg/sql_form.htm + +The JDBC apis for small devices and JDBC3 (under java/stubs/jsr169 and +java/stubs/jdbc3) were produced by trimming sources supplied by the +Apache Harmony project. In addition, the Harmony SerialBlob and +SerialClob implementations are used. The following notice covers the Harmony sources: + +Portions of Harmony were originally developed by +Intel Corporation and are licensed to the Apache Software +Foundation under the "Software Grant and Corporate Contribution +License Agreement", informally known as the "Intel Harmony CLA". + +The Derby build relies on source files supplied by the Apache Felix +project. The following notice covers the Felix files: + + Apache Felix Main + Copyright 2008 The Apache Software Foundation + + I. Included Software + + This product includes software developed at + The Apache Software Foundation (http://www.apache.org/). + Licensed under the Apache License 2.0. + + This product includes software developed at + The OSGi Alliance (http://www.osgi.org/). + Copyright (c) OSGi Alliance (2000, 2007). + Licensed under the Apache License 2.0. + + This product includes software from http://kxml.sourceforge.net. + Copyright (c) 2002,2003, Stefan Haustein, Oberhausen, Rhld., Germany. + Licensed under BSD License. + + II. Used Software + + This product uses software developed at + The OSGi Alliance (http://www.osgi.org/). + Copyright (c) OSGi Alliance (2000, 2007). + Licensed under the Apache License 2.0. + + III. License Summary + - Apache License 2.0 + - BSD License + +The Derby build relies on jar files supplied by the Apache Lucene +project. The following notice covers the Lucene files: + +Apache Lucene +Copyright 2013 The Apache Software Foundation + +Includes software from other Apache Software Foundation projects, +including, but not limited to: + - Apache Ant + - Apache Jakarta Regexp + - Apache Commons + - Apache Xerces + +ICU4J, (under analysis/icu) is licensed under an MIT styles license +and Copyright (c) 1995-2008 International Business Machines Corporation and others + +Some data files (under analysis/icu/src/data) are derived from Unicode data such +as the Unicode Character Database. See http://unicode.org/copyright.html for more +details. + +Brics Automaton (under core/src/java/org/apache/lucene/util/automaton) is +BSD-licensed, created by Anders Møller. See http://www.brics.dk/automaton/ + +The levenshtein automata tables (under core/src/java/org/apache/lucene/util/automaton) were +automatically generated with the moman/finenight FSA library, created by +Jean-Philippe Barrette-LaPierre. This library is available under an MIT license, +see http://sites.google.com/site/rrettesite/moman and +http://bitbucket.org/jpbarrette/moman/overview/ + +The class org.apache.lucene.util.WeakIdentityMap was derived from +the Apache CXF project and is Apache License 2.0. + +The Google Code Prettify is Apache License 2.0. +See http://code.google.com/p/google-code-prettify/ + +JUnit (junit-4.10) is licensed under the Common Public License v. 1.0 +See http://junit.sourceforge.net/cpl-v10.html + +This product includes code (JaspellTernarySearchTrie) from Java Spelling Checkin +g Package (jaspell): http://jaspell.sourceforge.net/ +License: The BSD License (http://www.opensource.org/licenses/bsd-license.php) + +The snowball stemmers in + analysis/common/src/java/net/sf/snowball +were developed by Martin Porter and Richard Boulton. +The snowball stopword lists in + analysis/common/src/resources/org/apache/lucene/analysis/snowball +were developed by Martin Porter and Richard Boulton. +The full snowball package is available from + http://snowball.tartarus.org/ + +The KStem stemmer in + analysis/common/src/org/apache/lucene/analysis/en +was developed by Bob Krovetz and Sergio Guzman-Lara (CIIR-UMass Amherst) +under the BSD-license. + +The Arabic,Persian,Romanian,Bulgarian, and Hindi analyzers (common) come with a default +stopword list that is BSD-licensed created by Jacques Savoy. These files reside in: +analysis/common/src/resources/org/apache/lucene/analysis/ar/stopwords.txt, +analysis/common/src/resources/org/apache/lucene/analysis/fa/stopwords.txt, +analysis/common/src/resources/org/apache/lucene/analysis/ro/stopwords.txt, +analysis/common/src/resources/org/apache/lucene/analysis/bg/stopwords.txt, +analysis/common/src/resources/org/apache/lucene/analysis/hi/stopwords.txt +See http://members.unine.ch/jacques.savoy/clef/index.html. + +The German,Spanish,Finnish,French,Hungarian,Italian,Portuguese,Russian and Swedish light stemmers +(common) are based on BSD-licensed reference implementations created by Jacques Savoy and +Ljiljana Dolamic. These files reside in: +analysis/common/src/java/org/apache/lucene/analysis/de/GermanLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/de/GermanMinimalStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/es/SpanishLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/fi/FinnishLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/fr/FrenchLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/fr/FrenchMinimalStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/hu/HungarianLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/it/ItalianLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/pt/PortugueseLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/ru/RussianLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/sv/SwedishLightStemmer.java + +The Stempel analyzer (stempel) includes BSD-licensed software developed +by the Egothor project http://egothor.sf.net/, created by Leo Galambos, Martin Kvapil, +and Edmond Nolan. + +The Polish analyzer (stempel) comes with a default +stopword list that is BSD-licensed created by the Carrot2 project. The file resides +in stempel/src/resources/org/apache/lucene/analysis/pl/stopwords.txt. +See http://project.carrot2.org/license.html. + +The SmartChineseAnalyzer source code (smartcn) was +provided by Xiaoping Gao and copyright 2009 by www.imdict.net. + +WordBreakTestUnicode_*.java (under modules/analysis/common/src/test/) +is derived from Unicode data such as the Unicode Character Database. +See http://unicode.org/copyright.html for more details. + +The Morfologik analyzer (morfologik) includes BSD-licensed software +developed by Dawid Weiss and Marcin Miłkowski (http://morfologik.blogspot.com/). + +Morfologik uses data from Polish ispell/myspell dictionary +(http://www.sjp.pl/slownik/en/) licenced on the terms of (inter alia) +LGPL and Creative Commons ShareAlike. + +Morfologic includes data from BSD-licensed dictionary of Polish (SGJP) +(http://sgjp.pl/morfeusz/) + +Servlet-api.jar and javax.servlet-*.jar are under the CDDL license, the original +source code for this can be found at http://www.eclipse.org/jetty/downloads.php + +=========================================================================== +Kuromoji Japanese Morphological Analyzer - Apache Lucene Integration +=========================================================================== + +This software includes a binary and/or source version of data from + + mecab-ipadic-2.7.0-20070801 + +which can be obtained from + + http://atilika.com/releases/mecab-ipadic/mecab-ipadic-2.7.0-20070801.tar.gz + +or + + http://jaist.dl.sourceforge.net/project/mecab/mecab-ipadic/2.7.0-20070801/mecab-ipadic-2.7.0-20070801.tar.gz + +=========================================================================== +mecab-ipadic-2.7.0-20070801 Notice +=========================================================================== + +Nara Institute of Science and Technology (NAIST), +the copyright holders, disclaims all warranties with regard to this +software, including all implied warranties of merchantability and +fitness, in no event shall NAIST be liable for +any special, indirect or consequential damages or any damages +whatsoever resulting from loss of use, data or profits, whether in an +action of contract, negligence or other tortuous action, arising out +of or in connection with the use or performance of this software. + +A large portion of the dictionary entries +originate from ICOT Free Software. The following conditions for ICOT +Free Software applies to the current dictionary as well. + +Each User may also freely distribute the Program, whether in its +original form or modified, to any third party or parties, PROVIDED +that the provisions of Section 3 ("NO WARRANTY") will ALWAYS appear +on, or be attached to, the Program, which is distributed substantially +in the same form as set out herein and that such intended +distribution, if actually made, will neither violate or otherwise +contravene any of the laws and regulations of the countries having +jurisdiction over the User or the intended distribution itself. + +NO WARRANTY + +The program was produced on an experimental basis in the course of the +research and development conducted during the project and is provided +to users as so produced on an experimental basis. Accordingly, the +program is provided without any warranty whatsoever, whether express, +implied, statutory or otherwise. The term "warranty" used herein +includes, but is not limited to, any warranty of the quality, +performance, merchantability and fitness for a particular purpose of +the program and the nonexistence of any infringement or violation of +any right of any third party. + +Each user of the program will agree and understand, and be deemed to +have agreed and understood, that there is no warranty whatsoever for +the program and, accordingly, the entire risk arising from or +otherwise connected with the program is assumed by the user. + +Therefore, neither ICOT, the copyright holder, or any other +organization that participated in or was otherwise related to the +development of the program and their respective officials, directors, +officers and other employees shall be held liable for any and all +damages, including, without limitation, general, special, incidental +and consequential damages, arising out of or otherwise in connection +with the use or inability to use the program or any product, material +or result produced or otherwise obtained by using the program, +regardless of whether they have been advised of, or otherwise had +knowledge of, the possibility of such damages at any time during the +project or thereafter. Each user will be deemed to have agreed to the +foregoing by his or her commencement of use of the program. The term +"use" as used herein includes, but is not limited to, the use, +modification, copying and distribution of the program and the +production of secondary products from the program. + +In the case where the program, whether in its original form or +modified, was distributed or delivered to or received by a user from +any person, organization or entity other than ICOT, unless it makes or +grants independently of ICOT any specific warranty to the user in +writing, such person, organization or entity, will also be exempted +from and not be held liable to the user for any such damages as noted +above as far as the program is concerned. + +The Derby build relies on a jar file supplied by the JSON Simple +project, hosted at https://code.google.com/p/json-simple/. +The JSON simple jar file is licensed under the Apache 2.0 License. + +Hive CLI +Copyright 2016 The Apache Software Foundation + +Hive JDBC +Copyright 2016 The Apache Software Foundation + + +Chill is a set of Scala extensions for Kryo. +Copyright 2012 Twitter, Inc. + +Third Party Dependencies: + +Kryo 2.17 +BSD 3-Clause License +http://code.google.com/p/kryo + +Commons-Codec 1.7 +Apache Public License 2.0 +http://hadoop.apache.org + + + +Breeze is distributed under an Apache License V2.0 (See LICENSE) + +=============================================================================== + +Proximal algorithms outlined in Proximal.scala (package breeze.optimize.proximal) +are based on https://github.com/cvxgrp/proximal (see LICENSE for details) and distributed with +Copyright (c) 2014 by Debasish Das (Verizon), all rights reserved. + +=============================================================================== + +QuadraticMinimizer class in package breeze.optimize.proximal is distributed with Copyright (c) +2014, Debasish Das (Verizon), all rights reserved. + +=============================================================================== + +NonlinearMinimizer class in package breeze.optimize.proximal is distributed with Copyright (c) +2015, Debasish Das (Verizon), all rights reserved. + + + ========================================================================= + == NOTICE file corresponding to section 4(d) of the Apache License, == + == Version 2.0, in this case for the distribution of jets3t. == + ========================================================================= + + This product includes software developed by: + + The Apache Software Foundation (http://www.apache.org/). + + The ExoLab Project (http://www.exolab.org/) + + Sun Microsystems (http://www.sun.com/) + + Codehaus (http://castor.codehaus.org) + + Tatu Saloranta (http://wiki.fasterxml.com/TatuSaloranta) + + + +stream-lib +Copyright 2016 AddThis + +This product includes software developed by AddThis. + +This product also includes code adapted from: + +Apache Solr (http://lucene.apache.org/solr/) +Copyright 2014 The Apache Software Foundation + +Apache Mahout (http://mahout.apache.org/) +Copyright 2014 The Apache Software Foundation \ No newline at end of file diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 9696f6987ad78..adfd3871f3426 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -201,13 +201,16 @@ exportMethods("%<=>%", "approxCountDistinct", "approxQuantile", "array_contains", + "array_distinct", "array_join", "array_max", "array_min", "array_position", + "array_remove", "array_repeat", "array_sort", "arrays_overlap", + "arrays_zip", "asc", "ascii", "asin", @@ -306,6 +309,7 @@ exportMethods("%<=>%", "lpad", "ltrim", "map_entries", + "map_from_arrays", "map_keys", "map_values", "max", diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 4c87f64e7f0e1..660f0864403e0 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -71,15 +71,20 @@ checkJavaVersion <- function() { # If java is missing from PATH, we get an error in Unix and a warning in Windows javaVersionOut <- tryCatch( - launchScript(javaBin, "-version", wait = TRUE, stdout = TRUE, stderr = TRUE), - error = function(e) { - stop("Java version check failed. Please make sure Java is installed", - " and set JAVA_HOME to point to the installation directory.", e) - }, - warning = function(w) { - stop("Java version check failed. Please make sure Java is installed", - " and set JAVA_HOME to point to the installation directory.", w) - }) + if (is_windows()) { + # See SPARK-24535 + system2(javaBin, "-version", wait = TRUE, stdout = TRUE, stderr = TRUE) + } else { + launchScript(javaBin, "-version", wait = TRUE, stdout = TRUE, stderr = TRUE) + }, + error = function(e) { + stop("Java version check failed. Please make sure Java is installed", + " and set JAVA_HOME to point to the installation directory.", e) + }, + warning = function(w) { + stop("Java version check failed. Please make sure Java is installed", + " and set JAVA_HOME to point to the installation directory.", w) + }) javaVersionFilter <- Filter( function(x) { grepl(" version", x) @@ -93,6 +98,7 @@ checkJavaVersion <- function() { stop(paste("Java version", sparkJavaVersion, "is required for this package; found version:", javaVersionStr)) } + return(javaVersionNum) } launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 8ec727dd042bc..3e996a5ba26fc 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -305,6 +305,8 @@ setCheckpointDirSC <- function(sc, dirName) { #' Currently directories are only supported for Hadoop-supported filesystems. #' Refer Hadoop-supported filesystems at \url{https://wiki.apache.org/hadoop/HCFS}. #' +#' Note: A path can be added only once. Subsequent additions of the same path are ignored. +#' #' @rdname spark.addFile #' @param path The path of the file to be added #' @param recursive Whether to add files recursively from the path. Default is FALSE. diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 3bff633fbc1ff..2929a00330c62 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -194,10 +194,12 @@ NULL #' \itemize{ #' \item \code{array_contains}: a value to be checked if contained in the column. #' \item \code{array_position}: a value to locate in the given array. +#' \item \code{array_remove}: a value to remove in the given array. #' } #' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains #' additional named properties to control how it is converted, accepts the same -#' options as the JSON data source. +#' options as the JSON data source. In \code{arrays_zip}, this contains additional +#' Columns of arrays to be merged. #' @name column_collection_functions #' @rdname column_collection_functions #' @family collection functions @@ -207,9 +209,9 @@ NULL #' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) -#' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1))) +#' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1), array_distinct(tmp$v1))) #' head(select(tmp, array_position(tmp$v1, 21), array_repeat(df$mpg, 3), array_sort(tmp$v1))) -#' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1))) +#' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1), array_remove(tmp$v1, 21))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) #' head(select(tmp, posexplode(tmp$v1))) @@ -221,6 +223,7 @@ NULL #' head(select(tmp3, element_at(tmp3$v3, "Valiant"))) #' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$cyl, df$hp)) #' head(select(tmp4, concat(tmp4$v4, tmp4$v5), arrays_overlap(tmp4$v4, tmp4$v5))) +#' head(select(tmp4, arrays_zip(tmp4$v4, tmp4$v5), map_from_arrays(tmp4$v4, tmp4$v5))) #' head(select(tmp, concat(df$mpg, df$cyl, df$hp))) #' tmp5 <- mutate(df, v6 = create_array(df$model, df$model)) #' head(select(tmp5, array_join(tmp5$v6, "#"), array_join(tmp5$v6, "#", "NULL")))} @@ -1978,7 +1981,7 @@ setMethod("levenshtein", signature(y = "Column"), }) #' @details -#' \code{months_between}: Returns number of months between dates \code{y} and \code{x}. +#' \code{months_between}: Returns number of months between dates \code{y} and \code{x}. #' If \code{y} is later than \code{x}, then the result is positive. If \code{y} and \code{x} #' are on the same day of month, or both are the last day of month, time of day will be ignored. #' Otherwise, the difference is calculated based on 31 days per month, and rounded to 8 digits. @@ -3008,6 +3011,19 @@ setMethod("array_contains", column(jc) }) +#' @details +#' \code{array_distinct}: Removes duplicate values from the array. +#' +#' @rdname column_collection_functions +#' @aliases array_distinct array_distinct,Column-method +#' @note array_distinct since 2.4.0 +setMethod("array_distinct", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_distinct", x@jc) + column(jc) + }) + #' @details #' \code{array_join}: Concatenates the elements of column using the delimiter. #' Null values are replaced with nullReplacement if set, otherwise they are ignored. @@ -3071,6 +3087,19 @@ setMethod("array_position", column(jc) }) +#' @details +#' \code{array_remove}: Removes all elements that equal to element from the given array. +#' +#' @rdname column_collection_functions +#' @aliases array_remove array_remove,Column-method +#' @note array_remove since 2.4.0 +setMethod("array_remove", + signature(x = "Column", value = "ANY"), + function(x, value) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_remove", x@jc, value) + column(jc) + }) + #' @details #' \code{array_repeat}: Creates an array containing \code{x} repeated the number of times #' given by \code{count}. @@ -3120,6 +3149,24 @@ setMethod("arrays_overlap", column(jc) }) +#' @details +#' \code{arrays_zip}: Returns a merged array of structs in which the N-th struct contains all N-th +#' values of input arrays. +#' +#' @rdname column_collection_functions +#' @aliases arrays_zip arrays_zip,Column-method +#' @note arrays_zip since 2.4.0 +setMethod("arrays_zip", + signature(x = "Column"), + function(x, ...) { + jcols <- lapply(list(x, ...), function(arg) { + stopifnot(class(arg) == "Column") + arg@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "arrays_zip", jcols) + column(jc) + }) + #' @details #' \code{flatten}: Creates a single array from an array of arrays. #' If a structure of nested arrays is deeper than two levels, only one level of nesting is removed. @@ -3147,6 +3194,21 @@ setMethod("map_entries", column(jc) }) +#' @details +#' \code{map_from_arrays}: Creates a new map column. The array in the first column is used for +#' keys. The array in the second column is used for values. All elements in the array for key +#' should not be null. +#' +#' @rdname column_collection_functions +#' @aliases map_from_arrays map_from_arrays,Column-method +#' @note map_from_arrays since 2.4.0 +setMethod("map_from_arrays", + signature(x = "Column", y = "Column"), + function(x, y) { + jc <- callJStatic("org.apache.spark.sql.functions", "map_from_arrays", x@jc, y@jc) + column(jc) + }) + #' @details #' \code{map_keys}: Returns an unordered array containing the keys of the map. #' diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 9321bbaf96ff8..4a7210bf1b902 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -757,6 +757,10 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun #' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_distinct", function(x) { standardGeneric("array_distinct") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("array_join", function(x, delimiter, ...) { standardGeneric("array_join") }) @@ -773,6 +777,10 @@ setGeneric("array_min", function(x) { standardGeneric("array_min") }) #' @name NULL setGeneric("array_position", function(x, value) { standardGeneric("array_position") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_remove", function(x, value) { standardGeneric("array_remove") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("array_repeat", function(x, count) { standardGeneric("array_repeat") }) @@ -785,6 +793,10 @@ setGeneric("array_sort", function(x) { standardGeneric("array_sort") }) #' @name NULL setGeneric("arrays_overlap", function(x, y) { standardGeneric("arrays_overlap") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("arrays_zip", function(x, ...) { standardGeneric("arrays_zip") }) + #' @rdname column_string_functions #' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) @@ -1050,6 +1062,10 @@ setGeneric("ltrim", function(x, trimString) { standardGeneric("ltrim") }) #' @name NULL setGeneric("map_entries", function(x) { standardGeneric("map_entries") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("map_from_arrays", function(x, y) { standardGeneric("map_from_arrays") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("map_keys", function(x) { standardGeneric("map_keys") }) diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index f7c1663d32c96..d3a9cbae7d808 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -167,7 +167,7 @@ sparkR.sparkContext <- function( submitOps <- getClientModeSparkSubmitOpts( Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"), sparkEnvirMap) - checkJavaVersion() + invisible(checkJavaVersion()) launchBackend( args = path, sparkHome = sparkHome, diff --git a/R/pkg/inst/tests/testthat/test_basic.R b/R/pkg/inst/tests/testthat/test_basic.R index 823d26f12feee..243f5f0298284 100644 --- a/R/pkg/inst/tests/testthat/test_basic.R +++ b/R/pkg/inst/tests/testthat/test_basic.R @@ -18,6 +18,10 @@ context("basic tests for CRAN") test_that("create DataFrame from list or data.frame", { + tryCatch( checkJavaVersion(), + error = function(e) { skip("error on Java check") }, + warning = function(e) { skip("warning on Java check") } ) + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE, sparkConfig = sparkRTestConfig) @@ -50,6 +54,10 @@ test_that("create DataFrame from list or data.frame", { }) test_that("spark.glm and predict", { + tryCatch( checkJavaVersion(), + error = function(e) { skip("error on Java check") }, + warning = function(e) { skip("warning on Java check") } ) + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE, sparkConfig = sparkRTestConfig) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 36e0f78bb0599..adcbbff823a2d 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1503,6 +1503,27 @@ test_that("column functions", { result <- collect(select(df2, reverse(df2[[1]])))[[1]] expect_equal(result, "cba") + # Test array_distinct() and array_remove() + df <- createDataFrame(list(list(list(1L, 2L, 3L, 1L, 2L)), list(list(6L, 5L, 5L, 4L, 6L)))) + result <- collect(select(df, array_distinct(df[[1]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L), list(6L, 5L, 4L))) + + result <- collect(select(df, array_remove(df[[1]], 2L)))[[1]] + expect_equal(result, list(list(1L, 3L, 1L), list(6L, 5L, 5L, 4L, 6L))) + + # Test arrays_zip() + df <- createDataFrame(list(list(list(1L, 2L), list(3L, 4L))), schema = c("c1", "c2")) + result <- collect(select(df, arrays_zip(df[[1]], df[[2]])))[[1]] + expected_entries <- list(listToStruct(list(c1 = 1L, c2 = 3L)), + listToStruct(list(c1 = 2L, c2 = 4L))) + expect_equal(result, list(expected_entries)) + + # Test map_from_arrays() + df <- createDataFrame(list(list(list("x", "y"), list(1, 2))), schema = c("k", "v")) + result <- collect(select(df, map_from_arrays(df$k, df$v)))[[1]] + expected_entries <- list(as.environment(list(x = 1, y = 2))) + expect_equal(result, expected_entries) + # Test array_repeat() df <- createDataFrame(list(list("a", 3L), list("b", 2L))) result <- collect(select(df, array_repeat(df[[1]], df[[2]])))[[1]] diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index d4713de7806a1..68a18ab57b28d 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -590,6 +590,7 @@ summary(model) Predict values on training data ```{r} prediction <- predict(model, training) +head(select(prediction, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` #### Logistic Regression @@ -613,6 +614,7 @@ summary(model) Predict values on training data ```{r} fitted <- predict(model, training) +head(select(fitted, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` Multinomial logistic regression against three classes @@ -807,6 +809,7 @@ df <- createDataFrame(t) dtModel <- spark.decisionTree(df, Survived ~ ., type = "classification", maxDepth = 2) summary(dtModel) predictions <- predict(dtModel, df) +head(select(predictions, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` #### Gradient-Boosted Trees @@ -822,6 +825,7 @@ df <- createDataFrame(t) gbtModel <- spark.gbt(df, Survived ~ ., type = "classification", maxDepth = 2, maxIter = 2) summary(gbtModel) predictions <- predict(gbtModel, df) +head(select(predictions, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` #### Random Forest @@ -837,6 +841,7 @@ df <- createDataFrame(t) rfModel <- spark.randomForest(df, Survived ~ ., type = "classification", maxDepth = 2, numTrees = 2) summary(rfModel) predictions <- predict(rfModel, df) +head(select(predictions, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` #### Bisecting k-Means diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index a871ab5d448c3..a3f1bcffaea57 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -70,17 +70,18 @@ function build { local BASEDOCKERFILE=${BASEDOCKERFILE:-"$IMG_PATH/spark/Dockerfile"} local PYDOCKERFILE=${PYDOCKERFILE:-"$IMG_PATH/spark/bindings/python/Dockerfile"} - docker build "${BUILD_ARGS[@]}" \ + docker build $NOCACHEARG "${BUILD_ARGS[@]}" \ -t $(image_ref spark) \ -f "$BASEDOCKERFILE" . - docker build "${BINDING_BUILD_ARGS[@]}" \ + docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ -t $(image_ref spark-py) \ -f "$PYDOCKERFILE" . } function push { docker push "$(image_ref spark)" + docker push "$(image_ref spark-py)" } function usage { @@ -99,6 +100,7 @@ Options: -r repo Repository address. -t tag Tag to apply to the built image, or to identify the image to be pushed. -m Use minikube's Docker daemon. + -n Build docker image with --no-cache Using minikube when building images will do so directly into minikube's Docker daemon. There is no need to push the images into minikube in that case, they'll be automatically @@ -127,7 +129,8 @@ REPO= TAG= BASEDOCKERFILE= PYDOCKERFILE= -while getopts f:mr:t: option +NOCACHEARG= +while getopts f:mr:t:n option do case "${option}" in @@ -135,6 +138,7 @@ do p) PYDOCKERFILE=${OPTARG};; r) REPO=${OPTARG};; t) TAG=${OPTARG};; + n) NOCACHEARG="--no-cache";; m) if ! which minikube 1>/dev/null; then error "Cannot find minikube." diff --git a/build/mvn b/build/mvn index efa4f9364ea52..ae4276dbc7e32 100755 --- a/build/mvn +++ b/build/mvn @@ -93,7 +93,7 @@ install_mvn() { install_zinc() { local zinc_path="zinc-0.3.15/bin/zinc" [ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1 - local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.typesafe.com} + local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.lightbend.com} install_app \ "${TYPESAFE_MIRROR}/zinc/0.3.15" \ @@ -109,7 +109,7 @@ install_scala() { # determine the Scala version used in Spark local scala_version=`grep "scala.version" "${_DIR}/../pom.xml" | head -n1 | awk -F '[<>]' '{print $3}'` local scala_bin="${_DIR}/scala-${scala_version}/bin/scala" - local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.typesafe.com} + local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.lightbend.com} install_app \ "${TYPESAFE_MIRROR}/scala/${scala_version}" \ @@ -154,4 +154,4 @@ export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"} echo "Using \`mvn\` from path: $MVN_BIN" 1>&2 # Last, call the `mvn` command as usual -${MVN_BIN} -DzincPort=${ZINC_PORT} "$@" +"${MVN_BIN}" -DzincPort=${ZINC_PORT} "$@" diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallbackWithID.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallbackWithID.java new file mode 100644 index 0000000000000..bd173b653e33e --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallbackWithID.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +public interface StreamCallbackWithID extends StreamCallback { + String getID(); +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java index b0e85bae7c309..f3eb744ff7345 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java @@ -22,22 +22,24 @@ import io.netty.buffer.ByteBuf; +import org.apache.spark.network.protocol.Message; +import org.apache.spark.network.server.MessageHandler; import org.apache.spark.network.util.TransportFrameDecoder; /** * An interceptor that is registered with the frame decoder to feed stream data to a * callback. */ -class StreamInterceptor implements TransportFrameDecoder.Interceptor { +public class StreamInterceptor implements TransportFrameDecoder.Interceptor { - private final TransportResponseHandler handler; + private final MessageHandler handler; private final String streamId; private final long byteCount; private final StreamCallback callback; private long bytesRead; - StreamInterceptor( - TransportResponseHandler handler, + public StreamInterceptor( + MessageHandler handler, String streamId, long byteCount, StreamCallback callback) { @@ -50,16 +52,24 @@ class StreamInterceptor implements TransportFrameDecoder.Interceptor { @Override public void exceptionCaught(Throwable cause) throws Exception { - handler.deactivateStream(); + deactivateStream(); callback.onFailure(streamId, cause); } @Override public void channelInactive() throws Exception { - handler.deactivateStream(); + deactivateStream(); callback.onFailure(streamId, new ClosedChannelException()); } + private void deactivateStream() { + if (handler instanceof TransportResponseHandler) { + // we only have to do this for TransportResponseHandler as it exposes numOutstandingFetches + // (there is no extra cleanup that needs to happen) + ((TransportResponseHandler) handler).deactivateStream(); + } + } + @Override public boolean handle(ByteBuf buf) throws Exception { int toRead = (int) Math.min(buf.readableBytes(), byteCount - bytesRead); @@ -72,10 +82,10 @@ public boolean handle(ByteBuf buf) throws Exception { RuntimeException re = new IllegalStateException(String.format( "Read too many bytes? Expected %d, but read %d.", byteCount, bytesRead)); callback.onFailure(streamId, re); - handler.deactivateStream(); + deactivateStream(); throw re; } else if (bytesRead == byteCount) { - handler.deactivateStream(); + deactivateStream(); callback.onComplete(streamId); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index 8f354ad78bbaa..325225dc0ea2c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -32,15 +32,15 @@ import com.google.common.base.Throwables; import com.google.common.util.concurrent.SettableFuture; import io.netty.channel.Channel; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.protocol.ChunkFetchRequest; -import org.apache.spark.network.protocol.OneWayMessage; -import org.apache.spark.network.protocol.RpcRequest; -import org.apache.spark.network.protocol.StreamChunkId; -import org.apache.spark.network.protocol.StreamRequest; +import org.apache.spark.network.protocol.*; + import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; /** @@ -133,34 +133,21 @@ public void fetchChunk( long streamId, int chunkIndex, ChunkReceivedCallback callback) { - long startTime = System.currentTimeMillis(); if (logger.isDebugEnabled()) { logger.debug("Sending fetch chunk request {} to {}", chunkIndex, getRemoteAddress(channel)); } StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex); - handler.addFetchRequest(streamChunkId, callback); - - channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(future -> { - if (future.isSuccess()) { - long timeTaken = System.currentTimeMillis() - startTime; - if (logger.isTraceEnabled()) { - logger.trace("Sending request {} to {} took {} ms", streamChunkId, - getRemoteAddress(channel), timeTaken); - } - } else { - String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId, - getRemoteAddress(channel), future.cause()); - logger.error(errorMsg, future.cause()); + StdChannelListener listener = new StdChannelListener(streamChunkId) { + @Override + void handleFailure(String errorMsg, Throwable cause) { handler.removeFetchRequest(streamChunkId); - channel.close(); - try { - callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause())); - } catch (Exception e) { - logger.error("Uncaught exception in RPC response callback handler!", e); - } + callback.onFailure(chunkIndex, new IOException(errorMsg, cause)); } - }); + }; + handler.addFetchRequest(streamChunkId, callback); + + channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(listener); } /** @@ -170,7 +157,12 @@ public void fetchChunk( * @param callback Object to call with the stream data. */ public void stream(String streamId, StreamCallback callback) { - long startTime = System.currentTimeMillis(); + StdChannelListener listener = new StdChannelListener(streamId) { + @Override + void handleFailure(String errorMsg, Throwable cause) throws Exception { + callback.onFailure(streamId, new IOException(errorMsg, cause)); + } + }; if (logger.isDebugEnabled()) { logger.debug("Sending stream request for {} to {}", streamId, getRemoteAddress(channel)); } @@ -180,25 +172,7 @@ public void stream(String streamId, StreamCallback callback) { // when responses arrive. synchronized (this) { handler.addStreamCallback(streamId, callback); - channel.writeAndFlush(new StreamRequest(streamId)).addListener(future -> { - if (future.isSuccess()) { - long timeTaken = System.currentTimeMillis() - startTime; - if (logger.isTraceEnabled()) { - logger.trace("Sending request for {} to {} took {} ms", streamId, - getRemoteAddress(channel), timeTaken); - } - } else { - String errorMsg = String.format("Failed to send request for %s to %s: %s", streamId, - getRemoteAddress(channel), future.cause()); - logger.error(errorMsg, future.cause()); - channel.close(); - try { - callback.onFailure(streamId, new IOException(errorMsg, future.cause())); - } catch (Exception e) { - logger.error("Uncaught exception in RPC response callback handler!", e); - } - } - }); + channel.writeAndFlush(new StreamRequest(streamId)).addListener(listener); } } @@ -211,35 +185,44 @@ public void stream(String streamId, StreamCallback callback) { * @return The RPC's id. */ public long sendRpc(ByteBuffer message, RpcResponseCallback callback) { - long startTime = System.currentTimeMillis(); if (logger.isTraceEnabled()) { logger.trace("Sending RPC to {}", getRemoteAddress(channel)); } - long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits()); + long requestId = requestId(); handler.addRpcRequest(requestId, callback); + RpcChannelListener listener = new RpcChannelListener(requestId, callback); channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message))) - .addListener(future -> { - if (future.isSuccess()) { - long timeTaken = System.currentTimeMillis() - startTime; - if (logger.isTraceEnabled()) { - logger.trace("Sending request {} to {} took {} ms", requestId, - getRemoteAddress(channel), timeTaken); - } - } else { - String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId, - getRemoteAddress(channel), future.cause()); - logger.error(errorMsg, future.cause()); - handler.removeRpcRequest(requestId); - channel.close(); - try { - callback.onFailure(new IOException(errorMsg, future.cause())); - } catch (Exception e) { - logger.error("Uncaught exception in RPC response callback handler!", e); - } - } - }); + .addListener(listener); + + return requestId; + } + + /** + * Send data to the remote end as a stream. This differs from stream() in that this is a request + * to *send* data to the remote end, not to receive it from the remote. + * + * @param meta meta data associated with the stream, which will be read completely on the + * receiving end before the stream itself. + * @param data this will be streamed to the remote end to allow for transferring large amounts + * of data without reading into memory. + * @param callback handles the reply -- onSuccess will only be called when both message and data + * are received successfully. + */ + public long uploadStream( + ManagedBuffer meta, + ManagedBuffer data, + RpcResponseCallback callback) { + if (logger.isTraceEnabled()) { + logger.trace("Sending RPC to {}", getRemoteAddress(channel)); + } + + long requestId = requestId(); + handler.addRpcRequest(requestId, callback); + + RpcChannelListener listener = new RpcChannelListener(requestId, callback); + channel.writeAndFlush(new UploadStream(requestId, meta, data)).addListener(listener); return requestId; } @@ -319,4 +302,60 @@ public String toString() { .add("isActive", isActive()) .toString(); } + + private static long requestId() { + return Math.abs(UUID.randomUUID().getLeastSignificantBits()); + } + + private class StdChannelListener + implements GenericFutureListener> { + final long startTime; + final Object requestId; + + StdChannelListener(Object requestId) { + this.startTime = System.currentTimeMillis(); + this.requestId = requestId; + } + + @Override + public void operationComplete(Future future) throws Exception { + if (future.isSuccess()) { + if (logger.isTraceEnabled()) { + long timeTaken = System.currentTimeMillis() - startTime; + logger.trace("Sending request {} to {} took {} ms", requestId, + getRemoteAddress(channel), timeTaken); + } + } else { + String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId, + getRemoteAddress(channel), future.cause()); + logger.error(errorMsg, future.cause()); + channel.close(); + try { + handleFailure(errorMsg, future.cause()); + } catch (Exception e) { + logger.error("Uncaught exception in RPC response callback handler!", e); + } + } + } + + void handleFailure(String errorMsg, Throwable cause) throws Exception {} + } + + private class RpcChannelListener extends StdChannelListener { + final long rpcRequestId; + final RpcResponseCallback callback; + + RpcChannelListener(long rpcRequestId, RpcResponseCallback callback) { + super("RPC " + rpcRequestId); + this.rpcRequestId = rpcRequestId; + this.callback = callback; + } + + @Override + void handleFailure(String errorMsg, Throwable cause) { + handler.removeRpcRequest(rpcRequestId); + callback.onFailure(new IOException(errorMsg, cause)); + } + } + } diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java index 8a6e3858081bf..fb44dbbb0953b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java @@ -29,6 +29,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.sasl.SaslRpcHandler; @@ -149,6 +150,14 @@ public void receive(TransportClient client, ByteBuffer message) { delegate.receive(client, message); } + @Override + public StreamCallbackWithID receiveStream( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + return delegate.receiveStream(client, message, callback); + } + @Override public StreamManager getStreamManager() { return delegate.getStreamManager(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java index 434935a8ef2ad..0ccd70c03aba8 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -37,7 +37,7 @@ enum Type implements Encodable { ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), RpcRequest(3), RpcResponse(4), RpcFailure(5), StreamRequest(6), StreamResponse(7), StreamFailure(8), - OneWayMessage(9), User(-1); + OneWayMessage(9), UploadStream(10), User(-1); private final byte id; @@ -65,6 +65,7 @@ public static Type decode(ByteBuf buf) { case 7: return StreamResponse; case 8: return StreamFailure; case 9: return OneWayMessage; + case 10: return UploadStream; case -1: throw new IllegalArgumentException("User type messages cannot be decoded."); default: throw new IllegalArgumentException("Unknown message type: " + id); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index 39a7495828a8a..bf80aed0afe10 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -80,6 +80,9 @@ private Message decode(Message.Type msgType, ByteBuf in) { case StreamFailure: return StreamFailure.decode(in); + case UploadStream: + return UploadStream.decode(in); + default: throw new IllegalArgumentException("Unexpected message type: " + msgType); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java index a5337656cbd84..e7b66a6f33a82 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java @@ -137,30 +137,15 @@ protected void deallocate() { } private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException { - ByteBuffer buffer = buf.nioBuffer(); - int written = (buffer.remaining() <= NIO_BUFFER_LIMIT) ? - target.write(buffer) : writeNioBuffer(target, buffer); + // SPARK-24578: cap the sub-region's size of returned nio buffer to improve the performance + // for the case that the passed-in buffer has too many components. + int length = Math.min(buf.readableBytes(), NIO_BUFFER_LIMIT); + ByteBuffer buffer = buf.nioBuffer(buf.readerIndex(), length); + int written = target.write(buffer); buf.skipBytes(written); return written; } - private int writeNioBuffer( - WritableByteChannel writeCh, - ByteBuffer buf) throws IOException { - int originalLimit = buf.limit(); - int ret = 0; - - try { - int ioSize = Math.min(buf.remaining(), NIO_BUFFER_LIMIT); - buf.limit(buf.position() + ioSize); - ret = writeCh.write(buf); - } finally { - buf.limit(originalLimit); - } - - return ret; - } - @Override public MessageWithHeader touch(Object o) { super.touch(o); diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java index 87e212f3e157b..50b811604b84b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java @@ -67,7 +67,7 @@ public static StreamResponse decode(ByteBuf buf) { @Override public int hashCode() { - return Objects.hashCode(byteCount, streamId, body()); + return Objects.hashCode(byteCount, streamId); } @Override diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java new file mode 100644 index 0000000000000..fa1d26e76b852 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** + * An RPC with data that is sent outside of the frame, so it can be read as a stream. + */ +public final class UploadStream extends AbstractMessage implements RequestMessage { + /** Used to link an RPC request with its response. */ + public final long requestId; + public final ManagedBuffer meta; + public final long bodyByteCount; + + public UploadStream(long requestId, ManagedBuffer meta, ManagedBuffer body) { + super(body, false); // body is *not* included in the frame + this.requestId = requestId; + this.meta = meta; + bodyByteCount = body.size(); + } + + // this version is called when decoding the bytes on the receiving end. The body is handled + // separately. + private UploadStream(long requestId, ManagedBuffer meta, long bodyByteCount) { + super(null, false); + this.requestId = requestId; + this.meta = meta; + this.bodyByteCount = bodyByteCount; + } + + @Override + public Type type() { return Type.UploadStream; } + + @Override + public int encodedLength() { + // the requestId, meta size, meta and bodyByteCount (body is not included) + return 8 + 4 + ((int) meta.size()) + 8; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(requestId); + try { + ByteBuffer metaBuf = meta.nioByteBuffer(); + buf.writeInt(metaBuf.remaining()); + buf.writeBytes(metaBuf); + } catch (IOException io) { + throw new RuntimeException(io); + } + buf.writeLong(bodyByteCount); + } + + public static UploadStream decode(ByteBuf buf) { + long requestId = buf.readLong(); + int metaSize = buf.readInt(); + ManagedBuffer meta = new NettyManagedBuffer(buf.readRetainedSlice(metaSize)); + long bodyByteCount = buf.readLong(); + // This is called by the frame decoder, so the data is still null. We need a StreamInterceptor + // to read the data. + return new UploadStream(requestId, meta, bodyByteCount); + } + + @Override + public int hashCode() { + return Long.hashCode(requestId); + } + + @Override + public boolean equals(Object other) { + if (other instanceof UploadStream) { + UploadStream o = (UploadStream) other; + return requestId == o.requestId && super.equals(o); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("requestId", requestId) + .add("body", body()) + .toString(); + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 0231428318add..355a3def8cc22 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -28,6 +28,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; @@ -132,6 +133,14 @@ public void receive(TransportClient client, ByteBuffer message) { delegate.receive(client, message); } + @Override + public StreamCallbackWithID receiveStream( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + return delegate.receiveStream(client, message, callback); + } + @Override public StreamManager getStreamManager() { return delegate.getStreamManager(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java index 8f7554e2e07d5..38569baf82bce 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -23,6 +23,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; /** @@ -36,7 +37,8 @@ public abstract class RpcHandler { * Receive a single RPC message. Any exception thrown while in this method will be sent back to * the client in string form as a standard RPC failure. * - * This method will not be called in parallel for a single TransportClient (i.e., channel). + * Neither this method nor #receiveStream will be called in parallel for a single + * TransportClient (i.e., channel). * * @param client A channel client which enables the handler to make requests back to the sender * of this RPC. This will always be the exact same object for a particular channel. @@ -49,6 +51,36 @@ public abstract void receive( ByteBuffer message, RpcResponseCallback callback); + /** + * Receive a single RPC message which includes data that is to be received as a stream. Any + * exception thrown while in this method will be sent back to the client in string form as a + * standard RPC failure. + * + * Neither this method nor #receive will be called in parallel for a single TransportClient + * (i.e., channel). + * + * An error while reading data from the stream + * ({@link org.apache.spark.network.client.StreamCallback#onData(String, ByteBuffer)}) + * will fail the entire channel. A failure in "post-processing" the stream in + * {@link org.apache.spark.network.client.StreamCallback#onComplete(String)} will result in an + * rpcFailure, but the channel will remain active. + * + * @param client A channel client which enables the handler to make requests back to the sender + * of this RPC. This will always be the exact same object for a particular channel. + * @param messageHeader The serialized bytes of the header portion of the RPC. This is in meant + * to be relatively small, and will be buffered entirely in memory, to + * facilitate how the streaming portion should be received. + * @param callback Callback which should be invoked exactly once upon success or failure of the + * RPC. + * @return a StreamCallback for handling the accompanying streaming data + */ + public StreamCallbackWithID receiveStream( + TransportClient client, + ByteBuffer messageHeader, + RpcResponseCallback callback) { + throw new UnsupportedOperationException(); + } + /** * Returns the StreamManager which contains the state about which streams are currently being * fetched by a TransportClient. diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index e94453578e6b0..e1d7b2dbff60f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -17,6 +17,7 @@ package org.apache.spark.network.server; +import java.io.IOException; import java.net.SocketAddress; import java.nio.ByteBuffer; @@ -28,20 +29,10 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.protocol.ChunkFetchRequest; -import org.apache.spark.network.protocol.ChunkFetchFailure; -import org.apache.spark.network.protocol.ChunkFetchSuccess; -import org.apache.spark.network.protocol.Encodable; -import org.apache.spark.network.protocol.OneWayMessage; -import org.apache.spark.network.protocol.RequestMessage; -import org.apache.spark.network.protocol.RpcFailure; -import org.apache.spark.network.protocol.RpcRequest; -import org.apache.spark.network.protocol.RpcResponse; -import org.apache.spark.network.protocol.StreamFailure; -import org.apache.spark.network.protocol.StreamRequest; -import org.apache.spark.network.protocol.StreamResponse; +import org.apache.spark.network.client.*; +import org.apache.spark.network.protocol.*; +import org.apache.spark.network.util.TransportFrameDecoder; + import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; /** @@ -52,6 +43,7 @@ * The messages should have been processed by the pipeline setup by {@link TransportServer}. */ public class TransportRequestHandler extends MessageHandler { + private static final Logger logger = LoggerFactory.getLogger(TransportRequestHandler.class); /** The Netty channel that this handler is associated with. */ @@ -113,6 +105,8 @@ public void handle(RequestMessage request) { processOneWayMessage((OneWayMessage) request); } else if (request instanceof StreamRequest) { processStreamRequest((StreamRequest) request); + } else if (request instanceof UploadStream) { + processStreamUpload((UploadStream) request); } else { throw new IllegalArgumentException("Unknown request type: " + request); } @@ -203,6 +197,79 @@ public void onFailure(Throwable e) { } } + /** + * Handle a request from the client to upload a stream of data. + */ + private void processStreamUpload(final UploadStream req) { + assert (req.body() == null); + try { + RpcResponseCallback callback = new RpcResponseCallback() { + @Override + public void onSuccess(ByteBuffer response) { + respond(new RpcResponse(req.requestId, new NioManagedBuffer(response))); + } + + @Override + public void onFailure(Throwable e) { + respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); + } + }; + TransportFrameDecoder frameDecoder = (TransportFrameDecoder) + channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); + ByteBuffer meta = req.meta.nioByteBuffer(); + StreamCallbackWithID streamHandler = rpcHandler.receiveStream(reverseClient, meta, callback); + if (streamHandler == null) { + throw new NullPointerException("rpcHandler returned a null streamHandler"); + } + StreamCallbackWithID wrappedCallback = new StreamCallbackWithID() { + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + streamHandler.onData(streamId, buf); + } + + @Override + public void onComplete(String streamId) throws IOException { + try { + streamHandler.onComplete(streamId); + callback.onSuccess(ByteBuffer.allocate(0)); + } catch (Exception ex) { + IOException ioExc = new IOException("Failure post-processing complete stream;" + + " failing this rpc and leaving channel active"); + callback.onFailure(ioExc); + streamHandler.onFailure(streamId, ioExc); + } + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + callback.onFailure(new IOException("Destination failed while reading stream", cause)); + streamHandler.onFailure(streamId, cause); + } + + @Override + public String getID() { + return streamHandler.getID(); + } + }; + if (req.bodyByteCount > 0) { + StreamInterceptor interceptor = new StreamInterceptor(this, wrappedCallback.getID(), + req.bodyByteCount, wrappedCallback); + frameDecoder.setInterceptor(interceptor); + } else { + wrappedCallback.onComplete(wrappedCallback.getID()); + } + } catch (Exception e) { + logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e); + respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); + // We choose to totally fail the channel, rather than trying to recover as we do in other + // cases. We don't know how many bytes of the stream the client has already sent for the + // stream, it's not worth trying to recover. + channel.pipeline().fireExceptionCaught(e); + } finally { + req.meta.release(); + } + } + private void processOneWayMessage(OneWayMessage req) { try { rpcHandler.receive(reverseClient, req.body().nioByteBuffer()); diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 8ff737b129641..1f4d75c7e2ec5 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -17,43 +17,46 @@ package org.apache.spark.network; +import java.io.*; import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashSet; -import java.util.Iterator; -import java.util.List; -import java.util.Set; +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import com.google.common.collect.Sets; +import com.google.common.io.Files; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; import static org.junit.Assert.*; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientFactory; -import org.apache.spark.network.server.OneForOneStreamManager; -import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.*; +import org.apache.spark.network.server.*; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; public class RpcIntegrationSuite { + static TransportConf conf; static TransportServer server; static TransportClientFactory clientFactory; static RpcHandler rpcHandler; static List oneWayMsgs; + static StreamTestHelper testData; + + static ConcurrentHashMap streamCallbacks = + new ConcurrentHashMap<>(); @BeforeClass public static void setUp() throws Exception { - TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); + conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); + testData = new StreamTestHelper(); rpcHandler = new RpcHandler() { @Override public void receive( @@ -71,6 +74,14 @@ public void receive( } } + @Override + public StreamCallbackWithID receiveStream( + TransportClient client, + ByteBuffer messageHeader, + RpcResponseCallback callback) { + return receiveStreamHelper(JavaUtils.bytesToString(messageHeader)); + } + @Override public void receive(TransportClient client, ByteBuffer message) { oneWayMsgs.add(JavaUtils.bytesToString(message)); @@ -85,10 +96,71 @@ public void receive(TransportClient client, ByteBuffer message) { oneWayMsgs = new ArrayList<>(); } + private static StreamCallbackWithID receiveStreamHelper(String msg) { + try { + if (msg.startsWith("fail/")) { + String[] parts = msg.split("/"); + switch (parts[1]) { + case "exception-ondata": + return new StreamCallbackWithID() { + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + throw new IOException("failed to read stream data!"); + } + + @Override + public void onComplete(String streamId) throws IOException { + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + } + + @Override + public String getID() { + return msg; + } + }; + case "exception-oncomplete": + return new StreamCallbackWithID() { + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + } + + @Override + public void onComplete(String streamId) throws IOException { + throw new IOException("exception in onComplete"); + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + } + + @Override + public String getID() { + return msg; + } + }; + case "null": + return null; + default: + throw new IllegalArgumentException("unexpected msg: " + msg); + } + } else { + VerifyingStreamCallback streamCallback = new VerifyingStreamCallback(msg); + streamCallbacks.put(msg, streamCallback); + return streamCallback; + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + @AfterClass public static void tearDown() { server.close(); clientFactory.close(); + testData.cleanup(); } static class RpcResult { @@ -130,6 +202,59 @@ public void onFailure(Throwable e) { return res; } + private RpcResult sendRpcWithStream(String... streams) throws Exception { + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + final Semaphore sem = new Semaphore(0); + RpcResult res = new RpcResult(); + res.successMessages = Collections.synchronizedSet(new HashSet()); + res.errorMessages = Collections.synchronizedSet(new HashSet()); + + for (String stream : streams) { + int idx = stream.lastIndexOf('/'); + ManagedBuffer meta = new NioManagedBuffer(JavaUtils.stringToBytes(stream)); + String streamName = (idx == -1) ? stream : stream.substring(idx + 1); + ManagedBuffer data = testData.openStream(conf, streamName); + client.uploadStream(meta, data, new RpcStreamCallback(stream, res, sem)); + } + + if (!sem.tryAcquire(streams.length, 5, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server"); + } + streamCallbacks.values().forEach(streamCallback -> { + try { + streamCallback.verify(); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + client.close(); + return res; + } + + private static class RpcStreamCallback implements RpcResponseCallback { + final String streamId; + final RpcResult res; + final Semaphore sem; + + RpcStreamCallback(String streamId, RpcResult res, Semaphore sem) { + this.streamId = streamId; + this.res = res; + this.sem = sem; + } + + @Override + public void onSuccess(ByteBuffer message) { + res.successMessages.add(streamId); + sem.release(); + } + + @Override + public void onFailure(Throwable e) { + res.errorMessages.add(e.getMessage()); + sem.release(); + } + } + @Test public void singleRPC() throws Exception { RpcResult res = sendRPC("hello/Aaron"); @@ -193,10 +318,83 @@ public void sendOneWayMessage() throws Exception { } } + @Test + public void sendRpcWithStreamOneAtATime() throws Exception { + for (String stream : StreamTestHelper.STREAMS) { + RpcResult res = sendRpcWithStream(stream); + assertTrue("there were error messages!" + res.errorMessages, res.errorMessages.isEmpty()); + assertEquals(Sets.newHashSet(stream), res.successMessages); + } + } + + @Test + public void sendRpcWithStreamConcurrently() throws Exception { + String[] streams = new String[10]; + for (int i = 0; i < 10; i++) { + streams[i] = StreamTestHelper.STREAMS[i % StreamTestHelper.STREAMS.length]; + } + RpcResult res = sendRpcWithStream(streams); + assertEquals(Sets.newHashSet(StreamTestHelper.STREAMS), res.successMessages); + assertTrue(res.errorMessages.isEmpty()); + } + + @Test + public void sendRpcWithStreamFailures() throws Exception { + // when there is a failure reading stream data, we don't try to keep the channel usable, + // just send back a decent error msg. + RpcResult exceptionInCallbackResult = + sendRpcWithStream("fail/exception-ondata/smallBuffer", "smallBuffer"); + assertErrorAndClosed(exceptionInCallbackResult, "Destination failed while reading stream"); + + RpcResult nullStreamHandler = + sendRpcWithStream("fail/null/smallBuffer", "smallBuffer"); + assertErrorAndClosed(exceptionInCallbackResult, "Destination failed while reading stream"); + + // OTOH, if there is a failure during onComplete, the channel should still be fine + RpcResult exceptionInOnComplete = + sendRpcWithStream("fail/exception-oncomplete/smallBuffer", "smallBuffer"); + assertErrorsContain(exceptionInOnComplete.errorMessages, + Sets.newHashSet("Failure post-processing")); + assertEquals(Sets.newHashSet("smallBuffer"), exceptionInOnComplete.successMessages); + } + private void assertErrorsContain(Set errors, Set contains) { - assertEquals(contains.size(), errors.size()); + assertEquals("Expected " + contains.size() + " errors, got " + errors.size() + "errors: " + + errors, contains.size(), errors.size()); + + Pair, Set> r = checkErrorsContain(errors, contains); + assertTrue("Could not find error containing " + r.getRight() + "; errors: " + errors, + r.getRight().isEmpty()); + + assertTrue(r.getLeft().isEmpty()); + } + + private void assertErrorAndClosed(RpcResult result, String expectedError) { + assertTrue("unexpected success: " + result.successMessages, result.successMessages.isEmpty()); + // we expect 1 additional error, which contains *either* "closed" or "Connection reset" + Set errors = result.errorMessages; + assertEquals("Expected 2 errors, got " + errors.size() + "errors: " + + errors, 2, errors.size()); + + Set containsAndClosed = Sets.newHashSet(expectedError); + containsAndClosed.add("closed"); + containsAndClosed.add("Connection reset"); + + Pair, Set> r = checkErrorsContain(errors, containsAndClosed); + Set errorsNotFound = r.getRight(); + assertEquals(1, errorsNotFound.size()); + String err = errorsNotFound.iterator().next(); + assertTrue(err.equals("closed") || err.equals("Connection reset")); + + assertTrue(r.getLeft().isEmpty()); + } + + private Pair, Set> checkErrorsContain( + Set errors, + Set contains) { Set remainingErrors = Sets.newHashSet(errors); + Set notFound = Sets.newHashSet(); for (String contain : contains) { Iterator it = remainingErrors.iterator(); boolean foundMatch = false; @@ -207,9 +405,66 @@ private void assertErrorsContain(Set errors, Set contains) { break; } } - assertTrue("Could not find error containing " + contain + "; errors: " + errors, foundMatch); + if (!foundMatch) { + notFound.add(contain); + } + } + return new ImmutablePair<>(remainingErrors, notFound); + } + + private static class VerifyingStreamCallback implements StreamCallbackWithID { + final String streamId; + final StreamSuite.TestCallback helper; + final OutputStream out; + final File outFile; + + VerifyingStreamCallback(String streamId) throws IOException { + if (streamId.equals("file")) { + outFile = File.createTempFile("data", ".tmp", testData.tempDir); + out = new FileOutputStream(outFile); + } else { + out = new ByteArrayOutputStream(); + outFile = null; + } + this.streamId = streamId; + helper = new StreamSuite.TestCallback(out); + } + + void verify() throws IOException { + if (streamId.equals("file")) { + assertTrue("File stream did not match.", Files.equal(testData.testFile, outFile)); + } else { + byte[] result = ((ByteArrayOutputStream)out).toByteArray(); + ByteBuffer srcBuffer = testData.srcBuffer(streamId); + ByteBuffer base; + synchronized (srcBuffer) { + base = srcBuffer.duplicate(); + } + byte[] expected = new byte[base.remaining()]; + base.get(expected); + assertEquals(expected.length, result.length); + assertTrue("buffers don't match", Arrays.equals(expected, result)); + } + } + + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + helper.onData(streamId, buf); } - assertTrue(remainingErrors.isEmpty()); + @Override + public void onComplete(String streamId) throws IOException { + helper.onComplete(streamId); + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + helper.onFailure(streamId, cause); + } + + @Override + public String getID() { + return streamId; + } } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java index f253a07e64be1..f3050cb79cdfd 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -26,7 +26,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.Random; import java.util.concurrent.Executors; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; @@ -37,9 +36,7 @@ import org.junit.Test; import static org.junit.Assert.*; -import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.TransportClient; @@ -51,16 +48,11 @@ import org.apache.spark.network.util.TransportConf; public class StreamSuite { - private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "emptyBuffer", "file" }; + private static final String[] STREAMS = StreamTestHelper.STREAMS; + private static StreamTestHelper testData; private static TransportServer server; private static TransportClientFactory clientFactory; - private static File testFile; - private static File tempDir; - - private static ByteBuffer emptyBuffer; - private static ByteBuffer smallBuffer; - private static ByteBuffer largeBuffer; private static ByteBuffer createBuffer(int bufSize) { ByteBuffer buf = ByteBuffer.allocate(bufSize); @@ -73,23 +65,7 @@ private static ByteBuffer createBuffer(int bufSize) { @BeforeClass public static void setUp() throws Exception { - tempDir = Files.createTempDir(); - emptyBuffer = createBuffer(0); - smallBuffer = createBuffer(100); - largeBuffer = createBuffer(100000); - - testFile = File.createTempFile("stream-test-file", "txt", tempDir); - FileOutputStream fp = new FileOutputStream(testFile); - try { - Random rnd = new Random(); - for (int i = 0; i < 512; i++) { - byte[] fileContent = new byte[1024]; - rnd.nextBytes(fileContent); - fp.write(fileContent); - } - } finally { - fp.close(); - } + testData = new StreamTestHelper(); final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); final StreamManager streamManager = new StreamManager() { @@ -100,18 +76,7 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { @Override public ManagedBuffer openStream(String streamId) { - switch (streamId) { - case "largeBuffer": - return new NioManagedBuffer(largeBuffer); - case "smallBuffer": - return new NioManagedBuffer(smallBuffer); - case "emptyBuffer": - return new NioManagedBuffer(emptyBuffer); - case "file": - return new FileSegmentManagedBuffer(conf, testFile, 0, testFile.length()); - default: - throw new IllegalArgumentException("Invalid stream: " + streamId); - } + return testData.openStream(conf, streamId); } }; RpcHandler handler = new RpcHandler() { @@ -137,12 +102,7 @@ public StreamManager getStreamManager() { public static void tearDown() { server.close(); clientFactory.close(); - if (tempDir != null) { - for (File f : tempDir.listFiles()) { - f.delete(); - } - tempDir.delete(); - } + testData.cleanup(); } @Test @@ -234,21 +194,21 @@ public void run() { case "largeBuffer": baos = new ByteArrayOutputStream(); out = baos; - srcBuffer = largeBuffer; + srcBuffer = testData.largeBuffer; break; case "smallBuffer": baos = new ByteArrayOutputStream(); out = baos; - srcBuffer = smallBuffer; + srcBuffer = testData.smallBuffer; break; case "file": - outFile = File.createTempFile("data", ".tmp", tempDir); + outFile = File.createTempFile("data", ".tmp", testData.tempDir); out = new FileOutputStream(outFile); break; case "emptyBuffer": baos = new ByteArrayOutputStream(); out = baos; - srcBuffer = emptyBuffer; + srcBuffer = testData.emptyBuffer; break; default: throw new IllegalArgumentException(streamId); @@ -256,10 +216,10 @@ public void run() { TestCallback callback = new TestCallback(out); client.stream(streamId, callback); - waitForCompletion(callback); + callback.waitForCompletion(timeoutMs); if (srcBuffer == null) { - assertTrue("File stream did not match.", Files.equal(testFile, outFile)); + assertTrue("File stream did not match.", Files.equal(testData.testFile, outFile)); } else { ByteBuffer base; synchronized (srcBuffer) { @@ -292,23 +252,9 @@ public void check() throws Throwable { throw error; } } - - private void waitForCompletion(TestCallback callback) throws Exception { - long now = System.currentTimeMillis(); - long deadline = now + timeoutMs; - synchronized (callback) { - while (!callback.completed && now < deadline) { - callback.wait(deadline - now); - now = System.currentTimeMillis(); - } - } - assertTrue("Timed out waiting for stream.", callback.completed); - assertNull(callback.error); - } - } - private static class TestCallback implements StreamCallback { + static class TestCallback implements StreamCallback { private final OutputStream out; public volatile boolean completed; @@ -344,6 +290,22 @@ public void onFailure(String streamId, Throwable cause) { } } + void waitForCompletion(long timeoutMs) { + long now = System.currentTimeMillis(); + long deadline = now + timeoutMs; + synchronized (this) { + while (!completed && now < deadline) { + try { + wait(deadline - now); + } catch (InterruptedException ie) { + throw new RuntimeException(ie); + } + now = System.currentTimeMillis(); + } + } + assertTrue("Timed out waiting for stream.", completed); + assertNull(error); + } } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java b/common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java new file mode 100644 index 0000000000000..0f5c82c9e9b1f --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Random; + +import com.google.common.io.Files; + +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.TransportConf; + +class StreamTestHelper { + static final String[] STREAMS = { "largeBuffer", "smallBuffer", "emptyBuffer", "file" }; + + final File testFile; + final File tempDir; + + final ByteBuffer emptyBuffer; + final ByteBuffer smallBuffer; + final ByteBuffer largeBuffer; + + private static ByteBuffer createBuffer(int bufSize) { + ByteBuffer buf = ByteBuffer.allocate(bufSize); + for (int i = 0; i < bufSize; i ++) { + buf.put((byte) i); + } + buf.flip(); + return buf; + } + + StreamTestHelper() throws Exception { + tempDir = Files.createTempDir(); + emptyBuffer = createBuffer(0); + smallBuffer = createBuffer(100); + largeBuffer = createBuffer(100000); + + testFile = File.createTempFile("stream-test-file", "txt", tempDir); + FileOutputStream fp = new FileOutputStream(testFile); + try { + Random rnd = new Random(); + for (int i = 0; i < 512; i++) { + byte[] fileContent = new byte[1024]; + rnd.nextBytes(fileContent); + fp.write(fileContent); + } + } finally { + fp.close(); + } + } + + public ByteBuffer srcBuffer(String name) { + switch (name) { + case "largeBuffer": + return largeBuffer; + case "smallBuffer": + return smallBuffer; + case "emptyBuffer": + return emptyBuffer; + default: + throw new IllegalArgumentException("Invalid stream: " + name); + } + } + + public ManagedBuffer openStream(TransportConf conf, String streamId) { + switch (streamId) { + case "file": + return new FileSegmentManagedBuffer(conf, testFile, 0, testFile.length()); + default: + return new NioManagedBuffer(srcBuffer(streamId)); + } + } + + void cleanup() { + if (tempDir != null) { + try { + JavaUtils.deleteRecursively(tempDir); + } catch (IOException io) { + throw new RuntimeException(io); + } + } + } +} diff --git a/core/pom.xml b/core/pom.xml index 220522d3a8296..d0b869e6ef92c 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -56,7 +56,7 @@ org.apache.xbean - xbean-asm5-shaded + xbean-asm6-shaded org.apache.hadoop diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 5f0045507aaab..9a767dd739b91 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -703,7 +703,7 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff // must be stored in the same memory page. // (8 byte key length) (key) (value) (8 byte pointer to next value) int uaoSize = UnsafeAlignedOffset.getUaoSize(); - final long recordLength = (2 * uaoSize) + klen + vlen + 8; + final long recordLength = (2L * uaoSize) + klen + vlen + 8; if (currentPage == null || currentPage.size() - pageCursor < recordLength) { if (!acquireNewPage(recordLength + uaoSize)) { return false; diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index aa363eeffffb8..17b88631bcb4c 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -488,9 +488,15 @@ private[spark] class ExecutorAllocationManager( newExecutorTotal = numExistingExecutors if (testing || executorsRemoved.nonEmpty) { executorsRemoved.foreach { removedExecutorId => + // If it is a cached block, it uses cachedExecutorIdleTimeoutS for timeout + val idleTimeout = if (blockManagerMaster.hasCachedBlocks(removedExecutorId)) { + cachedExecutorIdleTimeoutS + } else { + executorIdleTimeoutS + } newExecutorTotal -= 1 logInfo(s"Removing executor $removedExecutorId because it has been idle for " + - s"$executorIdleTimeoutS seconds (new desired total will be $newExecutorTotal)") + s"$idleTimeout seconds (new desired total will be $newExecutorTotal)") executorsPendingToRemove.add(removedExecutorId) } executorsRemoved diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index ff960b396dbf1..bcbc8df0d5865 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -74,10 +74,9 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) // "spark.network.timeout" uses "seconds", while `spark.storage.blockManagerSlaveTimeoutMs` uses // "milliseconds" - private val slaveTimeoutMs = - sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", "120s") private val executorTimeoutMs = - sc.conf.getTimeAsSeconds("spark.network.timeout", s"${slaveTimeoutMs}ms") * 1000 + sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", + s"${sc.conf.getTimeAsSeconds("spark.network.timeout", "120s")}s") // "spark.network.timeoutInterval" uses "seconds", while // "spark.storage.blockManagerTimeoutIntervalMs" uses "milliseconds" diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 04c38f12acc78..1632e0c69eef5 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -21,6 +21,7 @@ import java.io.File import java.security.NoSuchAlgorithmException import javax.net.ssl.SSLContext +import org.apache.hadoop.conf.Configuration import org.eclipse.jetty.util.ssl.SslContextFactory import org.apache.spark.internal.Logging @@ -163,11 +164,16 @@ private[spark] object SSLOptions extends Logging { * missing in SparkConf, the corresponding setting is used from the default configuration. * * @param conf Spark configuration object where the settings are collected from + * @param hadoopConf Hadoop configuration to get settings * @param ns the namespace name * @param defaults the default configuration * @return [[org.apache.spark.SSLOptions]] object */ - def parse(conf: SparkConf, ns: String, defaults: Option[SSLOptions] = None): SSLOptions = { + def parse( + conf: SparkConf, + hadoopConf: Configuration, + ns: String, + defaults: Option[SSLOptions] = None): SSLOptions = { val enabled = conf.getBoolean(s"$ns.enabled", defaultValue = defaults.exists(_.enabled)) val port = conf.getWithSubstitution(s"$ns.port").map(_.toInt) @@ -179,9 +185,11 @@ private[spark] object SSLOptions extends Logging { .orElse(defaults.flatMap(_.keyStore)) val keyStorePassword = conf.getWithSubstitution(s"$ns.keyStorePassword") + .orElse(Option(hadoopConf.getPassword(s"$ns.keyStorePassword")).map(new String(_))) .orElse(defaults.flatMap(_.keyStorePassword)) val keyPassword = conf.getWithSubstitution(s"$ns.keyPassword") + .orElse(Option(hadoopConf.getPassword(s"$ns.keyPassword")).map(new String(_))) .orElse(defaults.flatMap(_.keyPassword)) val keyStoreType = conf.getWithSubstitution(s"$ns.keyStoreType") @@ -194,6 +202,7 @@ private[spark] object SSLOptions extends Logging { .orElse(defaults.flatMap(_.trustStore)) val trustStorePassword = conf.getWithSubstitution(s"$ns.trustStorePassword") + .orElse(Option(hadoopConf.getPassword(s"$ns.trustStorePassword")).map(new String(_))) .orElse(defaults.flatMap(_.trustStorePassword)) val trustStoreType = conf.getWithSubstitution(s"$ns.trustStoreType") diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index b87476322573d..3cfafeb951105 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -19,11 +19,11 @@ package org.apache.spark import java.net.{Authenticator, PasswordAuthentication} import java.nio.charset.StandardCharsets.UTF_8 -import javax.net.ssl._ import org.apache.hadoop.io.Text import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.launcher.SparkLauncher @@ -111,11 +111,14 @@ private[spark] class SecurityManager( ) } + private val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf) // the default SSL configuration - it will be used by all communication layers unless overwritten - private val defaultSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl", defaults = None) + private val defaultSSLOptions = + SSLOptions.parse(sparkConf, hadoopConf, "spark.ssl", defaults = None) def getSSLOptions(module: String): SSLOptions = { - val opts = SSLOptions.parse(sparkConf, s"spark.ssl.$module", Some(defaultSSLOptions)) + val opts = + SSLOptions.parse(sparkConf, hadoopConf, s"spark.ssl.$module", Some(defaultSSLOptions)) logDebug(s"Created SSL options for $module: $opts") opts } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 5e8595603cc90..531384ab57305 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1306,11 +1306,12 @@ class SparkContext(config: SparkConf) extends Logging { /** Build the union of a list of RDDs. */ def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = withScope { - val partitioners = rdds.flatMap(_.partitioner).toSet - if (rdds.forall(_.partitioner.isDefined) && partitioners.size == 1) { - new PartitionerAwareUnionRDD(this, rdds) + val nonEmptyRdds = rdds.filter(!_.partitions.isEmpty) + val partitioners = nonEmptyRdds.flatMap(_.partitioner).toSet + if (nonEmptyRdds.forall(_.partitioner.isDefined) && partitioners.size == 1) { + new PartitionerAwareUnionRDD(this, nonEmptyRdds) } else { - new UnionRDD(this, rdds) + new UnionRDD(this, nonEmptyRdds) } } @@ -1495,6 +1496,8 @@ class SparkContext(config: SparkConf) extends Logging { * @param path can be either a local file, a file in HDFS (or other Hadoop-supported * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, * use `SparkFiles.get(fileName)` to find its download location. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addFile(path: String): Unit = { addFile(path, false) @@ -1515,6 +1518,8 @@ class SparkContext(config: SparkConf) extends Logging { * use `SparkFiles.get(fileName)` to find its download location. * @param recursive if true, a directory can be given in `path`. Currently directories are * only supported for Hadoop-supported filesystems. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addFile(path: String, recursive: Boolean): Unit = { val uri = new Path(path).toUri @@ -1554,6 +1559,9 @@ class SparkContext(config: SparkConf) extends Logging { Utils.fetchFile(uri.toString, new File(SparkFiles.getRootDirectory()), conf, env.securityManager, hadoopConfiguration, timestamp, useCache = false) postEnvironmentUpdate() + } else { + logWarning(s"The path $path has been added already. Overwriting of added paths " + + "is not supported in the current version.") } } @@ -1802,6 +1810,8 @@ class SparkContext(config: SparkConf) extends Logging { * * @param path can be either a local file, a file in HDFS (or other Hadoop-supported filesystems), * an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addJar(path: String) { def addJarFile(file: File): String = { @@ -1848,6 +1858,9 @@ class SparkContext(config: SparkConf) extends Logging { if (addedJars.putIfAbsent(key, timestamp).isEmpty) { logInfo(s"Added JAR $path at $key with timestamp $timestamp") postEnvironmentUpdate() + } else { + logWarning(s"The jar $path has been added already. Overwriting of added jars " + + "is not supported in the current version.") } } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index f1936bf587282..09c83849e26b2 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -668,6 +668,8 @@ class JavaSparkContext(val sc: SparkContext) * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, * use `SparkFiles.get(fileName)` to find its download location. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addFile(path: String) { sc.addFile(path) @@ -681,6 +683,8 @@ class JavaSparkContext(val sc: SparkContext) * * A directory can be given if the recursive option is set to true. Currently directories are only * supported for Hadoop-supported filesystems. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addFile(path: String, recursive: Boolean): Unit = { sc.addFile(path, recursive) @@ -690,6 +694,8 @@ class JavaSparkContext(val sc: SparkContext) * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future. * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported * filesystems), or an HTTP, HTTPS or FTP URI. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addJar(path: String) { sc.addJar(path) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index e83d82f847c61..e7310ee886103 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -385,7 +385,7 @@ private[spark] class SparkSubmit extends Logging { val forceDownloadSchemes = sparkConf.get(FORCE_DOWNLOAD_SCHEMES) def shouldDownload(scheme: String): Boolean = { - forceDownloadSchemes.contains(scheme) || + forceDownloadSchemes.contains("*") || forceDownloadSchemes.contains(scheme) || Try { FileSystem.getFileSystemClass(scheme, hadoopConf) }.isFailure } @@ -578,7 +578,8 @@ private[spark] class SparkSubmit extends Logging { } // Add the main application jar and any added jars to classpath in case YARN client // requires these jars. - // This assumes both primaryResource and user jars are local jars, otherwise it will not be + // This assumes both primaryResource and user jars are local jars, or already downloaded + // to local by configuring "spark.yarn.dist.forceDownloadSchemes", otherwise it will not be // added to the classpath of YARN client. if (isYarnCluster) { if (isUserJar(args.primaryResource)) { @@ -702,6 +703,8 @@ private[spark] class SparkSubmit extends Logging { childArgs ++= Array("--primary-java-resource", args.primaryResource) childArgs ++= Array("--main-class", args.mainClass) } + } else { + childArgs ++= Array("--main-class", args.mainClass) } if (args.childArgs != null) { args.childArgs.foreach { arg => diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 066275e8f8425..56f3f59504a7d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -124,7 +124,7 @@ class HistoryServer( attachHandler(ApiRootResource.getServletHandler(this)) - attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) + addStaticHandler(SparkUI.STATIC_RESOURCE_DIR) val contextHandler = new ServletContextHandler contextHandler.setContextPath(HistoryServer.UI_PATH_PREFIX) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 35b7ddd46e4db..e87b2240564bd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -43,7 +43,7 @@ class MasterWebUI( val masterPage = new MasterPage(this) attachPage(new ApplicationPage(this)) attachPage(masterPage) - attachHandler(createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static")) + addStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR) attachHandler(createRedirectHandler( "/app/kill", "/", masterPage.handleAppKillRequest, httpMethods = Set("POST"))) attachHandler(createRedirectHandler( diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 742a95841a138..31a8e3e60c067 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -233,30 +233,44 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { private[rest] def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = { import scala.concurrent.ExecutionContext.Implicits.global val responseFuture = Future { - val dataStream = - if (connection.getResponseCode == HttpServletResponse.SC_OK) { - connection.getInputStream - } else { - connection.getErrorStream + val responseCode = connection.getResponseCode + + if (responseCode != HttpServletResponse.SC_OK) { + val errString = Some(Source.fromInputStream(connection.getErrorStream()) + .getLines().mkString("\n")) + if (responseCode == HttpServletResponse.SC_INTERNAL_SERVER_ERROR && + !connection.getContentType().contains("application/json")) { + throw new SubmitRestProtocolException(s"Server responded with exception:\n${errString}") + } + logError(s"Server responded with error:\n${errString}") + val error = new ErrorResponse + if (responseCode == RestSubmissionServer.SC_UNKNOWN_PROTOCOL_VERSION) { + error.highestProtocolVersion = RestSubmissionServer.PROTOCOL_VERSION + } + error.message = errString.get + error + } else { + val dataStream = connection.getInputStream + + // If the server threw an exception while writing a response, it will not have a body + if (dataStream == null) { + throw new SubmitRestProtocolException("Server returned empty body") + } + val responseJson = Source.fromInputStream(dataStream).mkString + logDebug(s"Response from the server:\n$responseJson") + val response = SubmitRestProtocolMessage.fromJson(responseJson) + response.validate() + response match { + // If the response is an error, log the message + case error: ErrorResponse => + logError(s"Server responded with error:\n${error.message}") + error + // Otherwise, simply return the response + case response: SubmitRestProtocolResponse => response + case unexpected => + throw new SubmitRestProtocolException( + s"Message received from server was not a response:\n${unexpected.toJson}") } - // If the server threw an exception while writing a response, it will not have a body - if (dataStream == null) { - throw new SubmitRestProtocolException("Server returned empty body") - } - val responseJson = Source.fromInputStream(dataStream).mkString - logDebug(s"Response from the server:\n$responseJson") - val response = SubmitRestProtocolMessage.fromJson(responseJson) - response.validate() - response match { - // If the response is an error, log the message - case error: ErrorResponse => - logError(s"Server responded with error:\n${error.message}") - error - // Otherwise, simply return the response - case response: SubmitRestProtocolResponse => response - case unexpected => - throw new SubmitRestProtocolException( - s"Message received from server was not a response:\n${unexpected.toJson}") } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 58a181128eb4d..a6d13d12fc28d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -225,7 +225,7 @@ private[deploy] class DriverRunner( // check if attempting another run keepTrying = supervise && exitCode != 0 && !killed if (keepTrying) { - if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000) { + if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000L) { waitSeconds = 1 } logInfo(s"Command exited with status $exitCode, re-launching after $waitSeconds s.") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index db696b04384bd..ea67b7434a769 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -47,7 +47,7 @@ class WorkerWebUI( val logPage = new LogPage(this) attachPage(logPage) attachPage(new WorkerPage(this)) - attachHandler(createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static")) + addStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE) attachHandler(createServletHandler("/log", (request: HttpServletRequest) => logPage.renderLog(request), worker.securityMgr, diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala index f47cd38d712c3..04c5c4b90e8a1 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala @@ -53,6 +53,19 @@ private[spark] class WholeTextFileInputFormat val totalLen = files.map(file => if (file.isDirectory) 0L else file.getLen).sum val maxSplitSize = Math.ceil(totalLen * 1.0 / (if (minPartitions == 0) 1 else minPartitions)).toLong + + // For small files we need to ensure the min split size per node & rack <= maxSplitSize + val config = context.getConfiguration + val minSplitSizePerNode = config.getLong(CombineFileInputFormat.SPLIT_MINSIZE_PERNODE, 0L) + val minSplitSizePerRack = config.getLong(CombineFileInputFormat.SPLIT_MINSIZE_PERRACK, 0L) + + if (maxSplitSize < minSplitSizePerNode) { + super.setMinSplitSizeNode(maxSplitSize) + } + + if (maxSplitSize < minSplitSizePerRack) { + super.setMinSplitSizeRack(maxSplitSize) + } super.setMaxSplitSize(maxSplitSize) } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index a54b091a64d50..ba892bf7f60d6 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -72,6 +72,9 @@ package object config { private[spark] val EVENT_LOG_OVERWRITE = ConfigBuilder("spark.eventLog.overwrite").booleanConf.createWithDefault(false) + private[spark] val EVENT_LOG_CALLSITE_FORM = + ConfigBuilder("spark.eventLog.callsite").stringConf.createWithDefault("short") + private[spark] val EXECUTOR_CLASS_PATH = ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_CLASSPATH).stringConf.createOptional @@ -483,10 +486,11 @@ package object config { private[spark] val FORCE_DOWNLOAD_SCHEMES = ConfigBuilder("spark.yarn.dist.forceDownloadSchemes") - .doc("Comma-separated list of schemes for which files will be downloaded to the " + + .doc("Comma-separated list of schemes for which resources will be downloaded to the " + "local disk prior to being added to YARN's distributed cache. For use in cases " + "where the YARN service does not support schemes that are supported by Spark, like http, " + - "https and ftp.") + "https and ftp, or jars required to be in the local YARN client's classpath. Wildcard " + + "'*' is denoted to download resources for all the schemes.") .stringConf .toSequence .createWithDefault(Nil) @@ -552,4 +556,11 @@ package object config { .timeConf(TimeUnit.SECONDS) .createWithDefaultString("1h") + private[spark] val SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS = + ConfigBuilder("spark.shuffle.minNumPartitionsToHighlyCompress") + .internal() + .doc("Number of partitions to determine if MapStatus should use HighlyCompressedMapStatus") + .intConf + .checkValue(v => v > 0, "The value should be a positive integer.") + .createWithDefault(2000) } diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala index abf39213fa0d2..9ebd0aa301592 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala @@ -76,13 +76,17 @@ object SparkHadoopWriter extends Logging { // Try to write all RDD partitions as a Hadoop OutputFormat. try { val ret = sparkContext.runJob(rdd, (context: TaskContext, iter: Iterator[(K, V)]) => { + // SPARK-24552: Generate a unique "attempt ID" based on the stage and task attempt numbers. + // Assumes that there won't be more than Short.MaxValue attempts, at least not concurrently. + val attemptId = (context.stageAttemptNumber << 16) | context.attemptNumber + executeTask( context = context, config = config, jobTrackerId = jobTrackerId, commitJobId = commitJobId, sparkPartitionId = context.partitionId, - sparkAttemptNumber = context.attemptNumber, + sparkAttemptNumber = attemptId, committer = committer, iterator = iter) }) diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 764735dc4eae7..db8aff94ea1e1 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -69,9 +69,9 @@ object SparkHadoopMapRedUtil extends Logging { if (shouldCoordinateWithDriver) { val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator - val taskAttemptNumber = TaskContext.get().attemptNumber() - val stageId = TaskContext.get().stageId() - val canCommit = outputCommitCoordinator.canCommit(stageId, splitId, taskAttemptNumber) + val ctx = TaskContext.get() + val canCommit = outputCommitCoordinator.canCommit(ctx.stageId(), ctx.stageAttemptNumber(), + splitId, ctx.attemptNumber()) if (canCommit) { performCommit() @@ -81,7 +81,7 @@ object SparkHadoopMapRedUtil extends Logging { logInfo(message) // We need to abort the task so that the driver can reschedule new attempts, if necessary committer.abortTask(mrTaskContext) - throw new CommitDeniedException(message, stageId, splitId, taskAttemptNumber) + throw new CommitDeniedException(message, ctx.stageId(), splitId, ctx.attemptNumber()) } } else { // Speculation is disabled or a user has chosen to manually bypass the commit coordination diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index 13db4985b0b80..ba9dae4ad48ec 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -95,7 +95,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi // the left side of max is >=1 whenever partsScanned >= 2 numPartsToTry = Math.max(1, (1.5 * num * partsScanned / results.size).toInt - partsScanned) - numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) + numPartsToTry = Math.min(numPartsToTry, partsScanned * 4L) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index 4e036c2ed49b5..23cf19d55b4ae 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -30,7 +30,7 @@ private[spark] class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[BlockId]) extends RDD[T](sc, Nil) { - @transient lazy val _locations = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get) + @transient lazy val _locations = BlockManager.blockIdsToLocations(blockIds, SparkEnv.get) @volatile private var _isValid = true override def getPartitions: Array[Partition] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala index 30cf75d43ee09..980fbbe516b91 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -371,7 +371,7 @@ private[scheduler] class BlacklistTracker ( } -private[scheduler] object BlacklistTracker extends Logging { +private[spark] object BlacklistTracker extends Logging { private val DEFAULT_TIMEOUT = "1h" diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 041eade82d3ca..f74425d73b392 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1171,6 +1171,7 @@ class DAGScheduler( outputCommitCoordinator.taskCompleted( stageId, + task.stageAttemptId, task.partitionId, event.taskInfo.attemptNumber, // this is a task attempt number event.reason) @@ -1330,23 +1331,24 @@ class DAGScheduler( s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + s"(attempt ${failedStage.latestInfo.attemptNumber}) running") } else { + failedStage.fetchFailedAttemptIds.add(task.stageAttemptId) + val shouldAbortStage = + failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts || + disallowStageRetryForTest + // It is likely that we receive multiple FetchFailed for a single stage (because we have // multiple tasks running concurrently on different executors). In that case, it is // possible the fetch failure has already been handled by the scheduler. if (runningStages.contains(failedStage)) { logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + s"due to a fetch failure from $mapStage (${mapStage.name})") - markStageAsFinished(failedStage, Some(failureMessage)) + markStageAsFinished(failedStage, errorMessage = Some(failureMessage), + willRetry = !shouldAbortStage) } else { logDebug(s"Received fetch failure from $task, but its from $failedStage which is no " + s"longer running") } - failedStage.fetchFailedAttemptIds.add(task.stageAttemptId) - val shouldAbortStage = - failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts || - disallowStageRetryForTest - if (shouldAbortStage) { val abortMessage = if (disallowStageRetryForTest) { "Fetch failure will not retry stage due to testing config" @@ -1545,7 +1547,10 @@ class DAGScheduler( /** * Marks a stage as finished and removes it from the list of running stages. */ - private def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = { + private def markStageAsFinished( + stage: Stage, + errorMessage: Option[String] = None, + willRetry: Boolean = false): Unit = { val serviceTime = stage.latestInfo.submissionTime match { case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0) case _ => "Unknown" @@ -1564,7 +1569,9 @@ class DAGScheduler( logInfo(s"$stage (${stage.name}) failed in $serviceTime s due to ${errorMessage.get}") } - outputCommitCoordinator.stageEnd(stage.id) + if (!willRetry) { + outputCommitCoordinator.stageEnd(stage.id) + } listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) runningStages -= stage } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 2ec2f2031aa45..659694dd189ad 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -50,7 +50,9 @@ private[spark] sealed trait MapStatus { private[spark] object MapStatus { def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { - if (uncompressedSizes.length > 2000) { + if (uncompressedSizes.length > Option(SparkEnv.get) + .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS)) + .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get)) { HighlyCompressedMapStatus(loc, uncompressedSizes) } else { new CompressedMapStatus(loc, uncompressedSizes) diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 83d87b548a430..b382d623806e2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -27,7 +27,11 @@ import org.apache.spark.util.{RpcUtils, ThreadUtils} private sealed trait OutputCommitCoordinationMessage extends Serializable private case object StopCoordinator extends OutputCommitCoordinationMessage -private case class AskPermissionToCommitOutput(stage: Int, partition: Int, attemptNumber: Int) +private case class AskPermissionToCommitOutput( + stage: Int, + stageAttempt: Int, + partition: Int, + attemptNumber: Int) /** * Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins" @@ -45,13 +49,15 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) // Initialized by SparkEnv var coordinatorRef: Option[RpcEndpointRef] = None - private type StageId = Int - private type PartitionId = Int - private type TaskAttemptNumber = Int - private val NO_AUTHORIZED_COMMITTER: TaskAttemptNumber = -1 + // Class used to identify a committer. The task ID for a committer is implicitly defined by + // the partition being processed, but the coordinator needs to keep track of both the stage + // attempt and the task attempt, because in some situations the same task may be running + // concurrently in two different attempts of the same stage. + private case class TaskIdentifier(stageAttempt: Int, taskAttempt: Int) + private case class StageState(numPartitions: Int) { - val authorizedCommitters = Array.fill[TaskAttemptNumber](numPartitions)(NO_AUTHORIZED_COMMITTER) - val failures = mutable.Map[PartitionId, mutable.Set[TaskAttemptNumber]]() + val authorizedCommitters = Array.fill[TaskIdentifier](numPartitions)(null) + val failures = mutable.Map[Int, mutable.Set[TaskIdentifier]]() } /** @@ -64,7 +70,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance. */ - private val stageStates = mutable.Map[StageId, StageState]() + private val stageStates = mutable.Map[Int, StageState]() /** * Returns whether the OutputCommitCoordinator's internal data structures are all empty. @@ -87,10 +93,11 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * @return true if this task is authorized to commit, false otherwise */ def canCommit( - stage: StageId, - partition: PartitionId, - attemptNumber: TaskAttemptNumber): Boolean = { - val msg = AskPermissionToCommitOutput(stage, partition, attemptNumber) + stage: Int, + stageAttempt: Int, + partition: Int, + attemptNumber: Int): Boolean = { + val msg = AskPermissionToCommitOutput(stage, stageAttempt, partition, attemptNumber) coordinatorRef match { case Some(endpointRef) => ThreadUtils.awaitResult(endpointRef.ask[Boolean](msg), @@ -103,26 +110,35 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) } /** - * Called by the DAGScheduler when a stage starts. + * Called by the DAGScheduler when a stage starts. Initializes the stage's state if it hasn't + * yet been initialized. * * @param stage the stage id. * @param maxPartitionId the maximum partition id that could appear in this stage's tasks (i.e. * the maximum possible value of `context.partitionId`). */ - private[scheduler] def stageStart(stage: StageId, maxPartitionId: Int): Unit = synchronized { - stageStates(stage) = new StageState(maxPartitionId + 1) + private[scheduler] def stageStart(stage: Int, maxPartitionId: Int): Unit = synchronized { + stageStates.get(stage) match { + case Some(state) => + require(state.authorizedCommitters.length == maxPartitionId + 1) + logInfo(s"Reusing state from previous attempt of stage $stage.") + + case _ => + stageStates(stage) = new StageState(maxPartitionId + 1) + } } // Called by DAGScheduler - private[scheduler] def stageEnd(stage: StageId): Unit = synchronized { + private[scheduler] def stageEnd(stage: Int): Unit = synchronized { stageStates.remove(stage) } // Called by DAGScheduler private[scheduler] def taskCompleted( - stage: StageId, - partition: PartitionId, - attemptNumber: TaskAttemptNumber, + stage: Int, + stageAttempt: Int, + partition: Int, + attemptNumber: Int, reason: TaskEndReason): Unit = synchronized { val stageState = stageStates.getOrElse(stage, { logDebug(s"Ignoring task completion for completed stage") @@ -131,16 +147,17 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) reason match { case Success => // The task output has been committed successfully - case denied: TaskCommitDenied => - logInfo(s"Task was denied committing, stage: $stage, partition: $partition, " + - s"attempt: $attemptNumber") - case otherReason => + case _: TaskCommitDenied => + logInfo(s"Task was denied committing, stage: $stage.$stageAttempt, " + + s"partition: $partition, attempt: $attemptNumber") + case _ => // Mark the attempt as failed to blacklist from future commit protocol - stageState.failures.getOrElseUpdate(partition, mutable.Set()) += attemptNumber - if (stageState.authorizedCommitters(partition) == attemptNumber) { + val taskId = TaskIdentifier(stageAttempt, attemptNumber) + stageState.failures.getOrElseUpdate(partition, mutable.Set()) += taskId + if (stageState.authorizedCommitters(partition) == taskId) { logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " + s"partition=$partition) failed; clearing lock") - stageState.authorizedCommitters(partition) = NO_AUTHORIZED_COMMITTER + stageState.authorizedCommitters(partition) = null } } } @@ -155,47 +172,41 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) // Marked private[scheduler] instead of private so this can be mocked in tests private[scheduler] def handleAskPermissionToCommit( - stage: StageId, - partition: PartitionId, - attemptNumber: TaskAttemptNumber): Boolean = synchronized { + stage: Int, + stageAttempt: Int, + partition: Int, + attemptNumber: Int): Boolean = synchronized { stageStates.get(stage) match { - case Some(state) if attemptFailed(state, partition, attemptNumber) => - logInfo(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage," + - s" partition=$partition as task attempt $attemptNumber has already failed.") + case Some(state) if attemptFailed(state, stageAttempt, partition, attemptNumber) => + logInfo(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " + + s"task attempt $attemptNumber already marked as failed.") false case Some(state) => - state.authorizedCommitters(partition) match { - case NO_AUTHORIZED_COMMITTER => - logDebug(s"Authorizing attemptNumber=$attemptNumber to commit for stage=$stage, " + - s"partition=$partition") - state.authorizedCommitters(partition) = attemptNumber - true - case existingCommitter => - // Coordinator should be idempotent when receiving AskPermissionToCommit. - if (existingCommitter == attemptNumber) { - logWarning(s"Authorizing duplicate request to commit for " + - s"attemptNumber=$attemptNumber to commit for stage=$stage," + - s" partition=$partition; existingCommitter = $existingCommitter." + - s" This can indicate dropped network traffic.") - true - } else { - logDebug(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage, " + - s"partition=$partition; existingCommitter = $existingCommitter") - false - } + val existing = state.authorizedCommitters(partition) + if (existing == null) { + logDebug(s"Commit allowed for stage=$stage.$stageAttempt, partition=$partition, " + + s"task attempt $attemptNumber") + state.authorizedCommitters(partition) = TaskIdentifier(stageAttempt, attemptNumber) + true + } else { + logDebug(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " + + s"already committed by $existing") + false } case None => - logDebug(s"Stage $stage has completed, so not allowing" + - s" attempt number $attemptNumber of partition $partition to commit") + logDebug(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " + + "stage already marked as completed.") false } } private def attemptFailed( stageState: StageState, - partition: PartitionId, - attempt: TaskAttemptNumber): Boolean = synchronized { - stageState.failures.get(partition).exists(_.contains(attempt)) + stageAttempt: Int, + partition: Int, + attempt: Int): Boolean = synchronized { + val failInfo = TaskIdentifier(stageAttempt, attempt) + stageState.failures.get(partition).exists(_.contains(failInfo)) } } @@ -215,9 +226,10 @@ private[spark] object OutputCommitCoordinator { } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case AskPermissionToCommitOutput(stage, partition, attemptNumber) => + case AskPermissionToCommitOutput(stage, stageAttempt, partition, attemptNumber) => context.reply( - outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, attemptNumber)) + outputCommitCoordinator.handleAskPermissionToCommit(stage, stageAttempt, partition, + attemptNumber)) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index d8794e8e551aa..9b90e309d2e04 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -170,8 +170,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp if (executorDataMap.contains(executorId)) { executorRef.send(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) context.reply(true) - } else if (scheduler.nodeBlacklist != null && - scheduler.nodeBlacklist.contains(hostname)) { + } else if (scheduler.nodeBlacklist.contains(hostname)) { // If the cluster manager gives us an executor on a blacklisted node (because it // already started allocating those resources before we informed it of our blacklist, // or if it ignored our blacklist), then we reject that executor immediately. diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 688f25a9fdea1..e237281c552b1 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -471,7 +471,7 @@ private[spark] class AppStatusStore( def operationGraphForJob(jobId: Int): Seq[RDDOperationGraph] = { val job = store.read(classOf[JobDataWrapper], jobId) - val stages = job.info.stageIds + val stages = job.info.stageIds.sorted stages.map { id => val g = store.read(classOf[RDDOperationGraphWrapper], id).toRDDOperationGraph() diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index d121068718b8a..84c2ad48f1f27 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -28,7 +28,7 @@ import org.glassfish.jersey.server.ServerProperties import org.glassfish.jersey.servlet.ServletContainer import org.apache.spark.SecurityManager -import org.apache.spark.ui.SparkUI +import org.apache.spark.ui.{SparkUI, UIUtils} /** * Main entry point for serving spark application metrics as json, using JAX-RS. @@ -148,38 +148,18 @@ private[v1] trait BaseAppResource extends ApiRequestContext { } private[v1] class ForbiddenException(msg: String) extends WebApplicationException( - Response.status(Response.Status.FORBIDDEN).entity(msg).build()) + UIUtils.buildErrorResponse(Response.Status.FORBIDDEN, msg)) private[v1] class NotFoundException(msg: String) extends WebApplicationException( - new NoSuchElementException(msg), - Response - .status(Response.Status.NOT_FOUND) - .entity(ErrorWrapper(msg)) - .build() -) + UIUtils.buildErrorResponse(Response.Status.NOT_FOUND, msg)) private[v1] class ServiceUnavailable(msg: String) extends WebApplicationException( - new ServiceUnavailableException(msg), - Response - .status(Response.Status.SERVICE_UNAVAILABLE) - .entity(ErrorWrapper(msg)) - .build() -) + UIUtils.buildErrorResponse(Response.Status.SERVICE_UNAVAILABLE, msg)) private[v1] class BadParameterException(msg: String) extends WebApplicationException( - new IllegalArgumentException(msg), - Response - .status(Response.Status.BAD_REQUEST) - .entity(ErrorWrapper(msg)) - .build() -) { + UIUtils.buildErrorResponse(Response.Status.BAD_REQUEST, msg)) { def this(param: String, exp: String, actual: String) = { this(raw"""Bad value for parameter "$param". Expected a $exp, got "$actual"""") } } -/** - * Signal to JacksonMessageWriter to not convert the message into json (which would result in an - * extra set of quotes). - */ -private[v1] case class ErrorWrapper(s: String) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala index 76af33c1a18db..4560d300cb0c8 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala @@ -68,10 +68,7 @@ private[v1] class JacksonMessageWriter extends MessageBodyWriter[Object]{ mediaType: MediaType, multivaluedMap: MultivaluedMap[String, AnyRef], outputStream: OutputStream): Unit = { - t match { - case ErrorWrapper(err) => outputStream.write(err.getBytes(StandardCharsets.UTF_8)) - case _ => mapper.writeValue(outputStream, t) - } + mapper.writeValue(outputStream, t) } override def getSize( diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala index 974697890dd03..32100c5704538 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala @@ -140,11 +140,8 @@ private[v1] class AbstractApplicationResource extends BaseAppResource { .header("Content-Type", MediaType.APPLICATION_OCTET_STREAM) .build() } catch { - case NonFatal(e) => - Response.serverError() - .entity(s"Event logs are not available for app: $appId.") - .status(Response.Status.SERVICE_UNAVAILABLE) - .build() + case NonFatal(_) => + throw new ServiceUnavailable(s"Event logs are not available for app: $appId.") } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e0276a4dc4224..0e1c7d5fd3fa2 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -45,6 +45,7 @@ import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.{ExternalShuffleClient, TempFileManager} import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.rpc.RpcEnv +import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.serializer.{SerializerInstance, SerializerManager} import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage.memory._ @@ -291,7 +292,7 @@ private[spark] class BlockManager( case e: Exception if i < MAX_ATTEMPTS => logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}" + s" more times after waiting $SLEEP_TIME_SECS seconds...", e) - Thread.sleep(SLEEP_TIME_SECS * 1000) + Thread.sleep(SLEEP_TIME_SECS * 1000L) case NonFatal(e) => throw new SparkException("Unable to register with external shuffle server due to : " + e.getMessage, e) @@ -1554,7 +1555,7 @@ private[spark] class BlockManager( private[spark] object BlockManager { private val ID_GENERATOR = new IdGenerator - def blockIdsToHosts( + def blockIdsToLocations( blockIds: Array[BlockId], env: SparkEnv, blockManagerMaster: BlockManagerMaster = null): Map[BlockId, Seq[String]] = { @@ -1569,7 +1570,9 @@ private[spark] object BlockManager { val blockManagers = new HashMap[BlockId, Seq[String]] for (i <- 0 until blockIds.length) { - blockManagers(blockIds(i)) = blockLocations(i).map(_.host) + blockManagers(blockIds(i)) = blockLocations(i).map { loc => + ExecutorCacheTaskLocation(loc.host, loc.executorId).toString + } } blockManagers.toMap } diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index e5abbf745cc41..9ccc8f9cc585b 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -17,7 +17,9 @@ package org.apache.spark.storage +import org.apache.spark.SparkEnv import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.internal.config._ import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.util.Utils @@ -53,10 +55,16 @@ class RDDInfo( } private[spark] object RDDInfo { + private val callsiteForm = SparkEnv.get.conf.get(EVENT_LOG_CALLSITE_FORM) + def fromRdd(rdd: RDD[_]): RDDInfo = { val rddName = Option(rdd.name).getOrElse(Utils.getFormattedClassName(rdd)) val parentIds = rdd.dependencies.map(_.rdd.id) + val callSite = callsiteForm match { + case "short" => rdd.creationSite.shortForm + case "long" => rdd.creationSite.longForm + } new RDDInfo(rdd.id, rddName, rdd.partitions.length, - rdd.getStorageLevel, parentIds, rdd.creationSite.shortForm, rdd.scope) + rdd.getStorageLevel, parentIds, callSite, rdd.scope) } } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index b44ac0ea1febc..d315ef66e0dc0 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -65,7 +65,7 @@ private[spark] class SparkUI private ( attachTab(new StorageTab(this, store)) attachTab(new EnvironmentTab(this, store)) attachTab(new ExecutorsTab(this)) - attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) + addStaticHandler(SparkUI.STATIC_RESOURCE_DIR) attachHandler(createRedirectHandler("/", "/jobs/", basePath = basePath)) attachHandler(ApiRootResource.getServletHandler(this)) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 5d015b0531ef6..732b7528f499e 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -21,6 +21,7 @@ import java.net.URLDecoder import java.text.SimpleDateFormat import java.util.{Date, Locale, TimeZone} import javax.servlet.http.HttpServletRequest +import javax.ws.rs.core.{MediaType, Response} import scala.util.control.NonFatal import scala.xml._ @@ -566,4 +567,8 @@ private[spark] object UIUtils extends Logging { NEWLINE_AND_SINGLE_QUOTE_REGEX.replaceAllIn(requestParameter, "")) } } + + def buildErrorResponse(status: Response.Status, msg: String): Response = { + Response.status(status).entity(msg).`type`(MediaType.TEXT_PLAIN).build() + } } diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 8b75f5d8fe1a8..2e43f17e6a8e3 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -60,23 +60,25 @@ private[spark] abstract class WebUI( def getHandlers: Seq[ServletContextHandler] = handlers def getSecurityManager: SecurityManager = securityManager - /** Attach a tab to this UI, along with all of its attached pages. */ - def attachTab(tab: WebUITab) { + /** Attaches a tab to this UI, along with all of its attached pages. */ + def attachTab(tab: WebUITab): Unit = { tab.pages.foreach(attachPage) tabs += tab } - def detachTab(tab: WebUITab) { + /** Detaches a tab from this UI, along with all of its attached pages. */ + def detachTab(tab: WebUITab): Unit = { tab.pages.foreach(detachPage) tabs -= tab } - def detachPage(page: WebUIPage) { + /** Detaches a page from this UI, along with all of its attached handlers. */ + def detachPage(page: WebUIPage): Unit = { pageToHandlers.remove(page).foreach(_.foreach(detachHandler)) } - /** Attach a page to this UI. */ - def attachPage(page: WebUIPage) { + /** Attaches a page to this UI. */ + def attachPage(page: WebUIPage): Unit = { val pagePath = "/" + page.prefix val renderHandler = createServletHandler(pagePath, (request: HttpServletRequest) => page.render(request), securityManager, conf, basePath) @@ -88,41 +90,41 @@ private[spark] abstract class WebUI( handlers += renderHandler } - /** Attach a handler to this UI. */ - def attachHandler(handler: ServletContextHandler) { + /** Attaches a handler to this UI. */ + def attachHandler(handler: ServletContextHandler): Unit = { handlers += handler serverInfo.foreach(_.addHandler(handler)) } - /** Detach a handler from this UI. */ - def detachHandler(handler: ServletContextHandler) { + /** Detaches a handler from this UI. */ + def detachHandler(handler: ServletContextHandler): Unit = { handlers -= handler serverInfo.foreach(_.removeHandler(handler)) } /** - * Add a handler for static content. + * Detaches the content handler at `path` URI. * - * @param resourceBase Root of where to find resources to serve. - * @param path Path in UI where to mount the resources. + * @param path Path in UI to unmount. */ - def addStaticHandler(resourceBase: String, path: String): Unit = { - attachHandler(JettyUtils.createStaticHandler(resourceBase, path)) + def detachHandler(path: String): Unit = { + handlers.find(_.getContextPath() == path).foreach(detachHandler) } /** - * Remove a static content handler. + * Adds a handler for static content. * - * @param path Path in UI to unmount. + * @param resourceBase Root of where to find resources to serve. + * @param path Path in UI where to mount the resources. */ - def removeStaticHandler(path: String): Unit = { - handlers.find(_.getContextPath() == path).foreach(detachHandler) + def addStaticHandler(resourceBase: String, path: String = "/static"): Unit = { + attachHandler(JettyUtils.createStaticHandler(resourceBase, path)) } - /** Initialize all components of the server. */ + /** A hook to initialize components of the UI */ def initialize(): Unit - /** Bind to the HTTP server behind this web interface. */ + /** Binds to the HTTP server behind this web interface. */ def bind(): Unit = { assert(serverInfo.isEmpty, s"Attempted to bind $className more than once!") try { @@ -136,17 +138,17 @@ private[spark] abstract class WebUI( } } - /** Return the url of web interface. Only valid after bind(). */ + /** @return The url of web interface. Only valid after [[bind]]. */ def webUrl: String = s"http://$publicHostName:$boundPort" - /** Return the actual port to which this server is bound. Only valid after bind(). */ + /** @return The actual port to which this server is bound. Only valid after [[bind]]. */ def boundPort: Int = serverInfo.map(_.boundPort).getOrElse(-1) - /** Stop the server behind this web interface. Only valid after bind(). */ + /** Stops the server behind this web interface. Only valid after [[bind]]. */ def stop(): Unit = { assert(serverInfo.isDefined, s"Attempted to stop $className before binding to a server!") - serverInfo.get.stop() + serverInfo.foreach(_.stop()) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 178d2c8d1a10a..90e9a7a3630cf 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -464,7 +464,7 @@ private[ui] class JobDataSource( val jobDescription = UIUtils.makeDescription(lastStageDescription, basePath, plainText = false) - val detailUrl = "%s/jobs/job?id=%s".format(basePath, jobData.jobId) + val detailUrl = "%s/jobs/job/?id=%s".format(basePath, jobData.jobId) new JobTableRowData( jobData, diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index d4e6a7bc3effa..55eb989962668 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -282,7 +282,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val _taskTable = new TaskPagedTable( stageData, UIUtils.prependBaseUri(request, parent.basePath) + - s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}", + s"/stages/stage/?id=${stageId}&attempt=${stageAttemptId}", currentTime, pageSize = taskPageSize, sortColumn = taskSortColumn, diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 56e4d6838a99a..d01acdae59c9f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -370,7 +370,7 @@ private[ui] class StagePagedTable( Seq.empty } - val nameLinkUri = s"$basePathUri/stages/stage?id=${s.stageId}&attempt=${s.attemptId}" + val nameLinkUri = s"$basePathUri/stages/stage/?id=${s.stageId}&attempt=${s.attemptId}" val nameLink = {s.name} val cachedRddInfos = store.rddList().filter { rdd => s.rddIds.contains(rdd.id) } diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index ad0c0639521f6..073d71c63b0c7 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -22,8 +22,8 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.collection.mutable.{Map, Set, Stack} import scala.language.existentials -import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor, Type} -import org.apache.xbean.asm5.Opcodes._ +import org.apache.xbean.asm6.{ClassReader, ClassVisitor, MethodVisitor, Type} +import org.apache.xbean.asm6.Opcodes._ import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.internal.Logging diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 165a15c73e7ca..0f08a2b0ad895 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -19,13 +19,12 @@ package org.apache.spark.util import java.util.concurrent._ +import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor} -import scala.concurrent.duration.Duration +import scala.concurrent.duration.{Duration, FiniteDuration} import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread} import scala.util.control.NonFatal -import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} - import org.apache.spark.SparkException private[spark] object ThreadUtils { @@ -103,6 +102,22 @@ private[spark] object ThreadUtils { executor } + /** + * Wrapper over ScheduledThreadPoolExecutor. + */ + def newDaemonThreadPoolScheduledExecutor(threadNamePrefix: String, numThreads: Int) + : ScheduledExecutorService = { + val threadFactory = new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat(s"$threadNamePrefix-%d") + .build() + val executor = new ScheduledThreadPoolExecutor(numThreads, threadFactory) + // By default, a cancelled task is not automatically removed from the work queue until its delay + // elapses. We have to enable it manually. + executor.setRemoveOnCancelPolicy(true) + executor + } + /** * Run a piece of code in a new thread and return the result. Exception in the new thread is * thrown in the caller thread with an adjusted stack trace that removes references to this @@ -229,4 +244,14 @@ private[spark] object ThreadUtils { } } // scalastyle:on awaitready + + def shutdown( + executor: ExecutorService, + gracePeriod: Duration = FiniteDuration(30, TimeUnit.SECONDS)): Unit = { + executor.shutdown() + executor.awaitTermination(gracePeriod.toMillis, TimeUnit.MILLISECONDS) + if (!executor.isShutdown) { + executor.shutdownNow() + } + } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index c139db46b63a3..a6fd3637663e8 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -100,7 +100,7 @@ private[spark] object Utils extends Logging { */ val DEFAULT_MAX_TO_STRING_FIELDS = 25 - private def maxNumToStringFields = { + private[spark] def maxNumToStringFields = { if (SparkEnv.get != null) { SparkEnv.get.conf.getInt("spark.debug.maxToStringFields", DEFAULT_MAX_TO_STRING_FIELDS) } else { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 176f84fa2a0d2..b159200d79222 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -368,8 +368,8 @@ private[spark] class ExternalSorter[K, V, C]( val bufferedIters = iterators.filter(_.hasNext).map(_.buffered) type Iter = BufferedIterator[Product2[K, C]] val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] { - // Use the reverse of comparator.compare because PriorityQueue dequeues the max - override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1) + // Use the reverse order because PriorityQueue dequeues the max + override def compare(x: Iter, y: Iter): Int = comparator.compare(y.head._1, x.head._1) }) heap.enqueue(bufferedIters: _*) // Will contain only the iterators with hasNext = true new Iterator[Product2[K, C]] { diff --git a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java index db91329c94cb6..0bbaea6b834b8 100644 --- a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java +++ b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java @@ -17,6 +17,10 @@ package org.apache.spark.memory; +import com.google.common.annotations.VisibleForTesting; + +import org.apache.spark.unsafe.memory.MemoryBlock; + import java.io.IOException; public class TestMemoryConsumer extends MemoryConsumer { @@ -43,6 +47,12 @@ void free(long size) { used -= size; taskMemoryManager.releaseExecutionMemory(size, this); } + + @VisibleForTesting + public void freePage(MemoryBlock page) { + used -= page.size(); + taskMemoryManager.freePage(page, this); + } } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index c145532328514..85ffdca436e14 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -129,7 +129,6 @@ public int compare( final UnsafeSorterIterator iter = sorter.getSortedIterator(); int iterLength = 0; long prevPrefix = -1; - Arrays.sort(dataToSort); while (iter.hasNext()) { iter.loadNext(); final String str = diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 88916488c0def..b705556e54b14 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark import java.util.concurrent.{ExecutorService, TimeUnit} -import scala.collection.Map import scala.collection.mutable import scala.concurrent.Future import scala.concurrent.duration._ @@ -73,6 +72,7 @@ class HeartbeatReceiverSuite sc = spy(new SparkContext(conf)) scheduler = mock(classOf[TaskSchedulerImpl]) when(sc.taskScheduler).thenReturn(scheduler) + when(scheduler.nodeBlacklist).thenReturn(Predef.Set[String]()) when(scheduler.sc).thenReturn(sc) heartbeatReceiverClock = new ManualClock heartbeatReceiver = new HeartbeatReceiver(sc, heartbeatReceiverClock) @@ -241,7 +241,7 @@ class HeartbeatReceiverSuite } === Some(true)) } - private def getTrackedExecutors: Map[String, Long] = { + private def getTrackedExecutors: collection.Map[String, Long] = { // We may receive undesired SparkListenerExecutorAdded from LocalSchedulerBackend, // so exclude it from the map. See SPARK-10800. heartbeatReceiver.invokePrivate(_executorLastSeen()). @@ -272,7 +272,7 @@ private class FakeSchedulerBackend( protected override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = { clusterManagerEndpoint.ask[Boolean]( - RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount, Set.empty[String])) + RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount, Set.empty)) } protected override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = { diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 8eabc2b3cb958..5dbfc5c10a6f8 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -18,8 +18,11 @@ package org.apache.spark import java.io.File +import java.util.UUID import javax.net.ssl.SSLContext +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.security.alias.{CredentialProvider, CredentialProviderFactory} import org.scalatest.BeforeAndAfterAll import org.apache.spark.util.SparkConfWithEnv @@ -40,6 +43,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { .toSet val conf = new SparkConf + val hadoopConf = new Configuration() conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", keyStorePath) conf.set("spark.ssl.keyStorePassword", "password") @@ -49,7 +53,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.ssl.enabledAlgorithms", algorithms.mkString(",")) conf.set("spark.ssl.protocol", "TLSv1.2") - val opts = SSLOptions.parse(conf, "spark.ssl") + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl") assert(opts.enabled === true) assert(opts.trustStore.isDefined === true) @@ -70,6 +74,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath val conf = new SparkConf + val hadoopConf = new Configuration() conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", keyStorePath) conf.set("spark.ssl.keyStorePassword", "password") @@ -80,8 +85,8 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") conf.set("spark.ssl.protocol", "SSLv3") - val defaultOpts = SSLOptions.parse(conf, "spark.ssl", defaults = None) - val opts = SSLOptions.parse(conf, "spark.ssl.ui", defaults = Some(defaultOpts)) + val defaultOpts = SSLOptions.parse(conf, hadoopConf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl.ui", defaults = Some(defaultOpts)) assert(opts.enabled === true) assert(opts.trustStore.isDefined === true) @@ -103,6 +108,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath val conf = new SparkConf + val hadoopConf = new Configuration() conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.ui.enabled", "false") conf.set("spark.ssl.ui.port", "4242") @@ -117,8 +123,8 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.ssl.ui.enabledAlgorithms", "ABC, DEF") conf.set("spark.ssl.protocol", "SSLv3") - val defaultOpts = SSLOptions.parse(conf, "spark.ssl", defaults = None) - val opts = SSLOptions.parse(conf, "spark.ssl.ui", defaults = Some(defaultOpts)) + val defaultOpts = SSLOptions.parse(conf, hadoopConf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl.ui", defaults = Some(defaultOpts)) assert(opts.enabled === false) assert(opts.port === Some(4242)) @@ -139,14 +145,71 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val conf = new SparkConfWithEnv(Map( "ENV1" -> "val1", "ENV2" -> "val2")) + val hadoopConf = new Configuration() conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", "${env:ENV1}") conf.set("spark.ssl.trustStore", "${env:ENV2}") - val opts = SSLOptions.parse(conf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl", defaults = None) assert(opts.keyStore === Some(new File("val1"))) assert(opts.trustStore === Some(new File("val2"))) } + test("get password from Hadoop credential provider") { + val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath + val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath + + val conf = new SparkConf + val hadoopConf = new Configuration() + val tmpPath = s"localjceks://file${sys.props("java.io.tmpdir")}/test-" + + s"${UUID.randomUUID().toString}.jceks" + val provider = createCredentialProvider(tmpPath, hadoopConf) + + conf.set("spark.ssl.enabled", "true") + conf.set("spark.ssl.keyStore", keyStorePath) + storePassword(provider, "spark.ssl.keyStorePassword", "password") + storePassword(provider, "spark.ssl.keyPassword", "password") + conf.set("spark.ssl.trustStore", trustStorePath) + storePassword(provider, "spark.ssl.trustStorePassword", "password") + conf.set("spark.ssl.enabledAlgorithms", + "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") + conf.set("spark.ssl.protocol", "SSLv3") + + val defaultOpts = SSLOptions.parse(conf, hadoopConf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl.ui", defaults = Some(defaultOpts)) + + assert(opts.enabled === true) + assert(opts.trustStore.isDefined === true) + assert(opts.trustStore.get.getName === "truststore") + assert(opts.trustStore.get.getAbsolutePath === trustStorePath) + assert(opts.keyStore.isDefined === true) + assert(opts.keyStore.get.getName === "keystore") + assert(opts.keyStore.get.getAbsolutePath === keyStorePath) + assert(opts.trustStorePassword === Some("password")) + assert(opts.keyStorePassword === Some("password")) + assert(opts.keyPassword === Some("password")) + assert(opts.protocol === Some("SSLv3")) + assert(opts.enabledAlgorithms === + Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) + } + + private def createCredentialProvider(tmpPath: String, conf: Configuration): CredentialProvider = { + conf.set(CredentialProviderFactory.CREDENTIAL_PROVIDER_PATH, tmpPath) + + val provider = CredentialProviderFactory.getProviders(conf).get(0) + if (provider == null) { + throw new IllegalStateException(s"Fail to get credential provider with path $tmpPath") + } + + provider + } + + private def storePassword( + provider: CredentialProvider, + passwordKey: String, + password: String): Unit = { + provider.createCredentialEntry(passwordKey, password.toCharArray) + provider.flush() + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 545c8d0423dc3..f829fecc30840 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -995,20 +995,24 @@ class SparkSubmitSuite } test("download remote resource if it is not supported by yarn service") { - testRemoteResources(enableHttpFs = false, blacklistHttpFs = false) + testRemoteResources(enableHttpFs = false) } test("avoid downloading remote resource if it is supported by yarn service") { - testRemoteResources(enableHttpFs = true, blacklistHttpFs = false) + testRemoteResources(enableHttpFs = true) } test("force download from blacklisted schemes") { - testRemoteResources(enableHttpFs = true, blacklistHttpFs = true) + testRemoteResources(enableHttpFs = true, blacklistSchemes = Seq("http")) + } + + test("force download for all the schemes") { + testRemoteResources(enableHttpFs = true, blacklistSchemes = Seq("*")) } private def testRemoteResources( enableHttpFs: Boolean, - blacklistHttpFs: Boolean): Unit = { + blacklistSchemes: Seq[String] = Nil): Unit = { val hadoopConf = new Configuration() updateConfWithFakeS3Fs(hadoopConf) if (enableHttpFs) { @@ -1025,8 +1029,8 @@ class SparkSubmitSuite val tmpHttpJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpDir) val tmpHttpJarPath = s"http://${new File(tmpHttpJar.toURI).getAbsolutePath}" - val forceDownloadArgs = if (blacklistHttpFs) { - Seq("--conf", "spark.yarn.dist.forceDownloadSchemes=http") + val forceDownloadArgs = if (blacklistSchemes.nonEmpty) { + Seq("--conf", s"spark.yarn.dist.forceDownloadSchemes=${blacklistSchemes.mkString(",")}") } else { Nil } @@ -1044,14 +1048,19 @@ class SparkSubmitSuite val jars = conf.get("spark.yarn.dist.jars").split(",").toSet - // The URI of remote S3 resource should still be remote. - assert(jars.contains(tmpS3JarPath)) + def isSchemeBlacklisted(scheme: String) = { + blacklistSchemes.contains("*") || blacklistSchemes.contains(scheme) + } + + if (!isSchemeBlacklisted("s3")) { + assert(jars.contains(tmpS3JarPath)) + } - if (enableHttpFs && !blacklistHttpFs) { + if (enableHttpFs && blacklistSchemes.isEmpty) { // If Http FS is supported by yarn service, the URI of remote http resource should // still be remote. assert(jars.contains(tmpHttpJarPath)) - } else { + } else if (!enableHttpFs || isSchemeBlacklisted("http")) { // If Http FS is not supported by yarn service, or http scheme is configured to be force // downloading, the URI of remote http resource should be changed to a local one. val jarName = new File(tmpHttpJar.toURI).getName diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala new file mode 100644 index 0000000000000..817dc082b7d38 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.input + +import java.io.{DataOutputStream, File, FileOutputStream} + +import scala.collection.immutable.IndexedSeq + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +/** + * Tests the correctness of + * [[org.apache.spark.input.WholeTextFileInputFormat WholeTextFileInputFormat]]. A temporary + * directory containing files is created as fake input which is deleted in the end. + */ +class WholeTextFileInputFormatSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { + private var sc: SparkContext = _ + + override def beforeAll() { + super.beforeAll() + val conf = new SparkConf() + sc = new SparkContext("local", "test", conf) + } + + override def afterAll() { + try { + sc.stop() + } finally { + super.afterAll() + } + } + + private def createNativeFile(inputDir: File, fileName: String, contents: Array[Byte], + compress: Boolean) = { + val path = s"${inputDir.toString}/$fileName" + val out = new DataOutputStream(new FileOutputStream(path)) + out.write(contents, 0, contents.length) + out.close() + } + + test("for small files minimum split size per node and per rack should be less than or equal to " + + "maximum split size.") { + var dir : File = null; + try { + dir = Utils.createTempDir() + logInfo(s"Local disk address is ${dir.toString}.") + + // Set the minsize per node and rack to be larger than the size of the input file. + sc.hadoopConfiguration.setLong( + "mapreduce.input.fileinputformat.split.minsize.per.node", 123456) + sc.hadoopConfiguration.setLong( + "mapreduce.input.fileinputformat.split.minsize.per.rack", 123456) + + WholeTextFileInputFormatSuite.files.foreach { case (filename, contents) => + createNativeFile(dir, filename, contents, false) + } + // ensure spark job runs successfully without exceptions from the CombineFileInputFormat + assert(sc.wholeTextFiles(dir.toString).count == 3) + } finally { + Utils.deleteRecursively(dir) + } + } +} + +/** + * Files to be tested are defined here. + */ +object WholeTextFileInputFormatSuite { + private val testWords: IndexedSeq[Byte] = "Spark is easy to use.\n".map(_.toByte) + + private val fileNames = Array("part-00000", "part-00001", "part-00002") + private val fileLengths = Array(10, 100, 1000) + + private val files = fileLengths.zip(fileNames).map { case (upperBound, filename) => + filename -> Stream.continually(testWords.toList.toStream).flatten.take(upperBound).toArray + }.toMap +} diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 191c61250ce21..5148ce05bd918 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -154,6 +154,16 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { } } + test("SPARK-23778: empty RDD in union should not produce a UnionRDD") { + val rddWithPartitioner = sc.parallelize(Seq(1 -> true)).partitionBy(new HashPartitioner(1)) + val emptyRDD = sc.emptyRDD[(Int, Boolean)] + val unionRDD = sc.union(emptyRDD, rddWithPartitioner) + assert(unionRDD.isInstanceOf[PartitionerAwareUnionRDD[_]]) + val unionAllEmptyRDD = sc.union(emptyRDD, emptyRDD) + assert(unionAllEmptyRDD.isInstanceOf[UnionRDD[_]]) + assert(unionAllEmptyRDD.collect().isEmpty) + } + test("partitioner aware union") { def makeRDDWithPartitioner(seq: Seq[Int]): RDD[Int] = { sc.makeRDD(seq, 1) @@ -1047,7 +1057,9 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { private class CyclicalDependencyRDD[T: ClassTag] extends RDD[T](sc, Nil) { private val mutableDependencies: ArrayBuffer[Dependency[_]] = ArrayBuffer.empty override def compute(p: Partition, c: TaskContext): Iterator[T] = Iterator.empty - override def getPartitions: Array[Partition] = Array.empty + override def getPartitions: Array[Partition] = Array(new Partition { + override def index: Int = 0 + }) override def getDependencies: Seq[Dependency[_]] = mutableDependencies def addDependency(dep: Dependency[_]) { mutableDependencies += dep diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 2155a0f2b6c21..354e6386fa60e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -188,4 +188,32 @@ class MapStatusSuite extends SparkFunSuite { assert(count === 3000) } } + + test("SPARK-24519: HighlyCompressedMapStatus has configurable threshold") { + val conf = new SparkConf() + val env = mock(classOf[SparkEnv]) + doReturn(conf).when(env).conf + SparkEnv.set(env) + val sizes = Array.fill[Long](500)(150L) + // Test default value + val status = MapStatus(null, sizes) + assert(status.isInstanceOf[CompressedMapStatus]) + // Test Non-positive values + for (s <- -1 to 0) { + assertThrows[IllegalArgumentException] { + conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s) + val status = MapStatus(null, sizes) + } + } + // Test positive values + Seq(1, 100, 499, 500, 501).foreach { s => + conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s) + val status = MapStatus(null, sizes) + if(sizes.length > s) { + assert(status.isInstanceOf[HighlyCompressedMapStatus]) + } else { + assert(status.isInstanceOf[CompressedMapStatus]) + } + } + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 03b1903902491..158c9eb75f2b6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -35,6 +35,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark._ import org.apache.spark.internal.io.{FileCommitProtocol, HadoopMapRedCommitProtocol, SparkHadoopWriterUtils} import org.apache.spark.rdd.{FakeOutputCommitter, RDD} +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.{ThreadUtils, Utils} /** @@ -153,7 +154,7 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { test("Job should not complete if all commits are denied") { // Create a mock OutputCommitCoordinator that denies all attempts to commit doReturn(false).when(outputCommitCoordinator).handleAskPermissionToCommit( - Matchers.any(), Matchers.any(), Matchers.any()) + Matchers.any(), Matchers.any(), Matchers.any(), Matchers.any()) val rdd: RDD[Int] = sc.parallelize(Seq(1), 1) def resultHandler(x: Int, y: Unit): Unit = {} val futureAction: SimpleFutureAction[Unit] = sc.submitJob[Int, Unit, Unit](rdd, @@ -169,45 +170,106 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { test("Only authorized committer failures can clear the authorized committer lock (SPARK-6614)") { val stage: Int = 1 + val stageAttempt: Int = 1 val partition: Int = 2 val authorizedCommitter: Int = 3 val nonAuthorizedCommitter: Int = 100 outputCommitCoordinator.stageStart(stage, maxPartitionId = 2) - assert(outputCommitCoordinator.canCommit(stage, partition, authorizedCommitter)) - assert(!outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter)) + assert(outputCommitCoordinator.canCommit(stage, stageAttempt, partition, authorizedCommitter)) + assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition, + nonAuthorizedCommitter)) // The non-authorized committer fails - outputCommitCoordinator.taskCompleted( - stage, partition, attemptNumber = nonAuthorizedCommitter, reason = TaskKilled("test")) + outputCommitCoordinator.taskCompleted(stage, stageAttempt, partition, + attemptNumber = nonAuthorizedCommitter, reason = TaskKilled("test")) // New tasks should still not be able to commit because the authorized committer has not failed - assert( - !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 1)) + assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition, + nonAuthorizedCommitter + 1)) // The authorized committer now fails, clearing the lock - outputCommitCoordinator.taskCompleted( - stage, partition, attemptNumber = authorizedCommitter, reason = TaskKilled("test")) + outputCommitCoordinator.taskCompleted(stage, stageAttempt, partition, + attemptNumber = authorizedCommitter, reason = TaskKilled("test")) // A new task should now be allowed to become the authorized committer - assert( - outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 2)) + assert(outputCommitCoordinator.canCommit(stage, stageAttempt, partition, + nonAuthorizedCommitter + 2)) // There can only be one authorized committer - assert( - !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 3)) - } - - test("Duplicate calls to canCommit from the authorized committer gets idempotent responses.") { - val rdd = sc.parallelize(Seq(1), 1) - sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).callCanCommitMultipleTimes _, - 0 until rdd.partitions.size) + assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition, + nonAuthorizedCommitter + 3)) } test("SPARK-19631: Do not allow failed attempts to be authorized for committing") { val stage: Int = 1 + val stageAttempt: Int = 1 val partition: Int = 1 val failedAttempt: Int = 0 outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) - outputCommitCoordinator.taskCompleted(stage, partition, attemptNumber = failedAttempt, + outputCommitCoordinator.taskCompleted(stage, stageAttempt, partition, + attemptNumber = failedAttempt, reason = ExecutorLostFailure("0", exitCausedByApp = true, None)) - assert(!outputCommitCoordinator.canCommit(stage, partition, failedAttempt)) - assert(outputCommitCoordinator.canCommit(stage, partition, failedAttempt + 1)) + assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition, failedAttempt)) + assert(outputCommitCoordinator.canCommit(stage, stageAttempt, partition, failedAttempt + 1)) + } + + test("SPARK-24589: Differentiate tasks from different stage attempts") { + var stage = 1 + val taskAttempt = 1 + val partition = 1 + + outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) + assert(outputCommitCoordinator.canCommit(stage, 1, partition, taskAttempt)) + assert(!outputCommitCoordinator.canCommit(stage, 2, partition, taskAttempt)) + + // Fail the task in the first attempt, the task in the second attempt should succeed. + stage += 1 + outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) + outputCommitCoordinator.taskCompleted(stage, 1, partition, taskAttempt, + ExecutorLostFailure("0", exitCausedByApp = true, None)) + assert(!outputCommitCoordinator.canCommit(stage, 1, partition, taskAttempt)) + assert(outputCommitCoordinator.canCommit(stage, 2, partition, taskAttempt)) + + // Commit the 1st attempt, fail the 2nd attempt, make sure 3rd attempt cannot commit, + // then fail the 1st attempt and make sure the 4th one can commit again. + stage += 1 + outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) + assert(outputCommitCoordinator.canCommit(stage, 1, partition, taskAttempt)) + outputCommitCoordinator.taskCompleted(stage, 2, partition, taskAttempt, + ExecutorLostFailure("0", exitCausedByApp = true, None)) + assert(!outputCommitCoordinator.canCommit(stage, 3, partition, taskAttempt)) + outputCommitCoordinator.taskCompleted(stage, 1, partition, taskAttempt, + ExecutorLostFailure("0", exitCausedByApp = true, None)) + assert(outputCommitCoordinator.canCommit(stage, 4, partition, taskAttempt)) + } + + test("SPARK-24589: Make sure stage state is cleaned up") { + // Normal application without stage failures. + sc.parallelize(1 to 100, 100) + .map { i => (i % 10, i) } + .reduceByKey(_ + _) + .collect() + + assert(sc.dagScheduler.outputCommitCoordinator.isEmpty) + + // Force failures in a few tasks so that a stage is retried. Collect the ID of the failing + // stage so that we can check the state of the output committer. + val retriedStage = sc.parallelize(1 to 100, 10) + .map { i => (i % 10, i) } + .reduceByKey { case (_, _) => + val ctx = TaskContext.get() + if (ctx.stageAttemptNumber() == 0) { + throw new FetchFailedException(SparkEnv.get.blockManager.blockManagerId, 1, 1, 1, + new Exception("Failure for test.")) + } else { + ctx.stageId() + } + } + .collect() + .map { case (k, v) => v } + .toSet + + assert(retriedStage.size === 1) + assert(sc.dagScheduler.outputCommitCoordinator.isEmpty) + verify(sc.env.outputCommitCoordinator, times(2)) + .stageStart(Matchers.eq(retriedStage.head), Matchers.any()) + verify(sc.env.outputCommitCoordinator).stageEnd(Matchers.eq(retriedStage.head)) } } @@ -243,16 +305,6 @@ private case class OutputCommitFunctions(tempDirPath: String) { if (ctx.attemptNumber == 0) failingOutputCommitter else successfulOutputCommitter) } - // Receiver should be idempotent for AskPermissionToCommitOutput - def callCanCommitMultipleTimes(iter: Iterator[Int]): Unit = { - val ctx = TaskContext.get() - val canCommit1 = SparkEnv.get.outputCommitCoordinator - .canCommit(ctx.stageId(), ctx.partitionId(), ctx.attemptNumber()) - val canCommit2 = SparkEnv.get.outputCommitCoordinator - .canCommit(ctx.stageId(), ctx.partitionId(), ctx.attemptNumber()) - assert(canCommit1 && canCommit2) - } - private def runCommitWithProvidedCommitter( ctx: TaskContext, iter: Iterator[Int], diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index b19d8ebf72c61..08172f0b07b75 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -1422,6 +1422,19 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(mockBlockTransferService.tempFileManager === store.remoteBlockTempFileManager) } + test("query locations of blockIds") { + val mockBlockManagerMaster = mock(classOf[BlockManagerMaster]) + val blockLocations = Seq(BlockManagerId("1", "host1", 100), BlockManagerId("2", "host2", 200)) + when(mockBlockManagerMaster.getLocations(mc.any[Array[BlockId]])) + .thenReturn(Array(blockLocations)) + val env = mock(classOf[SparkEnv]) + + val blockIds: Array[BlockId] = Array(StreamBlockId(1, 2)) + val locs = BlockManager.blockIdsToLocations(blockIds, env, mockBlockManagerMaster) + val expectedLocs = Seq("executor_host1_1", "executor_host2_2") + assert(locs(blockIds(0)) == expectedLocs) + } + class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService { var numCalls = 0 var tempFileManager: TempFileManager = null diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 9552d001a079c..466135e72233a 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -11,6 +11,10 @@ cache .rat-excludes .*md derby.log +licenses/* +licenses-binary/* +LICENSE +NOTICE TAGS RELEASE control @@ -106,3 +110,4 @@ spark-warehouse structured-streaming/* kafka-source-initial-offset-version-2.1.0.bin kafka-source-initial-offset-future-version.bin +vote.tmpl diff --git a/dev/create-release/do-release-docker.sh b/dev/create-release/do-release-docker.sh new file mode 100755 index 0000000000000..fa7b73cdb40ec --- /dev/null +++ b/dev/create-release/do-release-docker.sh @@ -0,0 +1,143 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# Creates a Spark release candidate. The script will update versions, tag the branch, +# build Spark binary packages and documentation, and upload maven artifacts to a staging +# repository. There is also a dry run mode where only local builds are performed, and +# nothing is uploaded to the ASF repos. +# +# Run with "-h" for options. +# + +set -e +SELF=$(cd $(dirname $0) && pwd) +. "$SELF/release-util.sh" + +function usage { + local NAME=$(basename $0) + cat < "$GPG_KEY_FILE" + +run_silent "Building spark-rm image with tag $IMGTAG..." "docker-build.log" \ + docker build -t "spark-rm:$IMGTAG" --build-arg UID=$UID "$SELF/spark-rm" + +# Write the release information to a file with environment variables to be used when running the +# image. +ENVFILE="$WORKDIR/env.list" +fcreate_secure "$ENVFILE" + +function cleanup { + rm -f "$ENVFILE" + rm -f "$GPG_KEY_FILE" +} + +trap cleanup EXIT + +cat > $ENVFILE <> $ENVFILE + JAVA_VOL="--volume $JAVA:/opt/spark-java" +fi + +echo "Building $RELEASE_TAG; output will be at $WORKDIR/output" +docker run -ti \ + --env-file "$ENVFILE" \ + --volume "$WORKDIR:/opt/spark-rm" \ + $JAVA_VOL \ + "spark-rm:$IMGTAG" diff --git a/dev/create-release/do-release.sh b/dev/create-release/do-release.sh new file mode 100755 index 0000000000000..f1d4f3ab5ddec --- /dev/null +++ b/dev/create-release/do-release.sh @@ -0,0 +1,81 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +SELF=$(cd $(dirname $0) && pwd) +. "$SELF/release-util.sh" + +while getopts "bn" opt; do + case $opt in + b) GIT_BRANCH=$OPTARG ;; + n) DRY_RUN=1 ;; + ?) error "Invalid option: $OPTARG" ;; + esac +done + +if [ "$RUNNING_IN_DOCKER" = "1" ]; then + # Inside docker, need to import the GPG key stored in the current directory. + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --import "$SELF/gpg.key" + + # We may need to adjust the path since JAVA_HOME may be overridden by the driver script. + if [ -n "$JAVA_HOME" ]; then + export PATH="$JAVA_HOME/bin:$PATH" + else + # JAVA_HOME for the openjdk package. + export JAVA_HOME=/usr + fi +else + # Outside docker, need to ask for information about the release. + get_release_info +fi + +function should_build { + local WHAT=$1 + [ -z "$RELEASE_STEP" ] || [ "$WHAT" = "$RELEASE_STEP" ] +} + +if should_build "tag" && [ $SKIP_TAG = 0 ]; then + run_silent "Creating release tag $RELEASE_TAG..." "tag.log" \ + "$SELF/release-tag.sh" + echo "It may take some time for the tag to be synchronized to github." + echo "Press enter when you've verified that the new tag ($RELEASE_TAG) is available." + read +else + echo "Skipping tag creation for $RELEASE_TAG." +fi + +if should_build "build"; then + run_silent "Building Spark..." "build.log" \ + "$SELF/release-build.sh" package +else + echo "Skipping build step." +fi + +if should_build "docs"; then + run_silent "Building documentation..." "docs.log" \ + "$SELF/release-build.sh" docs +else + echo "Skipping docs step." +fi + +if should_build "publish"; then + run_silent "Publishing release" "publish.log" \ + "$SELF/release-build.sh" publish-release +else + echo "Skipping publish step." +fi diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 5faa3d3260a56..24a62a8f4c7d3 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -17,6 +17,9 @@ # limitations under the License. # +SELF=$(cd $(dirname $0) && pwd) +. "$SELF/release-util.sh" + function exit_with_usage { cat << EOF usage: release-build.sh @@ -87,49 +90,56 @@ NEXUS_ROOT=https://repository.apache.org/service/local/staging NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads BASE_DIR=$(pwd) -MVN="build/mvn --force" - -# Hive-specific profiles for some builds -HIVE_PROFILES="-Phive -Phive-thriftserver" -# Profiles for publishing snapshots and release to Maven Central -PUBLISH_PROFILES="-Pmesos -Pyarn -Pkubernetes -Pflume $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" -# Profiles for building binary releases -BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Pkubernetes -Pflume -Psparkr" -# Scala 2.11 only profiles for some builds -SCALA_2_11_PROFILES="-Pkafka-0-8" -# Scala 2.12 only profiles for some builds -SCALA_2_12_PROFILES="-Pscala-2.12" +init_java +init_maven_sbt rm -rf spark -git clone https://git-wip-us.apache.org/repos/asf/spark.git +git clone "$ASF_REPO" cd spark git checkout $GIT_REF git_hash=`git rev-parse --short HEAD` echo "Checked out Spark git hash $git_hash" if [ -z "$SPARK_VERSION" ]; then - SPARK_VERSION=$($MVN help:evaluate -Dexpression=project.version \ - | grep -v INFO | grep -v WARNING | grep -v Download) + # Run $MVN in a separate command so that 'set -e' does the right thing. + TMP=$(mktemp) + $MVN help:evaluate -Dexpression=project.version > $TMP + SPARK_VERSION=$(cat $TMP | grep -v INFO | grep -v WARNING | grep -v Download) + rm $TMP fi -# Verify we have the right java version set -if [ -z "$JAVA_HOME" ]; then - echo "Please set JAVA_HOME." - exit 1 +# Depending on the version being built, certain extra profiles need to be activated, and +# different versions of Scala are supported. +BASE_PROFILES="-Pmesos -Pyarn" +PUBLISH_SCALA_2_10=0 +SCALA_2_10_PROFILES="-Pscala-2.10" +SCALA_2_11_PROFILES= +SCALA_2_12_PROFILES="-Pscala-2.12" + +if [[ $SPARK_VERSION > "2.3" ]]; then + BASE_PROFILES="$BASE_PROFILES -Pkubernetes -Pflume" + SCALA_2_11_PROFILES="-Pkafka-0-8" +else + PUBLISH_SCALA_2_10=1 fi -java_version=$("${JAVA_HOME}"/bin/javac -version 2>&1 | cut -d " " -f 2) +# Hive-specific profiles for some builds +HIVE_PROFILES="-Phive -Phive-thriftserver" +# Profiles for publishing snapshots and release to Maven Central +PUBLISH_PROFILES="$BASE_PROFILES $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" +# Profiles for building binary releases +BASE_RELEASE_PROFILES="$BASE_PROFILES -Psparkr" if [[ ! $SPARK_VERSION < "2.2." ]]; then - if [[ $java_version < "1.8." ]]; then - echo "Java version $java_version is less than required 1.8 for 2.2+" + if [[ $JAVA_VERSION < "1.8." ]]; then + echo "Java version $JAVA_VERSION is less than required 1.8 for 2.2+" echo "Please set JAVA_HOME correctly." exit 1 fi else - if [[ $java_version > "1.7." ]]; then + if ! [[ $JAVA_VERSION =~ 1\.7\..* ]]; then if [ -z "$JAVA_7_HOME" ]; then - echo "Java version $java_version is higher than required 1.7 for pre-2.2" + echo "Java version $JAVA_VERSION is higher than required 1.7 for pre-2.2" echo "Please set JAVA_HOME correctly." exit 1 else @@ -174,8 +184,9 @@ if [[ "$1" == "package" ]]; then FLAGS=$2 ZINC_PORT=$3 BUILD_PACKAGE=$4 - cp -r spark spark-$SPARK_VERSION-bin-$NAME + echo "Building binary dist $NAME" + cp -r spark spark-$SPARK_VERSION-bin-$NAME cd spark-$SPARK_VERSION-bin-$NAME # TODO There should probably be a flag to make-distribution to allow 2.12 support @@ -244,31 +255,39 @@ if [[ "$1" == "package" ]]; then spark-$SPARK_VERSION-bin-$NAME.tgz.sha512 } - # TODO: Check exit codes of children here: - # http://stackoverflow.com/questions/1570262/shell-get-exit-code-of-background-process - # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds # share the same Zinc server. - make_binary_release "hadoop2.6" "-Phadoop-2.6 $HIVE_PROFILES $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3035" "withr" & - make_binary_release "hadoop2.7" "-Phadoop-2.7 $HIVE_PROFILES $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3036" "withpip" & - make_binary_release "without-hadoop" "-Phadoop-provided $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3038" & - wait - rm -rf spark-$SPARK_VERSION-bin-*/ + if ! make_binary_release "hadoop2.6" "$MVN_EXTRA_OPTS -B -Phadoop-2.6 $HIVE_PROFILES $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3035" "withr"; then + error "Failed to build hadoop2.6 package. Check logs for details." + fi + if ! is_dry_run; then + if ! make_binary_release "hadoop2.7" "$MVN_EXTRA_OPTS -B -Phadoop-2.7 $HIVE_PROFILES $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3036" "withpip"; then + error "Failed to build hadoop2.7 package. Check logs for details." + fi + if ! make_binary_release "without-hadoop" "$MVN_EXTRA_OPTS -B -Phadoop-provided $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3037"; then + error "Failed to build without-hadoop package. Check logs for details." + fi + fi - svn co --depth=empty $RELEASE_STAGING_LOCATION svn-spark - rm -rf "svn-spark/${DEST_DIR_NAME}-bin" - mkdir -p "svn-spark/${DEST_DIR_NAME}-bin" + rm -rf spark-$SPARK_VERSION-bin-*/ - echo "Copying release tarballs" - cp spark-* "svn-spark/${DEST_DIR_NAME}-bin/" - cp pyspark-* "svn-spark/${DEST_DIR_NAME}-bin/" - cp SparkR_* "svn-spark/${DEST_DIR_NAME}-bin/" - svn add "svn-spark/${DEST_DIR_NAME}-bin" + if ! is_dry_run; then + svn co --depth=empty $RELEASE_STAGING_LOCATION svn-spark + rm -rf "svn-spark/${DEST_DIR_NAME}-bin" + mkdir -p "svn-spark/${DEST_DIR_NAME}-bin" + + echo "Copying release tarballs" + cp spark-* "svn-spark/${DEST_DIR_NAME}-bin/" + cp pyspark-* "svn-spark/${DEST_DIR_NAME}-bin/" + cp SparkR_* "svn-spark/${DEST_DIR_NAME}-bin/" + svn add "svn-spark/${DEST_DIR_NAME}-bin" + + cd svn-spark + svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION" + cd .. + rm -rf svn-spark + fi - cd svn-spark - svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION" - cd .. - rm -rf svn-spark exit 0 fi @@ -282,18 +301,22 @@ if [[ "$1" == "docs" ]]; then cd .. cd .. - svn co --depth=empty $RELEASE_STAGING_LOCATION svn-spark - rm -rf "svn-spark/${DEST_DIR_NAME}-docs" - mkdir -p "svn-spark/${DEST_DIR_NAME}-docs" + if ! is_dry_run; then + svn co --depth=empty $RELEASE_STAGING_LOCATION svn-spark + rm -rf "svn-spark/${DEST_DIR_NAME}-docs" + mkdir -p "svn-spark/${DEST_DIR_NAME}-docs" - echo "Copying release documentation" - cp -R "spark/docs/_site" "svn-spark/${DEST_DIR_NAME}-docs/" - svn add "svn-spark/${DEST_DIR_NAME}-docs" + echo "Copying release documentation" + cp -R "spark/docs/_site" "svn-spark/${DEST_DIR_NAME}-docs/" + svn add "svn-spark/${DEST_DIR_NAME}-docs" - cd svn-spark - svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION docs" - cd .. - rm -rf svn-spark + cd svn-spark + svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION docs" + cd .. + rm -rf svn-spark + fi + + mv "spark/docs/_site" docs/ exit 0 fi @@ -341,13 +364,15 @@ if [[ "$1" == "publish-release" ]]; then # Using Nexus API documented here: # https://support.sonatype.com/entries/39720203-Uploading-to-a-Staging-Repository-via-REST-API - echo "Creating Nexus staging repository" - repo_request="Apache Spark $SPARK_VERSION (commit $git_hash)" - out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ - -H "Content-Type:application/xml" -v \ - $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start) - staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/") - echo "Created Nexus staging repository: $staged_repo_id" + if ! is_dry_run; then + echo "Creating Nexus staging repository" + repo_request="Apache Spark $SPARK_VERSION (commit $git_hash)" + out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ + -H "Content-Type:application/xml" -v \ + $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start) + staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/") + echo "Created Nexus staging repository: $staged_repo_id" + fi tmp_repo=$(mktemp -d spark-repo-XXXXX) @@ -356,6 +381,12 @@ if [[ "$1" == "publish-release" ]]; then $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -DskipTests $SCALA_2_11_PROFILES $PUBLISH_PROFILES clean install + if ! is_dry_run && [[ $PUBLISH_SCALA_2_10 = 1 ]]; then + ./dev/change-scala-version.sh 2.10 + $MVN -DzincPort=$((ZINC_PORT + 1)) -Dmaven.repo.local=$tmp_repo -Dscala-2.10 \ + -DskipTests $PUBLISH_PROFILES $SCALA_2_10_PROFILES clean install + fi + #./dev/change-scala-version.sh 2.12 #$MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo \ # -DskipTests $SCALA_2_12_PROFILES §$PUBLISH_PROFILES clean install @@ -386,23 +417,26 @@ if [[ "$1" == "publish-release" ]]; then sha1sum $file | cut -f1 -d' ' > $file.sha1 done - nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id - echo "Uplading files to $nexus_upload" - for file in $(find . -type f) - do - # strip leading ./ - file_short=$(echo $file | sed -e "s/\.\///") - dest_url="$nexus_upload/org/apache/spark/$file_short" - echo " Uploading $file_short" - curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url - done + if ! is_dry_run; then + nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id + echo "Uplading files to $nexus_upload" + for file in $(find . -type f) + do + # strip leading ./ + file_short=$(echo $file | sed -e "s/\.\///") + dest_url="$nexus_upload/org/apache/spark/$file_short" + echo " Uploading $file_short" + curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url + done + + echo "Closing nexus staging repository" + repo_request="$staged_repo_idApache Spark $SPARK_VERSION (commit $git_hash)" + out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ + -H "Content-Type:application/xml" -v \ + $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish) + echo "Closed Nexus staging repository: $staged_repo_id" + fi - echo "Closing nexus staging repository" - repo_request="$staged_repo_idApache Spark $SPARK_VERSION (commit $git_hash)" - out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ - -H "Content-Type:application/xml" -v \ - $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish) - echo "Closed Nexus staging repository: $staged_repo_id" popd rm -rf $tmp_repo cd .. diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh index a05716a5f66bb..628bc0504c9c8 100755 --- a/dev/create-release/release-tag.sh +++ b/dev/create-release/release-tag.sh @@ -17,6 +17,9 @@ # limitations under the License. # +SELF=$(cd $(dirname $0) && pwd) +. "$SELF/release-util.sh" + function exit_with_usage { cat << EOF usage: tag-release.sh @@ -36,6 +39,7 @@ EOF } set -e +set -o pipefail if [[ $@ == *"help"* ]]; then exit_with_usage @@ -54,8 +58,10 @@ for env in ASF_USERNAME ASF_PASSWORD RELEASE_VERSION RELEASE_TAG NEXT_VERSION GI fi done +init_java +init_maven_sbt + ASF_SPARK_REPO="git-wip-us.apache.org/repos/asf/spark.git" -MVN="build/mvn --force" rm -rf spark git clone "https://$ASF_USERNAME:$ASF_PASSWORD@$ASF_SPARK_REPO" -b $GIT_BRANCH @@ -94,9 +100,15 @@ sed -i".tmp7" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$R_NEXT_VERSION" git commit -a -m "Preparing development version $NEXT_VERSION" -# Push changes -git push origin $RELEASE_TAG -git push origin HEAD:$GIT_BRANCH - -cd .. -rm -rf spark +if ! is_dry_run; then + # Push changes + git push origin $RELEASE_TAG + git push origin HEAD:$GIT_BRANCH + + cd .. + rm -rf spark +else + cd .. + mv spark spark.tag + echo "Clone with version changes and tag available as spark.tag in the output directory." +fi diff --git a/dev/create-release/release-util.sh b/dev/create-release/release-util.sh new file mode 100644 index 0000000000000..7426b0d6ca08d --- /dev/null +++ b/dev/create-release/release-util.sh @@ -0,0 +1,228 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +DRY_RUN=${DRY_RUN:-0} +GPG="gpg --no-tty --batch" +ASF_REPO="https://git-wip-us.apache.org/repos/asf/spark.git" +ASF_REPO_WEBUI="https://git-wip-us.apache.org/repos/asf?p=spark.git" + +function error { + echo "$*" + exit 1 +} + +function read_config { + local PROMPT="$1" + local DEFAULT="$2" + local REPLY= + + read -p "$PROMPT [$DEFAULT]: " REPLY + local RETVAL="${REPLY:-$DEFAULT}" + if [ -z "$RETVAL" ]; then + error "$PROMPT is must be provided." + fi + echo "$RETVAL" +} + +function parse_version { + grep -e '.*' | \ + head -n 2 | tail -n 1 | cut -d'>' -f2 | cut -d '<' -f1 +} + +function run_silent { + local BANNER="$1" + local LOG_FILE="$2" + shift 2 + + echo "========================" + echo "= $BANNER" + echo "Command: $@" + echo "Log file: $LOG_FILE" + + "$@" 1>"$LOG_FILE" 2>&1 + + local EC=$? + if [ $EC != 0 ]; then + echo "Command FAILED. Check full logs for details." + tail "$LOG_FILE" + exit $EC + fi +} + +function fcreate_secure { + local FPATH="$1" + rm -f "$FPATH" + touch "$FPATH" + chmod 600 "$FPATH" +} + +function check_for_tag { + curl -s --head --fail "$ASF_REPO_WEBUI;a=commit;h=$1" >/dev/null +} + +function get_release_info { + if [ -z "$GIT_BRANCH" ]; then + # If no branch is specified, found out the latest branch from the repo. + GIT_BRANCH=$(git ls-remote --heads "$ASF_REPO" | + grep -v refs/heads/master | + awk '{print $2}' | + sort -r | + head -n 1 | + cut -d/ -f3) + fi + + export GIT_BRANCH=$(read_config "Branch" "$GIT_BRANCH") + + # Find the current version for the branch. + local VERSION=$(curl -s "$ASF_REPO_WEBUI;a=blob_plain;f=pom.xml;hb=refs/heads/$GIT_BRANCH" | + parse_version) + echo "Current branch version is $VERSION." + + if [[ ! $VERSION =~ .*-SNAPSHOT ]]; then + error "Not a SNAPSHOT version: $VERSION" + fi + + NEXT_VERSION="$VERSION" + RELEASE_VERSION="${VERSION/-SNAPSHOT/}" + SHORT_VERSION=$(echo "$VERSION" | cut -d . -f 1-2) + local REV=$(echo "$VERSION" | cut -d . -f 3) + + # Find out what rc is being prepared. + # - If the current version is "x.y.0", then this is rc1 of the "x.y.0" release. + # - If not, need to check whether the previous version has been already released or not. + # - If it has, then we're building rc1 of the current version. + # - If it has not, we're building the next RC of the previous version. + local RC_COUNT + if [ $REV != 0 ]; then + local PREV_REL_REV=$((REV - 1)) + local PREV_REL_TAG="v${SHORT_VERSION}.${PREV_REL_REV}" + if check_for_tag "$PREV_REL_TAG"; then + RC_COUNT=1 + REV=$((REV + 1)) + NEXT_VERSION="${SHORT_VERSION}.${REV}-SNAPSHOT" + else + RELEASE_VERSION="${SHORT_VERSION}.${PREV_REL_REV}" + RC_COUNT=$(git ls-remote --tags "$ASF_REPO" "v${RELEASE_VERSION}-rc*" | wc -l) + RC_COUNT=$((RC_COUNT + 1)) + fi + else + REV=$((REV + 1)) + NEXT_VERSION="${SHORT_VERSION}.${REV}-SNAPSHOT" + RC_COUNT=1 + fi + + export NEXT_VERSION + export RELEASE_VERSION=$(read_config "Release" "$RELEASE_VERSION") + + RC_COUNT=$(read_config "RC #" "$RC_COUNT") + + # Check if the RC already exists, and if re-creating the RC, skip tag creation. + RELEASE_TAG="v${RELEASE_VERSION}-rc${RC_COUNT}" + SKIP_TAG=0 + if check_for_tag "$RELEASE_TAG"; then + read -p "$RELEASE_TAG already exists. Continue anyway [y/n]? " ANSWER + if [ "$ANSWER" != "y" ]; then + error "Exiting." + fi + SKIP_TAG=1 + fi + + + export RELEASE_TAG + + GIT_REF="$RELEASE_TAG" + if is_dry_run; then + echo "This is a dry run. Please confirm the ref that will be built for testing." + GIT_REF=$(read_config "Ref" "$GIT_REF") + fi + export GIT_REF + export SPARK_PACKAGE_VERSION="$RELEASE_TAG" + + # Gather some user information. + export ASF_USERNAME=$(read_config "ASF user" "$LOGNAME") + + GIT_NAME=$(git config user.name || echo "") + export GIT_NAME=$(read_config "Full name" "$GIT_NAME") + + export GIT_EMAIL="$ASF_USERNAME@apache.org" + export GPG_KEY=$(read_config "GPG key" "$GIT_EMAIL") + + cat <&1 | cut -d " " -f 2) + export JAVA_VERSION +} + +# Initializes MVN_EXTRA_OPTS and SBT_OPTS depending on the JAVA_VERSION in use. Requires init_java. +function init_maven_sbt { + MVN="build/mvn -B" + MVN_EXTRA_OPTS= + SBT_OPTS= + if [[ $JAVA_VERSION < "1.8." ]]; then + # Needed for maven central when using Java 7. + SBT_OPTS="-Dhttps.protocols=TLSv1.1,TLSv1.2" + MVN_EXTRA_OPTS="-Dhttps.protocols=TLSv1.1,TLSv1.2" + MVN="$MVN $MVN_EXTRA_OPTS" + fi + export MVN MVN_EXTRA_OPTS SBT_OPTS +} diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py index 32f6cbb29f0be..ab812e1bb7c04 100755 --- a/dev/create-release/releaseutils.py +++ b/dev/create-release/releaseutils.py @@ -49,13 +49,16 @@ print("Install using 'sudo pip install unidecode'") sys.exit(-1) +if sys.version < '3': + input = raw_input + # Contributors list file name contributors_file_name = "contributors.txt" # Prompt the user to answer yes or no until they do so def yesOrNoPrompt(msg): - response = raw_input("%s [y/n]: " % msg) + response = input("%s [y/n]: " % msg) while response != "y" and response != "n": return yesOrNoPrompt(msg) return response == "y" diff --git a/dev/create-release/spark-rm/Dockerfile b/dev/create-release/spark-rm/Dockerfile new file mode 100644 index 0000000000000..07ce320177f5a --- /dev/null +++ b/dev/create-release/spark-rm/Dockerfile @@ -0,0 +1,87 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Image for building Spark releases. Based on Ubuntu 16.04. +# +# Includes: +# * Java 8 +# * Ivy +# * Python/PyPandoc (2.7.12/3.5.2) +# * R-base/R-base-dev (3.3.2+) +# * Ruby 2.3 build utilities + +FROM ubuntu:16.04 + +# These arguments are just for reuse and not really meant to be customized. +ARG APT_INSTALL="apt-get install --no-install-recommends -y" + +ARG BASE_PIP_PKGS="setuptools wheel virtualenv" +ARG PIP_PKGS="pyopenssl pypandoc numpy pygments sphinx" + +# Install extra needed repos and refresh. +# - CRAN repo +# - Ruby repo (for doc generation) +# +# This is all in a single "RUN" command so that if anything changes, "apt update" is run to fetch +# the most current package versions (instead of potentially using old versions cached by docker). +RUN echo 'deb http://cran.cnr.Berkeley.edu/bin/linux/ubuntu xenial/' >> /etc/apt/sources.list && \ + gpg --keyserver keyserver.ubuntu.com --recv-key E084DAB9 && \ + gpg -a --export E084DAB9 | apt-key add - && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* && \ + apt-get clean && \ + apt-get update && \ + $APT_INSTALL software-properties-common && \ + apt-add-repository -y ppa:brightbox/ruby-ng && \ + apt-get update && \ + # Install openjdk 8. + $APT_INSTALL openjdk-8-jdk && \ + update-alternatives --set java /usr/lib/jvm/java-8-openjdk-amd64/jre/bin/java && \ + # Install build / source control tools + $APT_INSTALL curl wget git maven ivy subversion make gcc lsof libffi-dev \ + pandoc pandoc-citeproc libssl-dev libcurl4-openssl-dev libxml2-dev && \ + ln -s -T /usr/share/java/ivy.jar /usr/share/ant/lib/ivy.jar && \ + curl -sL https://deb.nodesource.com/setup_4.x | bash && \ + $APT_INSTALL nodejs && \ + # Install needed python packages. Use pip for installing packages (for consistency). + $APT_INSTALL libpython2.7-dev libpython3-dev python-pip python3-pip && \ + pip install $BASE_PIP_PKGS && \ + pip install $PIP_PKGS && \ + cd && \ + virtualenv -p python3 p35 && \ + . p35/bin/activate && \ + pip install $BASE_PIP_PKGS && \ + pip install $PIP_PKGS && \ + # Install R packages and dependencies used when building. + # R depends on pandoc*, libssl (which are installed above). + $APT_INSTALL r-base r-base-dev && \ + $APT_INSTALL texlive-latex-base texlive texlive-fonts-extra texinfo qpdf && \ + Rscript -e "install.packages(c('curl', 'xml2', 'httr', 'devtools', 'testthat', 'knitr', 'rmarkdown', 'roxygen2', 'e1071', 'survival'), repos='http://cran.us.r-project.org/')" && \ + Rscript -e "devtools::install_github('jimhester/lintr')" && \ + # Install tools needed to build the documentation. + $APT_INSTALL ruby2.3 ruby2.3-dev && \ + gem install jekyll --no-rdoc --no-ri && \ + gem install jekyll-redirect-from && \ + gem install pygments.rb + +WORKDIR /opt/spark-rm/output + +ARG UID +RUN useradd -m -s /bin/bash -p spark-rm -u $UID spark-rm +USER spark-rm:spark-rm + +ENTRYPOINT [ "/opt/spark-rm/do-release.sh" ] diff --git a/dev/create-release/vote.tmpl b/dev/create-release/vote.tmpl new file mode 100644 index 0000000000000..2ce953c2f7ec4 --- /dev/null +++ b/dev/create-release/vote.tmpl @@ -0,0 +1,65 @@ +Please vote on releasing the following candidate as Apache Spark version {version}. + +The vote is open until {deadline} and passes if a majority +1 PMC votes are cast, with +a minimum of 3 +1 votes. + +[ ] +1 Release this package as Apache Spark {version} +[ ] -1 Do not release this package because ... + +To learn more about Apache Spark, please see http://spark.apache.org/ + +The tag to be voted on is {tag} (commit {tag_commit}): +https://github.com/apache/spark/tree/{tag} + +The release files, including signatures, digests, etc. can be found at: +https://dist.apache.org/repos/dist/dev/spark/{tag}-bin/ + +Signatures used for Spark RCs can be found in this file: +https://dist.apache.org/repos/dist/dev/spark/KEYS + +The staging repository for this release can be found at: +https://repository.apache.org/content/repositories/orgapachespark-{repo_id}/ + +The documentation corresponding to this release can be found at: +https://dist.apache.org/repos/dist/dev/spark/{tag}-docs/ + +The list of bug fixes going into {version} can be found at the following URL: +https://issues.apache.org/jira/projects/SPARK/versions/{jira_version_id} + +FAQ + +========================= +How can I help test this release? +========================= + +If you are a Spark user, you can help us test this release by taking +an existing Spark workload and running on this release candidate, then +reporting any regressions. + +If you're working in PySpark you can set up a virtual env and install +the current RC and see if anything important breaks, in the Java/Scala +you can add the staging repository to your projects resolvers and test +with the RC (make sure to clean up the artifact cache before/after so +you don't end up building with a out of date RC going forward). + +=========================================== +What should happen to JIRA tickets still targeting {version}? +=========================================== + +The current list of open tickets targeted at {version} can be found at: +{open_issues_link} + +Committers should look at those and triage. Extremely important bug +fixes, documentation, and API tweaks that impact compatibility should +be worked on immediately. Everything else please retarget to an +appropriate release. + +================== +But my bug isn't fixed? +================== + +In order to make timely releases, we will typically not hold the +release unless the bug in question is a regression from the previous +release. That being said, if there is something which is a regression +that has not been correctly targeted please ping me or a committer to +help target the issue. \ No newline at end of file diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 723180a14febb..f50a0aac0aefc 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -122,7 +122,7 @@ jersey-server-2.22.2.jar jets3t-0.9.4.jar jetty-6.1.26.jar jetty-util-6.1.26.jar -jline-2.12.1.jar +jline-2.14.3.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar @@ -172,10 +172,10 @@ parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar py4j-0.10.7.jar pyrolite-4.13.jar -scala-compiler-2.11.8.jar -scala-library-2.11.8.jar -scala-parser-combinators_2.11-1.0.4.jar -scala-reflect-2.11.8.jar +scala-compiler-2.11.12.jar +scala-library-2.11.12.jar +scala-parser-combinators_2.11-1.1.0.jar +scala-reflect-2.11.12.jar scala-xml_2.11-1.0.5.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar @@ -192,7 +192,7 @@ stringtemplate-3.2.1.jar super-csv-2.2.0.jar univocity-parsers-2.6.3.jar validation-api-1.1.0.Final.jar -xbean-asm5-shaded-4.4.jar +xbean-asm6-shaded-4.8.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar xz-1.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index ea08a001a1c9b..774f9dc39ce4d 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -122,7 +122,7 @@ jersey-server-2.22.2.jar jets3t-0.9.4.jar jetty-6.1.26.jar jetty-util-6.1.26.jar -jline-2.12.1.jar +jline-2.14.3.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar @@ -173,10 +173,10 @@ parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar py4j-0.10.7.jar pyrolite-4.13.jar -scala-compiler-2.11.8.jar -scala-library-2.11.8.jar -scala-parser-combinators_2.11-1.0.4.jar -scala-reflect-2.11.8.jar +scala-compiler-2.11.12.jar +scala-library-2.11.12.jar +scala-parser-combinators_2.11-1.1.0.jar +scala-reflect-2.11.12.jar scala-xml_2.11-1.0.5.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar @@ -193,7 +193,7 @@ stringtemplate-3.2.1.jar super-csv-2.2.0.jar univocity-parsers-2.6.3.jar validation-api-1.1.0.Final.jar -xbean-asm5-shaded-4.4.jar +xbean-asm6-shaded-4.8.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar xz-1.0.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index da874026d7d10..19c05ad1e991f 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -122,7 +122,7 @@ jersey-server-2.22.2.jar jets3t-0.9.4.jar jetty-webapp-9.3.20.v20170531.jar jetty-xml-9.3.20.v20170531.jar -jline-2.12.1.jar +jline-2.14.3.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar @@ -192,10 +192,10 @@ protobuf-java-2.5.0.jar py4j-0.10.7.jar pyrolite-4.13.jar re2j-1.1.jar -scala-compiler-2.11.8.jar -scala-library-2.11.8.jar -scala-parser-combinators_2.11-1.0.4.jar -scala-reflect-2.11.8.jar +scala-compiler-2.11.12.jar +scala-library-2.11.12.jar +scala-parser-combinators_2.11-1.1.0.jar +scala-reflect-2.11.12.jar scala-xml_2.11-1.0.5.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar @@ -214,7 +214,7 @@ token-provider-1.0.1.jar univocity-parsers-2.6.3.jar validation-api-1.1.0.Final.jar woodstox-core-5.0.3.jar -xbean-asm5-shaded-4.4.jar +xbean-asm6-shaded-4.8.jar xz-1.0.jar zjsonpatch-0.3.0.jar zookeeper-3.4.9.jar diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index 84233c64caa9c..ad99ce55806af 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -211,9 +211,10 @@ mkdir -p "$DISTDIR/examples/src/main" cp -r "$SPARK_HOME/examples/src/main" "$DISTDIR/examples/src/" # Copy license and ASF files -cp "$SPARK_HOME/LICENSE" "$DISTDIR" -cp -r "$SPARK_HOME/licenses" "$DISTDIR" -cp "$SPARK_HOME/NOTICE" "$DISTDIR" +cp "$SPARK_HOME/LICENSE-binary" "$DISTDIR/LICENSE" +mkdir -p "$DISTDIR/licenses" +cp -r "$SPARK_HOME/licenses-binary" "$DISTDIR/licenses" +cp "$SPARK_HOME/NOTICE-binary" "$DISTDIR/NOTICE" if [ -e "$SPARK_HOME/CHANGES.txt" ]; then cp "$SPARK_HOME/CHANGES.txt" "$DISTDIR" diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 7f46a1c8f6a7c..79c7c021fe74a 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -39,6 +39,9 @@ except ImportError: JIRA_IMPORTED = False +if sys.version < '3': + input = raw_input + # Location of your Spark git development area SPARK_HOME = os.environ.get("SPARK_HOME", os.getcwd()) # Remote name which points to the Gihub site @@ -95,7 +98,7 @@ def run_cmd(cmd): def continue_maybe(prompt): - result = raw_input("\n%s (y/n): " % prompt) + result = input("\n%s (y/n): " % prompt) if result.lower() != "y": fail("Okay, exiting") @@ -134,7 +137,7 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): '--pretty=format:%an <%ae>']).split("\n") distinct_authors = sorted(set(commit_authors), key=lambda x: commit_authors.count(x), reverse=True) - primary_author = raw_input( + primary_author = input( "Enter primary author in the format of \"name \" [%s]: " % distinct_authors[0]) if primary_author == "": @@ -184,7 +187,7 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): def cherry_pick(pr_num, merge_hash, default_branch): - pick_ref = raw_input("Enter a branch name [%s]: " % default_branch) + pick_ref = input("Enter a branch name [%s]: " % default_branch) if pick_ref == "": pick_ref = default_branch @@ -231,7 +234,7 @@ def resolve_jira_issue(merge_branches, comment, default_jira_id=""): asf_jira = jira.client.JIRA({'server': JIRA_API_BASE}, basic_auth=(JIRA_USERNAME, JIRA_PASSWORD)) - jira_id = raw_input("Enter a JIRA id [%s]: " % default_jira_id) + jira_id = input("Enter a JIRA id [%s]: " % default_jira_id) if jira_id == "": jira_id = default_jira_id @@ -276,7 +279,7 @@ def resolve_jira_issue(merge_branches, comment, default_jira_id=""): default_fix_versions = filter(lambda x: x != v, default_fix_versions) default_fix_versions = ",".join(default_fix_versions) - fix_versions = raw_input("Enter comma-separated fix version(s) [%s]: " % default_fix_versions) + fix_versions = input("Enter comma-separated fix version(s) [%s]: " % default_fix_versions) if fix_versions == "": fix_versions = default_fix_versions fix_versions = fix_versions.replace(" ", "").split(",") @@ -315,7 +318,7 @@ def choose_jira_assignee(issue, asf_jira): if author in commentors: annotations.append("Commentor") print("[%d] %s (%s)" % (idx, author.displayName, ",".join(annotations))) - raw_assignee = raw_input( + raw_assignee = input( "Enter number of user, or userid, to assign to (blank to leave unassigned):") if raw_assignee == "": return None @@ -428,7 +431,7 @@ def main(): # Assumes branch names can be sorted lexicographically latest_branch = sorted(branch_names, reverse=True)[0] - pr_num = raw_input("Which pull request would you like to merge? (e.g. 34): ") + pr_num = input("Which pull request would you like to merge? (e.g. 34): ") pr = get_json("%s/pulls/%s" % (GITHUB_API_BASE, pr_num)) pr_events = get_json("%s/issues/%s/events" % (GITHUB_API_BASE, pr_num)) @@ -440,7 +443,7 @@ def main(): print("I've re-written the title as follows to match the standard format:") print("Original: %s" % pr["title"]) print("Modified: %s" % modified_title) - result = raw_input("Would you like to use the modified title? (y/n): ") + result = input("Would you like to use the modified title? (y/n): ") if result.lower() == "y": title = modified_title print("Using modified title:") @@ -491,7 +494,7 @@ def main(): merge_hash = merge_pr(pr_num, target_ref, title, body, pr_repo_desc) pick_prompt = "Would you like to pick %s into another branch?" % merge_hash - while raw_input("\n%s (y/n): " % pick_prompt).lower() == "y": + while input("\n%s (y/n): " % pick_prompt).lower() == "y": merged_refs = merged_refs + [cherry_pick(pr_num, merge_hash, latest_branch)] if JIRA_IMPORTED: diff --git a/dev/requirements.txt b/dev/requirements.txt index 79782279f8fbd..fa833ab96b8e7 100644 --- a/dev/requirements.txt +++ b/dev/requirements.txt @@ -2,3 +2,4 @@ jira==1.0.3 PyGithub==1.26.0 Unidecode==0.04.19 pypandoc==1.3.3 +sphinx diff --git a/dev/run-tests.py b/dev/run-tests.py index 5e8c8590b5c34..d9d3789ac1255 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -110,7 +110,7 @@ def determine_modules_to_test(changed_modules): ['graphx', 'examples'] >>> x = [x.name for x in determine_modules_to_test([modules.sql])] >>> x # doctest: +NORMALIZE_WHITESPACE - ['sql', 'hive', 'mllib', 'sql-kafka-0-10', 'examples', 'hive-thriftserver', + ['sql', 'avro', 'hive', 'mllib', 'sql-kafka-0-10', 'examples', 'hive-thriftserver', 'pyspark-sql', 'repl', 'sparkr', 'pyspark-mllib', 'pyspark-ml'] """ modules_to_test = set() @@ -357,7 +357,7 @@ def build_spark_unidoc_sbt(hadoop_version): exec_sbt(profiles_and_goals) -def build_spark_assembly_sbt(hadoop_version): +def build_spark_assembly_sbt(hadoop_version, checkstyle=False): # Enable all of the profiles for the build: build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags sbt_goals = ["assembly/package"] @@ -366,6 +366,9 @@ def build_spark_assembly_sbt(hadoop_version): " ".join(profiles_and_goals)) exec_sbt(profiles_and_goals) + if checkstyle: + run_java_style_checks() + # Note that we skip Unidoc build only if Hadoop 2.6 is explicitly set in this SBT build. # Due to a different dependency resolution in SBT & Unidoc by an unknown reason, the # documentation build fails on a specific machine & environment in Jenkins but it was unable @@ -570,11 +573,13 @@ def main(): or f.endswith("scalastyle-config.xml") for f in changed_files): run_scala_style_checks() + should_run_java_style_checks = False if not changed_files or any(f.endswith(".java") or f.endswith("checkstyle.xml") or f.endswith("checkstyle-suppressions.xml") for f in changed_files): - run_java_style_checks() + # Run SBT Checkstyle after the build to prevent a side-effect to the build. + should_run_java_style_checks = True if not changed_files or any(f.endswith("lint-python") or f.endswith("tox.ini") or f.endswith(".py") @@ -603,7 +608,7 @@ def main(): detect_binary_inop_with_mima(hadoop_version) # Since we did not build assembly/package before running dev/mima, we need to # do it here because the tests still rely on it; see SPARK-13294 for details. - build_spark_assembly_sbt(hadoop_version) + build_spark_assembly_sbt(hadoop_version, should_run_java_style_checks) # run the test suites run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index dfea762db98c6..2aa355504bf29 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -170,6 +170,16 @@ def __hash__(self): ] ) +avro = Module( + name="avro", + dependencies=[sql], + source_file_regexes=[ + "external/avro", + ], + sbt_test_goals=[ + "avro/test", + ] +) sql_kafka = Module( name="sql-kafka-0-10", diff --git a/docs/building-spark.md b/docs/building-spark.md index 0236bb05849ad..c3bcd90ccc78f 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -215,19 +215,23 @@ If you are building Spark for use in a Python environment and you wish to pip in Alternatively, you can also run make-distribution with the --pip option. -## PySpark Tests with Maven +## PySpark Tests with Maven or SBT If you are building PySpark and wish to run the PySpark tests you will need to build Spark with Hive support. ./build/mvn -DskipTests clean package -Phive ./python/run-tests +If you are building PySpark with SBT and wish to run the PySpark tests, you will need to build Spark with Hive support and also build the test components: + + ./build/sbt -Phive clean package + ./build/sbt test:compile + ./python/run-tests + The run-tests script also can be limited to a specific Python version or a specific module ./python/run-tests --python-executables=python --modules=pyspark-sql -**Note:** You can also run Python tests with an sbt build, provided you build Spark with Hive support. - ## Running R Tests To run the SparkR tests you will need to install the [knitr](https://cran.r-project.org/package=knitr), [rmarkdown](https://cran.r-project.org/package=rmarkdown), [testthat](https://cran.r-project.org/package=testthat), [e1071](https://cran.r-project.org/package=e1071) and [survival](https://cran.r-project.org/package=survival) packages first: diff --git a/docs/cloud-integration.md b/docs/cloud-integration.md index ac1c336988930..18e8fe77bbdbe 100644 --- a/docs/cloud-integration.md +++ b/docs/cloud-integration.md @@ -70,7 +70,7 @@ be safely used as the direct destination of work with the normal rename-based co ### Installation With the relevant libraries on the classpath and Spark configured with valid credentials, -objects can be can be read or written by using their URLs as the path to data. +objects can be read or written by using their URLs as the path to data. For example `sparkContext.textFile("s3a://landsat-pds/scene_list.gz")` will create an RDD of the file `scene_list.gz` stored in S3, using the s3a connector. @@ -184,7 +184,8 @@ is no need for a workflow of write-then-rename to ensure that files aren't picke while they are still being written. Applications can write straight to the monitored directory. 1. Streams should only be checkpointed to a store implementing a fast and -atomic `rename()` operation Otherwise the checkpointing may be slow and potentially unreliable. +atomic `rename()` operation. +Otherwise the checkpointing may be slow and potentially unreliable. ## Further Reading diff --git a/docs/configuration.md b/docs/configuration.md index 6aa7878fe614d..0c7c4472be643 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -456,33 +456,6 @@ Apart from these, the following properties are also available, and may be useful from JVM to Python worker for every task. - - spark.sql.repl.eagerEval.enabled - false - - Enable eager evaluation or not. If true and the REPL you are using supports eager evaluation, - Dataset will be ran automatically. The HTML table which generated by _repl_html_ - called by notebooks like Jupyter will feedback the queries user have defined. For plain Python - REPL, the output will be shown like dataframe.show() - (see SPARK-24215 for more details). - - - - spark.sql.repl.eagerEval.maxNumRows - 20 - - Default number of rows in eager evaluation output HTML table generated by _repr_html_ or plain text, - this only take effect when spark.sql.repl.eagerEval.enabled is set to true. - - - - spark.sql.repl.eagerEval.truncate - 20 - - Default number of truncate in eager evaluation output HTML table generated by _repr_html_ or - plain text, this only take effect when spark.sql.repl.eagerEval.enabled set to true. - - spark.files diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index da90342406c84..2316f175676ee 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -264,3 +264,11 @@ within it for the various settings. For example: A full example is also available in `conf/fairscheduler.xml.template`. Note that any pools not configured in the XML file will simply get default values for all settings (scheduling mode FIFO, weight 1, and minShare 0). + +## Scheduling using JDBC Connections +To set a [Fair Scheduler](job-scheduling.html#fair-scheduler-pools) pool for a JDBC client session, +users can set the `spark.sql.thriftserver.scheduler.pool` variable: + +{% highlight SQL %} +SET spark.sql.thriftserver.scheduler.pool=accounting; +{% endhighlight %} diff --git a/docs/ml-features.md b/docs/ml-features.md index 7aed2341584fc..ad6e718b37f1b 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1429,7 +1429,7 @@ for more details on the API. ## Imputer -The `Imputer` transformer completes missing values in a dataset, either using the mean or the +The `Imputer` estimator completes missing values in a dataset, either using the mean or the median of the columns in which the missing values are located. The input columns should be of `DoubleType` or `FloatType`. Currently `Imputer` does not support categorical features and possibly creates incorrect values for columns containing categorical features. Imputer can impute custom values diff --git a/docs/ml-statistics.md b/docs/ml-statistics.md index abfb3cab1e566..6c82b3bb94b24 100644 --- a/docs/ml-statistics.md +++ b/docs/ml-statistics.md @@ -89,4 +89,32 @@ Refer to the [`ChiSquareTest` Python docs](api/python/index.html#pyspark.ml.stat {% include_example python/ml/chi_square_test_example.py %} + + +## Summarizer + +We provide vector column summary statistics for `Dataframe` through `Summarizer`. +Available metrics are the column-wise max, min, mean, variance, and number of nonzeros, as well as the total count. + +
+
+The following example demonstrates using [`Summarizer`](api/scala/index.html#org.apache.spark.ml.stat.Summarizer$) +to compute the mean and variance for a vector column of the input dataframe, with and without a weight column. + +{% include_example scala/org/apache/spark/examples/ml/SummarizerExample.scala %} +
+ +
+The following example demonstrates using [`Summarizer`](api/java/org/apache/spark/ml/stat/Summarizer.html) +to compute the mean and variance for a vector column of the input dataframe, with and without a weight column. + +{% include_example java/org/apache/spark/examples/ml/JavaSummarizerExample.java %} +
+ +
+Refer to the [`Summarizer` Python docs](api/python/index.html#pyspark.ml.stat.Summarizer$) for details on the API. + +{% include_example python/ml/summarizer_example.py %} +
+
\ No newline at end of file diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 408e446ea4822..7149616e534aa 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -629,6 +629,54 @@ specific to Spark on Kubernetes. Add as an environment variable to the executor container with name EnvName (case sensitive), the value referenced by key key in the data of the referenced Kubernetes Secret. For example, spark.kubernetes.executor.secrets.ENV_VAR=spark-secret:key. + + + spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.path + (none) + + Add the Kubernetes Volume named VolumeName of the VolumeType type to the driver pod on the path specified in the value. For example, + spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.mount.path=/checkpoint. + + + + spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.readOnly + (none) + + Specify if the mounted volume is read only or not. For example, + spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.mount.readOnly=false. + + + + spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].options.[OptionName] + (none) + + Configure Kubernetes Volume options passed to the Kubernetes with OptionName as key having specified value, must conform with Kubernetes option format. For example, + spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.options.claimName=spark-pvc-claim. + + + + spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].mount.path + (none) + + Add the Kubernetes Volume named VolumeName of the VolumeType type to the executor pod on the path specified in the value. For example, + spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.mount.path=/checkpoint. + + + + spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].mount.readOnly + false + + Specify if the mounted volume is read only or not. For example, + spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.mount.readOnly=false. + + + + spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].options.[OptionName] + (none) + + Configure Kubernetes Volume options passed to the Kubernetes with OptionName as key having specified value. For example, + spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.options.claimName=spark-pvc-claim. + spark.kubernetes.memoryOverheadFactor diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 4dbcbeafbbd9d..0b265b0cb1b31 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -218,9 +218,10 @@ To use a custom metrics.properties for the application master and executors, upd spark.yarn.dist.forceDownloadSchemes (none) - Comma-separated list of schemes for which files will be downloaded to the local disk prior to + Comma-separated list of schemes for which resources will be downloaded to the local disk prior to being added to YARN's distributed cache. For use in cases where the YARN service does not - support schemes that are supported by Spark, like http, https and ftp. + support schemes that are supported by Spark, like http, https and ftp, or jars required to be in the + local YARN client's classpath. Wildcard '*' is denoted to download resources for all the schemes. @@ -411,6 +412,16 @@ To use a custom metrics.properties for the application master and executors, upd name matches both the include and the exclude pattern, this file will be excluded eventually. + + spark.yarn.blacklist.executor.launch.blacklisting.enabled + false + + Flag to enable blacklisting of nodes having YARN resource allocation problems. + The error limit for blacklisting can be configured by + spark.blacklist.application.maxFailedExecutorsPerNode. + + + # Important notes diff --git a/docs/security.md b/docs/security.md index 8c0c66fb5a285..6ef3a808e0471 100644 --- a/docs/security.md +++ b/docs/security.md @@ -177,7 +177,7 @@ ACLs can be configured for either users or groups. Configuration entries accept lists as input, meaning multiple users or groups can be given the desired privileges. This can be used if you run on a shared cluster and have a set of administrators or developers who need to monitor applications they may not have started themselves. A wildcard (`*`) added to specific ACL -means that all users will have the respective pivilege. By default, only the user submitting the +means that all users will have the respective privilege. By default, only the user submitting the application is added to the ACLs. Group membership is established by using a configurable group mapping provider. The mapper is @@ -446,6 +446,27 @@ replaced with one of the above namespaces. +Spark also supports retrieving `${ns}.keyPassword`, `${ns}.keyStorePassword` and `${ns}.trustStorePassword` from +[Hadoop Credential Providers](https://hadoop.apache.org/docs/current/hadoop-project-dist/hadoop-common/CredentialProviderAPI.html). +User could store password into credential file and make it accessible by different components, like: + +``` +hadoop credential create spark.ssl.keyPassword -value password \ + -provider jceks://hdfs@nn1.example.com:9001/user/backup/ssl.jceks +``` + +To configure the location of the credential provider, set the `hadoop.security.credential.provider.path` +config option in the Hadoop configuration used by Spark, like: + +``` + + hadoop.security.credential.provider.path + jceks://hdfs@nn1.example.com:9001/user/backup/ssl.jceks + +``` + +Or via SparkConf "spark.hadoop.hadoop.security.credential.provider.path=jceks://hdfs@nn1.example.com:9001/user/backup/ssl.jceks". + ## Preparing the key stores Key stores can be generated by `keytool` program. The reference documentation for this tool for diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 4d8a738507bd1..ad23dae7c6b7c 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1302,9 +1302,33 @@ the following case-insensitive options: dbtable - The JDBC table that should be read. Note that anything that is valid in a FROM clause of - a SQL query can be used. For example, instead of a full table you could also use a - subquery in parentheses. + The JDBC table that should be read from or written into. Note that when using it in the read + path anything that is valid in a FROM clause of a SQL query can be used. + For example, instead of a full table you could also use a subquery in parentheses. It is not + allowed to specify `dbtable` and `query` options at the same time. + + + + query + + A query that will be used to read data into Spark. The specified query will be parenthesized and used + as a subquery in the FROM clause. Spark will also assign an alias to the subquery clause. + As an example, spark will issue a query of the following form to the JDBC Source.

+ SELECT <columns> FROM (<user_specified_query>) spark_gen_alias

+ Below are couple of restrictions while using this option.
+
    +
  1. It is not allowed to specify `dbtable` and `query` options at the same time.
  2. +
  3. It is not allowed to spcify `query` and `partitionColumn` options at the same time. When specifying + `partitionColumn` option is required, the subquery can be specified using `dbtable` option instead and + partition columns can be qualified using the subquery alias provided as part of `dbtable`.
    + Example:
    + + spark.read.format("jdbc")
    +    .option("dbtable", "(select c1, c2 from t1) as subq")
    +    .option("partitionColumn", "subq.c1"
    +    .load() +
  4. +
@@ -1752,14 +1776,10 @@ To use `groupBy().apply()`, the user needs to define the following: * A Python function that defines the computation for each group. * A `StructType` object or a string that defines the schema of the output `DataFrame`. -The output schema will be applied to the columns of the returned `pandas.DataFrame` in order by position, -not by name. This means that the columns in the `pandas.DataFrame` must be indexed so that their -position matches the corresponding field in the schema. - -Note that when creating a new `pandas.DataFrame` using a dictionary, the actual position of the column -can differ from the order that it was placed in the dictionary. It is recommended in this case to -explicitly define the column order using the `columns` keyword, e.g. -`pandas.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])`, or alternatively use an `OrderedDict`. +The column labels of the returned `pandas.DataFrame` must either match the field names in the +defined output schema if specified as strings, or match the field data types by position if not +strings, e.g. integer indices. See [pandas.DataFrame](https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.html#pandas.DataFrame) +on how to label columns when constructing a `pandas.DataFrame`. Note that all data for a group will be loaded into memory before the function is applied. This can lead to out of memory exceptons, especially if the group sizes are skewed. The configuration for @@ -1830,7 +1850,9 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. + - Since Spark 2.4, renaming a managed table to existing location is not allowed. An exception is thrown when attempting to rename a managed table to existing location. - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. + - Since Spark 2.4, Spark has enabled non-cascading SQL cache invalidation in addition to the traditional cache invalidation mechanism. The non-cascading cache invalidation mechanism allows users to remove a cache without impacting its dependent caches. This new cache invalidation mechanism is used in scenarios where the data of the cache to be removed is still valid, e.g., calling unpersist() on a Dataset, or dropping a temporary view. This allows users to free up memory and keep the desired caches valid at the same time. - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files. - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. To set `false` to `spark.sql.hive.convertMetastoreOrc` restores the previous behavior. @@ -1996,6 +2018,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Literal values used in SQL operations are converted to DECIMAL with the exact precision and scale needed by them. - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark uses previous rules, ie. it doesn't adjust the needed scale to represent the values and it returns NULL if an exact representation of the value is not possible. - In PySpark, `df.replace` does not allow to omit `value` when `to_replace` is not a dictionary. Previously, `value` could be omitted in the other cases and had `None` by default, which is counterintuitive and error-prone. + - Un-aliased subquery's semantic has not been well defined with confusing behaviors. Since Spark 2.3, we invalidate such confusing cases, for example: `SELECT v.i from (SELECT i FROM v)`, Spark will throw an analysis exception in this case because users should not be able to use the qualifier inside a subquery. See [SPARK-20690](https://issues.apache.org/jira/browse/SPARK-20690) and [SPARK-21335](https://issues.apache.org/jira/browse/SPARK-21335) for more details. ## Upgrading From Spark SQL 2.1 to 2.2 diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index c30959263cdfa..118b05355c74d 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -2176,6 +2176,8 @@ the input data stream (using `inputStream.repartition()`). This distributes the received batches of data across the specified number of machines in the cluster before further processing. +For direct stream, please refer to [Spark Streaming + Kafka Integration Guide](streaming-kafka-integration.html) + ### Level of Parallelism in Data Processing {:.no_toc} Cluster resources can be under-utilized if the number of parallel tasks used in any stage of the diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSummarizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSummarizerExample.java new file mode 100644 index 0000000000000..e9b84365d86ed --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSummarizerExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.sql.*; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.linalg.Vector; +import org.apache.spark.ml.linalg.Vectors; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.stat.Summarizer; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaSummarizerExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaSummarizerExample") + .getOrCreate(); + + // $example on$ + List data = Arrays.asList( + RowFactory.create(Vectors.dense(2.0, 3.0, 5.0), 1.0), + RowFactory.create(Vectors.dense(4.0, 6.0, 7.0), 2.0) + ); + + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + new StructField("weight", DataTypes.DoubleType, false, Metadata.empty()) + }); + + Dataset df = spark.createDataFrame(data, schema); + + Row result1 = df.select(Summarizer.metrics("mean", "variance") + .summary(new Column("features"), new Column("weight")).as("summary")) + .select("summary.mean", "summary.variance").first(); + System.out.println("with weight: mean = " + result1.getAs(0).toString() + + ", variance = " + result1.getAs(1).toString()); + + Row result2 = df.select( + Summarizer.mean(new Column("features")), + Summarizer.variance(new Column("features")) + ).first(); + System.out.println("without weight: mean = " + result2.getAs(0).toString() + + ", variance = " + result2.getAs(1).toString()); + // $example off$ + spark.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java index b6b163fa8b2cd..748bf58f30350 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java @@ -26,7 +26,9 @@ import scala.Tuple2; +import org.apache.kafka.clients.consumer.ConsumerConfig; import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.serialization.StringDeserializer; import org.apache.spark.SparkConf; import org.apache.spark.streaming.api.java.*; @@ -37,30 +39,33 @@ /** * Consumes messages from one or more topics in Kafka and does wordcount. - * Usage: JavaDirectKafkaWordCount + * Usage: JavaDirectKafkaWordCount * is a list of one or more Kafka brokers + * is a consumer group name to consume from topics * is a list of one or more kafka topics to consume from * * Example: * $ bin/run-example streaming.JavaDirectKafkaWordCount broker1-host:port,broker2-host:port \ - * topic1,topic2 + * consumer-group topic1,topic2 */ public final class JavaDirectKafkaWordCount { private static final Pattern SPACE = Pattern.compile(" "); public static void main(String[] args) throws Exception { - if (args.length < 2) { - System.err.println("Usage: JavaDirectKafkaWordCount \n" + - " is a list of one or more Kafka brokers\n" + - " is a list of one or more kafka topics to consume from\n\n"); + if (args.length < 3) { + System.err.println("Usage: JavaDirectKafkaWordCount \n" + + " is a list of one or more Kafka brokers\n" + + " is a consumer group name to consume from topics\n" + + " is a list of one or more kafka topics to consume from\n\n"); System.exit(1); } StreamingExamples.setStreamingLogLevels(); String brokers = args[0]; - String topics = args[1]; + String groupId = args[1]; + String topics = args[2]; // Create context with a 2 seconds batch interval SparkConf sparkConf = new SparkConf().setAppName("JavaDirectKafkaWordCount"); @@ -68,7 +73,10 @@ public static void main(String[] args) throws Exception { Set topicsSet = new HashSet<>(Arrays.asList(topics.split(","))); Map kafkaParams = new HashMap<>(); - kafkaParams.put("metadata.broker.list", brokers); + kafkaParams.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, brokers); + kafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, groupId); + kafkaParams.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); + kafkaParams.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); // Create direct kafka stream with brokers and topics JavaInputDStream> messages = KafkaUtils.createDirectStream( diff --git a/examples/src/main/python/ml/summarizer_example.py b/examples/src/main/python/ml/summarizer_example.py new file mode 100644 index 0000000000000..8835f189a1ad4 --- /dev/null +++ b/examples/src/main/python/ml/summarizer_example.py @@ -0,0 +1,59 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +An example for summarizer. +Run with: + bin/spark-submit examples/src/main/python/ml/summarizer_example.py +""" +from __future__ import print_function + +from pyspark.sql import SparkSession +# $example on$ +from pyspark.ml.stat import Summarizer +from pyspark.sql import Row +from pyspark.ml.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + spark = SparkSession \ + .builder \ + .appName("SummarizerExample") \ + .getOrCreate() + sc = spark.sparkContext + + # $example on$ + df = sc.parallelize([Row(weight=1.0, features=Vectors.dense(1.0, 1.0, 1.0)), + Row(weight=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF() + + # create summarizer for multiple metrics "mean" and "count" + summarizer = Summarizer.metrics("mean", "count") + + # compute statistics for multiple metrics with weight + df.select(summarizer.summary(df.features, df.weight)).show(truncate=False) + + # compute statistics for multiple metrics without weight + df.select(summarizer.summary(df.features)).show(truncate=False) + + # compute statistics for single metric "mean" with weight + df.select(Summarizer.mean(df.features, df.weight)).show(truncate=False) + + # compute statistics for single metric "mean" without weight + df.select(Summarizer.mean(df.features)).show(truncate=False) + # $example off$ + + spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SummarizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SummarizerExample.scala new file mode 100644 index 0000000000000..2f54d1d81bc48 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SummarizerExample.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.stat.Summarizer +// $example off$ +import org.apache.spark.sql.SparkSession + +object SummarizerExample { + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName("SummarizerExample") + .getOrCreate() + + import spark.implicits._ + import Summarizer._ + + // $example on$ + val data = Seq( + (Vectors.dense(2.0, 3.0, 5.0), 1.0), + (Vectors.dense(4.0, 6.0, 7.0), 2.0) + ) + + val df = data.toDF("features", "weight") + + val (meanVal, varianceVal) = df.select(metrics("mean", "variance") + .summary($"features", $"weight").as("summary")) + .select("summary.mean", "summary.variance") + .as[(Vector, Vector)].first() + + println(s"with weight: mean = ${meanVal}, variance = ${varianceVal}") + + val (meanVal2, varianceVal2) = df.select(mean($"features"), variance($"features")) + .as[(Vector, Vector)].first() + + println(s"without weight: mean = ${meanVal2}, sum = ${varianceVal2}") + // $example off$ + + spark.stop() + } +} +// scalastyle:on println diff --git a/external/avro/pom.xml b/external/avro/pom.xml new file mode 100644 index 0000000000000..42e865bc38824 --- /dev/null +++ b/external/avro/pom.xml @@ -0,0 +1,73 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.4.0-SNAPSHOT + ../../pom.xml + + + spark-sql-avro_2.11 + + avro + + jar + Spark Avro + http://spark.apache.org/ + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + provided + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-tags_${scala.binary.version} + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000000000..95835f0d4ca49 --- /dev/null +++ b/external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.apache.spark.sql.avro.AvroFileFormat diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala new file mode 100644 index 0000000000000..b31149a2c74c2 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -0,0 +1,348 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.nio.ByteBuffer + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.avro.{Schema, SchemaBuilder} +import org.apache.avro.Schema.Type._ +import org.apache.avro.generic._ +import org.apache.avro.util.Utf8 + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * A deserializer to deserialize data in avro format to data in catalyst format. + */ +class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) { + private val converter: Any => Any = rootCatalystType match { + // A shortcut for empty schema. + case st: StructType if st.isEmpty => + (data: Any) => InternalRow.empty + + case st: StructType => + val resultRow = new SpecificInternalRow(st.map(_.dataType)) + val fieldUpdater = new RowUpdater(resultRow) + val writer = getRecordWriter(rootAvroType, st, Nil) + (data: Any) => { + val record = data.asInstanceOf[GenericRecord] + writer(fieldUpdater, record) + resultRow + } + + case _ => + val tmpRow = new SpecificInternalRow(Seq(rootCatalystType)) + val fieldUpdater = new RowUpdater(tmpRow) + val writer = newWriter(rootAvroType, rootCatalystType, Nil) + (data: Any) => { + writer(fieldUpdater, 0, data) + tmpRow.get(0, rootCatalystType) + } + } + + def deserialize(data: Any): Any = converter(data) + + /** + * Creates a writer to write avro values to Catalyst values at the given ordinal with the given + * updater. + */ + private def newWriter( + avroType: Schema, + catalystType: DataType, + path: List[String]): (CatalystDataUpdater, Int, Any) => Unit = + (avroType.getType, catalystType) match { + case (NULL, NullType) => (updater, ordinal, _) => + updater.setNullAt(ordinal) + + // TODO: we can avoid boxing if future version of avro provide primitive accessors. + case (BOOLEAN, BooleanType) => (updater, ordinal, value) => + updater.setBoolean(ordinal, value.asInstanceOf[Boolean]) + + case (INT, IntegerType) => (updater, ordinal, value) => + updater.setInt(ordinal, value.asInstanceOf[Int]) + + case (LONG, LongType) => (updater, ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[Long]) + + case (LONG, TimestampType) => (updater, ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[Long] * 1000) + + case (LONG, DateType) => (updater, ordinal, value) => + updater.setInt(ordinal, (value.asInstanceOf[Long] / DateTimeUtils.MILLIS_PER_DAY).toInt) + + case (FLOAT, FloatType) => (updater, ordinal, value) => + updater.setFloat(ordinal, value.asInstanceOf[Float]) + + case (DOUBLE, DoubleType) => (updater, ordinal, value) => + updater.setDouble(ordinal, value.asInstanceOf[Double]) + + case (STRING, StringType) => (updater, ordinal, value) => + val str = value match { + case s: String => UTF8String.fromString(s) + case s: Utf8 => + val bytes = new Array[Byte](s.getByteLength) + System.arraycopy(s.getBytes, 0, bytes, 0, s.getByteLength) + UTF8String.fromBytes(bytes) + } + updater.set(ordinal, str) + + case (ENUM, StringType) => (updater, ordinal, value) => + updater.set(ordinal, UTF8String.fromString(value.toString)) + + case (FIXED, BinaryType) => (updater, ordinal, value) => + updater.set(ordinal, value.asInstanceOf[GenericFixed].bytes().clone()) + + case (BYTES, BinaryType) => (updater, ordinal, value) => + val bytes = value match { + case b: ByteBuffer => + val bytes = new Array[Byte](b.remaining) + b.get(bytes) + bytes + case b: Array[Byte] => b + case other => throw new RuntimeException(s"$other is not a valid avro binary.") + + } + updater.set(ordinal, bytes) + + case (RECORD, st: StructType) => + val writeRecord = getRecordWriter(avroType, st, path) + (updater, ordinal, value) => + val row = new SpecificInternalRow(st) + writeRecord(new RowUpdater(row), value.asInstanceOf[GenericRecord]) + updater.set(ordinal, row) + + case (ARRAY, ArrayType(elementType, containsNull)) => + val elementWriter = newWriter(avroType.getElementType, elementType, path) + (updater, ordinal, value) => + val array = value.asInstanceOf[GenericData.Array[Any]] + val len = array.size() + val result = createArrayData(elementType, len) + val elementUpdater = new ArrayDataUpdater(result) + + var i = 0 + while (i < len) { + val element = array.get(i) + if (element == null) { + if (!containsNull) { + throw new RuntimeException(s"Array value at path ${path.mkString(".")} is not " + + "allowed to be null") + } else { + elementUpdater.setNullAt(i) + } + } else { + elementWriter(elementUpdater, i, element) + } + i += 1 + } + + updater.set(ordinal, result) + + case (MAP, MapType(keyType, valueType, valueContainsNull)) if keyType == StringType => + val keyWriter = newWriter(SchemaBuilder.builder().stringType(), StringType, path) + val valueWriter = newWriter(avroType.getValueType, valueType, path) + (updater, ordinal, value) => + val map = value.asInstanceOf[java.util.Map[AnyRef, AnyRef]] + val keyArray = createArrayData(keyType, map.size()) + val keyUpdater = new ArrayDataUpdater(keyArray) + val valueArray = createArrayData(valueType, map.size()) + val valueUpdater = new ArrayDataUpdater(valueArray) + val iter = map.entrySet().iterator() + var i = 0 + while (iter.hasNext) { + val entry = iter.next() + assert(entry.getKey != null) + keyWriter(keyUpdater, i, entry.getKey) + if (entry.getValue == null) { + if (!valueContainsNull) { + throw new RuntimeException(s"Map value at path ${path.mkString(".")} is not " + + "allowed to be null") + } else { + valueUpdater.setNullAt(i) + } + } else { + valueWriter(valueUpdater, i, entry.getValue) + } + i += 1 + } + + updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray)) + + case (UNION, _) => + val allTypes = avroType.getTypes.asScala + val nonNullTypes = allTypes.filter(_.getType != NULL) + if (nonNullTypes.nonEmpty) { + if (nonNullTypes.length == 1) { + newWriter(nonNullTypes.head, catalystType, path) + } else { + nonNullTypes.map(_.getType) match { + case Seq(a, b) if Set(a, b) == Set(INT, LONG) && catalystType == LongType => + (updater, ordinal, value) => value match { + case null => updater.setNullAt(ordinal) + case l: java.lang.Long => updater.setLong(ordinal, l) + case i: java.lang.Integer => updater.setLong(ordinal, i.longValue()) + } + + case Seq(a, b) if Set(a, b) == Set(FLOAT, DOUBLE) && catalystType == DoubleType => + (updater, ordinal, value) => value match { + case null => updater.setNullAt(ordinal) + case d: java.lang.Double => updater.setDouble(ordinal, d) + case f: java.lang.Float => updater.setDouble(ordinal, f.doubleValue()) + } + + case _ => + catalystType match { + case st: StructType if st.length == nonNullTypes.size => + val fieldWriters = nonNullTypes.zip(st.fields).map { + case (schema, field) => newWriter(schema, field.dataType, path :+ field.name) + }.toArray + (updater, ordinal, value) => { + val row = new SpecificInternalRow(st) + val fieldUpdater = new RowUpdater(row) + val i = GenericData.get().resolveUnion(avroType, value) + fieldWriters(i)(fieldUpdater, i, value) + updater.set(ordinal, row) + } + + case _ => + throw new IncompatibleSchemaException( + s"Cannot convert Avro to catalyst because schema at path " + + s"${path.mkString(".")} is not compatible " + + s"(avroType = $avroType, sqlType = $catalystType).\n" + + s"Source Avro schema: $rootAvroType.\n" + + s"Target Catalyst type: $rootCatalystType") + } + } + } + } else { + (updater, ordinal, value) => updater.setNullAt(ordinal) + } + + case _ => + throw new IncompatibleSchemaException( + s"Cannot convert Avro to catalyst because schema at path ${path.mkString(".")} " + + s"is not compatible (avroType = $avroType, sqlType = $catalystType).\n" + + s"Source Avro schema: $rootAvroType.\n" + + s"Target Catalyst type: $rootCatalystType") + } + + private def getRecordWriter( + avroType: Schema, + sqlType: StructType, + path: List[String]): (CatalystDataUpdater, GenericRecord) => Unit = { + val validFieldIndexes = ArrayBuffer.empty[Int] + val fieldWriters = ArrayBuffer.empty[(CatalystDataUpdater, Any) => Unit] + + val length = sqlType.length + var i = 0 + while (i < length) { + val sqlField = sqlType.fields(i) + val avroField = avroType.getField(sqlField.name) + if (avroField != null) { + validFieldIndexes += avroField.pos() + + val baseWriter = newWriter(avroField.schema(), sqlField.dataType, path :+ sqlField.name) + val ordinal = i + val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => { + if (value == null) { + fieldUpdater.setNullAt(ordinal) + } else { + baseWriter(fieldUpdater, ordinal, value) + } + } + fieldWriters += fieldWriter + } else if (!sqlField.nullable) { + throw new IncompatibleSchemaException( + s""" + |Cannot find non-nullable field ${path.mkString(".")}.${sqlField.name} in Avro schema. + |Source Avro schema: $rootAvroType. + |Target Catalyst type: $rootCatalystType. + """.stripMargin) + } + i += 1 + } + + (fieldUpdater, record) => { + var i = 0 + while (i < validFieldIndexes.length) { + fieldWriters(i)(fieldUpdater, record.get(validFieldIndexes(i))) + i += 1 + } + } + } + + private def createArrayData(elementType: DataType, length: Int): ArrayData = elementType match { + case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length)) + case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length)) + case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length)) + case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length)) + case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length)) + case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length)) + case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length)) + case _ => new GenericArrayData(new Array[Any](length)) + } + + /** + * A base interface for updating values inside catalyst data structure like `InternalRow` and + * `ArrayData`. + */ + sealed trait CatalystDataUpdater { + def set(ordinal: Int, value: Any): Unit + + def setNullAt(ordinal: Int): Unit = set(ordinal, null) + def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value) + def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value) + def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value) + def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value) + def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value) + def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value) + def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value) + } + + final class RowUpdater(row: InternalRow) extends CatalystDataUpdater { + override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value) + + override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal) + override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value) + } + + final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater { + override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value) + + override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal) + override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = array.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value) + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala new file mode 100755 index 0000000000000..9eb206457809c --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -0,0 +1,285 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.io._ +import java.net.URI +import java.util.zip.Deflater + +import scala.util.control.NonFatal + +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} +import org.apache.avro.Schema +import org.apache.avro.file.{DataFileConstants, DataFileReader} +import org.apache.avro.generic.{GenericDatumReader, GenericRecord} +import org.apache.avro.mapred.{AvroOutputFormat, FsInput} +import org.apache.avro.mapreduce.AvroJob +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.mapreduce.Job +import org.slf4j.LoggerFactory + +import org.apache.spark.TaskContext +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile} +import org.apache.spark.sql.sources.{DataSourceRegister, Filter} +import org.apache.spark.sql.types.StructType + +private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { + private val log = LoggerFactory.getLogger(getClass) + + override def equals(other: Any): Boolean = other match { + case _: AvroFileFormat => true + case _ => false + } + + // Dummy hashCode() to appease ScalaStyle. + override def hashCode(): Int = super.hashCode() + + override def inferSchema( + spark: SparkSession, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + val conf = spark.sparkContext.hadoopConfiguration + + // Schema evolution is not supported yet. Here we only pick a single random sample file to + // figure out the schema of the whole dataset. + val sampleFile = + if (AvroFileFormat.ignoreFilesWithoutExtensions(conf)) { + files.find(_.getPath.getName.endsWith(".avro")).getOrElse { + throw new FileNotFoundException( + "No Avro files found. Hadoop option \"avro.mapred.ignore.inputs.without.extension\" " + + " is set to true. Do all input files have \".avro\" extension?" + ) + } + } else { + files.headOption.getOrElse { + throw new FileNotFoundException("No Avro files found.") + } + } + + // User can specify an optional avro json schema. + val avroSchema = options.get(AvroFileFormat.AvroSchema) + .map(new Schema.Parser().parse) + .getOrElse { + val in = new FsInput(sampleFile.getPath, conf) + try { + val reader = DataFileReader.openReader(in, new GenericDatumReader[GenericRecord]()) + try { + reader.getSchema + } finally { + reader.close() + } + } finally { + in.close() + } + } + + SchemaConverters.toSqlType(avroSchema).dataType match { + case t: StructType => Some(t) + case _ => throw new RuntimeException( + s"""Avro schema cannot be converted to a Spark SQL StructType: + | + |${avroSchema.toString(true)} + |""".stripMargin) + } + } + + override def shortName(): String = "avro" + + override def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = true + + override def prepareWrite( + spark: SparkSession, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + val recordName = options.getOrElse("recordName", "topLevelRecord") + val recordNamespace = options.getOrElse("recordNamespace", "") + val outputAvroSchema = SchemaConverters.toAvroType( + dataSchema, nullable = false, recordName, recordNamespace) + + AvroJob.setOutputKeySchema(job, outputAvroSchema) + val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec" + val AVRO_DEFLATE_LEVEL = "spark.sql.avro.deflate.level" + val COMPRESS_KEY = "mapred.output.compress" + + spark.conf.get(AVRO_COMPRESSION_CODEC, "snappy") match { + case "uncompressed" => + log.info("writing uncompressed Avro records") + job.getConfiguration.setBoolean(COMPRESS_KEY, false) + + case "snappy" => + log.info("compressing Avro output using Snappy") + job.getConfiguration.setBoolean(COMPRESS_KEY, true) + job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, DataFileConstants.SNAPPY_CODEC) + + case "deflate" => + val deflateLevel = spark.conf.get( + AVRO_DEFLATE_LEVEL, Deflater.DEFAULT_COMPRESSION.toString).toInt + log.info(s"compressing Avro output using deflate (level=$deflateLevel)") + job.getConfiguration.setBoolean(COMPRESS_KEY, true) + job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, DataFileConstants.DEFLATE_CODEC) + job.getConfiguration.setInt(AvroOutputFormat.DEFLATE_LEVEL_KEY, deflateLevel) + + case unknown: String => + log.error(s"unsupported compression codec $unknown") + } + + new AvroOutputWriterFactory(dataSchema, new SerializableSchema(outputAvroSchema)) + } + + override def buildReader( + spark: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + + val broadcastedConf = + spark.sparkContext.broadcast(new AvroFileFormat.SerializableConfiguration(hadoopConf)) + + (file: PartitionedFile) => { + val log = LoggerFactory.getLogger(classOf[AvroFileFormat]) + val conf = broadcastedConf.value.value + val userProvidedSchema = options.get(AvroFileFormat.AvroSchema).map(new Schema.Parser().parse) + + // TODO Removes this check once `FileFormat` gets a general file filtering interface method. + // Doing input file filtering is improper because we may generate empty tasks that process no + // input files but stress the scheduler. We should probably add a more general input file + // filtering mechanism for `FileFormat` data sources. See SPARK-16317. + if (AvroFileFormat.ignoreFilesWithoutExtensions(conf) && !file.filePath.endsWith(".avro")) { + Iterator.empty + } else { + val reader = { + val in = new FsInput(new Path(new URI(file.filePath)), conf) + try { + val datumReader = userProvidedSchema match { + case Some(userSchema) => new GenericDatumReader[GenericRecord](userSchema) + case _ => new GenericDatumReader[GenericRecord]() + } + DataFileReader.openReader(in, datumReader) + } catch { + case NonFatal(e) => + log.error("Exception while opening DataFileReader", e) + in.close() + throw e + } + } + + // Ensure that the reader is closed even if the task fails or doesn't consume the entire + // iterator of records. + Option(TaskContext.get()).foreach { taskContext => + taskContext.addTaskCompletionListener { _ => + reader.close() + } + } + + reader.sync(file.start) + val stop = file.start + file.length + + val deserializer = + new AvroDeserializer(userProvidedSchema.getOrElse(reader.getSchema), requiredSchema) + + new Iterator[InternalRow] { + private[this] var completed = false + + override def hasNext: Boolean = { + if (completed) { + false + } else { + val r = reader.hasNext && !reader.pastSync(stop) + if (!r) { + reader.close() + completed = true + } + r + } + } + + override def next(): InternalRow = { + if (!hasNext) { + throw new NoSuchElementException("next on empty iterator") + } + val record = reader.next() + deserializer.deserialize(record).asInstanceOf[InternalRow] + } + } + } + } + } +} + +private[avro] object AvroFileFormat { + val IgnoreFilesWithoutExtensionProperty = "avro.mapred.ignore.inputs.without.extension" + + val AvroSchema = "avroSchema" + + class SerializableConfiguration(@transient var value: Configuration) + extends Serializable with KryoSerializable { + @transient private[avro] lazy val log = LoggerFactory.getLogger(getClass) + + private def writeObject(out: ObjectOutputStream): Unit = tryOrIOException { + out.defaultWriteObject() + value.write(out) + } + + private def readObject(in: ObjectInputStream): Unit = tryOrIOException { + value = new Configuration(false) + value.readFields(in) + } + + private def tryOrIOException[T](block: => T): T = { + try { + block + } catch { + case e: IOException => + log.error("Exception encountered", e) + throw e + case NonFatal(e) => + log.error("Exception encountered", e) + throw new IOException(e) + } + } + + def write(kryo: Kryo, out: Output): Unit = { + val dos = new DataOutputStream(out) + value.write(dos) + dos.flush() + } + + def read(kryo: Kryo, in: Input): Unit = { + value = new Configuration(false) + value.readFields(new DataInputStream(in)) + } + } + + def ignoreFilesWithoutExtensions(conf: Configuration): Boolean = { + // Files without .avro extensions are not ignored by default + val defaultValue = false + + conf.getBoolean(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, defaultValue) + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala new file mode 100644 index 0000000000000..06507115f5ed8 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.io.{IOException, OutputStream} + +import org.apache.avro.Schema +import org.apache.avro.generic.GenericRecord +import org.apache.avro.mapred.AvroKey +import org.apache.avro.mapreduce.AvroKeyOutputFormat +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.NullWritable +import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext} + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.OutputWriter +import org.apache.spark.sql.types._ + +// NOTE: This class is instantiated and used on executor side only, no need to be serializable. +private[avro] class AvroOutputWriter( + path: String, + context: TaskAttemptContext, + schema: StructType, + avroSchema: Schema) extends OutputWriter { + + // The input rows will never be null. + private lazy val serializer = new AvroSerializer(schema, avroSchema, nullable = false) + + /** + * Overrides the couple of methods responsible for generating the output streams / files so + * that the data can be correctly partitioned + */ + private val recordWriter: RecordWriter[AvroKey[GenericRecord], NullWritable] = + new AvroKeyOutputFormat[GenericRecord]() { + + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + new Path(path) + } + + @throws(classOf[IOException]) + override def getAvroFileOutputStream(c: TaskAttemptContext): OutputStream = { + val path = getDefaultWorkFile(context, ".avro") + path.getFileSystem(context.getConfiguration).create(path) + } + + }.getRecordWriter(context) + + override def write(row: InternalRow): Unit = { + val key = new AvroKey(serializer.serialize(row).asInstanceOf[GenericRecord]) + recordWriter.write(key, NullWritable.get()) + } + + override def close(): Unit = recordWriter.close(context) +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala new file mode 100644 index 0000000000000..18a6d93951408 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.hadoop.mapreduce.TaskAttemptContext + +import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory} +import org.apache.spark.sql.types.StructType + +private[avro] class AvroOutputWriterFactory( + schema: StructType, + avroSchema: SerializableSchema) extends OutputWriterFactory { + + override def getFileExtension(context: TaskAttemptContext): String = ".avro" + + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new AvroOutputWriter(path, context, schema, avroSchema.value) + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala new file mode 100644 index 0000000000000..2b4c5813a535b --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.nio.ByteBuffer + +import scala.collection.JavaConverters._ + +import org.apache.avro.Schema +import org.apache.avro.Schema.Type.NULL +import org.apache.avro.generic.GenericData.Record +import org.apache.avro.util.Utf8 + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ + +/** + * A serializer to serialize data in catalyst format to data in avro format. + */ +class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean) { + + def serialize(catalystData: Any): Any = { + converter.apply(catalystData) + } + + private val converter: Any => Any = { + val actualAvroType = resolveNullableType(rootAvroType, nullable) + val baseConverter = rootCatalystType match { + case st: StructType => + newStructConverter(st, actualAvroType).asInstanceOf[Any => Any] + case _ => + val tmpRow = new SpecificInternalRow(Seq(rootCatalystType)) + val converter = newConverter(rootCatalystType, actualAvroType) + (data: Any) => + tmpRow.update(0, data) + converter.apply(tmpRow, 0) + } + if (nullable) { + (data: Any) => + if (data == null) { + null + } else { + baseConverter.apply(data) + } + } else { + baseConverter + } + } + + private type Converter = (SpecializedGetters, Int) => Any + + private def newConverter(catalystType: DataType, avroType: Schema): Converter = { + catalystType match { + case NullType => + (getter, ordinal) => null + case BooleanType => + (getter, ordinal) => getter.getBoolean(ordinal) + case ByteType => + (getter, ordinal) => getter.getByte(ordinal).toInt + case ShortType => + (getter, ordinal) => getter.getShort(ordinal).toInt + case IntegerType => + (getter, ordinal) => getter.getInt(ordinal) + case LongType => + (getter, ordinal) => getter.getLong(ordinal) + case FloatType => + (getter, ordinal) => getter.getFloat(ordinal) + case DoubleType => + (getter, ordinal) => getter.getDouble(ordinal) + case d: DecimalType => + (getter, ordinal) => getter.getDecimal(ordinal, d.precision, d.scale).toString + case StringType => + (getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes) + case BinaryType => + (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal)) + case DateType => + (getter, ordinal) => getter.getInt(ordinal) * DateTimeUtils.MILLIS_PER_DAY + case TimestampType => + (getter, ordinal) => getter.getLong(ordinal) / 1000 + + case ArrayType(et, containsNull) => + val elementConverter = newConverter( + et, resolveNullableType(avroType.getElementType, containsNull)) + (getter, ordinal) => { + val arrayData = getter.getArray(ordinal) + val result = new java.util.ArrayList[Any] + var i = 0 + while (i < arrayData.numElements()) { + if (arrayData.isNullAt(i)) { + result.add(null) + } else { + result.add(elementConverter(arrayData, i)) + } + i += 1 + } + result + } + + case st: StructType => + val structConverter = newStructConverter(st, avroType) + val numFields = st.length + (getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields)) + + case MapType(kt, vt, valueContainsNull) if kt == StringType => + val valueConverter = newConverter( + vt, resolveNullableType(avroType.getValueType, valueContainsNull)) + (getter, ordinal) => + val mapData = getter.getMap(ordinal) + val result = new java.util.HashMap[String, Any](mapData.numElements()) + val keyArray = mapData.keyArray() + val valueArray = mapData.valueArray() + var i = 0 + while (i < mapData.numElements()) { + val key = keyArray.getUTF8String(i).toString + if (valueArray.isNullAt(i)) { + result.put(key, null) + } else { + result.put(key, valueConverter(valueArray, i)) + } + i += 1 + } + result + + case other => + throw new IncompatibleSchemaException(s"Unexpected type: $other") + } + } + + private def newStructConverter( + catalystStruct: StructType, avroStruct: Schema): InternalRow => Record = { + val avroFields = avroStruct.getFields + assert(avroFields.size() == catalystStruct.length) + val fieldConverters = catalystStruct.zip(avroFields.asScala).map { + case (f1, f2) => newConverter(f1.dataType, resolveNullableType(f2.schema(), f1.nullable)) + } + val numFields = catalystStruct.length + (row: InternalRow) => + val result = new Record(avroStruct) + var i = 0 + while (i < numFields) { + if (row.isNullAt(i)) { + result.put(i, null) + } else { + result.put(i, fieldConverters(i).apply(row, i)) + } + i += 1 + } + result + } + + private def resolveNullableType(avroType: Schema, nullable: Boolean): Schema = { + if (nullable) { + // avro uses union to represent nullable type. + val fields = avroType.getTypes.asScala + assert(fields.length == 2) + val actualType = fields.filter(_.getType != NULL) + assert(actualType.length == 1) + actualType.head + } else { + avroType + } + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala new file mode 100644 index 0000000000000..87fae63aeff2b --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import scala.collection.JavaConverters._ + +import org.apache.avro.{Schema, SchemaBuilder} +import org.apache.avro.Schema.Type._ + +import org.apache.spark.sql.types._ + +/** + * This object contains method that are used to convert sparkSQL schemas to avro schemas and vice + * versa. + */ +object SchemaConverters { + case class SchemaType(dataType: DataType, nullable: Boolean) + + /** + * This function takes an avro schema and returns a sql schema. + */ + def toSqlType(avroSchema: Schema): SchemaType = { + avroSchema.getType match { + case INT => SchemaType(IntegerType, nullable = false) + case STRING => SchemaType(StringType, nullable = false) + case BOOLEAN => SchemaType(BooleanType, nullable = false) + case BYTES => SchemaType(BinaryType, nullable = false) + case DOUBLE => SchemaType(DoubleType, nullable = false) + case FLOAT => SchemaType(FloatType, nullable = false) + case LONG => SchemaType(LongType, nullable = false) + case FIXED => SchemaType(BinaryType, nullable = false) + case ENUM => SchemaType(StringType, nullable = false) + + case RECORD => + val fields = avroSchema.getFields.asScala.map { f => + val schemaType = toSqlType(f.schema()) + StructField(f.name, schemaType.dataType, schemaType.nullable) + } + + SchemaType(StructType(fields), nullable = false) + + case ARRAY => + val schemaType = toSqlType(avroSchema.getElementType) + SchemaType( + ArrayType(schemaType.dataType, containsNull = schemaType.nullable), + nullable = false) + + case MAP => + val schemaType = toSqlType(avroSchema.getValueType) + SchemaType( + MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable), + nullable = false) + + case UNION => + if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) { + // In case of a union with null, eliminate it and make a recursive call + val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL) + if (remainingUnionTypes.size == 1) { + toSqlType(remainingUnionTypes.head).copy(nullable = true) + } else { + toSqlType(Schema.createUnion(remainingUnionTypes.asJava)).copy(nullable = true) + } + } else avroSchema.getTypes.asScala.map(_.getType) match { + case Seq(t1) => + toSqlType(avroSchema.getTypes.get(0)) + case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) => + SchemaType(LongType, nullable = false) + case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) => + SchemaType(DoubleType, nullable = false) + case _ => + // Convert complex unions to struct types where field names are member0, member1, etc. + // This is consistent with the behavior when converting between Avro and Parquet. + val fields = avroSchema.getTypes.asScala.zipWithIndex.map { + case (s, i) => + val schemaType = toSqlType(s) + // All fields are nullable because only one of them is set at a time + StructField(s"member$i", schemaType.dataType, nullable = true) + } + + SchemaType(StructType(fields), nullable = false) + } + + case other => throw new IncompatibleSchemaException(s"Unsupported type $other") + } + } + + def toAvroType( + catalystType: DataType, + nullable: Boolean = false, + recordName: String = "topLevelRecord", + prevNameSpace: String = ""): Schema = { + val builder = if (nullable) { + SchemaBuilder.builder().nullable() + } else { + SchemaBuilder.builder() + } + catalystType match { + case BooleanType => builder.booleanType() + case ByteType | ShortType | IntegerType => builder.intType() + case LongType => builder.longType() + case DateType => builder.longType() + case TimestampType => builder.longType() + case FloatType => builder.floatType() + case DoubleType => builder.doubleType() + case _: DecimalType | StringType => builder.stringType() + case BinaryType => builder.bytesType() + case ArrayType(et, containsNull) => + builder.array().items(toAvroType(et, containsNull, recordName, prevNameSpace)) + case MapType(StringType, vt, valueContainsNull) => + builder.map().values(toAvroType(vt, valueContainsNull, recordName, prevNameSpace)) + case st: StructType => + val nameSpace = s"$prevNameSpace.$recordName" + val fieldsAssembler = builder.record(recordName).namespace(nameSpace).fields() + st.foreach { f => + val fieldAvroType = toAvroType(f.dataType, f.nullable, f.name, nameSpace) + fieldsAssembler.name(f.name).`type`(fieldAvroType).noDefault() + } + fieldsAssembler.endRecord() + + // This should never happen. + case other => throw new IncompatibleSchemaException(s"Unexpected type $other.") + } + } +} + +class IncompatibleSchemaException(msg: String, ex: Throwable = null) extends Exception(msg, ex) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/SerializableSchema.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/SerializableSchema.scala new file mode 100644 index 0000000000000..ec0ddc778c8f6 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/SerializableSchema.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.io._ + +import scala.util.control.NonFatal + +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} +import org.apache.avro.Schema +import org.slf4j.LoggerFactory + +class SerializableSchema(@transient var value: Schema) + extends Serializable with KryoSerializable { + + @transient private[avro] lazy val log = LoggerFactory.getLogger(getClass) + + private def writeObject(out: ObjectOutputStream): Unit = tryOrIOException { + out.defaultWriteObject() + out.writeUTF(value.toString()) + out.flush() + } + + private def readObject(in: ObjectInputStream): Unit = tryOrIOException { + val json = in.readUTF() + value = new Schema.Parser().parse(json) + } + + private def tryOrIOException[T](block: => T): T = { + try { + block + } catch { + case e: IOException => + log.error("Exception encountered", e) + throw e + case NonFatal(e) => + log.error("Exception encountered", e) + throw new IOException(e) + } + } + + def write(kryo: Kryo, out: Output): Unit = { + val dos = new DataOutputStream(out) + dos.writeUTF(value.toString()) + dos.flush() + } + + def read(kryo: Kryo, in: Input): Unit = { + val dis = new DataInputStream(in) + val json = dis.readUTF() + value = new Schema.Parser().parse(json) + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala new file mode 100755 index 0000000000000..b3c8a669cf820 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +package object avro { + /** + * Adds a method, `avro`, to DataFrameWriter that allows you to write avro files using + * the DataFileWriter + */ + implicit class AvroDataFrameWriter[T](writer: DataFrameWriter[T]) { + def avro: String => Unit = writer.format("avro").save + } + + /** + * Adds a method, `avro`, to DataFrameReader that allows you to read avro files using + * the DataFileReader + */ + implicit class AvroDataFrameReader(reader: DataFrameReader) { + def avro: String => DataFrame = reader.format("avro").load + + @scala.annotation.varargs + def avro(sources: String*): DataFrame = reader.format("avro").load(sources: _*) + } +} diff --git a/external/avro/src/test/resources/episodes.avro b/external/avro/src/test/resources/episodes.avro new file mode 100644 index 0000000000000..58a028ce19e6a Binary files /dev/null and b/external/avro/src/test/resources/episodes.avro differ diff --git a/external/avro/src/test/resources/log4j.properties b/external/avro/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..c18a724007c27 --- /dev/null +++ b/external/avro/src/test/resources/log4j.properties @@ -0,0 +1,49 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file core/target/unit-tests.log +log4j.rootLogger=DEBUG, CA, FA + +#Console Appender +log4j.appender.CA=org.apache.log4j.ConsoleAppender +log4j.appender.CA.layout=org.apache.log4j.PatternLayout +log4j.appender.CA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %p %c: %m%n +log4j.appender.CA.Threshold = WARN + + +#File Appender +log4j.appender.FA=org.apache.log4j.FileAppender +log4j.appender.FA.append=false +log4j.appender.FA.file=target/unit-tests.log +log4j.appender.FA.layout=org.apache.log4j.PatternLayout +log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Set the logger level of File Appender to WARN +log4j.appender.FA.Threshold = INFO + +# Some packages are noisy for no good reason. +log4j.additivity.parquet.hadoop.ParquetRecordReader=false +log4j.logger.parquet.hadoop.ParquetRecordReader=OFF + +log4j.additivity.org.apache.hadoop.hive.serde2.lazy.LazyStruct=false +log4j.logger.org.apache.hadoop.hive.serde2.lazy.LazyStruct=OFF + +log4j.additivity.org.apache.hadoop.hive.metastore.RetryingHMSHandler=false +log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=OFF + +log4j.additivity.hive.ql.metadata.Hive=false +log4j.logger.hive.ql.metadata.Hive=OFF \ No newline at end of file diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00000.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00000.avro new file mode 100755 index 0000000000000..fece892444979 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00000.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00001.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00001.avro new file mode 100755 index 0000000000000..1ca623a07dcf3 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00001.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00002.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00002.avro new file mode 100755 index 0000000000000..a12e9459e7461 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00002.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00003.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00003.avro new file mode 100755 index 0000000000000..60c095691d5d5 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00003.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00004.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00004.avro new file mode 100755 index 0000000000000..af56dfc8083dc Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00004.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00005.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00005.avro new file mode 100755 index 0000000000000..87d78447526f9 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00005.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00006.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00006.avro new file mode 100755 index 0000000000000..c326fc434bf18 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00006.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00007.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00007.avro new file mode 100755 index 0000000000000..279f36c317eb8 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00007.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00008.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00008.avro new file mode 100755 index 0000000000000..8d70f5d1274d4 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00008.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00009.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00009.avro new file mode 100755 index 0000000000000..6839d7217e492 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00009.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00010.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00010.avro new file mode 100755 index 0000000000000..aedc7f7e0e61c Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00010.avro differ diff --git a/external/avro/src/test/resources/test.avro b/external/avro/src/test/resources/test.avro new file mode 100644 index 0000000000000..6425e2107e304 Binary files /dev/null and b/external/avro/src/test/resources/test.avro differ diff --git a/external/avro/src/test/resources/test.avsc b/external/avro/src/test/resources/test.avsc new file mode 100644 index 0000000000000..d7119a01f6aa0 --- /dev/null +++ b/external/avro/src/test/resources/test.avsc @@ -0,0 +1,53 @@ +{ + "type" : "record", + "name" : "test_schema", + "fields" : [{ + "name" : "string", + "type" : "string", + "doc" : "Meaningless string of characters" + }, { + "name" : "simple_map", + "type" : {"type": "map", "values": "int"} + }, { + "name" : "complex_map", + "type" : {"type": "map", "values": {"type": "map", "values": "string"}} + }, { + "name" : "union_string_null", + "type" : ["null", "string"] + }, { + "name" : "union_int_long_null", + "type" : ["int", "long", "null"] + }, { + "name" : "union_float_double", + "type" : ["float", "double"] + }, { + "name": "fixed3", + "type": {"type": "fixed", "size": 3, "name": "fixed3"} + }, { + "name": "fixed2", + "type": {"type": "fixed", "size": 2, "name": "fixed2"} + }, { + "name": "enum", + "type": { "type": "enum", + "name": "Suit", + "symbols" : ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"] + } + }, { + "name": "record", + "type": { + "type": "record", + "name": "record", + "aliases": ["RecordAlias"], + "fields" : [{ + "name": "value_field", + "type": "string" + }] + } + }, { + "name": "array_of_boolean", + "type": {"type": "array", "items": "boolean"} + }, { + "name": "bytes", + "type": "bytes" + }] +} diff --git a/external/avro/src/test/resources/test.json b/external/avro/src/test/resources/test.json new file mode 100644 index 0000000000000..780189a92b378 --- /dev/null +++ b/external/avro/src/test/resources/test.json @@ -0,0 +1,42 @@ +{ + "string": "OMG SPARK IS AWESOME", + "simple_map": {"abc": 1, "bcd": 7}, + "complex_map": {"key": {"a": "b", "c": "d"}}, + "union_string_null": {"string": "abc"}, + "union_int_long_null": {"int": 1}, + "union_float_double": {"float": 3.1415926535}, + "fixed3":"\u0002\u0003\u0004", + "fixed2":"\u0011\u0012", + "enum": "SPADES", + "record": {"value_field": "Two things are infinite: the universe and human stupidity; and I'm not sure about universe."}, + "array_of_boolean": [true, false, false], + "bytes": "\u0041\u0042\u0043" +} +{ + "string": "Terran is IMBA!", + "simple_map": {"mmm": 0, "qqq": 66}, + "complex_map": {"key": {"1": "2", "3": "4"}}, + "union_string_null": {"string": "123"}, + "union_int_long_null": {"long": 66}, + "union_float_double": {"double": 6.6666666666666}, + "fixed3":"\u0007\u0007\u0007", + "fixed2":"\u0001\u0002", + "enum": "CLUBS", + "record": {"value_field": "Life did not intend to make us perfect. Whoever is perfect belongs in a museum."}, + "array_of_boolean": [], + "bytes": "" +} +{ + "string": "The cake is a LIE!", + "simple_map": {}, + "complex_map": {"key": {}}, + "union_string_null": {"null": null}, + "union_int_long_null": {"null": null}, + "union_float_double": {"double": 0}, + "fixed3":"\u0011\u0022\u0009", + "fixed2":"\u0010\u0090", + "enum": "DIAMONDS", + "record": {"value_field": "TEST_STR123"}, + "array_of_boolean": [false], + "bytes": "\u0053" +} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala new file mode 100644 index 0000000000000..446b42124ceca --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -0,0 +1,839 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.io._ +import java.net.URL +import java.nio.file.{Files, Path, Paths} +import java.sql.{Date, Timestamp} +import java.util.{TimeZone, UUID} + +import scala.collection.JavaConverters._ + +import org.apache.avro.Schema +import org.apache.avro.Schema.{Field, Type} +import org.apache.avro.file.DataFileWriter +import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord} +import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} +import org.apache.commons.io.FileUtils + +import org.apache.spark.sql._ +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.sql.types._ + +class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { + val episodesAvro = testFile("episodes.avro") + val testAvro = testFile("test.avro") + + override protected def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set("spark.sql.files.maxPartitionBytes", 1024) + } + + def checkReloadMatchesSaved(originalFile: String, newFile: String): Unit = { + val originalEntries = spark.read.avro(testAvro).collect() + val newEntries = spark.read.avro(newFile) + checkAnswer(newEntries, originalEntries) + } + + test("reading from multiple paths") { + val df = spark.read.avro(episodesAvro, episodesAvro) + assert(df.count == 16) + } + + test("reading and writing partitioned data") { + val df = spark.read.avro(episodesAvro) + val fields = List("title", "air_date", "doctor") + for (field <- fields) { + withTempPath { dir => + val outputDir = s"$dir/${UUID.randomUUID}" + df.write.partitionBy(field).avro(outputDir) + val input = spark.read.avro(outputDir) + // makes sure that no fields got dropped. + // We convert Rows to Seqs in order to work around SPARK-10325 + assert(input.select(field).collect().map(_.toSeq).toSet === + df.select(field).collect().map(_.toSeq).toSet) + } + } + } + + test("request no fields") { + val df = spark.read.avro(episodesAvro) + df.createOrReplaceTempView("avro_table") + assert(spark.sql("select count(*) from avro_table").collect().head === Row(8)) + } + + test("convert formats") { + withTempPath { dir => + val df = spark.read.avro(episodesAvro) + df.write.parquet(dir.getCanonicalPath) + assert(spark.read.parquet(dir.getCanonicalPath).count() === df.count) + } + } + + test("rearrange internal schema") { + withTempPath { dir => + val df = spark.read.avro(episodesAvro) + df.select("doctor", "title").write.avro(dir.getCanonicalPath) + } + } + + test("test NULL avro type") { + withTempPath { dir => + val fields = + Seq(new Field("null", Schema.create(Type.NULL), "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(schema, new File(s"$dir.avro")) + val avroRec = new GenericData.Record(schema) + avroRec.put("null", null) + dataFileWriter.append(avroRec) + dataFileWriter.flush() + dataFileWriter.close() + + intercept[IncompatibleSchemaException] { + spark.read.avro(s"$dir.avro") + } + } + } + + test("union(int, long) is read as long") { + withTempPath { dir => + val avroSchema: Schema = { + val union = + Schema.createUnion(List(Schema.create(Type.INT), Schema.create(Type.LONG)).asJava) + val fields = Seq(new Field("field1", union, "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + schema + } + + val datumWriter = new GenericDatumWriter[GenericRecord](avroSchema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(avroSchema, new File(s"$dir.avro")) + val rec1 = new GenericData.Record(avroSchema) + rec1.put("field1", 1.toLong) + dataFileWriter.append(rec1) + val rec2 = new GenericData.Record(avroSchema) + rec2.put("field1", 2) + dataFileWriter.append(rec2) + dataFileWriter.flush() + dataFileWriter.close() + val df = spark.read.avro(s"$dir.avro") + assert(df.schema.fields === Seq(StructField("field1", LongType, nullable = true))) + assert(df.collect().toSet == Set(Row(1L), Row(2L))) + } + } + + test("union(float, double) is read as double") { + withTempPath { dir => + val avroSchema: Schema = { + val union = + Schema.createUnion(List(Schema.create(Type.FLOAT), Schema.create(Type.DOUBLE)).asJava) + val fields = Seq(new Field("field1", union, "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + schema + } + + val datumWriter = new GenericDatumWriter[GenericRecord](avroSchema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(avroSchema, new File(s"$dir.avro")) + val rec1 = new GenericData.Record(avroSchema) + rec1.put("field1", 1.toFloat) + dataFileWriter.append(rec1) + val rec2 = new GenericData.Record(avroSchema) + rec2.put("field1", 2.toDouble) + dataFileWriter.append(rec2) + dataFileWriter.flush() + dataFileWriter.close() + val df = spark.read.avro(s"$dir.avro") + assert(df.schema.fields === Seq(StructField("field1", DoubleType, nullable = true))) + assert(df.collect().toSet == Set(Row(1.toDouble), Row(2.toDouble))) + } + } + + test("union(float, double, null) is read as nullable double") { + withTempPath { dir => + val avroSchema: Schema = { + val union = Schema.createUnion( + List(Schema.create(Type.FLOAT), + Schema.create(Type.DOUBLE), + Schema.create(Type.NULL) + ).asJava + ) + val fields = Seq(new Field("field1", union, "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + schema + } + + val datumWriter = new GenericDatumWriter[GenericRecord](avroSchema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(avroSchema, new File(s"$dir.avro")) + val rec1 = new GenericData.Record(avroSchema) + rec1.put("field1", 1.toFloat) + dataFileWriter.append(rec1) + val rec2 = new GenericData.Record(avroSchema) + rec2.put("field1", null) + dataFileWriter.append(rec2) + dataFileWriter.flush() + dataFileWriter.close() + val df = spark.read.avro(s"$dir.avro") + assert(df.schema.fields === Seq(StructField("field1", DoubleType, nullable = true))) + assert(df.collect().toSet == Set(Row(1.toDouble), Row(null))) + } + } + + test("Union of a single type") { + withTempPath { dir => + val UnionOfOne = Schema.createUnion(List(Schema.create(Type.INT)).asJava) + val fields = Seq(new Field("field1", UnionOfOne, "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(schema, new File(s"$dir.avro")) + val avroRec = new GenericData.Record(schema) + + avroRec.put("field1", 8) + + dataFileWriter.append(avroRec) + dataFileWriter.flush() + dataFileWriter.close() + + val df = spark.read.avro(s"$dir.avro") + assert(df.first() == Row(8)) + } + } + + test("Complex Union Type") { + withTempPath { dir => + val fixedSchema = Schema.createFixed("fixed_name", "doc", "namespace", 4) + val enumSchema = Schema.createEnum("enum_name", "doc", "namespace", List("e1", "e2").asJava) + val complexUnionType = Schema.createUnion( + List(Schema.create(Type.INT), Schema.create(Type.STRING), fixedSchema, enumSchema).asJava) + val fields = Seq( + new Field("field1", complexUnionType, "doc", null), + new Field("field2", complexUnionType, "doc", null), + new Field("field3", complexUnionType, "doc", null), + new Field("field4", complexUnionType, "doc", null) + ).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(schema, new File(s"$dir.avro")) + val avroRec = new GenericData.Record(schema) + val field1 = 1234 + val field2 = "Hope that was not load bearing" + val field3 = Array[Byte](1, 2, 3, 4) + val field4 = "e2" + avroRec.put("field1", field1) + avroRec.put("field2", field2) + avroRec.put("field3", new Fixed(fixedSchema, field3)) + avroRec.put("field4", new EnumSymbol(enumSchema, field4)) + dataFileWriter.append(avroRec) + dataFileWriter.flush() + dataFileWriter.close() + + val df = spark.sqlContext.read.avro(s"$dir.avro") + assertResult(field1)(df.selectExpr("field1.member0").first().get(0)) + assertResult(field2)(df.selectExpr("field2.member1").first().get(0)) + assertResult(field3)(df.selectExpr("field3.member2").first().get(0)) + assertResult(field4)(df.selectExpr("field4.member3").first().get(0)) + } + } + + test("Lots of nulls") { + withTempPath { dir => + val schema = StructType(Seq( + StructField("binary", BinaryType, true), + StructField("timestamp", TimestampType, true), + StructField("array", ArrayType(ShortType), true), + StructField("map", MapType(StringType, StringType), true), + StructField("struct", StructType(Seq(StructField("int", IntegerType, true)))))) + val rdd = spark.sparkContext.parallelize(Seq[Row]( + Row(null, new Timestamp(1), Array[Short](1, 2, 3), null, null), + Row(null, null, null, null, null), + Row(null, null, null, null, null), + Row(null, null, null, null, null))) + val df = spark.createDataFrame(rdd, schema) + df.write.avro(dir.toString) + assert(spark.read.avro(dir.toString).count == rdd.count) + } + } + + test("Struct field type") { + withTempPath { dir => + val schema = StructType(Seq( + StructField("float", FloatType, true), + StructField("short", ShortType, true), + StructField("byte", ByteType, true), + StructField("boolean", BooleanType, true) + )) + val rdd = spark.sparkContext.parallelize(Seq( + Row(1f, 1.toShort, 1.toByte, true), + Row(2f, 2.toShort, 2.toByte, true), + Row(3f, 3.toShort, 3.toByte, true) + )) + val df = spark.createDataFrame(rdd, schema) + df.write.avro(dir.toString) + assert(spark.read.avro(dir.toString).count == rdd.count) + } + } + + test("Date field type") { + withTempPath { dir => + val schema = StructType(Seq( + StructField("float", FloatType, true), + StructField("date", DateType, true) + )) + TimeZone.setDefault(TimeZone.getTimeZone("UTC")) + val rdd = spark.sparkContext.parallelize(Seq( + Row(1f, null), + Row(2f, new Date(1451948400000L)), + Row(3f, new Date(1460066400500L)) + )) + val df = spark.createDataFrame(rdd, schema) + df.write.avro(dir.toString) + assert(spark.read.avro(dir.toString).count == rdd.count) + assert(spark.read.avro(dir.toString).select("date").collect().map(_(0)).toSet == + Array(null, 1451865600000L, 1459987200000L).toSet) + } + } + + test("Array data types") { + withTempPath { dir => + val testSchema = StructType(Seq( + StructField("byte_array", ArrayType(ByteType), true), + StructField("short_array", ArrayType(ShortType), true), + StructField("float_array", ArrayType(FloatType), true), + StructField("bool_array", ArrayType(BooleanType), true), + StructField("long_array", ArrayType(LongType), true), + StructField("double_array", ArrayType(DoubleType), true), + StructField("decimal_array", ArrayType(DecimalType(10, 0)), true), + StructField("bin_array", ArrayType(BinaryType), true), + StructField("timestamp_array", ArrayType(TimestampType), true), + StructField("array_array", ArrayType(ArrayType(StringType), true), true), + StructField("struct_array", ArrayType( + StructType(Seq(StructField("name", StringType, true))))))) + + val arrayOfByte = new Array[Byte](4) + for (i <- arrayOfByte.indices) { + arrayOfByte(i) = i.toByte + } + + val rdd = spark.sparkContext.parallelize(Seq( + Row(arrayOfByte, Array[Short](1, 2, 3, 4), Array[Float](1f, 2f, 3f, 4f), + Array[Boolean](true, false, true, false), Array[Long](1L, 2L), Array[Double](1.0, 2.0), + Array[BigDecimal](BigDecimal.valueOf(3)), Array[Array[Byte]](arrayOfByte, arrayOfByte), + Array[Timestamp](new Timestamp(0)), + Array[Array[String]](Array[String]("CSH, tearing down the walls that divide us", "-jd")), + Array[Row](Row("Bobby G. can't swim"))))) + val df = spark.createDataFrame(rdd, testSchema) + df.write.avro(dir.toString) + assert(spark.read.avro(dir.toString).count == rdd.count) + } + } + + test("write with compression") { + withTempPath { dir => + val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec" + val AVRO_DEFLATE_LEVEL = "spark.sql.avro.deflate.level" + val uncompressDir = s"$dir/uncompress" + val deflateDir = s"$dir/deflate" + val snappyDir = s"$dir/snappy" + + val df = spark.read.avro(testAvro) + spark.conf.set(AVRO_COMPRESSION_CODEC, "uncompressed") + df.write.avro(uncompressDir) + spark.conf.set(AVRO_COMPRESSION_CODEC, "deflate") + spark.conf.set(AVRO_DEFLATE_LEVEL, "9") + df.write.avro(deflateDir) + spark.conf.set(AVRO_COMPRESSION_CODEC, "snappy") + df.write.avro(snappyDir) + + val uncompressSize = FileUtils.sizeOfDirectory(new File(uncompressDir)) + val deflateSize = FileUtils.sizeOfDirectory(new File(deflateDir)) + val snappySize = FileUtils.sizeOfDirectory(new File(snappyDir)) + + assert(uncompressSize > deflateSize) + assert(snappySize > deflateSize) + } + } + + test("dsl test") { + val results = spark.read.avro(episodesAvro).select("title").collect() + assert(results.length === 8) + } + + test("support of various data types") { + // This test uses data from test.avro. You can see the data and the schema of this file in + // test.json and test.avsc + val all = spark.read.avro(testAvro).collect() + assert(all.length == 3) + + val str = spark.read.avro(testAvro).select("string").collect() + assert(str.map(_(0)).toSet.contains("Terran is IMBA!")) + + val simple_map = spark.read.avro(testAvro).select("simple_map").collect() + assert(simple_map(0)(0).getClass.toString.contains("Map")) + assert(simple_map.map(_(0).asInstanceOf[Map[String, Some[Int]]].size).toSet == Set(2, 0)) + + val union0 = spark.read.avro(testAvro).select("union_string_null").collect() + assert(union0.map(_(0)).toSet == Set("abc", "123", null)) + + val union1 = spark.read.avro(testAvro).select("union_int_long_null").collect() + assert(union1.map(_(0)).toSet == Set(66, 1, null)) + + val union2 = spark.read.avro(testAvro).select("union_float_double").collect() + assert( + union2 + .map(x => new java.lang.Double(x(0).toString)) + .exists(p => Math.abs(p - Math.PI) < 0.001)) + + val fixed = spark.read.avro(testAvro).select("fixed3").collect() + assert(fixed.map(_(0).asInstanceOf[Array[Byte]]).exists(p => p(1) == 3)) + + val enum = spark.read.avro(testAvro).select("enum").collect() + assert(enum.map(_(0)).toSet == Set("SPADES", "CLUBS", "DIAMONDS")) + + val record = spark.read.avro(testAvro).select("record").collect() + assert(record(0)(0).getClass.toString.contains("Row")) + assert(record.map(_(0).asInstanceOf[Row](0)).contains("TEST_STR123")) + + val array_of_boolean = spark.read.avro(testAvro).select("array_of_boolean").collect() + assert(array_of_boolean.map(_(0).asInstanceOf[Seq[Boolean]].size).toSet == Set(3, 1, 0)) + + val bytes = spark.read.avro(testAvro).select("bytes").collect() + assert(bytes.map(_(0).asInstanceOf[Array[Byte]].length).toSet == Set(3, 1, 0)) + } + + test("sql test") { + spark.sql( + s""" + |CREATE TEMPORARY VIEW avroTable + |USING avro + |OPTIONS (path "${episodesAvro}") + """.stripMargin.replaceAll("\n", " ")) + + assert(spark.sql("SELECT * FROM avroTable").collect().length === 8) + } + + test("conversion to avro and back") { + // Note that test.avro includes a variety of types, some of which are nullable. We expect to + // get the same values back. + withTempPath { dir => + val avroDir = s"$dir/avro" + spark.read.avro(testAvro).write.avro(avroDir) + checkReloadMatchesSaved(testAvro, avroDir) + } + } + + test("conversion to avro and back with namespace") { + // Note that test.avro includes a variety of types, some of which are nullable. We expect to + // get the same values back. + withTempPath { tempDir => + val name = "AvroTest" + val namespace = "com.databricks.spark.avro" + val parameters = Map("recordName" -> name, "recordNamespace" -> namespace) + + val avroDir = tempDir + "/namedAvro" + spark.read.avro(testAvro).write.options(parameters).avro(avroDir) + checkReloadMatchesSaved(testAvro, avroDir) + + // Look at raw file and make sure has namespace info + val rawSaved = spark.sparkContext.textFile(avroDir) + val schema = rawSaved.collect().mkString("") + assert(schema.contains(name)) + assert(schema.contains(namespace)) + } + } + + test("converting some specific sparkSQL types to avro") { + withTempPath { tempDir => + val testSchema = StructType(Seq( + StructField("Name", StringType, false), + StructField("Length", IntegerType, true), + StructField("Time", TimestampType, false), + StructField("Decimal", DecimalType(10, 2), true), + StructField("Binary", BinaryType, false))) + + val arrayOfByte = new Array[Byte](4) + for (i <- arrayOfByte.indices) { + arrayOfByte(i) = i.toByte + } + val cityRDD = spark.sparkContext.parallelize(Seq( + Row("San Francisco", 12, new Timestamp(666), null, arrayOfByte), + Row("Palo Alto", null, new Timestamp(777), null, arrayOfByte), + Row("Munich", 8, new Timestamp(42), Decimal(3.14), arrayOfByte))) + val cityDataFrame = spark.createDataFrame(cityRDD, testSchema) + + val avroDir = tempDir + "/avro" + cityDataFrame.write.avro(avroDir) + assert(spark.read.avro(avroDir).collect().length == 3) + + // TimesStamps are converted to longs + val times = spark.read.avro(avroDir).select("Time").collect() + assert(times.map(_(0)).toSet == Set(666, 777, 42)) + + // DecimalType should be converted to string + val decimals = spark.read.avro(avroDir).select("Decimal").collect() + assert(decimals.map(_(0)).contains("3.14")) + + // There should be a null entry + val length = spark.read.avro(avroDir).select("Length").collect() + assert(length.map(_(0)).contains(null)) + + val binary = spark.read.avro(avroDir).select("Binary").collect() + for (i <- arrayOfByte.indices) { + assert(binary(1)(0).asInstanceOf[Array[Byte]](i) == arrayOfByte(i)) + } + } + } + + test("correctly read long as date/timestamp type") { + withTempPath { tempDir => + val sparkSession = spark + import sparkSession.implicits._ + + val currentTime = new Timestamp(System.currentTimeMillis()) + val currentDate = new Date(System.currentTimeMillis()) + val schema = StructType(Seq( + StructField("_1", DateType, false), StructField("_2", TimestampType, false))) + val writeDs = Seq((currentDate, currentTime)).toDS + + val avroDir = tempDir + "/avro" + writeDs.write.avro(avroDir) + assert(spark.read.avro(avroDir).collect().length == 1) + + val readDs = spark.read.schema(schema).avro(avroDir).as[(Date, Timestamp)] + + assert(readDs.collect().sameElements(writeDs.collect())) + } + } + + test("support of globbed paths") { + val resourceDir = testFile(".") + val e1 = spark.read.avro(resourceDir + "../*/episodes.avro").collect() + assert(e1.length == 8) + + val e2 = spark.read.avro(resourceDir + "../../*/*/episodes.avro").collect() + assert(e2.length == 8) + } + + test("does not coerce null date/timestamp value to 0 epoch.") { + withTempPath { tempDir => + val sparkSession = spark + import sparkSession.implicits._ + + val nullTime: Timestamp = null + val nullDate: Date = null + val schema = StructType(Seq( + StructField("_1", DateType, nullable = true), + StructField("_2", TimestampType, nullable = true)) + ) + val writeDs = Seq((nullDate, nullTime)).toDS + + val avroDir = tempDir + "/avro" + writeDs.write.avro(avroDir) + val readValues = spark.read.schema(schema).avro(avroDir).as[(Date, Timestamp)].collect + + assert(readValues.size == 1) + assert(readValues.head == ((nullDate, nullTime))) + } + } + + test("support user provided avro schema") { + val avroSchema = + """ + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [{ + | "name" : "string", + | "type" : "string", + | "doc" : "Meaningless string of characters" + | }] + |} + """.stripMargin + val result = spark + .read + .option(AvroFileFormat.AvroSchema, avroSchema) + .avro(testAvro) + .collect() + val expected = spark.read.avro(testAvro).select("string").collect() + assert(result.sameElements(expected)) + } + + test("support user provided avro schema with defaults for missing fields") { + val avroSchema = + """ + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [{ + | "name" : "missingField", + | "type" : "string", + | "default" : "foo" + | }] + |} + """.stripMargin + val result = spark.read.option(AvroFileFormat.AvroSchema, avroSchema) + .avro(testAvro).select("missingField").first + assert(result === Row("foo")) + } + + test("reading from invalid path throws exception") { + + // Directory given has no avro files + intercept[AnalysisException] { + withTempPath(dir => spark.read.avro(dir.getCanonicalPath)) + } + + intercept[AnalysisException] { + spark.read.avro("very/invalid/path/123.avro") + } + + // In case of globbed path that can't be matched to anything, another exception is thrown (and + // exception message is helpful) + intercept[AnalysisException] { + spark.read.avro("*/*/*/*/*/*/*/something.avro") + } + + intercept[FileNotFoundException] { + withTempPath { dir => + FileUtils.touch(new File(dir, "test")) + val hadoopConf = spark.sqlContext.sparkContext.hadoopConfiguration + try { + hadoopConf.set(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, "true") + spark.read.avro(dir.toString) + } finally { + hadoopConf.unset(AvroFileFormat.IgnoreFilesWithoutExtensionProperty) } + } + } + + } + + test("SQL test insert overwrite") { + withTempPath { tempDir => + val tempEmptyDir = s"$tempDir/sqlOverwrite" + // Create a temp directory for table that will be overwritten + new File(tempEmptyDir).mkdirs() + spark.sql( + s""" + |CREATE TEMPORARY VIEW episodes + |USING avro + |OPTIONS (path "${episodesAvro}") + """.stripMargin.replaceAll("\n", " ")) + spark.sql( + s""" + |CREATE TEMPORARY VIEW episodesEmpty + |(name string, air_date string, doctor int) + |USING avro + |OPTIONS (path "$tempEmptyDir") + """.stripMargin.replaceAll("\n", " ")) + + assert(spark.sql("SELECT * FROM episodes").collect().length === 8) + assert(spark.sql("SELECT * FROM episodesEmpty").collect().isEmpty) + + spark.sql( + s""" + |INSERT OVERWRITE TABLE episodesEmpty + |SELECT * FROM episodes + """.stripMargin.replaceAll("\n", " ")) + assert(spark.sql("SELECT * FROM episodesEmpty").collect().length == 8) + } + } + + test("test save and load") { + // Test if load works as expected + withTempPath { tempDir => + val df = spark.read.avro(episodesAvro) + assert(df.count == 8) + + val tempSaveDir = s"$tempDir/save/" + + df.write.avro(tempSaveDir) + val newDf = spark.read.avro(tempSaveDir) + assert(newDf.count == 8) + } + } + + test("test load with non-Avro file") { + // Test if load works as expected + withTempPath { tempDir => + val df = spark.read.avro(episodesAvro) + assert(df.count == 8) + + val tempSaveDir = s"$tempDir/save/" + df.write.avro(tempSaveDir) + + Files.createFile(new File(tempSaveDir, "non-avro").toPath) + + val hadoopConf = spark.sqlContext.sparkContext.hadoopConfiguration + val count = try { + hadoopConf.set(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, "true") + val newDf = spark + .read + .avro(tempSaveDir) + newDf.count() + } finally { + hadoopConf.unset(AvroFileFormat.IgnoreFilesWithoutExtensionProperty) + } + + assert(count == 8) + } + } + + test("read avro with user defined schema: read partial columns") { + val partialColumns = StructType(Seq( + StructField("string", StringType, false), + StructField("simple_map", MapType(StringType, IntegerType), false), + StructField("complex_map", MapType(StringType, MapType(StringType, StringType)), false), + StructField("union_string_null", StringType, true), + StructField("union_int_long_null", LongType, true), + StructField("fixed3", BinaryType, true), + StructField("fixed2", BinaryType, true), + StructField("enum", StringType, false), + StructField("record", StructType(Seq(StructField("value_field", StringType, false))), false), + StructField("array_of_boolean", ArrayType(BooleanType), false), + StructField("bytes", BinaryType, true))) + val withSchema = spark.read.schema(partialColumns).avro(testAvro).collect() + val withOutSchema = spark + .read + .avro(testAvro) + .select("string", "simple_map", "complex_map", "union_string_null", "union_int_long_null", + "fixed3", "fixed2", "enum", "record", "array_of_boolean", "bytes") + .collect() + assert(withSchema.sameElements(withOutSchema)) + } + + test("read avro with user defined schema: read non-exist columns") { + val schema = + StructType( + Seq( + StructField("non_exist_string", StringType, true), + StructField( + "record", + StructType(Seq( + StructField("non_exist_field", StringType, false), + StructField("non_exist_field2", StringType, false))), + false))) + val withEmptyColumn = spark.read.schema(schema).avro(testAvro).collect() + + assert(withEmptyColumn.forall(_ == Row(null: String, Row(null: String, null: String)))) + } + + test("read avro file partitioned") { + withTempPath { dir => + val sparkSession = spark + import sparkSession.implicits._ + val df = (0 to 1024 * 3).toDS.map(i => s"record${i}").toDF("records") + val outputDir = s"$dir/${UUID.randomUUID}" + df.write.avro(outputDir) + val input = spark.read.avro(outputDir) + assert(input.collect.toSet.size === 1024 * 3 + 1) + assert(input.rdd.partitions.size > 2) + } + } + + case class NestedBottom(id: Int, data: String) + + case class NestedMiddle(id: Int, data: NestedBottom) + + case class NestedTop(id: Int, data: NestedMiddle) + + test("saving avro that has nested records with the same name") { + withTempPath { tempDir => + // Save avro file on output folder path + val writeDf = spark.createDataFrame(List(NestedTop(1, NestedMiddle(2, NestedBottom(3, "1"))))) + val outputFolder = s"$tempDir/duplicate_names/" + writeDf.write.avro(outputFolder) + // Read avro file saved on the last step + val readDf = spark.read.avro(outputFolder) + // Check if the written DataFrame is equals than read DataFrame + assert(readDf.collect().sameElements(writeDf.collect())) + } + } + + case class NestedMiddleArray(id: Int, data: Array[NestedBottom]) + + case class NestedTopArray(id: Int, data: NestedMiddleArray) + + test("saving avro that has nested records with the same name inside an array") { + withTempPath { tempDir => + // Save avro file on output folder path + val writeDf = spark.createDataFrame( + List(NestedTopArray(1, NestedMiddleArray(2, Array( + NestedBottom(3, "1"), NestedBottom(4, "2") + )))) + ) + val outputFolder = s"$tempDir/duplicate_names_array/" + writeDf.write.avro(outputFolder) + // Read avro file saved on the last step + val readDf = spark.read.avro(outputFolder) + // Check if the written DataFrame is equals than read DataFrame + assert(readDf.collect().sameElements(writeDf.collect())) + } + } + + case class NestedMiddleMap(id: Int, data: Map[String, NestedBottom]) + + case class NestedTopMap(id: Int, data: NestedMiddleMap) + + test("saving avro that has nested records with the same name inside a map") { + withTempPath { tempDir => + // Save avro file on output folder path + val writeDf = spark.createDataFrame( + List(NestedTopMap(1, NestedMiddleMap(2, Map( + "1" -> NestedBottom(3, "1"), "2" -> NestedBottom(4, "2") + )))) + ) + val outputFolder = s"$tempDir/duplicate_names_map/" + writeDf.write.avro(outputFolder) + // Read avro file saved on the last step + val readDf = spark.read.avro(outputFolder) + // Check if the written DataFrame is equals than read DataFrame + assert(readDf.collect().sameElements(writeDf.collect())) + } + } + + test("SPARK-24805: do not ignore files without .avro extension by default") { + withTempDir { dir => + Files.copy( + Paths.get(new URL(episodesAvro).toURI), + Paths.get(dir.getCanonicalPath, "episodes")) + + val fileWithoutExtension = s"${dir.getCanonicalPath}/episodes" + val df1 = spark.read.avro(fileWithoutExtension) + assert(df1.count == 8) + + val schema = new StructType() + .add("title", StringType) + .add("air_date", StringType) + .add("doctor", IntegerType) + val df2 = spark.read.schema(schema).avro(fileWithoutExtension) + assert(df2.count == 8) + } + } +} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableConfigurationSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableConfigurationSuite.scala new file mode 100755 index 0000000000000..a0f88515ed9d4 --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableConfigurationSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerInstance} + +class SerializableConfigurationSuite extends SparkFunSuite { + + private def testSerialization(serializer: SerializerInstance): Unit = { + import AvroFileFormat.SerializableConfiguration + val conf = new SerializableConfiguration(new Configuration()) + + val serialized = serializer.serialize(conf) + + serializer.deserialize[Any](serialized) match { + case c: SerializableConfiguration => + assert(c.log != null, "log was null") + assert(c.value != null, "value was null") + case other => fail( + s"Expecting ${classOf[SerializableConfiguration]}, but got ${other.getClass}.") + } + } + + test("serialization with JavaSerializer") { + testSerialization(new JavaSerializer(new SparkConf()).newInstance()) + } + + test("serialization with KryoSerializer") { + testSerialization(new KryoSerializer(new SparkConf()).newInstance()) + } + +} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableSchemaSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableSchemaSuite.scala new file mode 100644 index 0000000000000..510bcbdd31929 --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableSchemaSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.avro.Schema + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerInstance} + +class SerializableSchemaSuite extends SparkFunSuite { + + private def testSerialization(serializer: SerializerInstance): Unit = { + val avroTypeJson = + s""" + |{ + | "type": "string", + | "name": "my_string" + |} + """.stripMargin + val avroSchema = new Schema.Parser().parse(avroTypeJson) + val serializableSchema = new SerializableSchema(avroSchema) + val serialized = serializer.serialize(serializableSchema) + + serializer.deserialize[Any](serialized) match { + case c: SerializableSchema => + assert(c.log != null, "log was null") + assert(c.value != null, "value was null") + assert(c.value == avroSchema) + case other => fail( + s"Expecting ${classOf[SerializableSchema]}, but got ${other.getClass}.") + } + } + + test("serialization with JavaSerializer") { + testSerialization(new JavaSerializer(new SparkConf()).newInstance()) + } + + test("serialization with KryoSerializer") { + testSerialization(new KryoSerializer(new SparkConf()).newInstance()) + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala index ae5b5c52d514e..32923dc9f5a6b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala @@ -67,7 +67,7 @@ case class KafkaStreamWriterFactory( override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[InternalRow] = { new KafkaStreamDataWriter(topic, producerParams, schema.toAttributes) } diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index c3221481556f5..0246006acf0bd 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -166,6 +166,8 @@ private[spark] class DirectKafkaInputDStream[K, V]( * which would throw off consumer position. Fix position if this happens. */ private def paranoidPoll(c: Consumer[K, V]): Unit = { + // don't actually want to consume any messages, so pause all partitions + c.pause(c.assignment()) val msgs = c.poll(0) if (!msgs.isEmpty) { // position should be minimum offset per topicpartition @@ -204,8 +206,6 @@ private[spark] class DirectKafkaInputDStream[K, V]( // position for new partitions determined by auto.offset.reset if no commit currentOffsets = currentOffsets ++ newPartitions.map(tp => tp -> c.position(tp)).toMap - // don't want to consume messages, so pause - c.pause(newPartitions.asJava) // find latest available offsets c.seekToEnd(currentOffsets.keySet.asJava) parts.map(tp => tp -> c.position(tp)).toMap @@ -262,9 +262,6 @@ private[spark] class DirectKafkaInputDStream[K, V]( tp -> c.position(tp) }.toMap } - - // don't actually want to consume any messages, so pause all partitions - c.pause(currentOffsets.keySet.asJava) } override def stop(): Unit = this.synchronized { diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index fa0de6298a5f1..69c52365b1bf8 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -160,7 +160,6 @@ private[kinesis] class KinesisReceiver[T]( cloudWatchCreds.map(_.provider).getOrElse(kinesisProvider), workerId) .withKinesisEndpoint(endpointUrl) - .withInitialPositionInStream(initialPosition.getPosition) .withTaskBackoffTimeMillis(500) .withRegionName(regionName) @@ -169,7 +168,8 @@ private[kinesis] class KinesisReceiver[T]( initialPosition match { case ts: AtTimestamp => baseClientLibConfiguration.withTimestampAtInitialPositionInStream(ts.getTimestamp) - case _ => baseClientLibConfiguration + case _ => + baseClientLibConfiguration.withInitialPositionInStream(initialPosition.getPosition) } } diff --git a/graphx/pom.xml b/graphx/pom.xml index fbe77fcb958d5..0f5dc548600b2 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -53,7 +53,7 @@
org.apache.xbean - xbean-asm5-shaded + xbean-asm6-shaded com.google.guava diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala index d76e84ed8c9ed..a559685b1633c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala @@ -22,8 +22,8 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.collection.mutable.HashSet import scala.language.existentials -import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor} -import org.apache.xbean.asm5.Opcodes._ +import org.apache.xbean.asm6.{ClassReader, ClassVisitor, MethodVisitor} +import org.apache.xbean.asm6.Opcodes._ import org.apache.spark.util.Utils diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index 1e34bb8c73279..d967aa39a4827 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -17,6 +17,7 @@ package org.apache.spark.launcher; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -54,10 +55,12 @@ public static void main(String[] argsArray) throws Exception { String className = args.remove(0); boolean printLaunchCommand = !isEmpty(System.getenv("SPARK_PRINT_LAUNCH_COMMAND")); - AbstractCommandBuilder builder; + Map env = new HashMap<>(); + List cmd; if (className.equals("org.apache.spark.deploy.SparkSubmit")) { try { - builder = new SparkSubmitCommandBuilder(args); + AbstractCommandBuilder builder = new SparkSubmitCommandBuilder(args); + cmd = buildCommand(builder, env, printLaunchCommand); } catch (IllegalArgumentException e) { printLaunchCommand = false; System.err.println("Error: " + e.getMessage()); @@ -76,17 +79,12 @@ public static void main(String[] argsArray) throws Exception { help.add(parser.className); } help.add(parser.USAGE_ERROR); - builder = new SparkSubmitCommandBuilder(help); + AbstractCommandBuilder builder = new SparkSubmitCommandBuilder(help); + cmd = buildCommand(builder, env, printLaunchCommand); } } else { - builder = new SparkClassCommandBuilder(className, args); - } - - Map env = new HashMap<>(); - List cmd = builder.buildCommand(env); - if (printLaunchCommand) { - System.err.println("Spark Command: " + join(" ", cmd)); - System.err.println("========================================"); + AbstractCommandBuilder builder = new SparkClassCommandBuilder(className, args); + cmd = buildCommand(builder, env, printLaunchCommand); } if (isWindows()) { @@ -101,6 +99,22 @@ public static void main(String[] argsArray) throws Exception { } } + /** + * Prepare spark commands with the appropriate command builder. + * If printLaunchCommand is set then the commands will be printed to the stderr. + */ + private static List buildCommand( + AbstractCommandBuilder builder, + Map env, + boolean printLaunchCommand) throws IOException, IllegalArgumentException { + List cmd = builder.buildCommand(env); + if (printLaunchCommand) { + System.err.println("Spark Command: " + join(" ", cmd)); + System.err.println("========================================"); + } + return cmd; + } + /** * Prepare a command line for execution from a Windows batch script. * diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 5cb6457bf5c21..cc65f78b45c30 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -90,7 +90,8 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { final List userArgs; private final List parsedArgs; - private final boolean requiresAppResource; + // Special command means no appResource and no mainClass required + private final boolean isSpecialCommand; private final boolean isExample; /** @@ -105,7 +106,7 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { * spark-submit argument list to be modified after creation. */ SparkSubmitCommandBuilder() { - this.requiresAppResource = true; + this.isSpecialCommand = false; this.isExample = false; this.parsedArgs = new ArrayList<>(); this.userArgs = new ArrayList<>(); @@ -138,25 +139,26 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { case RUN_EXAMPLE: isExample = true; + appResource = SparkLauncher.NO_RESOURCE; submitArgs = args.subList(1, args.size()); } this.isExample = isExample; OptionParser parser = new OptionParser(true); parser.parse(submitArgs); - this.requiresAppResource = parser.requiresAppResource; + this.isSpecialCommand = parser.isSpecialCommand; } else { this.isExample = isExample; - this.requiresAppResource = false; + this.isSpecialCommand = true; } } @Override public List buildCommand(Map env) throws IOException, IllegalArgumentException { - if (PYSPARK_SHELL.equals(appResource) && requiresAppResource) { + if (PYSPARK_SHELL.equals(appResource) && !isSpecialCommand) { return buildPySparkShellCommand(env); - } else if (SPARKR_SHELL.equals(appResource) && requiresAppResource) { + } else if (SPARKR_SHELL.equals(appResource) && !isSpecialCommand) { return buildSparkRCommand(env); } else { return buildSparkSubmitCommand(env); @@ -166,18 +168,18 @@ public List buildCommand(Map env) List buildSparkSubmitArgs() { List args = new ArrayList<>(); OptionParser parser = new OptionParser(false); - final boolean requiresAppResource; + final boolean isSpecialCommand; // If the user args array is not empty, we need to parse it to detect exactly what // the user is trying to run, so that checks below are correct. if (!userArgs.isEmpty()) { parser.parse(userArgs); - requiresAppResource = parser.requiresAppResource; + isSpecialCommand = parser.isSpecialCommand; } else { - requiresAppResource = this.requiresAppResource; + isSpecialCommand = this.isSpecialCommand; } - if (!allowsMixedArguments && requiresAppResource) { + if (!allowsMixedArguments && !isSpecialCommand) { checkArgument(appResource != null, "Missing application resource."); } @@ -229,7 +231,7 @@ List buildSparkSubmitArgs() { args.add(join(",", pyFiles)); } - if (isExample) { + if (isExample && !isSpecialCommand) { checkArgument(mainClass != null, "Missing example class name."); } @@ -421,7 +423,7 @@ private List findExamplesJars() { private class OptionParser extends SparkSubmitOptionParser { - boolean requiresAppResource = true; + boolean isSpecialCommand = false; private final boolean errorOnUnknownArgs; OptionParser(boolean errorOnUnknownArgs) { @@ -470,17 +472,14 @@ protected boolean handle(String opt, String value) { break; case KILL_SUBMISSION: case STATUS: - requiresAppResource = false; + isSpecialCommand = true; parsedArgs.add(opt); parsedArgs.add(value); break; case HELP: case USAGE_ERROR: - requiresAppResource = false; - parsedArgs.add(opt); - break; case VERSION: - requiresAppResource = false; + isSpecialCommand = true; parsedArgs.add(opt); break; default: diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java index 2e050f8413074..b343094b2e7b8 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.launcher; import java.io.File; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -27,7 +28,10 @@ import org.junit.AfterClass; import org.junit.BeforeClass; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; + import static org.junit.Assert.*; public class SparkSubmitCommandBuilderSuite extends BaseSuite { @@ -35,6 +39,9 @@ public class SparkSubmitCommandBuilderSuite extends BaseSuite { private static File dummyPropsFile; private static SparkSubmitOptionParser parser; + @Rule + public ExpectedException expectedException = ExpectedException.none(); + @BeforeClass public static void setUp() throws Exception { dummyPropsFile = File.createTempFile("spark", "properties"); @@ -74,8 +81,11 @@ public void testCliHelpAndNoArg() throws Exception { @Test public void testCliKillAndStatus() throws Exception { - testCLIOpts(parser.STATUS); - testCLIOpts(parser.KILL_SUBMISSION); + List params = Arrays.asList("driver-20160531171222-0000"); + testCLIOpts(null, parser.STATUS, params); + testCLIOpts(null, parser.KILL_SUBMISSION, params); + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.STATUS, params); + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.KILL_SUBMISSION, params); } @Test @@ -190,6 +200,33 @@ public void testSparkRShell() throws Exception { env.get("SPARKR_SUBMIT_ARGS")); } + @Test(expected = IllegalArgumentException.class) + public void testExamplesRunnerNoArg() throws Exception { + List sparkSubmitArgs = Arrays.asList(SparkSubmitCommandBuilder.RUN_EXAMPLE); + Map env = new HashMap<>(); + buildCommand(sparkSubmitArgs, env); + } + + @Test + public void testExamplesRunnerNoMainClass() throws Exception { + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.HELP, null); + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.USAGE_ERROR, null); + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.VERSION, null); + } + + @Test + public void testExamplesRunnerWithMasterNoMainClass() throws Exception { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Missing example class name."); + + List sparkSubmitArgs = Arrays.asList( + SparkSubmitCommandBuilder.RUN_EXAMPLE, + parser.MASTER + "=foo" + ); + Map env = new HashMap<>(); + buildCommand(sparkSubmitArgs, env); + } + @Test public void testExamplesRunner() throws Exception { List sparkSubmitArgs = Arrays.asList( @@ -344,10 +381,17 @@ private List buildCommand(List args, Map env) th return newCommandBuilder(args).buildCommand(env); } - private void testCLIOpts(String opt) throws Exception { - List helpArgs = Arrays.asList(opt, "driver-20160531171222-0000"); + private void testCLIOpts(String appResource, String opt, List params) throws Exception { + List args = new ArrayList<>(); + if (appResource != null) { + args.add(appResource); + } + args.add(opt); + if (params != null) { + args.addAll(params); + } Map env = new HashMap<>(); - List cmd = buildCommand(helpArgs, env); + List cmd = buildCommand(args, env); assertTrue(opt + " should be contained in the final cmd.", cmd.contains(opt)); } diff --git a/licenses/LICENSE-scopt.txt b/licenses-binary/LICENSE-AnchorJS.txt similarity index 100% rename from licenses/LICENSE-scopt.txt rename to licenses-binary/LICENSE-AnchorJS.txt diff --git a/licenses-binary/LICENSE-CC0.txt b/licenses-binary/LICENSE-CC0.txt new file mode 100644 index 0000000000000..1625c17936079 --- /dev/null +++ b/licenses-binary/LICENSE-CC0.txt @@ -0,0 +1,121 @@ +Creative Commons Legal Code + +CC0 1.0 Universal + + CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE + LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN + ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS + INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES + REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS + PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM + THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED + HEREUNDER. + +Statement of Purpose + +The laws of most jurisdictions throughout the world automatically confer +exclusive Copyright and Related Rights (defined below) upon the creator +and subsequent owner(s) (each and all, an "owner") of an original work of +authorship and/or a database (each, a "Work"). + +Certain owners wish to permanently relinquish those rights to a Work for +the purpose of contributing to a commons of creative, cultural and +scientific works ("Commons") that the public can reliably and without fear +of later claims of infringement build upon, modify, incorporate in other +works, reuse and redistribute as freely as possible in any form whatsoever +and for any purposes, including without limitation commercial purposes. +These owners may contribute to the Commons to promote the ideal of a free +culture and the further production of creative, cultural and scientific +works, or to gain reputation or greater distribution for their Work in +part through the use and efforts of others. + +For these and/or other purposes and motivations, and without any +expectation of additional consideration or compensation, the person +associating CC0 with a Work (the "Affirmer"), to the extent that he or she +is an owner of Copyright and Related Rights in the Work, voluntarily +elects to apply CC0 to the Work and publicly distribute the Work under its +terms, with knowledge of his or her Copyright and Related Rights in the +Work and the meaning and intended legal effect of CC0 on those rights. + +1. Copyright and Related Rights. A Work made available under CC0 may be +protected by copyright and related or neighboring rights ("Copyright and +Related Rights"). Copyright and Related Rights include, but are not +limited to, the following: + + i. the right to reproduce, adapt, distribute, perform, display, + communicate, and translate a Work; + ii. moral rights retained by the original author(s) and/or performer(s); +iii. publicity and privacy rights pertaining to a person's image or + likeness depicted in a Work; + iv. rights protecting against unfair competition in regards to a Work, + subject to the limitations in paragraph 4(a), below; + v. rights protecting the extraction, dissemination, use and reuse of data + in a Work; + vi. database rights (such as those arising under Directive 96/9/EC of the + European Parliament and of the Council of 11 March 1996 on the legal + protection of databases, and under any national implementation + thereof, including any amended or successor version of such + directive); and +vii. other similar, equivalent or corresponding rights throughout the + world based on applicable law or treaty, and any national + implementations thereof. + +2. Waiver. To the greatest extent permitted by, but not in contravention +of, applicable law, Affirmer hereby overtly, fully, permanently, +irrevocably and unconditionally waives, abandons, and surrenders all of +Affirmer's Copyright and Related Rights and associated claims and causes +of action, whether now known or unknown (including existing as well as +future claims and causes of action), in the Work (i) in all territories +worldwide, (ii) for the maximum duration provided by applicable law or +treaty (including future time extensions), (iii) in any current or future +medium and for any number of copies, and (iv) for any purpose whatsoever, +including without limitation commercial, advertising or promotional +purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each +member of the public at large and to the detriment of Affirmer's heirs and +successors, fully intending that such Waiver shall not be subject to +revocation, rescission, cancellation, termination, or any other legal or +equitable action to disrupt the quiet enjoyment of the Work by the public +as contemplated by Affirmer's express Statement of Purpose. + +3. Public License Fallback. Should any part of the Waiver for any reason +be judged legally invalid or ineffective under applicable law, then the +Waiver shall be preserved to the maximum extent permitted taking into +account Affirmer's express Statement of Purpose. In addition, to the +extent the Waiver is so judged Affirmer hereby grants to each affected +person a royalty-free, non transferable, non sublicensable, non exclusive, +irrevocable and unconditional license to exercise Affirmer's Copyright and +Related Rights in the Work (i) in all territories worldwide, (ii) for the +maximum duration provided by applicable law or treaty (including future +time extensions), (iii) in any current or future medium and for any number +of copies, and (iv) for any purpose whatsoever, including without +limitation commercial, advertising or promotional purposes (the +"License"). The License shall be deemed effective as of the date CC0 was +applied by Affirmer to the Work. Should any part of the License for any +reason be judged legally invalid or ineffective under applicable law, such +partial invalidity or ineffectiveness shall not invalidate the remainder +of the License, and in such case Affirmer hereby affirms that he or she +will not (i) exercise any of his or her remaining Copyright and Related +Rights in the Work or (ii) assert any associated claims and causes of +action with respect to the Work, in either case contrary to Affirmer's +express Statement of Purpose. + +4. Limitations and Disclaimers. + + a. No trademark or patent rights held by Affirmer are waived, abandoned, + surrendered, licensed or otherwise affected by this document. + b. Affirmer offers the Work as-is and makes no representations or + warranties of any kind concerning the Work, express, implied, + statutory or otherwise, including without limitation warranties of + title, merchantability, fitness for a particular purpose, non + infringement, or the absence of latent or other defects, accuracy, or + the present or absence of errors, whether or not discoverable, all to + the greatest extent permissible under applicable law. + c. Affirmer disclaims responsibility for clearing rights of other persons + that may apply to the Work or any use thereof, including without + limitation any person's Copyright and Related Rights in the Work. + Further, Affirmer disclaims responsibility for obtaining any necessary + consents, permissions or other rights required for any use of the + Work. + d. Affirmer understands and acknowledges that Creative Commons is not a + party to this document and has no duty or obligation with respect to + this CC0 or use of the Work. \ No newline at end of file diff --git a/licenses/LICENSE-antlr.txt b/licenses-binary/LICENSE-antlr.txt similarity index 100% rename from licenses/LICENSE-antlr.txt rename to licenses-binary/LICENSE-antlr.txt diff --git a/licenses-binary/LICENSE-arpack.txt b/licenses-binary/LICENSE-arpack.txt new file mode 100644 index 0000000000000..a3ad80087bb63 --- /dev/null +++ b/licenses-binary/LICENSE-arpack.txt @@ -0,0 +1,8 @@ +Copyright © 2018 The University of Tennessee. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +· Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +· Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer listed in this license in the documentation and/or other materials provided with the distribution. +· Neither the name of the copyright holders nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +This software is provided by the copyright holders and contributors "as is" and any express or implied warranties, including, but not limited to, the implied warranties of merchantability and fitness for a particular purpose are disclaimed. in no event shall the copyright owner or contributors be liable for any direct, indirect, incidental, special, exemplary, or consequential damages (including, but not limited to, procurement of substitute goods or services; loss of use, data, or profits; or business interruption) however caused and on any theory of liability, whether in contract, strict liability, or tort (including negligence or otherwise) arising in any way out of the use of this software, even if advised of the possibility of such damage. \ No newline at end of file diff --git a/licenses-binary/LICENSE-automaton.txt b/licenses-binary/LICENSE-automaton.txt new file mode 100644 index 0000000000000..2fc6e8c3432f0 --- /dev/null +++ b/licenses-binary/LICENSE-automaton.txt @@ -0,0 +1,24 @@ +Copyright (c) 2001-2017 Anders Moeller +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. +3. The name of the author may not be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR +IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES +OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, +INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT +NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF +THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-bootstrap.txt b/licenses-binary/LICENSE-bootstrap.txt new file mode 100644 index 0000000000000..6c711832fbc85 --- /dev/null +++ b/licenses-binary/LICENSE-bootstrap.txt @@ -0,0 +1,13 @@ +Copyright 2013 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/licenses-binary/LICENSE-bouncycastle-bcprov.txt b/licenses-binary/LICENSE-bouncycastle-bcprov.txt new file mode 100644 index 0000000000000..c445a93a06dd4 --- /dev/null +++ b/licenses-binary/LICENSE-bouncycastle-bcprov.txt @@ -0,0 +1,7 @@ +Copyright (c) 2000 - 2018 The Legion of the Bouncy Castle Inc. (https://www.bouncycastle.org) + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-cloudpickle.txt b/licenses-binary/LICENSE-cloudpickle.txt new file mode 100644 index 0000000000000..b1e20fa1eda88 --- /dev/null +++ b/licenses-binary/LICENSE-cloudpickle.txt @@ -0,0 +1,28 @@ +Copyright (c) 2012, Regents of the University of California. +Copyright (c) 2009 `PiCloud, Inc. `_. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the University of California, Berkeley nor the + names of its contributors may be used to endorse or promote + products derived from this software without specific prior written + permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-d3.min.js.txt b/licenses-binary/LICENSE-d3.min.js.txt new file mode 100644 index 0000000000000..c71e3f254c068 --- /dev/null +++ b/licenses-binary/LICENSE-d3.min.js.txt @@ -0,0 +1,26 @@ +Copyright (c) 2010-2015, Michael Bostock +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* The name Michael Bostock may not be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL MICHAEL BOSTOCK BE LIABLE FOR ANY DIRECT, +INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-Mockito.txt b/licenses-binary/LICENSE-dagre-d3.txt similarity index 94% rename from licenses/LICENSE-Mockito.txt rename to licenses-binary/LICENSE-dagre-d3.txt index e0840a446caf5..4864fe05e9803 100644 --- a/licenses/LICENSE-Mockito.txt +++ b/licenses-binary/LICENSE-dagre-d3.txt @@ -1,6 +1,4 @@ -The MIT License - -Copyright (c) 2007 Mockito contributors +Copyright (c) 2013 Chris Pettitt Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/licenses-binary/LICENSE-datatables.txt b/licenses-binary/LICENSE-datatables.txt new file mode 100644 index 0000000000000..bb7708b5b5a49 --- /dev/null +++ b/licenses-binary/LICENSE-datatables.txt @@ -0,0 +1,7 @@ +Copyright (C) 2008-2018, SpryMedia Ltd. + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-f2j.txt b/licenses-binary/LICENSE-f2j.txt similarity index 100% rename from licenses/LICENSE-f2j.txt rename to licenses-binary/LICENSE-f2j.txt diff --git a/licenses-binary/LICENSE-graphlib-dot.txt b/licenses-binary/LICENSE-graphlib-dot.txt new file mode 100644 index 0000000000000..4864fe05e9803 --- /dev/null +++ b/licenses-binary/LICENSE-graphlib-dot.txt @@ -0,0 +1,19 @@ +Copyright (c) 2013 Chris Pettitt + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-heapq.txt b/licenses-binary/LICENSE-heapq.txt new file mode 100644 index 0000000000000..0c4c4b954bea4 --- /dev/null +++ b/licenses-binary/LICENSE-heapq.txt @@ -0,0 +1,280 @@ + +# A. HISTORY OF THE SOFTWARE +# ========================== +# +# Python was created in the early 1990s by Guido van Rossum at Stichting +# Mathematisch Centrum (CWI, see http://www.cwi.nl) in the Netherlands +# as a successor of a language called ABC. Guido remains Python's +# principal author, although it includes many contributions from others. +# +# In 1995, Guido continued his work on Python at the Corporation for +# National Research Initiatives (CNRI, see http://www.cnri.reston.va.us) +# in Reston, Virginia where he released several versions of the +# software. +# +# In May 2000, Guido and the Python core development team moved to +# BeOpen.com to form the BeOpen PythonLabs team. In October of the same +# year, the PythonLabs team moved to Digital Creations (now Zope +# Corporation, see http://www.zope.com). In 2001, the Python Software +# Foundation (PSF, see http://www.python.org/psf/) was formed, a +# non-profit organization created specifically to own Python-related +# Intellectual Property. Zope Corporation is a sponsoring member of +# the PSF. +# +# All Python releases are Open Source (see http://www.opensource.org for +# the Open Source Definition). Historically, most, but not all, Python +# releases have also been GPL-compatible; the table below summarizes +# the various releases. +# +# Release Derived Year Owner GPL- +# from compatible? (1) +# +# 0.9.0 thru 1.2 1991-1995 CWI yes +# 1.3 thru 1.5.2 1.2 1995-1999 CNRI yes +# 1.6 1.5.2 2000 CNRI no +# 2.0 1.6 2000 BeOpen.com no +# 1.6.1 1.6 2001 CNRI yes (2) +# 2.1 2.0+1.6.1 2001 PSF no +# 2.0.1 2.0+1.6.1 2001 PSF yes +# 2.1.1 2.1+2.0.1 2001 PSF yes +# 2.2 2.1.1 2001 PSF yes +# 2.1.2 2.1.1 2002 PSF yes +# 2.1.3 2.1.2 2002 PSF yes +# 2.2.1 2.2 2002 PSF yes +# 2.2.2 2.2.1 2002 PSF yes +# 2.2.3 2.2.2 2003 PSF yes +# 2.3 2.2.2 2002-2003 PSF yes +# 2.3.1 2.3 2002-2003 PSF yes +# 2.3.2 2.3.1 2002-2003 PSF yes +# 2.3.3 2.3.2 2002-2003 PSF yes +# 2.3.4 2.3.3 2004 PSF yes +# 2.3.5 2.3.4 2005 PSF yes +# 2.4 2.3 2004 PSF yes +# 2.4.1 2.4 2005 PSF yes +# 2.4.2 2.4.1 2005 PSF yes +# 2.4.3 2.4.2 2006 PSF yes +# 2.4.4 2.4.3 2006 PSF yes +# 2.5 2.4 2006 PSF yes +# 2.5.1 2.5 2007 PSF yes +# 2.5.2 2.5.1 2008 PSF yes +# 2.5.3 2.5.2 2008 PSF yes +# 2.6 2.5 2008 PSF yes +# 2.6.1 2.6 2008 PSF yes +# 2.6.2 2.6.1 2009 PSF yes +# 2.6.3 2.6.2 2009 PSF yes +# 2.6.4 2.6.3 2009 PSF yes +# 2.6.5 2.6.4 2010 PSF yes +# 2.7 2.6 2010 PSF yes +# +# Footnotes: +# +# (1) GPL-compatible doesn't mean that we're distributing Python under +# the GPL. All Python licenses, unlike the GPL, let you distribute +# a modified version without making your changes open source. The +# GPL-compatible licenses make it possible to combine Python with +# other software that is released under the GPL; the others don't. +# +# (2) According to Richard Stallman, 1.6.1 is not GPL-compatible, +# because its license has a choice of law clause. According to +# CNRI, however, Stallman's lawyer has told CNRI's lawyer that 1.6.1 +# is "not incompatible" with the GPL. +# +# Thanks to the many outside volunteers who have worked under Guido's +# direction to make these releases possible. +# +# +# B. TERMS AND CONDITIONS FOR ACCESSING OR OTHERWISE USING PYTHON +# =============================================================== +# +# PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 +# -------------------------------------------- +# +# 1. This LICENSE AGREEMENT is between the Python Software Foundation +# ("PSF"), and the Individual or Organization ("Licensee") accessing and +# otherwise using this software ("Python") in source or binary form and +# its associated documentation. +# +# 2. Subject to the terms and conditions of this License Agreement, PSF hereby +# grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, +# analyze, test, perform and/or display publicly, prepare derivative works, +# distribute, and otherwise use Python alone or in any derivative version, +# provided, however, that PSF's License Agreement and PSF's notice of copyright, +# i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, +# 2011, 2012, 2013 Python Software Foundation; All Rights Reserved" are retained +# in Python alone or in any derivative version prepared by Licensee. +# +# 3. In the event Licensee prepares a derivative work that is based on +# or incorporates Python or any part thereof, and wants to make +# the derivative work available to others as provided herein, then +# Licensee hereby agrees to include in any such work a brief summary of +# the changes made to Python. +# +# 4. PSF is making Python available to Licensee on an "AS IS" +# basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +# FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 6. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 7. Nothing in this License Agreement shall be deemed to create any +# relationship of agency, partnership, or joint venture between PSF and +# Licensee. This License Agreement does not grant permission to use PSF +# trademarks or trade name in a trademark sense to endorse or promote +# products or services of Licensee, or any third party. +# +# 8. By copying, installing or otherwise using Python, Licensee +# agrees to be bound by the terms and conditions of this License +# Agreement. +# +# +# BEOPEN.COM LICENSE AGREEMENT FOR PYTHON 2.0 +# ------------------------------------------- +# +# BEOPEN PYTHON OPEN SOURCE LICENSE AGREEMENT VERSION 1 +# +# 1. This LICENSE AGREEMENT is between BeOpen.com ("BeOpen"), having an +# office at 160 Saratoga Avenue, Santa Clara, CA 95051, and the +# Individual or Organization ("Licensee") accessing and otherwise using +# this software in source or binary form and its associated +# documentation ("the Software"). +# +# 2. Subject to the terms and conditions of this BeOpen Python License +# Agreement, BeOpen hereby grants Licensee a non-exclusive, +# royalty-free, world-wide license to reproduce, analyze, test, perform +# and/or display publicly, prepare derivative works, distribute, and +# otherwise use the Software alone or in any derivative version, +# provided, however, that the BeOpen Python License is retained in the +# Software, alone or in any derivative version prepared by Licensee. +# +# 3. BeOpen is making the Software available to Licensee on an "AS IS" +# basis. BEOPEN MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, BEOPEN MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF THE SOFTWARE WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 4. BEOPEN SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF THE +# SOFTWARE FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS +# AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THE SOFTWARE, OR ANY +# DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 5. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 6. This License Agreement shall be governed by and interpreted in all +# respects by the law of the State of California, excluding conflict of +# law provisions. Nothing in this License Agreement shall be deemed to +# create any relationship of agency, partnership, or joint venture +# between BeOpen and Licensee. This License Agreement does not grant +# permission to use BeOpen trademarks or trade names in a trademark +# sense to endorse or promote products or services of Licensee, or any +# third party. As an exception, the "BeOpen Python" logos available at +# http://www.pythonlabs.com/logos.html may be used according to the +# permissions granted on that web page. +# +# 7. By copying, installing or otherwise using the software, Licensee +# agrees to be bound by the terms and conditions of this License +# Agreement. +# +# +# CNRI LICENSE AGREEMENT FOR PYTHON 1.6.1 +# --------------------------------------- +# +# 1. This LICENSE AGREEMENT is between the Corporation for National +# Research Initiatives, having an office at 1895 Preston White Drive, +# Reston, VA 20191 ("CNRI"), and the Individual or Organization +# ("Licensee") accessing and otherwise using Python 1.6.1 software in +# source or binary form and its associated documentation. +# +# 2. Subject to the terms and conditions of this License Agreement, CNRI +# hereby grants Licensee a nonexclusive, royalty-free, world-wide +# license to reproduce, analyze, test, perform and/or display publicly, +# prepare derivative works, distribute, and otherwise use Python 1.6.1 +# alone or in any derivative version, provided, however, that CNRI's +# License Agreement and CNRI's notice of copyright, i.e., "Copyright (c) +# 1995-2001 Corporation for National Research Initiatives; All Rights +# Reserved" are retained in Python 1.6.1 alone or in any derivative +# version prepared by Licensee. Alternately, in lieu of CNRI's License +# Agreement, Licensee may substitute the following text (omitting the +# quotes): "Python 1.6.1 is made available subject to the terms and +# conditions in CNRI's License Agreement. This Agreement together with +# Python 1.6.1 may be located on the Internet using the following +# unique, persistent identifier (known as a handle): 1895.22/1013. This +# Agreement may also be obtained from a proxy server on the Internet +# using the following URL: http://hdl.handle.net/1895.22/1013". +# +# 3. In the event Licensee prepares a derivative work that is based on +# or incorporates Python 1.6.1 or any part thereof, and wants to make +# the derivative work available to others as provided herein, then +# Licensee hereby agrees to include in any such work a brief summary of +# the changes made to Python 1.6.1. +# +# 4. CNRI is making Python 1.6.1 available to Licensee on an "AS IS" +# basis. CNRI MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, CNRI MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON 1.6.1 WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 5. CNRI SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +# 1.6.1 FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON 1.6.1, +# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 6. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 7. This License Agreement shall be governed by the federal +# intellectual property law of the United States, including without +# limitation the federal copyright law, and, to the extent such +# U.S. federal law does not apply, by the law of the Commonwealth of +# Virginia, excluding Virginia's conflict of law provisions. +# Notwithstanding the foregoing, with regard to derivative works based +# on Python 1.6.1 that incorporate non-separable material that was +# previously distributed under the GNU General Public License (GPL), the +# law of the Commonwealth of Virginia shall govern this License +# Agreement only as to issues arising under or with respect to +# Paragraphs 4, 5, and 7 of this License Agreement. Nothing in this +# License Agreement shall be deemed to create any relationship of +# agency, partnership, or joint venture between CNRI and Licensee. This +# License Agreement does not grant permission to use CNRI trademarks or +# trade name in a trademark sense to endorse or promote products or +# services of Licensee, or any third party. +# +# 8. By clicking on the "ACCEPT" button where indicated, or by copying, +# installing or otherwise using Python 1.6.1, Licensee agrees to be +# bound by the terms and conditions of this License Agreement. +# +# ACCEPT +# +# +# CWI LICENSE AGREEMENT FOR PYTHON 0.9.0 THROUGH 1.2 +# -------------------------------------------------- +# +# Copyright (c) 1991 - 1995, Stichting Mathematisch Centrum Amsterdam, +# The Netherlands. All rights reserved. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose and without fee is hereby granted, +# provided that the above copyright notice appear in all copies and that +# both that copyright notice and this permission notice appear in +# supporting documentation, and that the name of Stichting Mathematisch +# Centrum or CWI not be used in advertising or publicity pertaining to +# distribution of the software without specific, written prior +# permission. +# +# STICHTING MATHEMATISCH CENTRUM DISCLAIMS ALL WARRANTIES WITH REGARD TO +# THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND +# FITNESS, IN NO EVENT SHALL STICHTING MATHEMATISCH CENTRUM BE LIABLE +# FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-janino.txt b/licenses-binary/LICENSE-janino.txt new file mode 100644 index 0000000000000..d1e1f237c4641 --- /dev/null +++ b/licenses-binary/LICENSE-janino.txt @@ -0,0 +1,31 @@ +Janino - An embedded Java[TM] compiler + +Copyright (c) 2001-2016, Arno Unkrig +Copyright (c) 2015-2016 TIBCO Software Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + 2. Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials + provided with the distribution. + 3. Neither the name of JANINO nor the names of its contributors + may be used to endorse or promote products derived from this + software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER +IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN +IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-javassist.html b/licenses-binary/LICENSE-javassist.html new file mode 100644 index 0000000000000..5abd563a0c4d9 --- /dev/null +++ b/licenses-binary/LICENSE-javassist.html @@ -0,0 +1,373 @@ + + + Javassist License + + + + +
MOZILLA PUBLIC LICENSE
Version + 1.1 +

+


+
+

1. Definitions. +

    1.0.1. "Commercial Use" means distribution or otherwise making the + Covered Code available to a third party. +

    1.1. ''Contributor'' means each entity that creates or contributes + to the creation of Modifications. +

    1.2. ''Contributor Version'' means the combination of the Original + Code, prior Modifications used by a Contributor, and the Modifications made by + that particular Contributor. +

    1.3. ''Covered Code'' means the Original Code or Modifications or + the combination of the Original Code and Modifications, in each case including + portions thereof. +

    1.4. ''Electronic Distribution Mechanism'' means a mechanism + generally accepted in the software development community for the electronic + transfer of data. +

    1.5. ''Executable'' means Covered Code in any form other than Source + Code. +

    1.6. ''Initial Developer'' means the individual or entity identified + as the Initial Developer in the Source Code notice required by Exhibit + A. +

    1.7. ''Larger Work'' means a work which combines Covered Code or + portions thereof with code not governed by the terms of this License. +

    1.8. ''License'' means this document. +

    1.8.1. "Licensable" means having the right to grant, to the maximum + extent possible, whether at the time of the initial grant or subsequently + acquired, any and all of the rights conveyed herein. +

    1.9. ''Modifications'' means any addition to or deletion from the + substance or structure of either the Original Code or any previous + Modifications. When Covered Code is released as a series of files, a + Modification is: +

      A. Any addition to or deletion from the contents of a file + containing Original Code or previous Modifications. +

      B. Any new file that contains any part of the Original Code or + previous Modifications.
       

    1.10. ''Original Code'' +means Source Code of computer software code which is described in the Source +Code notice required by Exhibit A as Original Code, and which, at the +time of its release under this License is not already Covered Code governed by +this License. +

    1.10.1. "Patent Claims" means any patent claim(s), now owned or + hereafter acquired, including without limitation,  method, process, and + apparatus claims, in any patent Licensable by grantor. +

    1.11. ''Source Code'' means the preferred form of the Covered Code + for making modifications to it, including all modules it contains, plus any + associated interface definition files, scripts used to control compilation and + installation of an Executable, or source code differential comparisons against + either the Original Code or another well known, available Covered Code of the + Contributor's choice. The Source Code can be in a compressed or archival form, + provided the appropriate decompression or de-archiving software is widely + available for no charge. +

    1.12. "You'' (or "Your")  means an individual or a legal entity + exercising rights under, and complying with all of the terms of, this License + or a future version of this License issued under Section 6.1. For legal + entities, "You'' includes any entity which controls, is controlled by, or is + under common control with You. For purposes of this definition, "control'' + means (a) the power, direct or indirect, to cause the direction or management + of such entity, whether by contract or otherwise, or (b) ownership of more + than fifty percent (50%) of the outstanding shares or beneficial ownership of + such entity.

2. Source Code License. +
    2.1. The Initial Developer Grant.
    The Initial Developer hereby + grants You a world-wide, royalty-free, non-exclusive license, subject to third + party intellectual property claims: +
      (a)  under intellectual property rights (other than + patent or trademark) Licensable by Initial Developer to use, reproduce, + modify, display, perform, sublicense and distribute the Original Code (or + portions thereof) with or without Modifications, and/or as part of a Larger + Work; and +

      (b) under Patents Claims infringed by the making, using or selling + of Original Code, to make, have made, use, practice, sell, and offer for + sale, and/or otherwise dispose of the Original Code (or portions thereof). +

        +
        (c) the licenses granted in this Section 2.1(a) and (b) + are effective on the date Initial Developer first distributes Original Code + under the terms of this License. +

        (d) Notwithstanding Section 2.1(b) above, no patent license is + granted: 1) for code that You delete from the Original Code; 2) separate + from the Original Code;  or 3) for infringements caused by: i) the + modification of the Original Code or ii) the combination of the Original + Code with other software or devices.
         

      2.2. Contributor + Grant.
      Subject to third party intellectual property claims, each + Contributor hereby grants You a world-wide, royalty-free, non-exclusive + license +

        (a)  under intellectual property rights (other + than patent or trademark) Licensable by Contributor, to use, reproduce, + modify, display, perform, sublicense and distribute the Modifications + created by such Contributor (or portions thereof) either on an unmodified + basis, with other Modifications, as Covered Code and/or as part of a Larger + Work; and +

        (b) under Patent Claims infringed by the making, using, or selling + of  Modifications made by that Contributor either alone and/or in combination with its Contributor Version (or portions of such + combination), to make, use, sell, offer for sale, have made, and/or + otherwise dispose of: 1) Modifications made by that Contributor (or portions + thereof); and 2) the combination of  Modifications made by that + Contributor with its Contributor Version (or portions of such + combination). +

        (c) the licenses granted in Sections 2.2(a) and 2.2(b) are + effective on the date Contributor first makes Commercial Use of the Covered + Code. +

        (d)    Notwithstanding Section 2.2(b) above, no + patent license is granted: 1) for any code that Contributor has deleted from + the Contributor Version; 2)  separate from the Contributor + Version;  3)  for infringements caused by: i) third party + modifications of Contributor Version or ii)  the combination of + Modifications made by that Contributor with other software  (except as + part of the Contributor Version) or other devices; or 4) under Patent Claims + infringed by Covered Code in the absence of Modifications made by that + Contributor.

    +


    3. Distribution Obligations. +

      3.1. Application of License.
      The Modifications which You create + or to which You contribute are governed by the terms of this License, + including without limitation Section 2.2. The Source Code version of + Covered Code may be distributed only under the terms of this License or a + future version of this License released under Section 6.1, and You must + include a copy of this License with every copy of the Source Code You + distribute. You may not offer or impose any terms on any Source Code version + that alters or restricts the applicable version of this License or the + recipients' rights hereunder. However, You may include an additional document + offering the additional rights described in Section 3.5. +

      3.2. Availability of Source Code.
      Any Modification which You + create or to which You contribute must be made available in Source Code form + under the terms of this License either on the same media as an Executable + version or via an accepted Electronic Distribution Mechanism to anyone to whom + you made an Executable version available; and if made available via Electronic + Distribution Mechanism, must remain available for at least twelve (12) months + after the date it initially became available, or at least six (6) months after + a subsequent version of that particular Modification has been made available + to such recipients. You are responsible for ensuring that the Source Code + version remains available even if the Electronic Distribution Mechanism is + maintained by a third party. +

      3.3. Description of Modifications.
      You must cause all Covered + Code to which You contribute to contain a file documenting the changes You + made to create that Covered Code and the date of any change. You must include + a prominent statement that the Modification is derived, directly or + indirectly, from Original Code provided by the Initial Developer and including + the name of the Initial Developer in (a) the Source Code, and (b) in any + notice in an Executable version or related documentation in which You describe + the origin or ownership of the Covered Code. +

      3.4. Intellectual Property Matters +

        (a) Third Party Claims.
        If Contributor has knowledge that a + license under a third party's intellectual property rights is required to + exercise the rights granted by such Contributor under Sections 2.1 or 2.2, + Contributor must include a text file with the Source Code distribution + titled "LEGAL'' which describes the claim and the party making the claim in + sufficient detail that a recipient will know whom to contact. If Contributor + obtains such knowledge after the Modification is made available as described + in Section 3.2, Contributor shall promptly modify the LEGAL file in all + copies Contributor makes available thereafter and shall take other steps + (such as notifying appropriate mailing lists or newsgroups) reasonably + calculated to inform those who received the Covered Code that new knowledge + has been obtained. +

        (b) Contributor APIs.
        If Contributor's Modifications include + an application programming interface and Contributor has knowledge of patent + licenses which are reasonably necessary to implement that API, Contributor + must also include this information in the LEGAL file. +
         

                +(c)    Representations. +
        Contributor represents that, except as disclosed pursuant to Section + 3.4(a) above, Contributor believes that Contributor's Modifications are + Contributor's original creation(s) and/or Contributor has sufficient rights + to grant the rights conveyed by this License.
      +


      3.5. Required Notices.
      You must duplicate the notice in + Exhibit A in each file of the Source Code.  If it is not possible + to put such notice in a particular Source Code file due to its structure, then + You must include such notice in a location (such as a relevant directory) + where a user would be likely to look for such a notice.  If You created + one or more Modification(s) You may add your name as a Contributor to the + notice described in Exhibit A.  You must also duplicate this + License in any documentation for the Source Code where You describe + recipients' rights or ownership rights relating to Covered Code.  You may + choose to offer, and to charge a fee for, warranty, support, indemnity or + liability obligations to one or more recipients of Covered Code. However, You + may do so only on Your own behalf, and not on behalf of the Initial Developer + or any Contributor. You must make it absolutely clear than any such warranty, + support, indemnity or liability obligation is offered by You alone, and You + hereby agree to indemnify the Initial Developer and every Contributor for any + liability incurred by the Initial Developer or such Contributor as a result of + warranty, support, indemnity or liability terms You offer. +

      3.6. Distribution of Executable Versions.
      You may distribute + Covered Code in Executable form only if the requirements of Section + 3.1-3.5 have been met for that Covered Code, and if You include a + notice stating that the Source Code version of the Covered Code is available + under the terms of this License, including a description of how and where You + have fulfilled the obligations of Section 3.2. The notice must be + conspicuously included in any notice in an Executable version, related + documentation or collateral in which You describe recipients' rights relating + to the Covered Code. You may distribute the Executable version of Covered Code + or ownership rights under a license of Your choice, which may contain terms + different from this License, provided that You are in compliance with the + terms of this License and that the license for the Executable version does not + attempt to limit or alter the recipient's rights in the Source Code version + from the rights set forth in this License. If You distribute the Executable + version under a different license You must make it absolutely clear that any + terms which differ from this License are offered by You alone, not by the + Initial Developer or any Contributor. You hereby agree to indemnify the + Initial Developer and every Contributor for any liability incurred by the + Initial Developer or such Contributor as a result of any such terms You offer. + +

      3.7. Larger Works.
      You may create a Larger Work by combining + Covered Code with other code not governed by the terms of this License and + distribute the Larger Work as a single product. In such a case, You must make + sure the requirements of this License are fulfilled for the Covered + Code.

    4. Inability to Comply Due to Statute or Regulation. +
      If it is impossible for You to comply with any of the terms of this + License with respect to some or all of the Covered Code due to statute, + judicial order, or regulation then You must: (a) comply with the terms of this + License to the maximum extent possible; and (b) describe the limitations and + the code they affect. Such description must be included in the LEGAL file + described in Section 3.4 and must be included with all distributions of + the Source Code. Except to the extent prohibited by statute or regulation, + such description must be sufficiently detailed for a recipient of ordinary + skill to be able to understand it.
    5. Application of this License. +
      This License applies to code to which the Initial Developer has attached + the notice in Exhibit A and to related Covered Code.
    6. Versions + of the License. +
      6.1. New Versions.
      Netscape Communications Corporation + (''Netscape'') may publish revised and/or new versions of the License from + time to time. Each version will be given a distinguishing version number. +

      6.2. Effect of New Versions.
      Once Covered Code has been + published under a particular version of the License, You may always continue + to use it under the terms of that version. You may also choose to use such + Covered Code under the terms of any subsequent version of the License + published by Netscape. No one other than Netscape has the right to modify the + terms applicable to Covered Code created under this License. +

      6.3. Derivative Works.
      If You create or use a modified version + of this License (which you may only do in order to apply it to code which is + not already Covered Code governed by this License), You must (a) rename Your + license so that the phrases ''Mozilla'', ''MOZILLAPL'', ''MOZPL'', + ''Netscape'', "MPL", ''NPL'' or any confusingly similar phrase do not appear + in your license (except to note that your license differs from this License) + and (b) otherwise make it clear that Your version of the license contains + terms which differ from the Mozilla Public License and Netscape Public + License. (Filling in the name of the Initial Developer, Original Code or + Contributor in the notice described in Exhibit A shall not of + themselves be deemed to be modifications of this License.)

    7. + DISCLAIMER OF WARRANTY. +
      COVERED CODE IS PROVIDED UNDER THIS LICENSE ON AN "AS IS'' BASIS, WITHOUT + WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, WITHOUT + LIMITATION, WARRANTIES THAT THE COVERED CODE IS FREE OF DEFECTS, MERCHANTABLE, + FIT FOR A PARTICULAR PURPOSE OR NON-INFRINGING. THE ENTIRE RISK AS TO THE + QUALITY AND PERFORMANCE OF THE COVERED CODE IS WITH YOU. SHOULD ANY COVERED + CODE PROVE DEFECTIVE IN ANY RESPECT, YOU (NOT THE INITIAL DEVELOPER OR ANY + OTHER CONTRIBUTOR) ASSUME THE COST OF ANY NECESSARY SERVICING, REPAIR OR + CORRECTION. THIS DISCLAIMER OF WARRANTY CONSTITUTES AN ESSENTIAL PART OF THIS + LICENSE. NO USE OF ANY COVERED CODE IS AUTHORIZED HEREUNDER EXCEPT UNDER THIS + DISCLAIMER.
    8. TERMINATION. +
      8.1.  This License and the rights granted hereunder will + terminate automatically if You fail to comply with terms herein and fail to + cure such breach within 30 days of becoming aware of the breach. All + sublicenses to the Covered Code which are properly granted shall survive any + termination of this License. Provisions which, by their nature, must remain in + effect beyond the termination of this License shall survive. +

      8.2.  If You initiate litigation by asserting a patent + infringement claim (excluding declatory judgment actions) against Initial + Developer or a Contributor (the Initial Developer or Contributor against whom + You file such action is referred to as "Participant")  alleging that: +

      (a)  such Participant's Contributor Version directly or + indirectly infringes any patent, then any and all rights granted by such + Participant to You under Sections 2.1 and/or 2.2 of this License shall, upon + 60 days notice from Participant terminate prospectively, unless if within 60 + days after receipt of notice You either: (i)  agree in writing to pay + Participant a mutually agreeable reasonable royalty for Your past and future + use of Modifications made by such Participant, or (ii) withdraw Your + litigation claim with respect to the Contributor Version against such + Participant.  If within 60 days of notice, a reasonable royalty and + payment arrangement are not mutually agreed upon in writing by the parties or + the litigation claim is not withdrawn, the rights granted by Participant to + You under Sections 2.1 and/or 2.2 automatically terminate at the expiration of + the 60 day notice period specified above. +

      (b)  any software, hardware, or device, other than such + Participant's Contributor Version, directly or indirectly infringes any + patent, then any rights granted to You by such Participant under Sections + 2.1(b) and 2.2(b) are revoked effective as of the date You first made, used, + sold, distributed, or had made, Modifications made by that Participant. +

      8.3.  If You assert a patent infringement claim against + Participant alleging that such Participant's Contributor Version directly or + indirectly infringes any patent where such claim is resolved (such as by + license or settlement) prior to the initiation of patent infringement + litigation, then the reasonable value of the licenses granted by such + Participant under Sections 2.1 or 2.2 shall be taken into account in + determining the amount or value of any payment or license. +

      8.4.  In the event of termination under Sections 8.1 or 8.2 + above,  all end user license agreements (excluding distributors and + resellers) which have been validly granted by You or any distributor hereunder + prior to termination shall survive termination.

    9. LIMITATION OF + LIABILITY. +
      UNDER NO CIRCUMSTANCES AND UNDER NO LEGAL THEORY, WHETHER TORT (INCLUDING + NEGLIGENCE), CONTRACT, OR OTHERWISE, SHALL YOU, THE INITIAL DEVELOPER, ANY + OTHER CONTRIBUTOR, OR ANY DISTRIBUTOR OF COVERED CODE, OR ANY SUPPLIER OF ANY + OF SUCH PARTIES, BE LIABLE TO ANY PERSON FOR ANY INDIRECT, SPECIAL, + INCIDENTAL, OR CONSEQUENTIAL DAMAGES OF ANY CHARACTER INCLUDING, WITHOUT + LIMITATION, DAMAGES FOR LOSS OF GOODWILL, WORK STOPPAGE, COMPUTER FAILURE OR + MALFUNCTION, OR ANY AND ALL OTHER COMMERCIAL DAMAGES OR LOSSES, EVEN IF SUCH + PARTY SHALL HAVE BEEN INFORMED OF THE POSSIBILITY OF SUCH DAMAGES. THIS + LIMITATION OF LIABILITY SHALL NOT APPLY TO LIABILITY FOR DEATH OR PERSONAL + INJURY RESULTING FROM SUCH PARTY'S NEGLIGENCE TO THE EXTENT APPLICABLE LAW + PROHIBITS SUCH LIMITATION. SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OR + LIMITATION OF INCIDENTAL OR CONSEQUENTIAL DAMAGES, SO THIS EXCLUSION AND + LIMITATION MAY NOT APPLY TO YOU.
    10. U.S. GOVERNMENT END USERS. +
      The Covered Code is a ''commercial item,'' as that term is defined in 48 + C.F.R. 2.101 (Oct. 1995), consisting of ''commercial computer software'' and + ''commercial computer software documentation,'' as such terms are used in 48 + C.F.R. 12.212 (Sept. 1995). Consistent with 48 C.F.R. 12.212 and 48 C.F.R. + 227.7202-1 through 227.7202-4 (June 1995), all U.S. Government End Users + acquire Covered Code with only those rights set forth herein.
    11. + MISCELLANEOUS. +
      This License represents the complete agreement concerning subject matter + hereof. If any provision of this License is held to be unenforceable, such + provision shall be reformed only to the extent necessary to make it + enforceable. This License shall be governed by California law provisions + (except to the extent applicable law, if any, provides otherwise), excluding + its conflict-of-law provisions. With respect to disputes in which at least one + party is a citizen of, or an entity chartered or registered to do business in + the United States of America, any litigation relating to this License shall be + subject to the jurisdiction of the Federal Courts of the Northern District of + California, with venue lying in Santa Clara County, California, with the + losing party responsible for costs, including without limitation, court costs + and reasonable attorneys' fees and expenses. The application of the United + Nations Convention on Contracts for the International Sale of Goods is + expressly excluded. Any law or regulation which provides that the language of + a contract shall be construed against the drafter shall not apply to this + License.
    12. RESPONSIBILITY FOR CLAIMS. +
      As between Initial Developer and the Contributors, each party is + responsible for claims and damages arising, directly or indirectly, out of its + utilization of rights under this License and You agree to work with Initial + Developer and Contributors to distribute such responsibility on an equitable + basis. Nothing herein is intended or shall be deemed to constitute any + admission of liability.
    13. MULTIPLE-LICENSED CODE. +
      Initial Developer may designate portions of the Covered Code as + "Multiple-Licensed".  "Multiple-Licensed" means that the Initial + Developer permits you to utilize portions of the Covered Code under Your + choice of the MPL or the alternative licenses, if any, specified by the + Initial Developer in the file described in Exhibit A.
    +


    EXHIBIT A -Mozilla Public License. +

      The contents of this file are subject to the Mozilla Public License + Version 1.1 (the "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at +
      http://www.mozilla.org/MPL/ +

      Software distributed under the License is distributed on an "AS IS" basis, + WITHOUT WARRANTY OF
      ANY KIND, either express or implied. See the License + for the specific language governing rights and
      limitations under the + License. +

      The Original Code is Javassist. +

      The Initial Developer of the Original Code is Shigeru Chiba. + Portions created by the Initial Developer are
        + Copyright (C) 1999- Shigeru Chiba. All Rights Reserved. +

      Contributor(s): __Bill Burke, Jason T. Greene______________. + +

      Alternatively, the contents of this software may be used under the + terms of the GNU Lesser General Public License Version 2.1 or later + (the "LGPL"), or the Apache License Version 2.0 (the "AL"), + in which case the provisions of the LGPL or the AL are applicable + instead of those above. If you wish to allow use of your version of + this software only under the terms of either the LGPL or the AL, and not to allow others to + use your version of this software under the terms of the MPL, indicate + your decision by deleting the provisions above and replace them with + the notice and other provisions required by the LGPL or the AL. If you do not + delete the provisions above, a recipient may use your version of this + software under the terms of any one of the MPL, the LGPL or the AL. + +

    + + \ No newline at end of file diff --git a/licenses/LICENSE-javolution.txt b/licenses-binary/LICENSE-javolution.txt similarity index 100% rename from licenses/LICENSE-javolution.txt rename to licenses-binary/LICENSE-javolution.txt diff --git a/licenses/LICENSE-jline.txt b/licenses-binary/LICENSE-jline.txt similarity index 100% rename from licenses/LICENSE-jline.txt rename to licenses-binary/LICENSE-jline.txt diff --git a/licenses/LICENSE-junit-interface.txt b/licenses-binary/LICENSE-jodd.txt similarity index 69% rename from licenses/LICENSE-junit-interface.txt rename to licenses-binary/LICENSE-jodd.txt index e835350c4e2a4..cc6b458adb386 100644 --- a/licenses/LICENSE-junit-interface.txt +++ b/licenses-binary/LICENSE-jodd.txt @@ -1,15 +1,15 @@ -Copyright (c) 2009-2012, Stefan Zeiger +Copyright (c) 2003-present, Jodd Team (https://jodd.org) All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - * Redistributions of source code must retain the above copyright notice, - this list of conditions and the following disclaimer. +1. Redistributions of source code must retain the above copyright notice, +this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE diff --git a/licenses/LICENSE-DPark.txt b/licenses-binary/LICENSE-join.txt similarity index 100% rename from licenses/LICENSE-DPark.txt rename to licenses-binary/LICENSE-join.txt diff --git a/licenses-binary/LICENSE-jquery.txt b/licenses-binary/LICENSE-jquery.txt new file mode 100644 index 0000000000000..45930542204fb --- /dev/null +++ b/licenses-binary/LICENSE-jquery.txt @@ -0,0 +1,20 @@ +Copyright JS Foundation and other contributors, https://js.foundation/ + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-json-formatter.txt b/licenses-binary/LICENSE-json-formatter.txt new file mode 100644 index 0000000000000..5193348fce126 --- /dev/null +++ b/licenses-binary/LICENSE-json-formatter.txt @@ -0,0 +1,6 @@ +Copyright 2014 Mohsen Azimi + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. \ No newline at end of file diff --git a/licenses-binary/LICENSE-jtransforms.html b/licenses-binary/LICENSE-jtransforms.html new file mode 100644 index 0000000000000..351c17412357b --- /dev/null +++ b/licenses-binary/LICENSE-jtransforms.html @@ -0,0 +1,388 @@ + + +Mozilla Public License version 1.1 + + + + +

    Mozilla Public License Version 1.1

    +

    1. Definitions.

    +
    +
    1.0.1. "Commercial Use" +
    means distribution or otherwise making the Covered Code available to a third party. +
    1.1. "Contributor" +
    means each entity that creates or contributes to the creation of Modifications. +
    1.2. "Contributor Version" +
    means the combination of the Original Code, prior Modifications used by a Contributor, + and the Modifications made by that particular Contributor. +
    1.3. "Covered Code" +
    means the Original Code or Modifications or the combination of the Original Code and + Modifications, in each case including portions thereof. +
    1.4. "Electronic Distribution Mechanism" +
    means a mechanism generally accepted in the software development community for the + electronic transfer of data. +
    1.5. "Executable" +
    means Covered Code in any form other than Source Code. +
    1.6. "Initial Developer" +
    means the individual or entity identified as the Initial Developer in the Source Code + notice required by Exhibit A. +
    1.7. "Larger Work" +
    means a work which combines Covered Code or portions thereof with code not governed + by the terms of this License. +
    1.8. "License" +
    means this document. +
    1.8.1. "Licensable" +
    means having the right to grant, to the maximum extent possible, whether at the + time of the initial grant or subsequently acquired, any and all of the rights + conveyed herein. +
    1.9. "Modifications" +
    +

    means any addition to or deletion from the substance or structure of either the + Original Code or any previous Modifications. When Covered Code is released as a + series of files, a Modification is: +

      +
    1. Any addition to or deletion from the contents of a file + containing Original Code or previous Modifications. +
    2. Any new file that contains any part of the Original Code or + previous Modifications. +
    +
    1.10. "Original Code" +
    means Source Code of computer software code which is described in the Source Code + notice required by Exhibit A as Original Code, and which, + at the time of its release under this License is not already Covered Code governed + by this License. +
    1.10.1. "Patent Claims" +
    means any patent claim(s), now owned or hereafter acquired, including without + limitation, method, process, and apparatus claims, in any patent Licensable by + grantor. +
    1.11. "Source Code" +
    means the preferred form of the Covered Code for making modifications to it, + including all modules it contains, plus any associated interface definition files, + scripts used to control compilation and installation of an Executable, or source + code differential comparisons against either the Original Code or another well known, + available Covered Code of the Contributor's choice. The Source Code can be in a + compressed or archival form, provided the appropriate decompression or de-archiving + software is widely available for no charge. +
    1.12. "You" (or "Your") +
    means an individual or a legal entity exercising rights under, and complying with + all of the terms of, this License or a future version of this License issued under + Section 6.1. For legal entities, "You" includes any entity + which controls, is controlled by, or is under common control with You. For purposes of + this definition, "control" means (a) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or otherwise, or (b) + ownership of more than fifty percent (50%) of the outstanding shares or beneficial + ownership of such entity. +
    +

    2. Source Code License.

    +

    2.1. The Initial Developer Grant.

    +

    The Initial Developer hereby grants You a world-wide, royalty-free, non-exclusive + license, subject to third party intellectual property claims: +

      +
    1. under intellectual property rights (other than patent or + trademark) Licensable by Initial Developer to use, reproduce, modify, display, perform, + sublicense and distribute the Original Code (or portions thereof) with or without + Modifications, and/or as part of a Larger Work; and +
    2. under Patents Claims infringed by the making, using or selling + of Original Code, to make, have made, use, practice, sell, and offer for sale, and/or + otherwise dispose of the Original Code (or portions thereof). +
    3. the licenses granted in this Section 2.1 + (a) and (b) are effective on + the date Initial Developer first distributes Original Code under the terms of this + License. +
    4. Notwithstanding Section 2.1 (b) + above, no patent license is granted: 1) for code that You delete from the Original Code; + 2) separate from the Original Code; or 3) for infringements caused by: i) the + modification of the Original Code or ii) the combination of the Original Code with other + software or devices. +
    +

    2.2. Contributor Grant.

    +

    Subject to third party intellectual property claims, each Contributor hereby grants You + a world-wide, royalty-free, non-exclusive license +

      +
    1. under intellectual property rights (other than patent or trademark) + Licensable by Contributor, to use, reproduce, modify, display, perform, sublicense and + distribute the Modifications created by such Contributor (or portions thereof) either on + an unmodified basis, with other Modifications, as Covered Code and/or as part of a Larger + Work; and +
    2. under Patent Claims infringed by the making, using, or selling of + Modifications made by that Contributor either alone and/or in combination with its + Contributor Version (or portions of such combination), to make, use, sell, offer for + sale, have made, and/or otherwise dispose of: 1) Modifications made by that Contributor + (or portions thereof); and 2) the combination of Modifications made by that Contributor + with its Contributor Version (or portions of such combination). +
    3. the licenses granted in Sections 2.2 + (a) and 2.2 (b) are effective + on the date Contributor first makes Commercial Use of the Covered Code. +
    4. Notwithstanding Section 2.2 (b) + above, no patent license is granted: 1) for any code that Contributor has deleted from + the Contributor Version; 2) separate from the Contributor Version; 3) for infringements + caused by: i) third party modifications of Contributor Version or ii) the combination of + Modifications made by that Contributor with other software (except as part of the + Contributor Version) or other devices; or 4) under Patent Claims infringed by Covered Code + in the absence of Modifications made by that Contributor. +
    +

    3. Distribution Obligations.

    +

    3.1. Application of License.

    +

    The Modifications which You create or to which You contribute are governed by the terms + of this License, including without limitation Section 2.2. The + Source Code version of Covered Code may be distributed only under the terms of this License + or a future version of this License released under Section 6.1, + and You must include a copy of this License with every copy of the Source Code You + distribute. You may not offer or impose any terms on any Source Code version that alters or + restricts the applicable version of this License or the recipients' rights hereunder. + However, You may include an additional document offering the additional rights described in + Section 3.5. +

    3.2. Availability of Source Code.

    +

    Any Modification which You create or to which You contribute must be made available in + Source Code form under the terms of this License either on the same media as an Executable + version or via an accepted Electronic Distribution Mechanism to anyone to whom you made an + Executable version available; and if made available via Electronic Distribution Mechanism, + must remain available for at least twelve (12) months after the date it initially became + available, or at least six (6) months after a subsequent version of that particular + Modification has been made available to such recipients. You are responsible for ensuring + that the Source Code version remains available even if the Electronic Distribution + Mechanism is maintained by a third party. +

    3.3. Description of Modifications.

    +

    You must cause all Covered Code to which You contribute to contain a file documenting the + changes You made to create that Covered Code and the date of any change. You must include a + prominent statement that the Modification is derived, directly or indirectly, from Original + Code provided by the Initial Developer and including the name of the Initial Developer in + (a) the Source Code, and (b) in any notice in an Executable version or related documentation + in which You describe the origin or ownership of the Covered Code. +

    3.4. Intellectual Property Matters

    +

    (a) Third Party Claims

    +

    If Contributor has knowledge that a license under a third party's intellectual property + rights is required to exercise the rights granted by such Contributor under Sections + 2.1 or 2.2, Contributor must include a + text file with the Source Code distribution titled "LEGAL" which describes the claim and the + party making the claim in sufficient detail that a recipient will know whom to contact. If + Contributor obtains such knowledge after the Modification is made available as described in + Section 3.2, Contributor shall promptly modify the LEGAL file in + all copies Contributor makes available thereafter and shall take other steps (such as + notifying appropriate mailing lists or newsgroups) reasonably calculated to inform those who + received the Covered Code that new knowledge has been obtained. +

    (b) Contributor APIs

    +

    If Contributor's Modifications include an application programming interface and Contributor + has knowledge of patent licenses which are reasonably necessary to implement that + API, Contributor must also include this information in the + legal file. +

    (c) Representations.

    +

    Contributor represents that, except as disclosed pursuant to Section 3.4 + (a) above, Contributor believes that Contributor's Modifications + are Contributor's original creation(s) and/or Contributor has sufficient rights to grant the + rights conveyed by this License. +

    3.5. Required Notices.

    +

    You must duplicate the notice in Exhibit A in each file of the + Source Code. If it is not possible to put such notice in a particular Source Code file due to + its structure, then You must include such notice in a location (such as a relevant directory) + where a user would be likely to look for such a notice. If You created one or more + Modification(s) You may add your name as a Contributor to the notice described in + Exhibit A. You must also duplicate this License in any documentation + for the Source Code where You describe recipients' rights or ownership rights relating to + Covered Code. You may choose to offer, and to charge a fee for, warranty, support, indemnity + or liability obligations to one or more recipients of Covered Code. However, You may do so + only on Your own behalf, and not on behalf of the Initial Developer or any Contributor. You + must make it absolutely clear than any such warranty, support, indemnity or liability + obligation is offered by You alone, and You hereby agree to indemnify the Initial Developer + and every Contributor for any liability incurred by the Initial Developer or such Contributor + as a result of warranty, support, indemnity or liability terms You offer. +

    3.6. Distribution of Executable Versions.

    +

    You may distribute Covered Code in Executable form only if the requirements of Sections + 3.1, 3.2, + 3.3, 3.4 and + 3.5 have been met for that Covered Code, and if You include a + notice stating that the Source Code version of the Covered Code is available under the terms + of this License, including a description of how and where You have fulfilled the obligations + of Section 3.2. The notice must be conspicuously included in any + notice in an Executable version, related documentation or collateral in which You describe + recipients' rights relating to the Covered Code. You may distribute the Executable version of + Covered Code or ownership rights under a license of Your choice, which may contain terms + different from this License, provided that You are in compliance with the terms of this + License and that the license for the Executable version does not attempt to limit or alter the + recipient's rights in the Source Code version from the rights set forth in this License. If + You distribute the Executable version under a different license You must make it absolutely + clear that any terms which differ from this License are offered by You alone, not by the + Initial Developer or any Contributor. You hereby agree to indemnify the Initial Developer and + every Contributor for any liability incurred by the Initial Developer or such Contributor as + a result of any such terms You offer. +

    3.7. Larger Works.

    +

    You may create a Larger Work by combining Covered Code with other code not governed by the + terms of this License and distribute the Larger Work as a single product. In such a case, + You must make sure the requirements of this License are fulfilled for the Covered Code. +

    4. Inability to Comply Due to Statute or Regulation.

    +

    If it is impossible for You to comply with any of the terms of this License with respect to + some or all of the Covered Code due to statute, judicial order, or regulation then You must: + (a) comply with the terms of this License to the maximum extent possible; and (b) describe + the limitations and the code they affect. Such description must be included in the + legal file described in Section + 3.4 and must be included with all distributions of the Source Code. + Except to the extent prohibited by statute or regulation, such description must be + sufficiently detailed for a recipient of ordinary skill to be able to understand it. +

    5. Application of this License.

    +

    This License applies to code to which the Initial Developer has attached the notice in + Exhibit A and to related Covered Code. +

    6. Versions of the License.

    +

    6.1. New Versions

    +

    Netscape Communications Corporation ("Netscape") may publish revised and/or new versions + of the License from time to time. Each version will be given a distinguishing version number. +

    6.2. Effect of New Versions

    +

    Once Covered Code has been published under a particular version of the License, You may + always continue to use it under the terms of that version. You may also choose to use such + Covered Code under the terms of any subsequent version of the License published by Netscape. + No one other than Netscape has the right to modify the terms applicable to Covered Code + created under this License. +

    6.3. Derivative Works

    +

    If You create or use a modified version of this License (which you may only do in order to + apply it to code which is not already Covered Code governed by this License), You must (a) + rename Your license so that the phrases "Mozilla", "MOZILLAPL", "MOZPL", "Netscape", "MPL", + "NPL" or any confusingly similar phrase do not appear in your license (except to note that + your license differs from this License) and (b) otherwise make it clear that Your version of + the license contains terms which differ from the Mozilla Public License and Netscape Public + License. (Filling in the name of the Initial Developer, Original Code or Contributor in the + notice described in Exhibit A shall not of themselves be deemed to + be modifications of this License.) +

    7. Disclaimer of warranty

    +

    Covered code is provided under this license on an "as is" + basis, without warranty of any kind, either expressed or implied, including, without + limitation, warranties that the covered code is free of defects, merchantable, fit for a + particular purpose or non-infringing. The entire risk as to the quality and performance of + the covered code is with you. Should any covered code prove defective in any respect, you + (not the initial developer or any other contributor) assume the cost of any necessary + servicing, repair or correction. This disclaimer of warranty constitutes an essential part + of this license. No use of any covered code is authorized hereunder except under this + disclaimer. +

    8. Termination

    +

    8.1. This License and the rights granted hereunder will terminate + automatically if You fail to comply with terms herein and fail to cure such breach + within 30 days of becoming aware of the breach. All sublicenses to the Covered Code which + are properly granted shall survive any termination of this License. Provisions which, by + their nature, must remain in effect beyond the termination of this License shall survive. +

    8.2. If You initiate litigation by asserting a patent infringement + claim (excluding declatory judgment actions) against Initial Developer or a Contributor + (the Initial Developer or Contributor against whom You file such action is referred to + as "Participant") alleging that: +

      +
    1. such Participant's Contributor Version directly or indirectly + infringes any patent, then any and all rights granted by such Participant to You under + Sections 2.1 and/or 2.2 of this + License shall, upon 60 days notice from Participant terminate prospectively, unless if + within 60 days after receipt of notice You either: (i) agree in writing to pay + Participant a mutually agreeable reasonable royalty for Your past and future use of + Modifications made by such Participant, or (ii) withdraw Your litigation claim with + respect to the Contributor Version against such Participant. If within 60 days of + notice, a reasonable royalty and payment arrangement are not mutually agreed upon in + writing by the parties or the litigation claim is not withdrawn, the rights granted by + Participant to You under Sections 2.1 and/or + 2.2 automatically terminate at the expiration of the 60 day + notice period specified above. +
    2. any software, hardware, or device, other than such Participant's + Contributor Version, directly or indirectly infringes any patent, then any rights + granted to You by such Participant under Sections 2.1(b) + and 2.2(b) are revoked effective as of the date You first + made, used, sold, distributed, or had made, Modifications made by that Participant. +
    +

    8.3. If You assert a patent infringement claim against Participant + alleging that such Participant's Contributor Version directly or indirectly infringes + any patent where such claim is resolved (such as by license or settlement) prior to the + initiation of patent infringement litigation, then the reasonable value of the licenses + granted by such Participant under Sections 2.1 or + 2.2 shall be taken into account in determining the amount or + value of any payment or license. +

    8.4. In the event of termination under Sections + 8.1 or 8.2 above, all end user + license agreements (excluding distributors and resellers) which have been validly + granted by You or any distributor hereunder prior to termination shall survive + termination. +

    9. Limitation of liability

    +

    Under no circumstances and under no legal theory, whether + tort (including negligence), contract, or otherwise, shall you, the initial developer, + any other contributor, or any distributor of covered code, or any supplier of any of + such parties, be liable to any person for any indirect, special, incidental, or + consequential damages of any character including, without limitation, damages for loss + of goodwill, work stoppage, computer failure or malfunction, or any and all other + commercial damages or losses, even if such party shall have been informed of the + possibility of such damages. This limitation of liability shall not apply to liability + for death or personal injury resulting from such party's negligence to the extent + applicable law prohibits such limitation. Some jurisdictions do not allow the exclusion + or limitation of incidental or consequential damages, so this exclusion and limitation + may not apply to you. +

    10. U.S. government end users

    +

    The Covered Code is a "commercial item," as that term is defined in 48 + C.F.R. 2.101 (Oct. 1995), consisting of + "commercial computer software" and "commercial computer software documentation," as such + terms are used in 48 C.F.R. 12.212 (Sept. + 1995). Consistent with 48 C.F.R. 12.212 and 48 C.F.R. + 227.7202-1 through 227.7202-4 (June 1995), all U.S. Government End Users + acquire Covered Code with only those rights set forth herein. +

    11. Miscellaneous

    +

    This License represents the complete agreement concerning subject matter hereof. If + any provision of this License is held to be unenforceable, such provision shall be + reformed only to the extent necessary to make it enforceable. This License shall be + governed by California law provisions (except to the extent applicable law, if any, + provides otherwise), excluding its conflict-of-law provisions. With respect to + disputes in which at least one party is a citizen of, or an entity chartered or + registered to do business in the United States of America, any litigation relating to + this License shall be subject to the jurisdiction of the Federal Courts of the + Northern District of California, with venue lying in Santa Clara County, California, + with the losing party responsible for costs, including without limitation, court + costs and reasonable attorneys' fees and expenses. The application of the United + Nations Convention on Contracts for the International Sale of Goods is expressly + excluded. Any law or regulation which provides that the language of a contract + shall be construed against the drafter shall not apply to this License. +

    12. Responsibility for claims

    +

    As between Initial Developer and the Contributors, each party is responsible for + claims and damages arising, directly or indirectly, out of its utilization of rights + under this License and You agree to work with Initial Developer and Contributors to + distribute such responsibility on an equitable basis. Nothing herein is intended or + shall be deemed to constitute any admission of liability. +

    13. Multiple-licensed code

    +

    Initial Developer may designate portions of the Covered Code as + "Multiple-Licensed". "Multiple-Licensed" means that the Initial Developer permits + you to utilize portions of the Covered Code under Your choice of the MPL + or the alternative licenses, if any, specified by the Initial Developer in the file + described in Exhibit A. +

    Exhibit A - Mozilla Public License.

    +
    "The contents of this file are subject to the Mozilla Public License
    +Version 1.1 (the "License"); you may not use this file except in
    +compliance with the License. You may obtain a copy of the License at
    +http://www.mozilla.org/MPL/
    +
    +Software distributed under the License is distributed on an "AS IS"
    +basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See the
    +License for the specific language governing rights and limitations
    +under the License.
    +
    +The Original Code is JTransforms.
    +
    +The Initial Developer of the Original Code is
    +Piotr Wendykier, Emory University.
    +Portions created by the Initial Developer are Copyright (C) 2007-2009
    +the Initial Developer. All Rights Reserved.
    +
    +Alternatively, the contents of this file may be used under the terms of
    +either the GNU General Public License Version 2 or later (the "GPL"), or
    +the GNU Lesser General Public License Version 2.1 or later (the "LGPL"),
    +in which case the provisions of the GPL or the LGPL are applicable instead
    +of those above. If you wish to allow use of your version of this file only
    +under the terms of either the GPL or the LGPL, and not to allow others to
    +use your version of this file under the terms of the MPL, indicate your
    +decision by deleting the provisions above and replace them with the notice
    +and other provisions required by the GPL or the LGPL. If you do not delete
    +the provisions above, a recipient may use your version of this file under
    +the terms of any one of the MPL, the GPL or the LGPL.
    +

    NOTE: The text of this Exhibit A may differ slightly from the text of + the notices in the Source Code files of the Original Code. You should + use the text of this Exhibit A rather than the text found in the + Original Code Source Code for Your Modifications. + +

    \ No newline at end of file diff --git a/licenses/LICENSE-kryo.txt b/licenses-binary/LICENSE-kryo.txt similarity index 100% rename from licenses/LICENSE-kryo.txt rename to licenses-binary/LICENSE-kryo.txt diff --git a/licenses-binary/LICENSE-leveldbjni.txt b/licenses-binary/LICENSE-leveldbjni.txt new file mode 100644 index 0000000000000..b4dabb9174c6d --- /dev/null +++ b/licenses-binary/LICENSE-leveldbjni.txt @@ -0,0 +1,27 @@ +Copyright (c) 2011 FuseSource Corp. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of FuseSource Corp. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-machinist.txt b/licenses-binary/LICENSE-machinist.txt new file mode 100644 index 0000000000000..68cc3a3e3a9c4 --- /dev/null +++ b/licenses-binary/LICENSE-machinist.txt @@ -0,0 +1,19 @@ +Copyright (c) 2011-2014 Erik Osheim, Tom Switzer + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-matchMedia-polyfill.txt b/licenses-binary/LICENSE-matchMedia-polyfill.txt new file mode 100644 index 0000000000000..2fd0bc2b37448 --- /dev/null +++ b/licenses-binary/LICENSE-matchMedia-polyfill.txt @@ -0,0 +1 @@ +matchMedia() polyfill - Test a CSS media type/query in JS. Authors & copyright (c) 2012: Scott Jehl, Paul Irish, Nicholas Zakas. Dual MIT/BSD license \ No newline at end of file diff --git a/licenses/LICENSE-minlog.txt b/licenses-binary/LICENSE-minlog.txt similarity index 100% rename from licenses/LICENSE-minlog.txt rename to licenses-binary/LICENSE-minlog.txt diff --git a/licenses-binary/LICENSE-modernizr.txt b/licenses-binary/LICENSE-modernizr.txt new file mode 100644 index 0000000000000..2bf24b9b9f848 --- /dev/null +++ b/licenses-binary/LICENSE-modernizr.txt @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-netlib.txt b/licenses-binary/LICENSE-netlib.txt similarity index 100% rename from licenses/LICENSE-netlib.txt rename to licenses-binary/LICENSE-netlib.txt diff --git a/licenses/LICENSE-paranamer.txt b/licenses-binary/LICENSE-paranamer.txt similarity index 100% rename from licenses/LICENSE-paranamer.txt rename to licenses-binary/LICENSE-paranamer.txt diff --git a/licenses/LICENSE-jpmml-model.txt b/licenses-binary/LICENSE-pmml-model.txt similarity index 100% rename from licenses/LICENSE-jpmml-model.txt rename to licenses-binary/LICENSE-pmml-model.txt diff --git a/licenses/LICENSE-protobuf.txt b/licenses-binary/LICENSE-protobuf.txt similarity index 100% rename from licenses/LICENSE-protobuf.txt rename to licenses-binary/LICENSE-protobuf.txt diff --git a/licenses-binary/LICENSE-py4j.txt b/licenses-binary/LICENSE-py4j.txt new file mode 100644 index 0000000000000..70af3e69ed67a --- /dev/null +++ b/licenses-binary/LICENSE-py4j.txt @@ -0,0 +1,27 @@ +Copyright (c) 2009-2011, Barthelemy Dagenais All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +- Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +- Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +- The name of the author may not be used to endorse or promote products +derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. + diff --git a/licenses/LICENSE-pyrolite.txt b/licenses-binary/LICENSE-pyrolite.txt similarity index 100% rename from licenses/LICENSE-pyrolite.txt rename to licenses-binary/LICENSE-pyrolite.txt diff --git a/licenses/LICENSE-reflectasm.txt b/licenses-binary/LICENSE-reflectasm.txt similarity index 100% rename from licenses/LICENSE-reflectasm.txt rename to licenses-binary/LICENSE-reflectasm.txt diff --git a/licenses-binary/LICENSE-respond.txt b/licenses-binary/LICENSE-respond.txt new file mode 100644 index 0000000000000..dea4ff9e5b2ea --- /dev/null +++ b/licenses-binary/LICENSE-respond.txt @@ -0,0 +1,22 @@ +Copyright (c) 2012 Scott Jehl + +Permission is hereby granted, free of charge, to any person +obtaining a copy of this software and associated documentation +files (the "Software"), to deal in the Software without +restriction, including without limitation the rights to use, +copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-sbt-launch-lib.txt b/licenses-binary/LICENSE-sbt-launch-lib.txt new file mode 100644 index 0000000000000..3b9156baaab78 --- /dev/null +++ b/licenses-binary/LICENSE-sbt-launch-lib.txt @@ -0,0 +1,26 @@ +// Generated from http://www.opensource.org/licenses/bsd-license.php +Copyright (c) 2011, Paul Phillips. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + * Neither the name of the author nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-scala.txt b/licenses-binary/LICENSE-scala.txt similarity index 100% rename from licenses/LICENSE-scala.txt rename to licenses-binary/LICENSE-scala.txt diff --git a/licenses-binary/LICENSE-scopt.txt b/licenses-binary/LICENSE-scopt.txt new file mode 100644 index 0000000000000..e92e9b592fba0 --- /dev/null +++ b/licenses-binary/LICENSE-scopt.txt @@ -0,0 +1,9 @@ +This project is licensed under the MIT license. + +Copyright (c) scopt contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-slf4j.txt b/licenses-binary/LICENSE-slf4j.txt similarity index 100% rename from licenses/LICENSE-slf4j.txt rename to licenses-binary/LICENSE-slf4j.txt diff --git a/licenses-binary/LICENSE-sorttable.js.txt b/licenses-binary/LICENSE-sorttable.js.txt new file mode 100644 index 0000000000000..b31a5b206bf40 --- /dev/null +++ b/licenses-binary/LICENSE-sorttable.js.txt @@ -0,0 +1,16 @@ +Copyright (c) 1997-2007 Stuart Langridge + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/licenses/LICENSE-spire.txt b/licenses-binary/LICENSE-spire.txt similarity index 100% rename from licenses/LICENSE-spire.txt rename to licenses-binary/LICENSE-spire.txt diff --git a/licenses-binary/LICENSE-vis.txt b/licenses-binary/LICENSE-vis.txt new file mode 100644 index 0000000000000..18b7323059a41 --- /dev/null +++ b/licenses-binary/LICENSE-vis.txt @@ -0,0 +1,22 @@ +vis.js +https://github.com/almende/vis + +A dynamic, browser-based visualization library. + +@version 4.16.1 +@date 2016-04-18 + +@license +Copyright (C) 2011-2016 Almende B.V, http://almende.com + +Vis.js is dual licensed under both + +* The Apache 2.0 License + http://www.apache.org/licenses/LICENSE-2.0 + +and + +* The MIT License + http://opensource.org/licenses/MIT + +Vis.js may be distributed under either license. \ No newline at end of file diff --git a/licenses/LICENSE-xmlenc.txt b/licenses-binary/LICENSE-xmlenc.txt similarity index 100% rename from licenses/LICENSE-xmlenc.txt rename to licenses-binary/LICENSE-xmlenc.txt diff --git a/licenses/LICENSE-zstd-jni.txt b/licenses-binary/LICENSE-zstd-jni.txt similarity index 100% rename from licenses/LICENSE-zstd-jni.txt rename to licenses-binary/LICENSE-zstd-jni.txt diff --git a/licenses/LICENSE-zstd.txt b/licenses-binary/LICENSE-zstd.txt similarity index 100% rename from licenses/LICENSE-zstd.txt rename to licenses-binary/LICENSE-zstd.txt diff --git a/licenses/LICENSE-CC0.txt b/licenses/LICENSE-CC0.txt new file mode 100644 index 0000000000000..1625c17936079 --- /dev/null +++ b/licenses/LICENSE-CC0.txt @@ -0,0 +1,121 @@ +Creative Commons Legal Code + +CC0 1.0 Universal + + CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE + LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN + ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS + INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES + REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS + PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM + THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED + HEREUNDER. + +Statement of Purpose + +The laws of most jurisdictions throughout the world automatically confer +exclusive Copyright and Related Rights (defined below) upon the creator +and subsequent owner(s) (each and all, an "owner") of an original work of +authorship and/or a database (each, a "Work"). + +Certain owners wish to permanently relinquish those rights to a Work for +the purpose of contributing to a commons of creative, cultural and +scientific works ("Commons") that the public can reliably and without fear +of later claims of infringement build upon, modify, incorporate in other +works, reuse and redistribute as freely as possible in any form whatsoever +and for any purposes, including without limitation commercial purposes. +These owners may contribute to the Commons to promote the ideal of a free +culture and the further production of creative, cultural and scientific +works, or to gain reputation or greater distribution for their Work in +part through the use and efforts of others. + +For these and/or other purposes and motivations, and without any +expectation of additional consideration or compensation, the person +associating CC0 with a Work (the "Affirmer"), to the extent that he or she +is an owner of Copyright and Related Rights in the Work, voluntarily +elects to apply CC0 to the Work and publicly distribute the Work under its +terms, with knowledge of his or her Copyright and Related Rights in the +Work and the meaning and intended legal effect of CC0 on those rights. + +1. Copyright and Related Rights. A Work made available under CC0 may be +protected by copyright and related or neighboring rights ("Copyright and +Related Rights"). Copyright and Related Rights include, but are not +limited to, the following: + + i. the right to reproduce, adapt, distribute, perform, display, + communicate, and translate a Work; + ii. moral rights retained by the original author(s) and/or performer(s); +iii. publicity and privacy rights pertaining to a person's image or + likeness depicted in a Work; + iv. rights protecting against unfair competition in regards to a Work, + subject to the limitations in paragraph 4(a), below; + v. rights protecting the extraction, dissemination, use and reuse of data + in a Work; + vi. database rights (such as those arising under Directive 96/9/EC of the + European Parliament and of the Council of 11 March 1996 on the legal + protection of databases, and under any national implementation + thereof, including any amended or successor version of such + directive); and +vii. other similar, equivalent or corresponding rights throughout the + world based on applicable law or treaty, and any national + implementations thereof. + +2. Waiver. To the greatest extent permitted by, but not in contravention +of, applicable law, Affirmer hereby overtly, fully, permanently, +irrevocably and unconditionally waives, abandons, and surrenders all of +Affirmer's Copyright and Related Rights and associated claims and causes +of action, whether now known or unknown (including existing as well as +future claims and causes of action), in the Work (i) in all territories +worldwide, (ii) for the maximum duration provided by applicable law or +treaty (including future time extensions), (iii) in any current or future +medium and for any number of copies, and (iv) for any purpose whatsoever, +including without limitation commercial, advertising or promotional +purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each +member of the public at large and to the detriment of Affirmer's heirs and +successors, fully intending that such Waiver shall not be subject to +revocation, rescission, cancellation, termination, or any other legal or +equitable action to disrupt the quiet enjoyment of the Work by the public +as contemplated by Affirmer's express Statement of Purpose. + +3. Public License Fallback. Should any part of the Waiver for any reason +be judged legally invalid or ineffective under applicable law, then the +Waiver shall be preserved to the maximum extent permitted taking into +account Affirmer's express Statement of Purpose. In addition, to the +extent the Waiver is so judged Affirmer hereby grants to each affected +person a royalty-free, non transferable, non sublicensable, non exclusive, +irrevocable and unconditional license to exercise Affirmer's Copyright and +Related Rights in the Work (i) in all territories worldwide, (ii) for the +maximum duration provided by applicable law or treaty (including future +time extensions), (iii) in any current or future medium and for any number +of copies, and (iv) for any purpose whatsoever, including without +limitation commercial, advertising or promotional purposes (the +"License"). The License shall be deemed effective as of the date CC0 was +applied by Affirmer to the Work. Should any part of the License for any +reason be judged legally invalid or ineffective under applicable law, such +partial invalidity or ineffectiveness shall not invalidate the remainder +of the License, and in such case Affirmer hereby affirms that he or she +will not (i) exercise any of his or her remaining Copyright and Related +Rights in the Work or (ii) assert any associated claims and causes of +action with respect to the Work, in either case contrary to Affirmer's +express Statement of Purpose. + +4. Limitations and Disclaimers. + + a. No trademark or patent rights held by Affirmer are waived, abandoned, + surrendered, licensed or otherwise affected by this document. + b. Affirmer offers the Work as-is and makes no representations or + warranties of any kind concerning the Work, express, implied, + statutory or otherwise, including without limitation warranties of + title, merchantability, fitness for a particular purpose, non + infringement, or the absence of latent or other defects, accuracy, or + the present or absence of errors, whether or not discoverable, all to + the greatest extent permissible under applicable law. + c. Affirmer disclaims responsibility for clearing rights of other persons + that may apply to the Work or any use thereof, including without + limitation any person's Copyright and Related Rights in the Work. + Further, Affirmer disclaims responsibility for obtaining any necessary + consents, permissions or other rights required for any use of the + Work. + d. Affirmer understands and acknowledges that Creative Commons is not a + party to this document and has no duty or obligation with respect to + this CC0 or use of the Work. \ No newline at end of file diff --git a/licenses/LICENSE-SnapTree.txt b/licenses/LICENSE-SnapTree.txt deleted file mode 100644 index a538825d89ec5..0000000000000 --- a/licenses/LICENSE-SnapTree.txt +++ /dev/null @@ -1,35 +0,0 @@ -SNAPTREE LICENSE - -Copyright (c) 2009-2012 Stanford University, unless otherwise specified. -All rights reserved. - -This software was developed by the Pervasive Parallelism Laboratory of -Stanford University, California, USA. - -Permission to use, copy, modify, and distribute this software in source -or binary form for any purpose with or without fee is hereby granted, -provided that the following conditions are met: - - 1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - 2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - 3. Neither the name of Stanford University nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - -THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT -LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY -OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF -SUCH DAMAGE. diff --git a/licenses/LICENSE-bootstrap.txt b/licenses/LICENSE-bootstrap.txt new file mode 100644 index 0000000000000..6c711832fbc85 --- /dev/null +++ b/licenses/LICENSE-bootstrap.txt @@ -0,0 +1,13 @@ +Copyright 2013 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/licenses/LICENSE-boto.txt b/licenses/LICENSE-boto.txt deleted file mode 100644 index 7bba0cd9e10a4..0000000000000 --- a/licenses/LICENSE-boto.txt +++ /dev/null @@ -1,20 +0,0 @@ -Copyright (c) 2006-2008 Mitch Garnaat http://garnaat.org/ - -Permission is hereby granted, free of charge, to any person obtaining a -copy of this software and associated documentation files (the -"Software"), to deal in the Software without restriction, including -without limitation the rights to use, copy, modify, merge, publish, dis- -tribute, sublicense, and/or sell copies of the Software, and to permit -persons to whom the Software is furnished to do so, subject to the fol- -lowing conditions: - -The above copyright notice and this permission notice shall be included -in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL- -ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT -SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, -WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS -IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-datatables.txt b/licenses/LICENSE-datatables.txt new file mode 100644 index 0000000000000..bb7708b5b5a49 --- /dev/null +++ b/licenses/LICENSE-datatables.txt @@ -0,0 +1,7 @@ +Copyright (C) 2008-2018, SpryMedia Ltd. + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-graphlib-dot.txt b/licenses/LICENSE-graphlib-dot.txt index c9e18cd562423..4864fe05e9803 100644 --- a/licenses/LICENSE-graphlib-dot.txt +++ b/licenses/LICENSE-graphlib-dot.txt @@ -1,4 +1,4 @@ -Copyright (c) 2012-2013 Chris Pettitt +Copyright (c) 2013 Chris Pettitt Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/licenses/LICENSE-jbcrypt.txt b/licenses/LICENSE-jbcrypt.txt deleted file mode 100644 index d332534c06356..0000000000000 --- a/licenses/LICENSE-jbcrypt.txt +++ /dev/null @@ -1,17 +0,0 @@ -jBCrypt is subject to the following license: - -/* - * Copyright (c) 2006 Damien Miller - * - * Permission to use, copy, modify, and distribute this software for any - * purpose with or without fee is hereby granted, provided that the above - * copyright notice and this permission notice appear in all copies. - * - * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES - * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR - * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES - * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN - * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF - * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ diff --git a/licenses/LICENSE-join.txt b/licenses/LICENSE-join.txt new file mode 100644 index 0000000000000..1d916090e4ea0 --- /dev/null +++ b/licenses/LICENSE-join.txt @@ -0,0 +1,30 @@ +Copyright (c) 2011, Douban Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + + * Neither the name of the Douban Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-jquery.txt b/licenses/LICENSE-jquery.txt index e1dd696d3b6cc..45930542204fb 100644 --- a/licenses/LICENSE-jquery.txt +++ b/licenses/LICENSE-jquery.txt @@ -1,9 +1,20 @@ -The MIT License (MIT) +Copyright JS Foundation and other contributors, https://js.foundation/ -Copyright (c) +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-json-formatter.txt b/licenses/LICENSE-json-formatter.txt new file mode 100644 index 0000000000000..5193348fce126 --- /dev/null +++ b/licenses/LICENSE-json-formatter.txt @@ -0,0 +1,6 @@ +Copyright 2014 Mohsen Azimi + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. \ No newline at end of file diff --git a/licenses/LICENSE-matchMedia-polyfill.txt b/licenses/LICENSE-matchMedia-polyfill.txt new file mode 100644 index 0000000000000..2fd0bc2b37448 --- /dev/null +++ b/licenses/LICENSE-matchMedia-polyfill.txt @@ -0,0 +1 @@ +matchMedia() polyfill - Test a CSS media type/query in JS. Authors & copyright (c) 2012: Scott Jehl, Paul Irish, Nicholas Zakas. Dual MIT/BSD license \ No newline at end of file diff --git a/licenses/LICENSE-postgresql.txt b/licenses/LICENSE-postgresql.txt deleted file mode 100644 index 515bf9af4d432..0000000000000 --- a/licenses/LICENSE-postgresql.txt +++ /dev/null @@ -1,24 +0,0 @@ -PostgreSQL Database Management System -(formerly known as Postgres, then as Postgres95) - -Portions Copyright (c) 1996-2010, PostgreSQL Global Development Group - -Portions Copyright (c) 1994, The Regents of the University of California - -Permission to use, copy, modify, and distribute this software and its -documentation for any purpose, without fee, and without a written agreement -is hereby granted, provided that the above copyright notice and this -paragraph and the following two paragraphs appear in all copies. - -IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY FOR -DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING -LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS -DOCUMENTATION, EVEN IF THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. - -THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES, -INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY -AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS -ON AN "AS IS" BASIS, AND THE UNIVERSITY OF CALIFORNIA HAS NO OBLIGATIONS TO -PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. - diff --git a/licenses/LICENSE-respond.txt b/licenses/LICENSE-respond.txt new file mode 100644 index 0000000000000..dea4ff9e5b2ea --- /dev/null +++ b/licenses/LICENSE-respond.txt @@ -0,0 +1,22 @@ +Copyright (c) 2012 Scott Jehl + +Permission is hereby granted, free of charge, to any person +obtaining a copy of this software and associated documentation +files (the "Software"), to deal in the Software without +restriction, including without limitation the rights to use, +copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-scalacheck.txt b/licenses/LICENSE-scalacheck.txt deleted file mode 100644 index cb8f97842f4c4..0000000000000 --- a/licenses/LICENSE-scalacheck.txt +++ /dev/null @@ -1,32 +0,0 @@ -ScalaCheck LICENSE - -Copyright (c) 2007-2015, Rickard Nilsson -All rights reserved. - -Permission to use, copy, modify, and distribute this software in source -or binary form for any purpose with or without fee is hereby granted, -provided that the following conditions are met: - - 1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - 2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - 3. Neither the name of the author nor the names of its contributors - may be used to endorse or promote products derived from this - software without specific prior written permission. - - -THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT -LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY -OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF -SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-vis.txt b/licenses/LICENSE-vis.txt new file mode 100644 index 0000000000000..18b7323059a41 --- /dev/null +++ b/licenses/LICENSE-vis.txt @@ -0,0 +1,22 @@ +vis.js +https://github.com/almende/vis + +A dynamic, browser-based visualization library. + +@version 4.16.1 +@date 2016-04-18 + +@license +Copyright (C) 2011-2016 Almende B.V, http://almende.com + +Vis.js is dual licensed under both + +* The Apache 2.0 License + http://www.apache.org/licenses/LICENSE-2.0 + +and + +* The MIT License + http://opensource.org/licenses/MIT + +Vis.js may be distributed under either license. \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 06ca37bc75146..92e342ed4a464 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1202,6 +1202,11 @@ class LogisticRegressionModel private[spark] ( */ @Since("1.6.0") override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this) + + override def toString: String = { + s"LogisticRegressionModel: " + + s"uid = ${super.toString}, numClasses = $numClasses, numFeatures = $numFeatures" + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 9c9614509c64f..de564471c2b4d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -274,7 +274,7 @@ class BisectingKMeans @Since("2.0.0") ( val parentModel = bkm.run(rdd) val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this)) val summary = new BisectingKMeansSummary( - model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) + model.transform(dataset), $(predictionCol), $(featuresCol), $(k), $(maxIter)) model.setSummary(Some(summary)) instr.logNamedValue("clusterSizes", summary.clusterSizes) instr.logSuccess(model) @@ -304,6 +304,7 @@ object BisectingKMeans extends DefaultParamsReadable[BisectingKMeans] { * @param predictionCol Name for column of predicted clusters in `predictions`. * @param featuresCol Name for column of features in `predictions`. * @param k Number of clusters. + * @param numIter Number of iterations. */ @Since("2.1.0") @Experimental @@ -311,4 +312,5 @@ class BisectingKMeansSummary private[clustering] ( predictions: DataFrame, predictionCol: String, featuresCol: String, - k: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k) + k: Int, + numIter: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k, numIter) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala index 44e832b058b62..7da4c43a1abf3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.clustering -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.sql.{DataFrame, Row} /** @@ -28,13 +28,15 @@ import org.apache.spark.sql.{DataFrame, Row} * @param predictionCol Name for column of predicted clusters in `predictions`. * @param featuresCol Name for column of features in `predictions`. * @param k Number of clusters. + * @param numIter Number of iterations. */ @Experimental class ClusteringSummary private[clustering] ( @transient val predictions: DataFrame, val predictionCol: String, val featuresCol: String, - val k: Int) extends Serializable { + val k: Int, + @Since("2.4.0") val numIter: Int) extends Serializable { /** * Cluster centers of the transformed data. diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 64ecc1ebda589..f0707b380c673 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -341,7 +341,7 @@ class GaussianMixture @Since("2.0.0") ( val sc = dataset.sparkSession.sparkContext val numClusters = $(k) - val instances: RDD[Vector] = dataset + val instances = dataset .select(DatasetUtils.columnToVector(dataset, getFeaturesCol)).rdd.map { case Row(features: Vector) => features }.cache() @@ -416,6 +416,7 @@ class GaussianMixture @Since("2.0.0") ( iter += 1 } + instances.unpersist(false) val gaussianDists = gaussians.map { case (mean, covVec) => val cov = GaussianMixture.unpackUpperTriangularMatrix(numFeatures, covVec.values) new MultivariateGaussian(mean, cov) @@ -423,7 +424,7 @@ class GaussianMixture @Since("2.0.0") ( val model = copyValues(new GaussianMixtureModel(uid, weights, gaussianDists)).setParent(this) val summary = new GaussianMixtureSummary(model.transform(dataset), - $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood) + $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood, iter) model.setSummary(Some(summary)) instr.logNamedValue("logLikelihood", logLikelihood) instr.logNamedValue("clusterSizes", summary.clusterSizes) @@ -687,6 +688,7 @@ private class ExpectationAggregator( * @param featuresCol Name for column of features in `predictions`. * @param k Number of clusters. * @param logLikelihood Total log-likelihood for this model on the given data. + * @param numIter Number of iterations. */ @Since("2.0.0") @Experimental @@ -696,8 +698,9 @@ class GaussianMixtureSummary private[clustering] ( @Since("2.0.0") val probabilityCol: String, featuresCol: String, k: Int, - @Since("2.2.0") val logLikelihood: Double) - extends ClusteringSummary(predictions, predictionCol, featuresCol, k) { + @Since("2.2.0") val logLikelihood: Double, + numIter: Int) + extends ClusteringSummary(predictions, predictionCol, featuresCol, k, numIter) { /** * Probability of each cluster. diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 1704412741d49..f40037a8d9aa9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -356,7 +356,7 @@ class KMeans @Since("1.5.0") ( val parentModel = algo.run(instances, Option(instr)) val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) val summary = new KMeansSummary( - model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) + model.transform(dataset), $(predictionCol), $(featuresCol), $(k), parentModel.numIter) model.setSummary(Some(summary)) instr.logNamedValue("clusterSizes", summary.clusterSizes) @@ -388,6 +388,7 @@ object KMeans extends DefaultParamsReadable[KMeans] { * @param predictionCol Name for column of predicted clusters in `predictions`. * @param featuresCol Name for column of features in `predictions`. * @param k Number of clusters. + * @param numIter Number of iterations. */ @Since("2.0.0") @Experimental @@ -395,4 +396,5 @@ class KMeansSummary private[clustering] ( predictions: DataFrame, predictionCol: String, featuresCol: String, - k: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k) + k: Int, + numIter: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k, numIter) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala index a67a3b0abbc1f..a043033e96724 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala @@ -66,7 +66,7 @@ class MinHashLSHModel private[ml]( val elemsList = elems.toSparse.indices.toList val hashValues = randCoefficients.map { case (a, b) => elemsList.map { elem: Int => - ((1 + elem) * a + b) % MinHashLSH.HASH_PRIME + ((1L + elem) * a + b) % MinHashLSH.HASH_PRIME }.min.toDouble } // TODO: Output vectors of dimension numHashFunctions in SPARK-18450 diff --git a/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala index 8c975a2fba8ca..f1579ec5844a4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala @@ -42,9 +42,11 @@ private object RecursiveFlag { val old = Option(hadoopConf.get(flagName)) hadoopConf.set(flagName, value.toString) try f finally { - old match { - case Some(v) => hadoopConf.set(flagName, v) - case None => hadoopConf.unset(flagName) + // avoid false positive of DLS_DEAD_LOCAL_STORE_IN_RETURN by SpotBugs + if (old.isDefined) { + hadoopConf.set(flagName, old.get) + } else { + hadoopConf.unset(flagName) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index b5b1be3490497..37ae8b1a6171a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -348,7 +348,7 @@ class KMeans private ( logInfo(s"The cost is $cost.") - new KMeansModel(centers.map(_.vector), distanceMeasure) + new KMeansModel(centers.map(_.vector), distanceMeasure, iteration) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index a78c21e838e44..e3a88b42fbf73 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -36,8 +36,9 @@ import org.apache.spark.sql.{Row, SparkSession} * A clustering model for K-means. Each point belongs to the cluster with the closest center. */ @Since("0.8.0") -class KMeansModel @Since("2.4.0") (@Since("1.0.0") val clusterCenters: Array[Vector], - @Since("2.4.0") val distanceMeasure: String) +class KMeansModel (@Since("1.0.0") val clusterCenters: Array[Vector], + @Since("2.4.0") val distanceMeasure: String, + private[spark] val numIter: Int) extends Saveable with Serializable with PMMLExportable { private val distanceMeasureInstance: DistanceMeasure = @@ -46,6 +47,10 @@ class KMeansModel @Since("2.4.0") (@Since("1.0.0") val clusterCenters: Array[Vec private val clusterCentersWithNorm = if (clusterCenters == null) null else clusterCenters.map(new VectorWithNorm(_)) + @Since("2.4.0") + private[spark] def this(clusterCenters: Array[Vector], distanceMeasure: String) = + this(clusterCenters: Array[Vector], distanceMeasure, -1) + @Since("1.1.0") def this(clusterCenters: Array[Vector]) = this(clusterCenters: Array[Vector], DistanceMeasure.EUCLIDEAN) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index ac709ad72f0c0..7b49d4d0812f9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -78,8 +78,13 @@ class MatrixFactorizationModel @Since("0.8.0") ( /** Predict the rating of one user for one product. */ @Since("0.8.0") def predict(user: Int, product: Int): Double = { - val userVector = userFeatures.lookup(user).head - val productVector = productFeatures.lookup(product).head + val userFeatureSeq = userFeatures.lookup(user) + require(userFeatureSeq.nonEmpty, s"userId: $user not found in the model") + val productFeatureSeq = productFeatures.lookup(product) + require(productFeatureSeq.nonEmpty, s"productId: $product not found in the model") + + val userVector = userFeatureSeq.head + val productVector = productFeatureSeq.head blas.ddot(rank, userVector, 1, productVector, 1) } @@ -164,9 +169,12 @@ class MatrixFactorizationModel @Since("0.8.0") ( * recommended the product is. */ @Since("1.1.0") - def recommendProducts(user: Int, num: Int): Array[Rating] = - MatrixFactorizationModel.recommend(userFeatures.lookup(user).head, productFeatures, num) + def recommendProducts(user: Int, num: Int): Array[Rating] = { + val userFeatureSeq = userFeatures.lookup(user) + require(userFeatureSeq.nonEmpty, s"userId: $user not found in the model") + MatrixFactorizationModel.recommend(userFeatureSeq.head, productFeatures, num) .map(t => Rating(user, t._1, t._2)) + } /** * Recommends users to a product. That is, this returns users who are most likely to be @@ -181,9 +189,12 @@ class MatrixFactorizationModel @Since("0.8.0") ( * recommended the user is. */ @Since("1.1.0") - def recommendUsers(product: Int, num: Int): Array[Rating] = - MatrixFactorizationModel.recommend(productFeatures.lookup(product).head, userFeatures, num) + def recommendUsers(product: Int, num: Int): Array[Rating] = { + val productFeatureSeq = productFeatures.lookup(product) + require(productFeatureSeq.nonEmpty, s"productId: $product not found in the model") + MatrixFactorizationModel.recommend(productFeatureSeq.head, userFeatures, num) .map(t => Rating(t._1, product, t._2)) + } protected override val formatVersion: String = "1.0" diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 36b7e51f93d01..75c2aeb146786 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -2751,6 +2751,12 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { assert(model.getFamily === family) } } + + test("toString") { + val model = new LogisticRegressionModel("logReg", Vectors.dense(0.1, 0.2, 0.3), 0.0) + val expected = "LogisticRegressionModel: uid = logReg, numClasses = 2, numFeatures = 3" + assert(model.toString === expected) + } } object LogisticRegressionSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 81842afbddbbb..1b7780e171e77 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -133,6 +133,7 @@ class BisectingKMeansSuite extends MLTest with DefaultReadWriteTest { assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + assert(summary.numIter == 20) model.setSummary(None) assert(!model.hasSummary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 0b91f502f615b..13bed9dbe3e89 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -145,6 +145,7 @@ class GaussianMixtureSuite extends MLTest with DefaultReadWriteTest { assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + assert(summary.numIter == 2) model.setSummary(None) assert(!model.hasSummary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 2569e7a432ca4..829c90fe34e94 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -135,6 +135,7 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + assert(summary.numIter == 1) model.setSummary(None) assert(!model.hasSummary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala index b7072728d48f0..55b460f1a4524 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.clustering +import scala.collection.mutable + import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -76,12 +78,15 @@ class PowerIterationClusteringSuite extends SparkFunSuite .setMaxIter(40) .setWeightCol("weight") .assignClusters(data) - val localAssignments = assignments - .select('id, 'cluster) - .as[(Long, Int)].collect().toSet - val expectedResult = (0 until n1).map(x => (x, 1)).toSet ++ - (n1 until n).map(x => (x, 0)).toSet - assert(localAssignments === expectedResult) + .select("id", "cluster") + .as[(Long, Int)] + .collect() + + val predictions = Array.fill(2)(mutable.Set.empty[Long]) + assignments.foreach { + case (id, cluster) => predictions(cluster) += id + } + assert(predictions.toSet === Set((0 until n1).toSet, (n1 until n).toSet)) val assignments2 = new PowerIterationClustering() .setK(2) @@ -89,10 +94,15 @@ class PowerIterationClusteringSuite extends SparkFunSuite .setInitMode("degree") .setWeightCol("weight") .assignClusters(data) - val localAssignments2 = assignments2 - .select('id, 'cluster) - .as[(Long, Int)].collect().toSet - assert(localAssignments2 === expectedResult) + .select("id", "cluster") + .as[(Long, Int)] + .collect() + + val predictions2 = Array.fill(2)(mutable.Set.empty[Long]) + assignments2.foreach { + case (id, cluster) => predictions2(cluster) += id + } + assert(predictions2.toSet === Set((0 until n1).toSet, (n1 until n).toSet)) } test("supported input types") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala index 2c8ed057a516a..5ed9d077afe78 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala @@ -72,6 +72,27 @@ class MatrixFactorizationModelSuite extends SparkFunSuite with MLlibTestSparkCon } } + test("invalid user and product") { + val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures) + + intercept[IllegalArgumentException] { + // invalid user + model.predict(5, 2) + } + intercept[IllegalArgumentException] { + // invalid product + model.predict(0, 5) + } + intercept[IllegalArgumentException] { + // invalid user + model.recommendProducts(5, 2) + } + intercept[IllegalArgumentException] { + // invalid product + model.recommendUsers(5, 2) + } + } + test("batch predict API recommendProductsForUsers") { val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures) val topK = 10 diff --git a/pom.xml b/pom.xml index 23bbd3b09734e..039292337eaa0 100644 --- a/pom.xml +++ b/pom.xml @@ -104,6 +104,7 @@ external/kafka-0-10 external/kafka-0-10-assembly external/kafka-0-10-sql + external/avro @@ -155,7 +156,7 @@ 3.4.1 3.2.2 - 2.11.8 + 2.11.12 2.11 1.9.13 2.6.7 @@ -313,13 +314,13 @@ chill-java ${chill.version}
    - org.apache.xbean - xbean-asm5-shaded - 4.4 + xbean-asm6-shaded + 4.8 jline jline - 2.12.1 + 2.14.3 org.scalatest @@ -760,6 +761,12 @@ 1.10.19 test + + org.jmock + jmock-junit4 + test + 2.8.4 + org.scalacheck scalacheck_${scala.binary.version} @@ -2116,7 +2123,7 @@ org.apache.maven.plugins maven-surefire-plugin - 2.20.1 + 2.22.0 @@ -2600,6 +2607,28 @@ + + com.github.spotbugs + spotbugs-maven-plugin + 3.1.3 + + ${basedir}/target/scala-${scala.binary.version}/classes + ${basedir}/target/scala-${scala.binary.version}/test-classes + Max + Low + true + FindPuzzlers + false + + + + + check + + compile + + + @@ -2749,7 +2778,7 @@ scala-2.12 - 2.12.4 + 2.12.6 2.12 diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 4f6d5ff898681..8f96bb0f33849 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,14 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-23528] Add numIter to ClusteringSummary + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.ClusteringSummary.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.KMeansSummary.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.BisectingKMeansSummary.this"), + + // [SPARK-6237][NETWORK] Network-layer changes to allow stream upload + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockRpcServer.receive"), + // [SPARK-20087][CORE] Attach accumulators / metrics to 'TaskKilled' end reason ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.apply"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.copy"), diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b606f9355e03b..247b6fee394bc 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -40,8 +40,8 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile - val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, sqlKafka010) = Seq( - "catalyst", "sql", "hive", "hive-thriftserver", "sql-kafka-0-10" + val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, sqlKafka010, avro) = Seq( + "catalyst", "sql", "hive", "hive-thriftserver", "sql-kafka-0-10", "avro" ).map(ProjectRef(buildLocation, _)) val streamingProjects@Seq(streaming, streamingKafka010) = @@ -326,7 +326,7 @@ object SparkBuild extends PomBuild { val mimaProjects = allProjects.filterNot { x => Seq( spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn, - unsafe, tags, sqlKafka010, kvstore + unsafe, tags, sqlKafka010, kvstore, avro ).contains(x) } @@ -464,7 +464,8 @@ object DockerIntegrationTests { */ object DependencyOverrides { lazy val settings = Seq( - dependencyOverrides += "com.google.guava" % "guava" % "14.0.1") + dependencyOverrides += "com.google.guava" % "guava" % "14.0.1", + dependencyOverrides += "jline" % "jline" % "2.14.3") } /** @@ -687,9 +688,11 @@ object Unidoc { publish := {}, unidocProjectFilter in(ScalaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, yarn, tags, streamingKafka010, sqlKafka010), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, + yarn, tags, streamingKafka010, sqlKafka010, avro), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, yarn, tags, streamingKafka010, sqlKafka010), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, + yarn, tags, streamingKafka010, sqlKafka010, avro), unidocAllClasspaths in (ScalaUnidoc, unidoc) := { ignoreClasspaths((unidocAllClasspaths in (ScalaUnidoc, unidoc)).value) diff --git a/python/docs/Makefile b/python/docs/Makefile index b8e079483c90c..1ed1f33af2326 100644 --- a/python/docs/Makefile +++ b/python/docs/Makefile @@ -1,19 +1,44 @@ # Makefile for Sphinx documentation # +ifndef SPHINXBUILD +ifndef SPHINXPYTHON +SPHINXBUILD = sphinx-build +endif +endif + +ifdef SPHINXBUILD +# User-friendly check for sphinx-build. +ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) +$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) +endif +else +# Note that there is an issue with Python version and Sphinx in PySpark documentation generation. +# Please remove this check below when this issue is fixed. See SPARK-24530 for more details. +PYTHON_VERSION_CHECK = $(shell $(SPHINXPYTHON) -c 'import sys; print(sys.version_info < (3, 0, 0))') +ifeq ($(PYTHON_VERSION_CHECK), True) +$(error Note that Python 3 is required to generate PySpark documentation correctly for now. Current Python executable was less than Python 3. See SPARK-24530. To force Sphinx to use a specific Python executable, please set SPHINXPYTHON to point to the Python 3 executable.) +endif +# Check if Sphinx is installed. +ifeq ($(shell $(SPHINXPYTHON) -c 'import sphinx' >/dev/null 2>&1; echo $$?), 1) +$(error Python executable '$(SPHINXPYTHON)' did not have Sphinx installed. Make sure you have Sphinx installed, then set the SPHINXPYTHON environment variable to point to the Python executable having Sphinx installed. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) +endif +# Use 'SPHINXPYTHON -msphinx' instead of 'sphinx-build'. See https://github.com/sphinx-doc/sphinx/pull/3523 for more details. +SPHINXBUILD = $(SPHINXPYTHON) -msphinx +endif + # You can set these variables from the command line. SPHINXOPTS ?= -SPHINXBUILD ?= sphinx-build PAPER ?= BUILDDIR ?= _build +# You can set SPHINXBUILD to specify Sphinx build executable or SPHINXPYTHON to specify the Python executable used in Sphinx. +# They follow: +# 1. if SPHINXPYTHON is set, use Python. If SPHINXBUILD is set, use sphinx-build. +# 2. If both are set, SPHINXBUILD has a higher priority over SPHINXPYTHON +# 3. By default, SPHINXBUILD is used as 'sphinx-build'. export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.7-src.zip) -# User-friendly check for sphinx-build -ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) -$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) -endif - # Internal variables. PAPEROPT_a4 = -D latex_paper_size=a4 PAPEROPT_letter = -D latex_paper_size=letter diff --git a/python/pyspark/context.py b/python/pyspark/context.py index ede3b6af0a8cf..2cb3117184334 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -847,6 +847,8 @@ def addFile(self, path, recursive=False): A directory can be given if the recursive option is set to True. Currently directories are only supported for Hadoop-supported filesystems. + .. note:: A path can be added only once. Subsequent additions of the same path are ignored. + >>> from pyspark import SparkFiles >>> path = os.path.join(tempdir, "test.txt") >>> with open(path, "w") as testFile: @@ -867,6 +869,8 @@ def addPyFile(self, path): SparkContext in the future. The C{path} passed can be either a local file, a file in HDFS (or other Hadoop-supported filesystems), or an HTTP, HTTPS or FTP URI. + + .. note:: A path can be added only once. Subsequent additions of the same path are ignored. """ self.addFile(path) (dirname, filename) = os.path.split(path) # dirname may be directory or HDFS/S3 prefix diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 0afbe9dc6aa3e..fa2d5e8db716a 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -31,7 +31,7 @@ if sys.version >= '3': xrange = range -from py4j.java_gateway import java_import, JavaGateway, GatewayParameters +from py4j.java_gateway import java_import, JavaGateway, JavaObject, GatewayParameters from pyspark.find_spark_home import _find_spark_home from pyspark.serializers import read_int, write_with_length, UTF8Deserializer @@ -145,3 +145,26 @@ def do_server_auth(conn, auth_secret): if reply != "ok": conn.close() raise Exception("Unexpected reply from iterator server.") + + +def ensure_callback_server_started(gw): + """ + Start callback server if not already started. The callback server is needed if the Java + driver process needs to callback into the Python driver process to execute Python code. + """ + + # getattr will fallback to JVM, so we cannot test by hasattr() + if "_callback_server" not in gw.__dict__ or gw._callback_server is None: + gw.callback_server_parameters.eager_load = True + gw.callback_server_parameters.daemonize = True + gw.callback_server_parameters.daemonize_connections = True + gw.callback_server_parameters.port = 0 + gw.start_callback_server(gw.callback_server_parameters) + cbport = gw._callback_server.server_socket.getsockname()[1] + gw._callback_server.port = cbport + # gateway with real port + gw._python_proxy_port = gw._callback_server.port + # get the GatewayServer object in JVM by ID + jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client) + # update the port of CallbackClient with real port + jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 1754c48937a62..d5963f4f7042c 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -239,6 +239,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti True >>> blorModel.intercept == model2.intercept True + >>> model2 + LogisticRegressionModel: uid = ..., numClasses = 2, numFeatures = 2 .. versionadded:: 1.3.0 """ @@ -562,6 +564,9 @@ def evaluate(self, dataset): java_blr_summary = self._call_java("evaluate", dataset) return BinaryLogisticRegressionSummary(java_blr_summary) + def __repr__(self): + return self._call_java("toString") + class LogisticRegressionSummary(JavaWrapper): """ diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 4aa1cf84b5824..2f0660040dc7c 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -349,8 +349,8 @@ def summary(self): @inherit_doc -class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed, - JavaMLWritable, JavaMLReadable): +class KMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPredictionCol, HasMaxIter, + HasTol, HasSeed, JavaMLWritable, JavaMLReadable): """ K-means clustering with a k-means++ like initialization mode (the k-means|| algorithm by Bahmani et al). @@ -406,9 +406,6 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol typeConverter=TypeConverters.toString) initSteps = Param(Params._dummy(), "initSteps", "The number of steps for k-means|| " + "initialization mode. Must be > 0.", typeConverter=TypeConverters.toInt) - distanceMeasure = Param(Params._dummy(), "distanceMeasure", "The distance measure. " + - "Supported options: 'euclidean' and 'cosine'.", - typeConverter=TypeConverters.toString) @keyword_only def __init__(self, featuresCol="features", predictionCol="prediction", k=2, @@ -544,8 +541,8 @@ def summary(self): @inherit_doc -class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasSeed, - JavaMLWritable, JavaMLReadable): +class BisectingKMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPredictionCol, + HasMaxIter, HasSeed, JavaMLWritable, JavaMLReadable): """ A bisecting k-means algorithm based on the paper "A comparison of document clustering techniques" by Steinbach, Karypis, and Kumar, with modification to fit Spark. @@ -585,6 +582,8 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte >>> bkm2 = BisectingKMeans.load(bkm_path) >>> bkm2.getK() 2 + >>> bkm2.getDistanceMeasure() + 'euclidean' >>> model_path = temp_path + "/bkm_model" >>> model.save(model_path) >>> model2 = BisectingKMeansModel.load(model_path) @@ -607,10 +606,10 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte @keyword_only def __init__(self, featuresCol="features", predictionCol="prediction", maxIter=20, - seed=None, k=4, minDivisibleClusterSize=1.0): + seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean"): """ __init__(self, featuresCol="features", predictionCol="prediction", maxIter=20, \ - seed=None, k=4, minDivisibleClusterSize=1.0) + seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean") """ super(BisectingKMeans, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.BisectingKMeans", @@ -622,10 +621,10 @@ def __init__(self, featuresCol="features", predictionCol="prediction", maxIter=2 @keyword_only @since("2.0.0") def setParams(self, featuresCol="features", predictionCol="prediction", maxIter=20, - seed=None, k=4, minDivisibleClusterSize=1.0): + seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean"): """ setParams(self, featuresCol="features", predictionCol="prediction", maxIter=20, \ - seed=None, k=4, minDivisibleClusterSize=1.0) + seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean") Sets params for BisectingKMeans. """ kwargs = self._input_kwargs @@ -659,6 +658,20 @@ def getMinDivisibleClusterSize(self): """ return self.getOrDefault(self.minDivisibleClusterSize) + @since("2.4.0") + def setDistanceMeasure(self, value): + """ + Sets the value of :py:attr:`distanceMeasure`. + """ + return self._set(distanceMeasure=value) + + @since("2.4.0") + def getDistanceMeasure(self): + """ + Gets the value of `distanceMeasure` or its default value. + """ + return self.getOrDefault(self.distanceMeasure) + def _create_model(self, java_model): return BisectingKMeansModel(java_model) @@ -1332,8 +1345,14 @@ def assignClusters(self, dataset): if __name__ == "__main__": import doctest + import numpy import pyspark.ml.clustering from pyspark.sql import SparkSession + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.ml.clustering.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 14800d4d9327a..ddba7389145e3 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1294,14 +1294,14 @@ class MinHashLSH(JavaEstimator, LSHParams, HasInputCol, HasOutputCol, HasSeed, >>> mh = MinHashLSH(inputCol="features", outputCol="hashes", seed=12345) >>> model = mh.fit(df) >>> model.transform(df).head() - Row(id=0, features=SparseVector(6, {0: 1.0, 1: 1.0, 2: 1.0}), hashes=[DenseVector([-1638925... + Row(id=0, features=SparseVector(6, {0: 1.0, 1: 1.0, 2: 1.0}), hashes=[DenseVector([6179668... >>> data2 = [(3, Vectors.sparse(6, [1, 3, 5], [1.0, 1.0, 1.0]),), ... (4, Vectors.sparse(6, [2, 3, 5], [1.0, 1.0, 1.0]),), ... (5, Vectors.sparse(6, [1, 2, 4], [1.0, 1.0, 1.0]),)] >>> df2 = spark.createDataFrame(data2, ["id", "features"]) >>> key = Vectors.sparse(6, [1, 2], [1.0, 1.0]) >>> model.approxNearestNeighbors(df2, key, 1).collect() - [Row(id=5, features=SparseVector(6, {1: 1.0, 2: 1.0, 4: 1.0}), hashes=[DenseVector([-163892... + [Row(id=5, features=SparseVector(6, {1: 1.0, 2: 1.0, 4: 1.0}), hashes=[DenseVector([6179668... >>> model.approxSimilarityJoin(df, df2, 0.6, distCol="JaccardDistance").select( ... col("datasetA.id").alias("idA"), ... col("datasetB.id").alias("idB"), @@ -1309,8 +1309,8 @@ class MinHashLSH(JavaEstimator, LSHParams, HasInputCol, HasOutputCol, HasSeed, +---+---+---------------+ |idA|idB|JaccardDistance| +---+---+---------------+ - | 1| 4| 0.5| | 0| 5| 0.5| + | 1| 4| 0.5| +---+---+---------------+ ... >>> mhPath = temp_path + "/mh" diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py index 6a611a2b5b59d..2548fd0f50b33 100644 --- a/python/pyspark/ml/linalg/__init__.py +++ b/python/pyspark/ml/linalg/__init__.py @@ -1156,6 +1156,11 @@ def sparse(numRows, numCols, colPtrs, rowIndices, values): def _test(): import doctest + try: + # Numpy 1.14+ changed it's string format. + np.set_printoptions(legacy='1.13') + except TypeError: + pass (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) if failure_count: sys.exit(-1) diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 6e9e0a34cdfde..e45ba840b412b 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -162,7 +162,9 @@ def get$Name(self): "fitting. If set to true, then all sub-models will be available. Warning: For large " + "models, collecting all sub-models can cause OOMs on the Spark driver.", "False", "TypeConverters.toBoolean"), - ("loss", "the loss function to be optimized.", None, "TypeConverters.toString")] + ("loss", "the loss function to be optimized.", None, "TypeConverters.toString"), + ("distanceMeasure", "the distance measure. Supported options: 'euclidean' and 'cosine'.", + "'euclidean'", "TypeConverters.toString")] code = [] for name, doc, defaultValueStr, typeConverter in shared: diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 08408ee8fbfcc..618f5bf0a8103 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -790,3 +790,27 @@ def getCacheNodeIds(self): """ return self.getOrDefault(self.cacheNodeIds) + +class HasDistanceMeasure(Params): + """ + Mixin for param distanceMeasure: the distance measure. Supported options: 'euclidean' and 'cosine'. + """ + + distanceMeasure = Param(Params._dummy(), "distanceMeasure", "the distance measure. Supported options: 'euclidean' and 'cosine'.", typeConverter=TypeConverters.toString) + + def __init__(self): + super(HasDistanceMeasure, self).__init__() + self._setDefault(distanceMeasure='euclidean') + + def setDistanceMeasure(self, value): + """ + Sets the value of :py:attr:`distanceMeasure`. + """ + return self._set(distanceMeasure=value) + + def getDistanceMeasure(self): + """ + Gets the value of distanceMeasure or its default value. + """ + return self.getOrDefault(self.distanceMeasure) + diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index dba0e57b01a0b..83f0edb397271 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -95,6 +95,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction True >>> model.numFeatures 1 + >>> model.write().format("pmml").save(model_path + "_2") .. versionadded:: 1.4.0 """ @@ -161,7 +162,7 @@ def getEpsilon(self): return self.getOrDefault(self.epsilon) -class LinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): +class LinearRegressionModel(JavaModel, JavaPredictionModel, GeneralJavaMLWritable, JavaMLReadable): """ Model fitted by :class:`LinearRegression`. diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py index a06ab31a7a56a..370154fc6d62a 100644 --- a/python/pyspark/ml/stat.py +++ b/python/pyspark/ml/stat.py @@ -388,8 +388,14 @@ def summary(self, featuresCol, weightCol=None): if __name__ == "__main__": import doctest + import numpy import pyspark.ml.stat from pyspark.sql import SparkSession + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.ml.stat.__dict__.copy() # The small batch size here ensures that we see multiple batches, diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index ebd36cbb5f7a7..bc782138292bf 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1362,6 +1362,23 @@ def test_linear_regression(self): except OSError: pass + def test_linear_regression_pmml_basic(self): + # Most of the validation is done in the Scala side, here we just check + # that we output text rather than parquet (e.g. that the format flag + # was respected). + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LinearRegression(maxIter=1) + model = lr.fit(df) + path = tempfile.mkdtemp() + lr_path = path + "/lr-pmml" + model.write().format("pmml").save(lr_path) + pmml_text_list = self.sc.textFile(lr_path).collect() + pmml_text = "\n".join(pmml_text_list) + self.assertIn("Apache Spark", pmml_text) + self.assertIn("PMML", pmml_text) + def test_logistic_regression(self): lr = LogisticRegression(maxIter=1) path = tempfile.mkdtemp() diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 9fa85664939b8..e846834761e49 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -63,7 +63,7 @@ def _randomUID(cls): Generate a unique unicode id for the object. The default implementation concatenates the class name, "_", and 12 random hex chars. """ - return unicode(cls.__name__ + "_" + uuid.uuid4().hex[12:]) + return unicode(cls.__name__ + "_" + uuid.uuid4().hex[-12:]) @inherit_doc @@ -148,6 +148,23 @@ def overwrite(self): return self +@inherit_doc +class GeneralMLWriter(MLWriter): + """ + Utility class that can save ML instances in different formats. + + .. versionadded:: 2.4.0 + """ + + def format(self, source): + """ + Specifies the format of ML export (e.g. "pmml", "internal", or the fully qualified class + name for export). + """ + self.source = source + return self + + @inherit_doc class JavaMLWriter(MLWriter): """ @@ -192,6 +209,24 @@ def session(self, sparkSession): return self +@inherit_doc +class GeneralJavaMLWriter(JavaMLWriter): + """ + (Private) Specialization of :py:class:`GeneralMLWriter` for :py:class:`JavaParams` types + """ + + def __init__(self, instance): + super(GeneralJavaMLWriter, self).__init__(instance) + + def format(self, source): + """ + Specifies the format of ML export (e.g. "pmml", "internal", or the fully qualified class + name for export). + """ + self._jwrite.format(source) + return self + + @inherit_doc class MLWritable(object): """ @@ -220,6 +255,17 @@ def write(self): return JavaMLWriter(self) +@inherit_doc +class GeneralJavaMLWritable(JavaMLWritable): + """ + (Private) Mixin for ML instances that provide :py:class:`GeneralJavaMLWriter`. + """ + + def write(self): + """Returns an GeneralMLWriter instance for this ML instance.""" + return GeneralJavaMLWriter(self) + + @inherit_doc class MLReader(BaseReadWrite): """ diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index bb281981fd56b..e00ed95ef0701 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -258,6 +258,9 @@ def load(cls, sc, path): model.setThreshold(threshold) return model + def __repr__(self): + return self._call_java("toString") + class LogisticRegressionWithSGD(object): """ diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 0cbabab13a896..b09469b9f5c2d 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -1042,7 +1042,13 @@ def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0, def _test(): import doctest + import numpy import pyspark.mllib.clustering + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.mllib.clustering.__dict__.copy() globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 36cb03369b8c0..6c65da58e4e2b 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -532,8 +532,14 @@ def accuracy(self): def _test(): import doctest + import numpy from pyspark.sql import SparkSession import pyspark.mllib.evaluation + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.mllib.evaluation.__dict__.copy() spark = SparkSession.builder\ .master("local[4]")\ diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 60d96d8d5ceb8..4afd6666400b0 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -1368,6 +1368,12 @@ def R(self): def _test(): import doctest + import numpy + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) if failure_count: sys.exit(-1) diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index bba88542167ad..7e8b15056cabe 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -1364,9 +1364,15 @@ def toCoordinateMatrix(self): def _test(): import doctest + import numpy from pyspark.sql import SparkSession from pyspark.mllib.linalg import Matrices import pyspark.mllib.linalg.distributed + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.mllib.linalg.distributed.__dict__.copy() spark = SparkSession.builder\ .master("local[2]")\ diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py index 3c75b132ecad2..937bb154c2356 100644 --- a/python/pyspark/mllib/stat/_statistics.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -303,7 +303,13 @@ def kolmogorovSmirnovTest(data, distName="norm", *params): def _test(): import doctest + import numpy from pyspark.sql import SparkSession + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = globals().copy() spark = SparkSession.builder\ .master("local[4]")\ diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 7e7e5822a6b20..951851804b1d8 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1370,7 +1370,10 @@ def takeUpToNumLeft(iterator): iterator = iter(iterator) taken = 0 while taken < left: - yield next(iterator) + try: + yield next(iterator) + except StopIteration: + return taken += 1 p = range(partsScanned, min(partsScanned + numPartsToTry, totalParts)) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 15753f77bd903..4c16b5fc26f3d 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -33,8 +33,9 @@ [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] >>> sc.stop() -PySpark serialize objects in batches; By default, the batch size is chosen based -on the size of objects, also configurable by SparkContext's C{batchSize} parameter: +PySpark serializes objects in batches; by default, the batch size is chosen based +on the size of objects and is also configurable by SparkContext's C{batchSize} +parameter: >>> sc = SparkContext('local', 'test', batchSize=2) >>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) @@ -100,7 +101,7 @@ def load_stream(self, stream): def _load_stream_without_unbatching(self, stream): """ Return an iterator of deserialized batches (iterable) of objects from the input stream. - if the serializer does not operate on batches the default implementation returns an + If the serializer does not operate on batches the default implementation returns an iterator of single element lists. """ return map(lambda x: [x], self.load_stream(stream)) @@ -461,7 +462,7 @@ def dumps(self, obj): return obj -# Hook namedtuple, make it picklable +# Hack namedtuple, make it picklable __cls = {} @@ -525,15 +526,15 @@ def namedtuple(*args, **kwargs): cls = _old_namedtuple(*args, **kwargs) return _hack_namedtuple(cls) - # replace namedtuple with new one + # replace namedtuple with the new one collections.namedtuple.__globals__["_old_namedtuple_kwdefaults"] = _old_namedtuple_kwdefaults collections.namedtuple.__globals__["_old_namedtuple"] = _old_namedtuple collections.namedtuple.__globals__["_hack_namedtuple"] = _hack_namedtuple collections.namedtuple.__code__ = namedtuple.__code__ collections.namedtuple.__hijack = 1 - # hack the cls already generated by namedtuple - # those created in other module can be pickled as normal, + # hack the cls already generated by namedtuple. + # Those created in other modules can be pickled as normal, # so only hack those in __main__ module for n, o in sys.modules["__main__"].__dict__.items(): if (type(o) is type and o.__base__ is tuple @@ -627,7 +628,7 @@ def loads(self, obj): elif _type == b'P': return pickle.loads(obj[1:]) else: - raise ValueError("invalid sevialization type: %s" % _type) + raise ValueError("invalid serialization type: %s" % _type) class CompressedSerializer(FramedSerializer): diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py index db49040e17b63..f80bf598c2211 100644 --- a/python/pyspark/sql/conf.py +++ b/python/pyspark/sql/conf.py @@ -63,6 +63,14 @@ def _checkType(self, obj, identifier): raise TypeError("expected %s '%s' to be a string (was '%s')" % (identifier, obj, type(obj).__name__)) + @ignore_unicode_prefix + @since(2.4) + def isModifiable(self, key): + """Indicates whether the configuration property with the given key + is modifiable in the current session. + """ + return self._jconf.isModifiable(key) + def _test(): import os diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index e9ec7ba866761..9c094dd9a9033 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -93,6 +93,11 @@ def _ssql_ctx(self): """ return self._jsqlContext + @property + def _conf(self): + """Accessor for the JVM SQL-specific configurations""" + return self.sparkSession._jsparkSession.sessionState().conf() + @classmethod @since(1.6) def getOrCreate(cls, sc): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 1e6a1acebb5ca..c40aea9bcef0a 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -354,32 +354,12 @@ def show(self, n=20, truncate=True, vertical=False): else: print(self._jdf.showString(n, int(truncate), vertical)) - @property - def _eager_eval(self): - """Returns true if the eager evaluation enabled. - """ - return self.sql_ctx.getConf( - "spark.sql.repl.eagerEval.enabled", "false").lower() == "true" - - @property - def _max_num_rows(self): - """Returns the max row number for eager evaluation. - """ - return int(self.sql_ctx.getConf( - "spark.sql.repl.eagerEval.maxNumRows", "20")) - - @property - def _truncate(self): - """Returns the truncate length for eager evaluation. - """ - return int(self.sql_ctx.getConf( - "spark.sql.repl.eagerEval.truncate", "20")) - def __repr__(self): - if not self._support_repr_html and self._eager_eval: + if not self._support_repr_html and self.sql_ctx._conf.isReplEagerEvalEnabled(): vertical = False return self._jdf.showString( - self._max_num_rows, self._truncate, vertical) + self.sql_ctx._conf.replEagerEvalMaxNumRows(), + self.sql_ctx._conf.replEagerEvalTruncate(), vertical) else: return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) @@ -391,11 +371,10 @@ def _repr_html_(self): import cgi if not self._support_repr_html: self._support_repr_html = True - if self._eager_eval: - max_num_rows = max(self._max_num_rows, 0) - vertical = False + if self.sql_ctx._conf.isReplEagerEvalEnabled(): + max_num_rows = max(self.sql_ctx._conf.replEagerEvalMaxNumRows(), 0) sock_info = self._jdf.getRowsToPython( - max_num_rows, self._truncate, vertical) + max_num_rows, self.sql_ctx._conf.replEagerEvalTruncate()) rows = list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))) head = rows[0] row_data = rows[1:] @@ -2050,13 +2029,12 @@ def toPandas(self): import pandas as pd - if self.sql_ctx.getConf("spark.sql.execution.pandas.respectSessionTimeZone").lower() \ - == "true": - timezone = self.sql_ctx.getConf("spark.sql.session.timeZone") + if self.sql_ctx._conf.pandasRespectSessionTimeZone(): + timezone = self.sql_ctx._conf.sessionLocalTimeZone() else: timezone = None - if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": + if self.sql_ctx._conf.arrowEnabled(): use_arrow = True try: from pyspark.sql.types import to_arrow_schema @@ -2066,8 +2044,7 @@ def toPandas(self): to_arrow_schema(self.schema) except Exception as e: - if self.sql_ctx.getConf("spark.sql.execution.arrow.fallback.enabled", "true") \ - .lower() == "true": + if self.sql_ctx._conf.arrowFallbackEnabled(): msg = ( "toPandas attempted Arrow optimization because " "'spark.sql.execution.arrow.enabled' is set to true; however, " diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a5e3384e802b8..5ef73987a66a6 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1285,11 +1285,21 @@ def from_utc_timestamp(timestamp, tz): that time as a timestamp in the given time zone. For example, 'GMT+1' would yield '2017-07-14 03:40:00.0'. - >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(from_utc_timestamp(df.t, "PST").alias('local_time')).collect() + :param timestamp: the column that contains timestamps + :param tz: a string that has the ID of timezone, e.g. "GMT", "America/Los_Angeles", etc + + .. versionchanged:: 2.4 + `tz` can take a :class:`Column` containing timezone ID strings. + + >>> df = spark.createDataFrame([('1997-02-28 10:30:00', 'JST')], ['ts', 'tz']) + >>> df.select(from_utc_timestamp(df.ts, "PST").alias('local_time')).collect() [Row(local_time=datetime.datetime(1997, 2, 28, 2, 30))] + >>> df.select(from_utc_timestamp(df.ts, df.tz).alias('local_time')).collect() + [Row(local_time=datetime.datetime(1997, 2, 28, 19, 30))] """ sc = SparkContext._active_spark_context + if isinstance(tz, Column): + tz = _to_java_column(tz) return Column(sc._jvm.functions.from_utc_timestamp(_to_java_column(timestamp), tz)) @@ -1300,11 +1310,21 @@ def to_utc_timestamp(timestamp, tz): zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield '2017-07-14 01:40:00.0'. - >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['ts']) + :param timestamp: the column that contains timestamps + :param tz: a string that has the ID of timezone, e.g. "GMT", "America/Los_Angeles", etc + + .. versionchanged:: 2.4 + `tz` can take a :class:`Column` containing timezone ID strings. + + >>> df = spark.createDataFrame([('1997-02-28 10:30:00', 'JST')], ['ts', 'tz']) >>> df.select(to_utc_timestamp(df.ts, "PST").alias('utc_time')).collect() [Row(utc_time=datetime.datetime(1997, 2, 28, 18, 30))] + >>> df.select(to_utc_timestamp(df.ts, df.tz).alias('utc_time')).collect() + [Row(utc_time=datetime.datetime(1997, 2, 28, 1, 30))] """ sc = SparkContext._active_spark_context + if isinstance(tz, Column): + tz = _to_java_column(tz) return Column(sc._jvm.functions.to_utc_timestamp(_to_java_column(timestamp), tz)) @@ -1999,6 +2019,39 @@ def array_remove(col, element): return Column(sc._jvm.functions.array_remove(_to_java_column(col), element)) +@since(2.4) +def array_distinct(col): + """ + Collection function: removes duplicate values from the array. + :param col: name of column or expression + + >>> df = spark.createDataFrame([([1, 2, 3, 2],), ([4, 5, 5, 4],)], ['data']) + >>> df.select(array_distinct(df.data)).collect() + [Row(array_distinct(data)=[1, 2, 3]), Row(array_distinct(data)=[4, 5])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_distinct(_to_java_column(col))) + + +@ignore_unicode_prefix +@since(2.4) +def array_union(col1, col2): + """ + Collection function: returns an array of the elements in the union of col1 and col2, + without duplicates. + + :param col1: name of column containing array + :param col2: name of column containing array + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) + >>> df.select(array_union(df.c1, df.c2)).collect() + [Row(array_union(c1, c2)=[u'b', u'a', u'c', u'd', u'f'])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_union(_to_java_column(col1), _to_java_column(col2))) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. @@ -2168,19 +2221,23 @@ def from_json(col, schema, options={}): [Row(json=Row(a=1))] >>> df.select(from_json(df.value, "a INT").alias("json")).collect() [Row(json=Row(a=1))] - >>> schema = MapType(StringType(), IntegerType()) - >>> df.select(from_json(df.value, schema).alias("json")).collect() + >>> df.select(from_json(df.value, "MAP").alias("json")).collect() [Row(json={u'a': 1})] >>> data = [(1, '''[{"a": 1}]''')] >>> schema = ArrayType(StructType([StructField("a", IntegerType())])) >>> df = spark.createDataFrame(data, ("key", "value")) >>> df.select(from_json(df.value, schema).alias("json")).collect() [Row(json=[Row(a=1)])] + >>> schema = schema_of_json(lit('''{"a": 0}''')) + >>> df.select(from_json(df.value, schema).alias("json")).collect() + [Row(json=Row(a=1))] """ sc = SparkContext._active_spark_context if isinstance(schema, DataType): schema = schema.json() + elif isinstance(schema, Column): + schema = _to_java_column(schema) jc = sc._jvm.functions.from_json(_to_java_column(col), schema, options) return Column(jc) @@ -2222,6 +2279,28 @@ def to_json(col, options={}): return Column(jc) +@ignore_unicode_prefix +@since(2.4) +def schema_of_json(col): + """ + Parses a column containing a JSON string and infers its schema in DDL format. + + :param col: string column in json format + + >>> from pyspark.sql.types import * + >>> data = [(1, '{"a": 1}')] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(schema_of_json(df.value).alias("json")).collect() + [Row(json=u'struct')] + >>> df.select(schema_of_json(lit('{"a": 0}')).alias("json")).collect() + [Row(json=u'struct')] + """ + + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.schema_of_json(_to_java_column(col)) + return Column(jc) + + @since(1.5) def size(col): """ @@ -2399,6 +2478,26 @@ def map_entries(col): return Column(sc._jvm.functions.map_entries(_to_java_column(col))) +@since(2.4) +def map_from_entries(col): + """ + Collection function: Returns a map created from the given array of entries. + + :param col: name of column or expression + + >>> from pyspark.sql.functions import map_from_entries + >>> df = spark.sql("SELECT array(struct(1, 'a'), struct(2, 'b')) as data") + >>> df.select(map_from_entries("data").alias("map")).show() + +----------------+ + | map| + +----------------+ + |[1 -> a, 2 -> b]| + +----------------+ + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.map_from_entries(_to_java_column(col))) + + @ignore_unicode_prefix @since(2.4) def array_repeat(col, count): @@ -2430,6 +2529,28 @@ def arrays_zip(*cols): return Column(sc._jvm.functions.arrays_zip(_to_seq(sc, cols, _to_java_column))) +@since(2.4) +def map_concat(*cols): + """Returns the union of all the given maps. + + :param cols: list of column names (string) or list of :class:`Column` expressions + + >>> from pyspark.sql.functions import map_concat + >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as map1, map(3, 'c', 1, 'd') as map2") + >>> df.select(map_concat("map1", "map2").alias("map3")).show(truncate=False) + +--------------------------------+ + |map3 | + +--------------------------------+ + |[1 -> a, 2 -> b, 3 -> c, 1 -> d]| + +--------------------------------+ + """ + sc = SparkContext._active_spark_context + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + jc = sc._jvm.functions.map_concat(_to_seq(sc, cols, _to_java_column)) + return Column(jc) + + # ---------------------------- User Defined Function ---------------------------------- class PandasUDFType(object): @@ -2551,9 +2672,10 @@ def pandas_udf(f=None, returnType=None, functionType=None): A grouped map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` The returnType should be a :class:`StructType` describing the schema of the returned - `pandas.DataFrame`. - The length of the returned `pandas.DataFrame` can be arbitrary and the columns must be - indexed so that their position matches the corresponding field in the schema. + `pandas.DataFrame`. The column labels of the returned `pandas.DataFrame` must either match + the field names in the defined returnType schema if specified as strings, or match the + field data types by position if not strings, e.g. integer indices. + The length of the returned `pandas.DataFrame` can be arbitrary. Grouped map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`. diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index a0e20d39c20da..3efe2adb6e2a4 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -177,7 +177,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None, - encoding=None): + dropFieldIfAllNull=None, encoding=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -246,6 +246,9 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. :param samplingRatio: defines fraction of input JSON objects used for schema inferring. If None is set, it uses the default value, ``1.0``. + :param dropFieldIfAllNull: whether to ignore column of all null values or empty + array/struct during schema inference. If None is set, it + uses the default value, ``false``. >>> df1 = spark.read.json('python/test_support/sql/people.json') >>> df1.dtypes diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index e880dd1ca6d1a..f1ad6b1212ed9 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -567,14 +567,7 @@ def _create_shell_session(): .getOrCreate() else: return SparkSession.builder.getOrCreate() - except py4j.protocol.Py4JError: - if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': - warnings.warn("Fall back to non-hive support because failing to access HiveConf, " - "please make sure you build spark with hive") - - try: - return SparkSession.builder.getOrCreate() - except TypeError: + except (py4j.protocol.Py4JError, TypeError): if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': warnings.warn("Fall back to non-hive support because failing to access HiveConf, " "please make sure you build spark with hive") diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index fae50b3d5d532..8c1fd4af674d7 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -24,12 +24,14 @@ else: intlike = (int, long) +from py4j.java_gateway import java_import + from pyspark import since, keyword_only from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import _to_seq from pyspark.sql.readwriter import OptionUtils, to_str from pyspark.sql.types import * -from pyspark.sql.utils import StreamingQueryException +from pyspark.sql.utils import ForeachBatchFunction, StreamingQueryException __all__ = ["StreamingQuery", "StreamingQueryManager", "DataStreamReader", "DataStreamWriter"] @@ -854,6 +856,197 @@ def trigger(self, processingTime=None, once=None, continuous=None): self._jwrite = self._jwrite.trigger(jTrigger) return self + @since(2.4) + def foreach(self, f): + """ + Sets the output of the streaming query to be processed using the provided writer ``f``. + This is often used to write the output of a streaming query to arbitrary storage systems. + The processing logic can be specified in two ways. + + #. A **function** that takes a row as input. + This is a simple way to express your processing logic. Note that this does + not allow you to deduplicate generated data when failures cause reprocessing of + some input data. That would require you to specify the processing logic in the next + way. + + #. An **object** with a ``process`` method and optional ``open`` and ``close`` methods. + The object can have the following methods. + + * ``open(partition_id, epoch_id)``: *Optional* method that initializes the processing + (for example, open a connection, start a transaction, etc). Additionally, you can + use the `partition_id` and `epoch_id` to deduplicate regenerated data + (discussed later). + + * ``process(row)``: *Non-optional* method that processes each :class:`Row`. + + * ``close(error)``: *Optional* method that finalizes and cleans up (for example, + close connection, commit transaction, etc.) after all rows have been processed. + + The object will be used by Spark in the following way. + + * A single copy of this object is responsible of all the data generated by a + single task in a query. In other words, one instance is responsible for + processing one partition of the data generated in a distributed manner. + + * This object must be serializable because each task will get a fresh + serialized-deserialized copy of the provided object. Hence, it is strongly + recommended that any initialization for writing data (e.g. opening a + connection or starting a transaction) is done after the `open(...)` + method has been called, which signifies that the task is ready to generate data. + + * The lifecycle of the methods are as follows. + + For each partition with ``partition_id``: + + ... For each batch/epoch of streaming data with ``epoch_id``: + + ....... Method ``open(partitionId, epochId)`` is called. + + ....... If ``open(...)`` returns true, for each row in the partition and + batch/epoch, method ``process(row)`` is called. + + ....... Method ``close(errorOrNull)`` is called with error (if any) seen while + processing rows. + + Important points to note: + + * The `partitionId` and `epochId` can be used to deduplicate generated data when + failures cause reprocessing of some input data. This depends on the execution + mode of the query. If the streaming query is being executed in the micro-batch + mode, then every partition represented by a unique tuple (partition_id, epoch_id) + is guaranteed to have the same data. Hence, (partition_id, epoch_id) can be used + to deduplicate and/or transactionally commit data and achieve exactly-once + guarantees. However, if the streaming query is being executed in the continuous + mode, then this guarantee does not hold and therefore should not be used for + deduplication. + + * The ``close()`` method (if exists) will be called if `open()` method exists and + returns successfully (irrespective of the return value), except if the Python + crashes in the middle. + + .. note:: Evolving. + + >>> # Print every row using a function + >>> def print_row(row): + ... print(row) + ... + >>> writer = sdf.writeStream.foreach(print_row) + >>> # Print every row using a object with process() method + >>> class RowPrinter: + ... def open(self, partition_id, epoch_id): + ... print("Opened %d, %d" % (partition_id, epoch_id)) + ... return True + ... def process(self, row): + ... print(row) + ... def close(self, error): + ... print("Closed with error: %s" % str(error)) + ... + >>> writer = sdf.writeStream.foreach(RowPrinter()) + """ + + from pyspark.rdd import _wrap_function + from pyspark.serializers import PickleSerializer, AutoBatchedSerializer + from pyspark.taskcontext import TaskContext + + if callable(f): + # The provided object is a callable function that is supposed to be called on each row. + # Construct a function that takes an iterator and calls the provided function on each + # row. + def func_without_process(_, iterator): + for x in iterator: + f(x) + return iter([]) + + func = func_without_process + + else: + # The provided object is not a callable function. Then it is expected to have a + # 'process(row)' method, and optional 'open(partition_id, epoch_id)' and + # 'close(error)' methods. + + if not hasattr(f, 'process'): + raise Exception("Provided object does not have a 'process' method") + + if not callable(getattr(f, 'process')): + raise Exception("Attribute 'process' in provided object is not callable") + + def doesMethodExist(method_name): + exists = hasattr(f, method_name) + if exists and not callable(getattr(f, method_name)): + raise Exception( + "Attribute '%s' in provided object is not callable" % method_name) + return exists + + open_exists = doesMethodExist('open') + close_exists = doesMethodExist('close') + + def func_with_open_process_close(partition_id, iterator): + epoch_id = TaskContext.get().getLocalProperty('streaming.sql.batchId') + if epoch_id: + epoch_id = int(epoch_id) + else: + raise Exception("Could not get batch id from TaskContext") + + # Check if the data should be processed + should_process = True + if open_exists: + should_process = f.open(partition_id, epoch_id) + + error = None + + try: + if should_process: + for x in iterator: + f.process(x) + except Exception as ex: + error = ex + finally: + if close_exists: + f.close(error) + if error: + raise error + + return iter([]) + + func = func_with_open_process_close + + serializer = AutoBatchedSerializer(PickleSerializer()) + wrapped_func = _wrap_function(self._spark._sc, func, serializer, serializer) + jForeachWriter = \ + self._spark._sc._jvm.org.apache.spark.sql.execution.python.PythonForeachWriter( + wrapped_func, self._df._jdf.schema()) + self._jwrite.foreach(jForeachWriter) + return self + + @since(2.4) + def foreachBatch(self, func): + """ + Sets the output of the streaming query to be processed using the provided + function. This is supported only the in the micro-batch execution modes (that is, when the + trigger is not continuous). In every micro-batch, the provided function will be called in + every micro-batch with (i) the output rows as a DataFrame and (ii) the batch identifier. + The batchId can be used deduplicate and transactionally write the output + (that is, the provided Dataset) to external systems. The output DataFrame is guaranteed + to exactly same for the same batchId (assuming all operations are deterministic in the + query). + + .. note:: Evolving. + + >>> def func(batch_df, batch_id): + ... batch_df.collect() + ... + >>> writer = sdf.writeStream.foreach(func) + """ + + from pyspark.java_gateway import ensure_callback_server_started + gw = self._spark._sc._gateway + java_import(gw.jvm, "org.apache.spark.sql.execution.streaming.sources.*") + + wrapped_func = ForeachBatchFunction(self._spark, func) + gw.jvm.PythonForeachBatchHelper.callForeachBatch(self._jwrite, wrapped_func) + ensure_callback_server_started(gw) + return self + @ignore_unicode_prefix @since(2.0) def start(self, path=None, format=None, outputMode=None, partitionBy=None, queryName=None, diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 2d7a4f62d4ee8..565654e7f03bb 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1869,6 +1869,299 @@ def test_query_manager_await_termination(self): q.stop() shutil.rmtree(tmpPath) + class ForeachWriterTester: + + def __init__(self, spark): + self.spark = spark + + def write_open_event(self, partitionId, epochId): + self._write_event( + self.open_events_dir, + {'partition': partitionId, 'epoch': epochId}) + + def write_process_event(self, row): + self._write_event(self.process_events_dir, {'value': 'text'}) + + def write_close_event(self, error): + self._write_event(self.close_events_dir, {'error': str(error)}) + + def write_input_file(self): + self._write_event(self.input_dir, "text") + + def open_events(self): + return self._read_events(self.open_events_dir, 'partition INT, epoch INT') + + def process_events(self): + return self._read_events(self.process_events_dir, 'value STRING') + + def close_events(self): + return self._read_events(self.close_events_dir, 'error STRING') + + def run_streaming_query_on_writer(self, writer, num_files): + self._reset() + try: + sdf = self.spark.readStream.format('text').load(self.input_dir) + sq = sdf.writeStream.foreach(writer).start() + for i in range(num_files): + self.write_input_file() + sq.processAllAvailable() + finally: + self.stop_all() + + def assert_invalid_writer(self, writer, msg=None): + self._reset() + try: + sdf = self.spark.readStream.format('text').load(self.input_dir) + sq = sdf.writeStream.foreach(writer).start() + self.write_input_file() + sq.processAllAvailable() + self.fail("invalid writer %s did not fail the query" % str(writer)) # not expected + except Exception as e: + if msg: + assert msg in str(e), "%s not in %s" % (msg, str(e)) + + finally: + self.stop_all() + + def stop_all(self): + for q in self.spark._wrapped.streams.active: + q.stop() + + def _reset(self): + self.input_dir = tempfile.mkdtemp() + self.open_events_dir = tempfile.mkdtemp() + self.process_events_dir = tempfile.mkdtemp() + self.close_events_dir = tempfile.mkdtemp() + + def _read_events(self, dir, json): + rows = self.spark.read.schema(json).json(dir).collect() + dicts = [row.asDict() for row in rows] + return dicts + + def _write_event(self, dir, event): + import uuid + with open(os.path.join(dir, str(uuid.uuid4())), 'w') as f: + f.write("%s\n" % str(event)) + + def __getstate__(self): + return (self.open_events_dir, self.process_events_dir, self.close_events_dir) + + def __setstate__(self, state): + self.open_events_dir, self.process_events_dir, self.close_events_dir = state + + def test_streaming_foreach_with_simple_function(self): + tester = self.ForeachWriterTester(self.spark) + + def foreach_func(row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(foreach_func, 2) + self.assertEqual(len(tester.process_events()), 2) + + def test_streaming_foreach_with_basic_open_process_close(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partitionId, epochId): + tester.write_open_event(partitionId, epochId) + return True + + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + + open_events = tester.open_events() + self.assertEqual(len(open_events), 2) + self.assertSetEqual(set([e['epoch'] for e in open_events]), {0, 1}) + + self.assertEqual(len(tester.process_events()), 2) + + close_events = tester.close_events() + self.assertEqual(len(close_events), 2) + self.assertSetEqual(set([e['error'] for e in close_events]), {'None'}) + + def test_streaming_foreach_with_open_returning_false(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partition_id, epoch_id): + tester.write_open_event(partition_id, epoch_id) + return False + + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + + self.assertEqual(len(tester.open_events()), 2) + + self.assertEqual(len(tester.process_events()), 0) # no row was processed + + close_events = tester.close_events() + self.assertEqual(len(close_events), 2) + self.assertSetEqual(set([e['error'] for e in close_events]), {'None'}) + + def test_streaming_foreach_without_open_method(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 0) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 2) + + def test_streaming_foreach_without_close_method(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partition_id, epoch_id): + tester.write_open_event(partition_id, epoch_id) + return True + + def process(self, row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 2) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 0) + + def test_streaming_foreach_without_open_and_close_methods(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 0) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 0) + + def test_streaming_foreach_with_process_throwing_error(self): + from pyspark.sql.utils import StreamingQueryException + + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + raise Exception("test error") + + def close(self, error): + tester.write_close_event(error) + + try: + tester.run_streaming_query_on_writer(ForeachWriter(), 1) + self.fail("bad writer did not fail the query") # this is not expected + except StreamingQueryException as e: + # TODO: Verify whether original error message is inside the exception + pass + + self.assertEqual(len(tester.process_events()), 0) # no row was processed + close_events = tester.close_events() + self.assertEqual(len(close_events), 1) + # TODO: Verify whether original error message is inside the exception + + def test_streaming_foreach_with_invalid_writers(self): + + tester = self.ForeachWriterTester(self.spark) + + def func_with_iterator_input(iter): + for x in iter: + print(x) + + tester.assert_invalid_writer(func_with_iterator_input) + + class WriterWithoutProcess: + def open(self, partition): + pass + + tester.assert_invalid_writer(WriterWithoutProcess(), "does not have a 'process'") + + class WriterWithNonCallableProcess(): + process = True + + tester.assert_invalid_writer(WriterWithNonCallableProcess(), + "'process' in provided object is not callable") + + class WriterWithNoParamProcess(): + def process(self): + pass + + tester.assert_invalid_writer(WriterWithNoParamProcess()) + + # Abstract class for tests below + class WithProcess(): + def process(self, row): + pass + + class WriterWithNonCallableOpen(WithProcess): + open = True + + tester.assert_invalid_writer(WriterWithNonCallableOpen(), + "'open' in provided object is not callable") + + class WriterWithNoParamOpen(WithProcess): + def open(self): + pass + + tester.assert_invalid_writer(WriterWithNoParamOpen()) + + class WriterWithNonCallableClose(WithProcess): + close = True + + tester.assert_invalid_writer(WriterWithNonCallableClose(), + "'close' in provided object is not callable") + + def test_streaming_foreachBatch(self): + q = None + collected = dict() + + def collectBatch(batch_df, batch_id): + collected[batch_id] = batch_df.collect() + + try: + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + q = df.writeStream.foreachBatch(collectBatch).start() + q.processAllAvailable() + self.assertTrue(0 in collected) + self.assertTrue(len(collected[0]), 2) + finally: + if q: + q.stop() + + def test_streaming_foreachBatch_propagates_python_errors(self): + from pyspark.sql.utils import StreamingQueryException + + q = None + + def collectBatch(df, id): + raise Exception("this should fail the query") + + try: + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + q = df.writeStream.foreachBatch(collectBatch).start() + q.processAllAvailable() + self.fail("Expected a failure") + except StreamingQueryException as e: + self.assertTrue("this should fail" in str(e)) + finally: + if q: + q.stop() + def test_help_command(self): # Regression test for SPARK-5464 rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) @@ -3058,11 +3351,41 @@ def test_checking_csv_header(self): finally: shutil.rmtree(path) - def test_repr_html(self): + def test_repr_behaviors(self): import re pattern = re.compile(r'^ *\|', re.MULTILINE) df = self.spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value")) - self.assertEquals(None, df._repr_html_()) + + # test when eager evaluation is enabled and _repr_html_ will not be called + with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}): + expected1 = """+-----+-----+ + || key|value| + |+-----+-----+ + || 1| 1| + ||22222|22222| + |+-----+-----+ + |""" + self.assertEquals(re.sub(pattern, '', expected1), df.__repr__()) + with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): + expected2 = """+---+-----+ + ||key|value| + |+---+-----+ + || 1| 1| + ||222| 222| + |+---+-----+ + |""" + self.assertEquals(re.sub(pattern, '', expected2), df.__repr__()) + with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): + expected3 = """+---+-----+ + ||key|value| + |+---+-----+ + || 1| 1| + |+---+-----+ + |only showing top 1 row + |""" + self.assertEquals(re.sub(pattern, '', expected3), df.__repr__()) + + # test when eager evaluation is enabled and _repr_html_ will be called with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}): expected1 = """ | @@ -3088,6 +3411,18 @@ def test_repr_html(self): |""" self.assertEquals(re.sub(pattern, '', expected3), df._repr_html_()) + # test when eager evaluation is disabled and _repr_html_ will be called + with self.sql_conf({"spark.sql.repl.eagerEval.enabled": False}): + expected = "DataFrame[key: bigint, value: string]" + self.assertEquals(None, df._repr_html_()) + self.assertEquals(expected, df.__repr__()) + with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): + self.assertEquals(None, df._repr_html_()) + self.assertEquals(expected, df.__repr__()) + with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): + self.assertEquals(None, df._repr_html_()) + self.assertEquals(expected, df.__repr__()) + class HiveSparkSubmitTests(SparkSubmitTests): @@ -4449,7 +4784,6 @@ def test_vectorized_udf_chained(self): def test_vectorized_udf_wrong_return_type(self): from pyspark.sql.functions import pandas_udf, col - df = self.spark.range(10) with QuietTest(self.sc): with self.assertRaisesRegexp( NotImplementedError, @@ -5034,6 +5368,125 @@ def foo3(key, pdf): expected4 = udf3.func((), pdf) self.assertPandasEqual(expected4, result4) + def test_column_order(self): + from collections import OrderedDict + import pandas as pd + from pyspark.sql.functions import pandas_udf, PandasUDFType + + # Helper function to set column names from a list + def rename_pdf(pdf, names): + pdf.rename(columns={old: new for old, new in + zip(pd_result.columns, names)}, inplace=True) + + df = self.data + grouped_df = df.groupby('id') + grouped_pdf = df.toPandas().groupby('id') + + # Function returns a pdf with required column names, but order could be arbitrary using dict + def change_col_order(pdf): + # Constructing a DataFrame from a dict should result in the same order, + # but use from_items to ensure the pdf column order is different than schema + return pd.DataFrame.from_items([ + ('id', pdf.id), + ('u', pdf.v * 2), + ('v', pdf.v)]) + + ordered_udf = pandas_udf( + change_col_order, + 'id long, v int, u int', + PandasUDFType.GROUPED_MAP + ) + + # The UDF result should assign columns by name from the pdf + result = grouped_df.apply(ordered_udf).sort('id', 'v')\ + .select('id', 'u', 'v').toPandas() + pd_result = grouped_pdf.apply(change_col_order) + expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected, result) + + # Function returns a pdf with positional columns, indexed by range + def range_col_order(pdf): + # Create a DataFrame with positional columns, fix types to long + return pd.DataFrame(list(zip(pdf.id, pdf.v * 3, pdf.v)), dtype='int64') + + range_udf = pandas_udf( + range_col_order, + 'id long, u long, v long', + PandasUDFType.GROUPED_MAP + ) + + # The UDF result uses positional columns from the pdf + result = grouped_df.apply(range_udf).sort('id', 'v') \ + .select('id', 'u', 'v').toPandas() + pd_result = grouped_pdf.apply(range_col_order) + rename_pdf(pd_result, ['id', 'u', 'v']) + expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected, result) + + # Function returns a pdf with columns indexed with integers + def int_index(pdf): + return pd.DataFrame(OrderedDict([(0, pdf.id), (1, pdf.v * 4), (2, pdf.v)])) + + int_index_udf = pandas_udf( + int_index, + 'id long, u int, v int', + PandasUDFType.GROUPED_MAP + ) + + # The UDF result should assign columns by position of integer index + result = grouped_df.apply(int_index_udf).sort('id', 'v') \ + .select('id', 'u', 'v').toPandas() + pd_result = grouped_pdf.apply(int_index) + rename_pdf(pd_result, ['id', 'u', 'v']) + expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected, result) + + @pandas_udf('id long, v int', PandasUDFType.GROUPED_MAP) + def column_name_typo(pdf): + return pd.DataFrame({'iid': pdf.id, 'v': pdf.v}) + + @pandas_udf('id long, v int', PandasUDFType.GROUPED_MAP) + def invalid_positional_types(pdf): + return pd.DataFrame([(u'a', 1.2)]) + + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, "KeyError: 'id'"): + grouped_df.apply(column_name_typo).collect() + with self.assertRaisesRegexp(Exception, "No cast implemented"): + grouped_df.apply(invalid_positional_types).collect() + + def test_positional_assignment_conf(self): + import pandas as pd + from pyspark.sql.functions import pandas_udf, PandasUDFType + + with self.sql_conf({"spark.sql.execution.pandas.groupedMap.assignColumnsByPosition": True}): + + @pandas_udf("a string, b float", PandasUDFType.GROUPED_MAP) + def foo(_): + return pd.DataFrame([('hi', 1)], columns=['x', 'y']) + + df = self.data + result = df.groupBy('id').apply(foo).select('a', 'b').collect() + for r in result: + self.assertEqual(r.a, 'hi') + self.assertEqual(r.b, 1) + + def test_self_join_with_pandas(self): + import pyspark.sql.functions as F + + @F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP) + def dummy_pandas_udf(df): + return df[['key', 'col']] + + df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'), + Row(key=2, col='C')]) + df_with_pandas = df.groupBy('key').apply(dummy_pandas_udf) + + # this was throwing an AnalysisException before SPARK-24208 + res = df_with_pandas.alias('temp0').join(df_with_pandas.alias('temp1'), + F.col('temp0.key') == F.col('temp1.key')) + self.assertEquals(res.count(), 5) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 45363f089a73d..bb9ce02c4b60f 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -150,3 +150,26 @@ def require_minimum_pyarrow_version(): if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version): raise ImportError("PyArrow >= %s must be installed; however, " "your version was %s." % (minimum_pyarrow_version, pyarrow.__version__)) + + +class ForeachBatchFunction(object): + """ + This is the Python implementation of Java interface 'ForeachBatchFunction'. This wraps + the user-defined 'foreachBatch' function such that it can be called from the JVM when + the query is active. + """ + + def __init__(self, sql_ctx, func): + self.sql_ctx = sql_ctx + self.func = func + + def call(self, jdf, batch_id): + from pyspark.sql.dataframe import DataFrame + try: + self.func(DataFrame(jdf, self.sql_ctx), batch_id) + except Exception as e: + self.error = e + raise e + + class Java: + implements = ['org.apache.spark.sql.execution.streaming.sources.PythonForeachBatchFunction'] diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index dd924ef89868e..a4515828d180c 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -79,22 +79,8 @@ def _ensure_initialized(cls): java_import(gw.jvm, "org.apache.spark.streaming.api.java.*") java_import(gw.jvm, "org.apache.spark.streaming.api.python.*") - # start callback server - # getattr will fallback to JVM, so we cannot test by hasattr() - if "_callback_server" not in gw.__dict__ or gw._callback_server is None: - gw.callback_server_parameters.eager_load = True - gw.callback_server_parameters.daemonize = True - gw.callback_server_parameters.daemonize_connections = True - gw.callback_server_parameters.port = 0 - gw.start_callback_server(gw.callback_server_parameters) - cbport = gw._callback_server.server_socket.getsockname()[1] - gw._callback_server.port = cbport - # gateway with real port - gw._python_proxy_port = gw._callback_server.port - # get the GatewayServer object in JVM by ID - jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client) - # update the port of CallbackClient with real port - jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port) + from pyspark.java_gateway import ensure_callback_server_started + ensure_callback_server_started(gw) # register serializer for TransformFunction # it happens before creating SparkContext when loading from checkpointing diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 18b2f251dc9fd..a4c5fb1db8b37 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -581,9 +581,9 @@ def test_get_local_property(self): self.sc.setLocalProperty(key, value) try: rdd = self.sc.parallelize(range(1), 1) - prop1 = rdd.map(lambda x: TaskContext.get().getLocalProperty(key)).collect()[0] + prop1 = rdd.map(lambda _: TaskContext.get().getLocalProperty(key)).collect()[0] self.assertEqual(prop1, value) - prop2 = rdd.map(lambda x: TaskContext.get().getLocalProperty("otherkey")).collect()[0] + prop2 = rdd.map(lambda _: TaskContext.get().getLocalProperty("otherkey")).collect()[0] self.assertTrue(prop2 is None) finally: self.sc.setLocalProperty(key, None) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 38fe2ef06eac5..eaaae2b14e107 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -38,6 +38,9 @@ from pyspark.util import _get_argspec, fail_on_stopiteration from pyspark import shuffle +if sys.version >= '3': + basestring = str + pickleSer = PickleSerializer() utf8_deserializer = UTF8Deserializer() @@ -92,7 +95,10 @@ def verify_result_length(*a): return lambda *a: (verify_result_length(*a), arrow_return_type) -def wrap_grouped_map_pandas_udf(f, return_type, argspec): +def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf): + assign_cols_by_pos = runner_conf.get( + "spark.sql.execution.pandas.groupedMap.assignColumnsByPosition", False) + def wrapped(key_series, value_series): import pandas as pd @@ -110,9 +116,13 @@ def wrapped(key_series, value_series): "Number of columns of the returned pandas.DataFrame " "doesn't match specified schema. " "Expected: {} Actual: {}".format(len(return_type), len(result.columns))) - arrow_return_types = (to_arrow_type(field.dataType) for field in return_type) - return [(result[result.columns[i]], arrow_type) - for i, arrow_type in enumerate(arrow_return_types)] + + # Assign result columns by schema name if user labeled with strings, else use position + if not assign_cols_by_pos and any(isinstance(name, basestring) for name in result.columns): + return [(result[field.name], to_arrow_type(field.dataType)) for field in return_type] + else: + return [(result[result.columns[i]], to_arrow_type(field.dataType)) + for i, field in enumerate(return_type)] return wrapped @@ -143,7 +153,7 @@ def wrapped(*series): return lambda *a: (wrapped(*a), arrow_return_type) -def read_single_udf(pickleSer, infile, eval_type): +def read_single_udf(pickleSer, infile, eval_type, runner_conf): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] row_func = None @@ -163,7 +173,7 @@ def read_single_udf(pickleSer, infile, eval_type): return arg_offsets, wrap_scalar_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: argspec = _get_argspec(row_func) # signature was lost when wrapping it - return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec) + return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF: @@ -175,6 +185,26 @@ def read_single_udf(pickleSer, infile, eval_type): def read_udfs(pickleSer, infile, eval_type): + runner_conf = {} + + if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, + PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF): + + # Load conf used for pandas_udf evaluation + num_conf = read_int(infile) + for i in range(num_conf): + k = utf8_deserializer.loads(infile) + v = utf8_deserializer.loads(infile) + runner_conf[k] = v + + # NOTE: if timezone is set here, that implies respectSessionTimeZone is True + timezone = runner_conf.get("spark.sql.session.timeZone", None) + ser = ArrowStreamPandasSerializer(timezone) + else: + ser = BatchedSerializer(PickleSerializer(), 100) + num_udfs = read_int(infile) udfs = {} call_udf = [] @@ -189,7 +219,7 @@ def read_udfs(pickleSer, infile, eval_type): # See FlatMapGroupsInPandasExec for how arg_offsets are used to # distinguish between grouping attributes and data attributes - arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type) + arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf) udfs['f'] = udf split_offset = arg_offsets[0] + 1 arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]] @@ -201,7 +231,7 @@ def read_udfs(pickleSer, infile, eval_type): # In the special case of a single UDF this will return a single result rather # than a tuple of results; this is the format that the JVM side expects. for i in range(num_udfs): - arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type) + arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf) udfs['f%d' % i] = udf args = ["a[%d]" % o for o in arg_offsets] call_udf.append("f%d(%s)" % (i, ", ".join(args))) @@ -210,15 +240,6 @@ def read_udfs(pickleSer, infile, eval_type): mapper = eval(mapper_str, udfs) func = lambda _, it: map(mapper, it) - if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, - PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF): - timezone = utf8_deserializer.loads(infile) - ser = ArrowStreamPandasSerializer(timezone) - else: - ser = BatchedSerializer(PickleSerializer(), 100) - # profiling is not supported for UDF return func, None, ser, ser diff --git a/python/setup.py b/python/setup.py index d309e0564530a..45eb74eb87ce7 100644 --- a/python/setup.py +++ b/python/setup.py @@ -219,6 +219,7 @@ def _supports_symlinks(): 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: Implementation :: CPython', 'Programming Language :: Python :: Implementation :: PyPy'] ) diff --git a/repl/pom.xml b/repl/pom.xml index 6f4a863c48bc7..861bbd7c49654 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -102,7 +102,7 @@ org.apache.xbean - xbean-asm5-shaded + xbean-asm6-shaded @@ -166,7 +166,7 @@ - + scala-2.12 diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index e69441a475e9a..a44051b351e19 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -36,7 +36,7 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) def this() = this(None, new JPrintWriter(Console.out, true)) override def createInterpreter(): Unit = { - intp = new SparkILoopInterpreter(settings, out) + intp = new SparkILoopInterpreter(settings, out, initializeSpark) } val initializationCommands: Seq[String] = Seq( @@ -73,11 +73,15 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) "import org.apache.spark.sql.functions._" ) - def initializeSpark() { - intp.beQuietDuring { - savingReplayStack { // remove the commands from session history. - initializationCommands.foreach(processLine) + def initializeSpark(): Unit = { + if (!intp.reporter.hasErrors) { + // `savingReplayStack` removes the commands from session history. + savingReplayStack { + initializationCommands.foreach(intp quietRun _) } + } else { + throw new RuntimeException(s"Scala $versionString interpreter encountered " + + "errors during initialization") } } @@ -101,16 +105,6 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) /** Available commands */ override def commands: List[LoopCommand] = standardCommands - /** - * We override `loadFiles` because we need to initialize Spark *before* the REPL - * sees any files, so that the Spark context is visible in those files. This is a bit of a - * hack, but there isn't another hook available to us at this point. - */ - override def loadFiles(settings: Settings): Unit = { - initializeSpark() - super.loadFiles(settings) - } - override def resetCommand(line: String): Unit = { super.resetCommand(line) initializeSpark() diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala index e736607a9a6b9..4e63816402a10 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala @@ -21,8 +21,22 @@ import scala.collection.mutable import scala.tools.nsc.Settings import scala.tools.nsc.interpreter._ -class SparkILoopInterpreter(settings: Settings, out: JPrintWriter) extends IMain(settings, out) { - self => +class SparkILoopInterpreter(settings: Settings, out: JPrintWriter, initializeSpark: () => Unit) + extends IMain(settings, out) { self => + + /** + * We override `initializeSynchronous` to initialize Spark *after* `intp` is properly initialized + * and *before* the REPL sees any files in the private `loadInitFiles` functions, so that + * the Spark context is visible in those files. + * + * This is a bit of a hack, but there isn't another hook available to us at this point. + * + * See the discussion in Scala community https://github.com/scala/bug/issues/10913 for detail. + */ + override def initializeSynchronous(): Unit = { + super.initializeSynchronous() + initializeSpark() + } override lazy val memberHandlers = new { val intp: self.type = self diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 4dc399827ffed..42298b06a2c86 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -22,8 +22,8 @@ import java.net.{URI, URL, URLEncoder} import java.nio.channels.Channels import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.xbean.asm5._ -import org.apache.xbean.asm5.Opcodes._ +import org.apache.xbean.asm6._ +import org.apache.xbean.asm6.Opcodes._ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.deploy.SparkHadoopUtil diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index a62f271273465..920f0f6ebf2c8 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -47,6 +47,12 @@ test + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + + io.fabric8 kubernetes-client @@ -77,6 +83,12 @@ + + com.squareup.okhttp3 + okhttp + 3.8.1 + + org.mockito mockito-core @@ -84,9 +96,9 @@ - com.squareup.okhttp3 - okhttp - 3.8.1 + org.jmock + jmock-junit4 + test diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 590deaa72e7ee..f9a77e71ad618 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -176,6 +176,24 @@ private[spark] object Config extends Logging { .checkValue(interval => interval > 0, s"Logging interval must be a positive time value.") .createWithDefaultString("1s") + val KUBERNETES_EXECUTOR_API_POLLING_INTERVAL = + ConfigBuilder("spark.kubernetes.executor.apiPollingInterval") + .doc("Interval between polls against the Kubernetes API server to inspect the " + + "state of executors.") + .timeConf(TimeUnit.MILLISECONDS) + .checkValue(interval => interval > 0, s"API server polling interval must be a" + + " positive time value.") + .createWithDefaultString("30s") + + val KUBERNETES_EXECUTOR_EVENT_PROCESSING_INTERVAL = + ConfigBuilder("spark.kubernetes.executor.eventProcessingInterval") + .doc("Interval between successive inspection of executor events sent from the" + + " Kubernetes API.") + .timeConf(TimeUnit.MILLISECONDS) + .checkValue(interval => interval > 0, s"Event processing interval must be a positive" + + " time value.") + .createWithDefaultString("1s") + val MEMORY_OVERHEAD_FACTOR = ConfigBuilder("spark.kubernetes.memoryOverheadFactor") .doc("This sets the Memory Overhead Factor that will allocate memory to non-JVM jobs " + @@ -193,7 +211,6 @@ private[spark] object Config extends Logging { "Ensure that major Python version is either Python2 or Python3") .createWithDefault("2") - val KUBERNETES_AUTH_SUBMISSION_CONF_PREFIX = "spark.kubernetes.authenticate.submission" @@ -203,11 +220,23 @@ private[spark] object Config extends Logging { val KUBERNETES_DRIVER_ANNOTATION_PREFIX = "spark.kubernetes.driver.annotation." val KUBERNETES_DRIVER_SECRETS_PREFIX = "spark.kubernetes.driver.secrets." val KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX = "spark.kubernetes.driver.secretKeyRef." + val KUBERNETES_DRIVER_VOLUMES_PREFIX = "spark.kubernetes.driver.volumes." val KUBERNETES_EXECUTOR_LABEL_PREFIX = "spark.kubernetes.executor.label." val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation." val KUBERNETES_EXECUTOR_SECRETS_PREFIX = "spark.kubernetes.executor.secrets." val KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX = "spark.kubernetes.executor.secretKeyRef." + val KUBERNETES_EXECUTOR_VOLUMES_PREFIX = "spark.kubernetes.executor.volumes." + + val KUBERNETES_VOLUMES_HOSTPATH_TYPE = "hostPath" + val KUBERNETES_VOLUMES_PVC_TYPE = "persistentVolumeClaim" + val KUBERNETES_VOLUMES_EMPTYDIR_TYPE = "emptyDir" + val KUBERNETES_VOLUMES_MOUNT_PATH_KEY = "mount.path" + val KUBERNETES_VOLUMES_MOUNT_READONLY_KEY = "mount.readOnly" + val KUBERNETES_VOLUMES_OPTIONS_PATH_KEY = "options.path" + val KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY = "options.claimName" + val KUBERNETES_VOLUMES_OPTIONS_MEDIUM_KEY = "options.medium" + val KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY = "options.sizeLimit" val KUBERNETES_DRIVER_ENV_PREFIX = "spark.kubernetes.driverEnv." } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala index 69bd03d1eda6f..5ecdd3a04d77b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -25,9 +25,6 @@ private[spark] object Constants { val SPARK_POD_DRIVER_ROLE = "driver" val SPARK_POD_EXECUTOR_ROLE = "executor" - // Annotations - val SPARK_APP_NAME_ANNOTATION = "spark-app-name" - // Credentials secrets val DRIVER_CREDENTIALS_SECRETS_BASE_DIR = "/mnt/secrets/spark-kubernetes-credentials" @@ -50,17 +47,14 @@ private[spark] object Constants { val DEFAULT_BLOCKMANAGER_PORT = 7079 val DRIVER_PORT_NAME = "driver-rpc-port" val BLOCK_MANAGER_PORT_NAME = "blockmanager" - val EXECUTOR_PORT_NAME = "executor" // Environment Variables - val ENV_EXECUTOR_PORT = "SPARK_EXECUTOR_PORT" val ENV_DRIVER_URL = "SPARK_DRIVER_URL" val ENV_EXECUTOR_CORES = "SPARK_EXECUTOR_CORES" val ENV_EXECUTOR_MEMORY = "SPARK_EXECUTOR_MEMORY" val ENV_APPLICATION_ID = "SPARK_APPLICATION_ID" val ENV_EXECUTOR_ID = "SPARK_EXECUTOR_ID" val ENV_EXECUTOR_POD_IP = "SPARK_EXECUTOR_POD_IP" - val ENV_MOUNTED_CLASSPATH = "SPARK_MOUNTED_CLASSPATH" val ENV_JAVA_OPT_PREFIX = "SPARK_JAVA_OPT_" val ENV_CLASSPATH = "SPARK_CLASSPATH" val ENV_DRIVER_BIND_ADDRESS = "SPARK_DRIVER_BIND_ADDRESS" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala index b0ccaa36b01ed..51d205fdb68d1 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -59,6 +59,7 @@ private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( roleSecretNamesToMountPaths: Map[String, String], roleSecretEnvNamesToKeyRefs: Map[String, String], roleEnvs: Map[String, String], + roleVolumes: Iterable[KubernetesVolumeSpec[_ <: KubernetesVolumeSpecificConf]], sparkFiles: Seq[String]) { def namespace(): String = sparkConf.get(KUBERNETES_NAMESPACE) @@ -155,6 +156,12 @@ private[spark] object KubernetesConf { sparkConf, KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX) val driverEnvs = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_DRIVER_ENV_PREFIX) + val driverVolumes = KubernetesVolumeUtils.parseVolumesWithPrefix( + sparkConf, KUBERNETES_DRIVER_VOLUMES_PREFIX).map(_.get) + // Also parse executor volumes in order to verify configuration + // before the driver pod is created + KubernetesVolumeUtils.parseVolumesWithPrefix( + sparkConf, KUBERNETES_EXECUTOR_VOLUMES_PREFIX).map(_.get) val sparkFiles = sparkConf .getOption("spark.files") @@ -171,6 +178,7 @@ private[spark] object KubernetesConf { driverSecretNamesToMountPaths, driverSecretEnvNamesToKeyRefs, driverEnvs, + driverVolumes, sparkFiles) } @@ -203,6 +211,8 @@ private[spark] object KubernetesConf { val executorEnvSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX) val executorEnv = sparkConf.getExecutorEnv.toMap + val executorVolumes = KubernetesVolumeUtils.parseVolumesWithPrefix( + sparkConf, KUBERNETES_EXECUTOR_VOLUMES_PREFIX).map(_.get) KubernetesConf( sparkConf.clone(), @@ -214,6 +224,7 @@ private[spark] object KubernetesConf { executorMountSecrets, executorEnvSecrets, executorEnv, + executorVolumes, Seq.empty[String]) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index 593fb531a004d..66fff267545dc 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -16,8 +16,6 @@ */ package org.apache.spark.deploy.k8s -import io.fabric8.kubernetes.api.model.LocalObjectReference - import org.apache.spark.SparkConf import org.apache.spark.util.Utils diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala new file mode 100644 index 0000000000000..b1762d1efe2ea --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +private[spark] sealed trait KubernetesVolumeSpecificConf + +private[spark] case class KubernetesHostPathVolumeConf( + hostPath: String) + extends KubernetesVolumeSpecificConf + +private[spark] case class KubernetesPVCVolumeConf( + claimName: String) + extends KubernetesVolumeSpecificConf + +private[spark] case class KubernetesEmptyDirVolumeConf( + medium: Option[String], + sizeLimit: Option[String]) + extends KubernetesVolumeSpecificConf + +private[spark] case class KubernetesVolumeSpec[T <: KubernetesVolumeSpecificConf]( + volumeName: String, + mountPath: String, + mountReadOnly: Boolean, + volumeConf: T) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala new file mode 100644 index 0000000000000..713df5fffc3a2 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import java.util.NoSuchElementException + +import scala.util.{Failure, Success, Try} + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ + +private[spark] object KubernetesVolumeUtils { + /** + * Extract Spark volume configuration properties with a given name prefix. + * + * @param sparkConf Spark configuration + * @param prefix the given property name prefix + * @return a Map storing with volume name as key and spec as value + */ + def parseVolumesWithPrefix( + sparkConf: SparkConf, + prefix: String): Iterable[Try[KubernetesVolumeSpec[_ <: KubernetesVolumeSpecificConf]]] = { + val properties = sparkConf.getAllWithPrefix(prefix).toMap + + getVolumeTypesAndNames(properties).map { case (volumeType, volumeName) => + val pathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_PATH_KEY" + val readOnlyKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_READONLY_KEY" + + for { + path <- properties.getTry(pathKey) + volumeConf <- parseVolumeSpecificConf(properties, volumeType, volumeName) + } yield KubernetesVolumeSpec( + volumeName = volumeName, + mountPath = path, + mountReadOnly = properties.get(readOnlyKey).exists(_.toBoolean), + volumeConf = volumeConf + ) + } + } + + /** + * Get unique pairs of volumeType and volumeName, + * assuming options are formatted in this way: + * `volumeType`.`volumeName`.`property` = `value` + * @param properties flat mapping of property names to values + * @return Set[(volumeType, volumeName)] + */ + private def getVolumeTypesAndNames( + properties: Map[String, String] + ): Set[(String, String)] = { + properties.keys.flatMap { k => + k.split('.').toList match { + case tpe :: name :: _ => Some((tpe, name)) + case _ => None + } + }.toSet + } + + private def parseVolumeSpecificConf( + options: Map[String, String], + volumeType: String, + volumeName: String): Try[KubernetesVolumeSpecificConf] = { + volumeType match { + case KUBERNETES_VOLUMES_HOSTPATH_TYPE => + val pathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_PATH_KEY" + for { + path <- options.getTry(pathKey) + } yield KubernetesHostPathVolumeConf(path) + + case KUBERNETES_VOLUMES_PVC_TYPE => + val claimNameKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY" + for { + claimName <- options.getTry(claimNameKey) + } yield KubernetesPVCVolumeConf(claimName) + + case KUBERNETES_VOLUMES_EMPTYDIR_TYPE => + val mediumKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_MEDIUM_KEY" + val sizeLimitKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY" + Success(KubernetesEmptyDirVolumeConf(options.get(mediumKey), options.get(sizeLimitKey))) + + case _ => + Failure(new RuntimeException(s"Kubernetes Volume type `$volumeType` is not supported")) + } + } + + /** + * Convenience wrapper to accumulate key lookup errors + */ + implicit private class MapOps[A, B](m: Map[A, B]) { + def getTry(key: A): Try[B] = { + m + .get(key) + .fold[Try[B]](Failure(new NoSuchElementException(key.toString)))(Success(_)) + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala index 143dc8a12304e..7e67b51de6e04 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala @@ -19,10 +19,10 @@ package org.apache.spark.deploy.k8s.features import scala.collection.JavaConverters._ import scala.collection.mutable -import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, EnvVarSourceBuilder, HasMetadata, PodBuilder, QuantityBuilder} +import io.fabric8.kubernetes.api.model._ import org.apache.spark.SparkException -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod} +import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.submit._ @@ -103,6 +103,7 @@ private[spark] class BasicDriverFeatureStep( .addToImagePullSecrets(conf.imagePullSecrets(): _*) .endSpec() .build() + SparkPod(driverPod, driverContainer) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index 91c54a9776982..abaeff0313a79 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -18,10 +18,10 @@ package org.apache.spark.deploy.k8s.features import scala.collection.JavaConverters._ -import io.fabric8.kubernetes.api.model.{ContainerBuilder, ContainerPortBuilder, EnvVar, EnvVarBuilder, EnvVarSourceBuilder, HasMetadata, PodBuilder, QuantityBuilder} +import io.fabric8.kubernetes.api.model._ import org.apache.spark.SparkException -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} @@ -173,6 +173,7 @@ private[spark] class BasicExecutorFeatureStep( .addToImagePullSecrets(kubernetesConf.imagePullSecrets(): _*) .endSpec() .build() + SparkPod(executorPod, containerWithLimitCores) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala new file mode 100644 index 0000000000000..bb0e2b3128efd --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model._ + +import org.apache.spark.deploy.k8s._ + +private[spark] class MountVolumesFeatureStep( + kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf]) + extends KubernetesFeatureConfigStep { + + override def configurePod(pod: SparkPod): SparkPod = { + val (volumeMounts, volumes) = constructVolumes(kubernetesConf.roleVolumes).unzip + + val podWithVolumes = new PodBuilder(pod.pod) + .editSpec() + .addToVolumes(volumes.toSeq: _*) + .endSpec() + .build() + + val containerWithVolumeMounts = new ContainerBuilder(pod.container) + .addToVolumeMounts(volumeMounts.toSeq: _*) + .build() + + SparkPod(podWithVolumes, containerWithVolumeMounts) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty + + private def constructVolumes( + volumeSpecs: Iterable[KubernetesVolumeSpec[_ <: KubernetesVolumeSpecificConf]] + ): Iterable[(VolumeMount, Volume)] = { + volumeSpecs.map { spec => + val volumeMount = new VolumeMountBuilder() + .withMountPath(spec.mountPath) + .withReadOnly(spec.mountReadOnly) + .withName(spec.volumeName) + .build() + + val volumeBuilder = spec.volumeConf match { + case KubernetesHostPathVolumeConf(hostPath) => + new VolumeBuilder() + .withHostPath(new HostPathVolumeSource(hostPath)) + + case KubernetesPVCVolumeConf(claimName) => + new VolumeBuilder() + .withPersistentVolumeClaim( + new PersistentVolumeClaimVolumeSource(claimName, spec.mountReadOnly)) + + case KubernetesEmptyDirVolumeConf(medium, sizeLimit) => + new VolumeBuilder() + .withEmptyDir( + new EmptyDirVolumeSource(medium.getOrElse(""), + new Quantity(sizeLimit.orNull))) + } + + val volume = volumeBuilder.withName(spec.volumeName).build() + + (volumeMount, volume) + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index 5762d8245f778..7208e3d377593 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf} -import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, EnvSecretsFeatureStep, KubernetesFeatureConfigStep, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features._ import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep} private[spark] class KubernetesDriverBuilder( @@ -33,10 +33,13 @@ private[spark] class KubernetesDriverBuilder( new MountSecretsFeatureStep(_), provideEnvSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] => EnvSecretsFeatureStep) = - new EnvSecretsFeatureStep(_), - provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] - => LocalDirsFeatureStep) = + new EnvSecretsFeatureStep(_), + provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) + => LocalDirsFeatureStep = new LocalDirsFeatureStep(_), + provideVolumesStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] + => MountVolumesFeatureStep) = + new MountVolumesFeatureStep(_), provideJavaStep: ( KubernetesConf[KubernetesDriverSpecificConf] => JavaDriverFeatureStep) = @@ -54,22 +57,25 @@ private[spark] class KubernetesDriverBuilder( provideServiceStep(kubernetesConf), provideLocalDirsStep(kubernetesConf)) - val maybeRoleSecretNamesStep = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { - Some(provideSecretsStep(kubernetesConf)) } else None - - val maybeProvideSecretsStep = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { - Some(provideEnvSecretsStep(kubernetesConf)) } else None + val secretFeature = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + Seq(provideSecretsStep(kubernetesConf)) + } else Nil + val envSecretFeature = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { + Seq(provideEnvSecretsStep(kubernetesConf)) + } else Nil + val volumesFeature = if (kubernetesConf.roleVolumes.nonEmpty) { + Seq(provideVolumesStep(kubernetesConf)) + } else Nil val bindingsStep = kubernetesConf.roleSpecificConf.mainAppResource.map { case JavaMainAppResource(_) => provideJavaStep(kubernetesConf) case PythonMainAppResource(_) => - providePythonStep(kubernetesConf)}.getOrElse(provideJavaStep(kubernetesConf)) + providePythonStep(kubernetesConf)} + .getOrElse(provideJavaStep(kubernetesConf)) - val allFeatures: Seq[KubernetesFeatureConfigStep] = - (baseFeatures :+ bindingsStep) ++ - maybeRoleSecretNamesStep.toSeq ++ - maybeProvideSecretsStep.toSeq + val allFeatures = (baseFeatures :+ bindingsStep) ++ + secretFeature ++ envSecretFeature ++ volumesFeature var spec = KubernetesDriverSpec.initialSpec(kubernetesConf.sparkConf.getAll.toMap) for (feature <- allFeatures) { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala new file mode 100644 index 0000000000000..83daddf714489 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod + +sealed trait ExecutorPodState { + def pod: Pod +} + +case class PodRunning(pod: Pod) extends ExecutorPodState + +case class PodPending(pod: Pod) extends ExecutorPodState + +sealed trait FinalPodState extends ExecutorPodState + +case class PodSucceeded(pod: Pod) extends FinalPodState + +case class PodFailed(pod: Pod) extends FinalPodState + +case class PodDeleted(pod: Pod) extends FinalPodState + +case class PodUnknown(pod: Pod) extends ExecutorPodState diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala new file mode 100644 index 0000000000000..5a143ad3600fd --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.atomic.{AtomicInteger, AtomicLong} + +import io.fabric8.kubernetes.api.model.PodBuilder +import io.fabric8.kubernetes.client.KubernetesClient +import scala.collection.mutable + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.KubernetesConf +import org.apache.spark.internal.Logging +import org.apache.spark.util.{Clock, Utils} + +private[spark] class ExecutorPodsAllocator( + conf: SparkConf, + executorBuilder: KubernetesExecutorBuilder, + kubernetesClient: KubernetesClient, + snapshotsStore: ExecutorPodsSnapshotsStore, + clock: Clock) extends Logging { + + private val EXECUTOR_ID_COUNTER = new AtomicLong(0L) + + private val totalExpectedExecutors = new AtomicInteger(0) + + private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE) + + private val podAllocationDelay = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY) + + private val podCreationTimeout = math.max(podAllocationDelay * 5, 60000) + + private val kubernetesDriverPodName = conf + .get(KUBERNETES_DRIVER_POD_NAME) + .getOrElse(throw new SparkException("Must specify the driver pod name")) + + private val driverPod = kubernetesClient.pods() + .withName(kubernetesDriverPodName) + .get() + + // Executor IDs that have been requested from Kubernetes but have not been detected in any + // snapshot yet. Mapped to the timestamp when they were created. + private val newlyCreatedExecutors = mutable.Map.empty[Long, Long] + + def start(applicationId: String): Unit = { + snapshotsStore.addSubscriber(podAllocationDelay) { + onNewSnapshots(applicationId, _) + } + } + + def setTotalExpectedExecutors(total: Int): Unit = totalExpectedExecutors.set(total) + + private def onNewSnapshots(applicationId: String, snapshots: Seq[ExecutorPodsSnapshot]): Unit = { + newlyCreatedExecutors --= snapshots.flatMap(_.executorPods.keys) + // For all executors we've created against the API but have not seen in a snapshot + // yet - check the current time. If the current time has exceeded some threshold, + // assume that the pod was either never created (the API server never properly + // handled the creation request), or the API server created the pod but we missed + // both the creation and deletion events. In either case, delete the missing pod + // if possible, and mark such a pod to be rescheduled below. + newlyCreatedExecutors.foreach { case (execId, timeCreated) => + val currentTime = clock.getTimeMillis() + if (currentTime - timeCreated > podCreationTimeout) { + logWarning(s"Executor with id $execId was not detected in the Kubernetes" + + s" cluster after $podCreationTimeout milliseconds despite the fact that a" + + " previous allocation attempt tried to create it. The executor may have been" + + " deleted but the application missed the deletion event.") + Utils.tryLogNonFatalError { + kubernetesClient + .pods() + .withLabel(SPARK_EXECUTOR_ID_LABEL, execId.toString) + .delete() + } + newlyCreatedExecutors -= execId + } else { + logDebug(s"Executor with id $execId was not found in the Kubernetes cluster since it" + + s" was created ${currentTime - timeCreated} milliseconds ago.") + } + } + + if (snapshots.nonEmpty) { + // Only need to examine the cluster as of the latest snapshot, the "current" state, to see if + // we need to allocate more executors or not. + val latestSnapshot = snapshots.last + val currentRunningExecutors = latestSnapshot.executorPods.values.count { + case PodRunning(_) => true + case _ => false + } + val currentPendingExecutors = latestSnapshot.executorPods.values.count { + case PodPending(_) => true + case _ => false + } + val currentTotalExpectedExecutors = totalExpectedExecutors.get + logDebug(s"Currently have $currentRunningExecutors running executors and" + + s" $currentPendingExecutors pending executors. $newlyCreatedExecutors executors" + + s" have been requested but are pending appearance in the cluster.") + if (newlyCreatedExecutors.isEmpty + && currentPendingExecutors == 0 + && currentRunningExecutors < currentTotalExpectedExecutors) { + val numExecutorsToAllocate = math.min( + currentTotalExpectedExecutors - currentRunningExecutors, podAllocationSize) + logInfo(s"Going to request $numExecutorsToAllocate executors from Kubernetes.") + for ( _ <- 0 until numExecutorsToAllocate) { + val newExecutorId = EXECUTOR_ID_COUNTER.incrementAndGet() + val executorConf = KubernetesConf.createExecutorConf( + conf, + newExecutorId.toString, + applicationId, + driverPod) + val executorPod = executorBuilder.buildFromFeatures(executorConf) + val podWithAttachedContainer = new PodBuilder(executorPod.pod) + .editOrNewSpec() + .addToContainers(executorPod.container) + .endSpec() + .build() + kubernetesClient.pods().create(podWithAttachedContainer) + newlyCreatedExecutors(newExecutorId) = clock.getTimeMillis() + logDebug(s"Requested executor with id $newExecutorId from Kubernetes.") + } + } else if (currentRunningExecutors >= currentTotalExpectedExecutors) { + // TODO handle edge cases if we end up with more running executors than expected. + logDebug("Current number of running executors is equal to the number of requested" + + " executors. Not scaling up further.") + } else if (newlyCreatedExecutors.nonEmpty || currentPendingExecutors != 0) { + logDebug(s"Still waiting for ${newlyCreatedExecutors.size + currentPendingExecutors}" + + s" executors to begin running before requesting for more executors. # of executors in" + + s" pending status in the cluster: $currentPendingExecutors. # of executors that we have" + + s" created but we have not observed as being present in the cluster yet:" + + s" ${newlyCreatedExecutors.size}.") + } + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala new file mode 100644 index 0000000000000..b28d93990313e --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import com.google.common.cache.Cache +import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.KubernetesClient +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.ExecutorExited +import org.apache.spark.util.Utils + +private[spark] class ExecutorPodsLifecycleManager( + conf: SparkConf, + executorBuilder: KubernetesExecutorBuilder, + kubernetesClient: KubernetesClient, + snapshotsStore: ExecutorPodsSnapshotsStore, + // Use a best-effort to track which executors have been removed already. It's not generally + // job-breaking if we remove executors more than once but it's ideal if we make an attempt + // to avoid doing so. Expire cache entries so that this data structure doesn't grow beyond + // bounds. + removedExecutorsCache: Cache[java.lang.Long, java.lang.Long]) extends Logging { + + import ExecutorPodsLifecycleManager._ + + private val eventProcessingInterval = conf.get(KUBERNETES_EXECUTOR_EVENT_PROCESSING_INTERVAL) + + def start(schedulerBackend: KubernetesClusterSchedulerBackend): Unit = { + snapshotsStore.addSubscriber(eventProcessingInterval) { + onNewSnapshots(schedulerBackend, _) + } + } + + private def onNewSnapshots( + schedulerBackend: KubernetesClusterSchedulerBackend, + snapshots: Seq[ExecutorPodsSnapshot]): Unit = { + val execIdsRemovedInThisRound = mutable.HashSet.empty[Long] + snapshots.foreach { snapshot => + snapshot.executorPods.foreach { case (execId, state) => + state match { + case deleted@PodDeleted(_) => + logDebug(s"Snapshot reported deleted executor with id $execId," + + s" pod name ${state.pod.getMetadata.getName}") + removeExecutorFromSpark(schedulerBackend, deleted, execId) + execIdsRemovedInThisRound += execId + case failed@PodFailed(_) => + logDebug(s"Snapshot reported failed executor with id $execId," + + s" pod name ${state.pod.getMetadata.getName}") + onFinalNonDeletedState(failed, execId, schedulerBackend, execIdsRemovedInThisRound) + case succeeded@PodSucceeded(_) => + logDebug(s"Snapshot reported succeeded executor with id $execId," + + s" pod name ${state.pod.getMetadata.getName}. Note that succeeded executors are" + + s" unusual unless Spark specifically informed the executor to exit.") + onFinalNonDeletedState(succeeded, execId, schedulerBackend, execIdsRemovedInThisRound) + case _ => + } + } + } + + // Reconcile the case where Spark claims to know about an executor but the corresponding pod + // is missing from the cluster. This would occur if we miss a deletion event and the pod + // transitions immediately from running io absent. We only need to check against the latest + // snapshot for this, and we don't do this for executors in the deleted executors cache or + // that we just removed in this round. + if (snapshots.nonEmpty) { + val latestSnapshot = snapshots.last + (schedulerBackend.getExecutorIds().map(_.toLong).toSet + -- latestSnapshot.executorPods.keySet + -- execIdsRemovedInThisRound).foreach { missingExecutorId => + if (removedExecutorsCache.getIfPresent(missingExecutorId) == null) { + val exitReasonMessage = s"The executor with ID $missingExecutorId was not found in the" + + s" cluster but we didn't get a reason why. Marking the executor as failed. The" + + s" executor may have been deleted but the driver missed the deletion event." + logDebug(exitReasonMessage) + val exitReason = ExecutorExited( + UNKNOWN_EXIT_CODE, + exitCausedByApp = false, + exitReasonMessage) + schedulerBackend.doRemoveExecutor(missingExecutorId.toString, exitReason) + execIdsRemovedInThisRound += missingExecutorId + } + } + } + logDebug(s"Removed executors with ids ${execIdsRemovedInThisRound.mkString(",")}" + + s" from Spark that were either found to be deleted or non-existent in the cluster.") + } + + private def onFinalNonDeletedState( + podState: FinalPodState, + execId: Long, + schedulerBackend: KubernetesClusterSchedulerBackend, + execIdsRemovedInRound: mutable.Set[Long]): Unit = { + removeExecutorFromK8s(podState.pod) + removeExecutorFromSpark(schedulerBackend, podState, execId) + execIdsRemovedInRound += execId + } + + private def removeExecutorFromK8s(updatedPod: Pod): Unit = { + // If deletion failed on a previous try, we can try again if resync informs us the pod + // is still around. + // Delete as best attempt - duplicate deletes will throw an exception but the end state + // of getting rid of the pod is what matters. + Utils.tryLogNonFatalError { + kubernetesClient + .pods() + .withName(updatedPod.getMetadata.getName) + .delete() + } + } + + private def removeExecutorFromSpark( + schedulerBackend: KubernetesClusterSchedulerBackend, + podState: FinalPodState, + execId: Long): Unit = { + if (removedExecutorsCache.getIfPresent(execId) == null) { + removedExecutorsCache.put(execId, execId) + val exitReason = findExitReason(podState, execId) + schedulerBackend.doRemoveExecutor(execId.toString, exitReason) + } + } + + private def findExitReason(podState: FinalPodState, execId: Long): ExecutorExited = { + val exitCode = findExitCode(podState) + val (exitCausedByApp, exitMessage) = podState match { + case PodDeleted(_) => + (false, s"The executor with id $execId was deleted by a user or the framework.") + case _ => + val msg = exitReasonMessage(podState, execId, exitCode) + (true, msg) + } + ExecutorExited(exitCode, exitCausedByApp, exitMessage) + } + + private def exitReasonMessage(podState: FinalPodState, execId: Long, exitCode: Int) = { + val pod = podState.pod + s""" + |The executor with id $execId exited with exit code $exitCode. + |The API gave the following brief reason: ${pod.getStatus.getReason} + |The API gave the following message: ${pod.getStatus.getMessage} + |The API gave the following container statuses: + | + |${pod.getStatus.getContainerStatuses.asScala.map(_.toString).mkString("\n===\n")} + """.stripMargin + } + + private def findExitCode(podState: FinalPodState): Int = { + podState.pod.getStatus.getContainerStatuses.asScala.find { containerStatus => + containerStatus.getState.getTerminated != null + }.map { terminatedContainer => + terminatedContainer.getState.getTerminated.getExitCode.toInt + }.getOrElse(UNKNOWN_EXIT_CODE) + } +} + +private object ExecutorPodsLifecycleManager { + val UNKNOWN_EXIT_CODE = -1 +} + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala new file mode 100644 index 0000000000000..e77e604d00e0f --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.{Future, ScheduledExecutorService, TimeUnit} + +import io.fabric8.kubernetes.client.KubernetesClient +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging +import org.apache.spark.util.ThreadUtils + +private[spark] class ExecutorPodsPollingSnapshotSource( + conf: SparkConf, + kubernetesClient: KubernetesClient, + snapshotsStore: ExecutorPodsSnapshotsStore, + pollingExecutor: ScheduledExecutorService) extends Logging { + + private val pollingInterval = conf.get(KUBERNETES_EXECUTOR_API_POLLING_INTERVAL) + + private var pollingFuture: Future[_] = _ + + def start(applicationId: String): Unit = { + require(pollingFuture == null, "Cannot start polling more than once.") + logDebug(s"Starting to check for executor pod state every $pollingInterval ms.") + pollingFuture = pollingExecutor.scheduleWithFixedDelay( + new PollRunnable(applicationId), pollingInterval, pollingInterval, TimeUnit.MILLISECONDS) + } + + def stop(): Unit = { + if (pollingFuture != null) { + pollingFuture.cancel(true) + pollingFuture = null + } + ThreadUtils.shutdown(pollingExecutor) + } + + private class PollRunnable(applicationId: String) extends Runnable { + override def run(): Unit = { + logDebug(s"Resynchronizing full executor pod state from Kubernetes.") + snapshotsStore.replaceSnapshot(kubernetesClient + .pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .list() + .getItems + .asScala) + } + } + +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala new file mode 100644 index 0000000000000..26be918043412 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod + +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging + +/** + * An immutable view of the current executor pods that are running in the cluster. + */ +private[spark] case class ExecutorPodsSnapshot(executorPods: Map[Long, ExecutorPodState]) { + + import ExecutorPodsSnapshot._ + + def withUpdate(updatedPod: Pod): ExecutorPodsSnapshot = { + val newExecutorPods = executorPods ++ toStatesByExecutorId(Seq(updatedPod)) + new ExecutorPodsSnapshot(newExecutorPods) + } +} + +object ExecutorPodsSnapshot extends Logging { + + def apply(executorPods: Seq[Pod]): ExecutorPodsSnapshot = { + ExecutorPodsSnapshot(toStatesByExecutorId(executorPods)) + } + + def apply(): ExecutorPodsSnapshot = ExecutorPodsSnapshot(Map.empty[Long, ExecutorPodState]) + + private def toStatesByExecutorId(executorPods: Seq[Pod]): Map[Long, ExecutorPodState] = { + executorPods.map { pod => + (pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL).toLong, toState(pod)) + }.toMap + } + + private def toState(pod: Pod): ExecutorPodState = { + if (isDeleted(pod)) { + PodDeleted(pod) + } else { + val phase = pod.getStatus.getPhase.toLowerCase + phase match { + case "pending" => + PodPending(pod) + case "running" => + PodRunning(pod) + case "failed" => + PodFailed(pod) + case "succeeded" => + PodSucceeded(pod) + case _ => + logWarning(s"Received unknown phase $phase for executor pod with name" + + s" ${pod.getMetadata.getName} in namespace ${pod.getMetadata.getNamespace}") + PodUnknown(pod) + } + } + } + + private def isDeleted(pod: Pod): Boolean = pod.getMetadata.getDeletionTimestamp != null +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStore.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStore.scala new file mode 100644 index 0000000000000..dd264332cf9e8 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStore.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod + +private[spark] trait ExecutorPodsSnapshotsStore { + + def addSubscriber + (processBatchIntervalMillis: Long) + (onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit) + + def stop(): Unit + + def updatePod(updatedPod: Pod): Unit + + def replaceSnapshot(newSnapshot: Seq[Pod]): Unit +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreImpl.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreImpl.scala new file mode 100644 index 0000000000000..5583b4617eeb2 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreImpl.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent._ + +import io.fabric8.kubernetes.api.model.Pod +import javax.annotation.concurrent.GuardedBy +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.util.{ThreadUtils, Utils} + +/** + * Controls the propagation of the Spark application's executor pods state to subscribers that + * react to that state. + *
    + * Roughly follows a producer-consumer model. Producers report states of executor pods, and these + * states are then published to consumers that can perform any actions in response to these states. + *
    + * Producers push updates in one of two ways. An incremental update sent by updatePod() represents + * a known new state of a single executor pod. A full sync sent by replaceSnapshot() indicates that + * the passed pods are all of the most up to date states of all executor pods for the application. + * The combination of the states of all executor pods for the application is collectively known as + * a snapshot. The store keeps track of the most up to date snapshot, and applies updates to that + * most recent snapshot - either by incrementally updating the snapshot with a single new pod state, + * or by replacing the snapshot entirely on a full sync. + *
    + * Consumers, or subscribers, register that they want to be informed about all snapshots of the + * executor pods. Every time the store replaces its most up to date snapshot from either an + * incremental update or a full sync, the most recent snapshot after the update is posted to the + * subscriber's buffer. Subscribers receive blocks of snapshots produced by the producers in + * time-windowed chunks. Each subscriber can choose to receive their snapshot chunks at different + * time intervals. + */ +private[spark] class ExecutorPodsSnapshotsStoreImpl(subscribersExecutor: ScheduledExecutorService) + extends ExecutorPodsSnapshotsStore { + + private val SNAPSHOT_LOCK = new Object() + + private val subscribers = mutable.Buffer.empty[SnapshotsSubscriber] + private val pollingTasks = mutable.Buffer.empty[Future[_]] + + @GuardedBy("SNAPSHOT_LOCK") + private var currentSnapshot = ExecutorPodsSnapshot() + + override def addSubscriber( + processBatchIntervalMillis: Long) + (onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit): Unit = { + val newSubscriber = SnapshotsSubscriber( + new LinkedBlockingQueue[ExecutorPodsSnapshot](), onNewSnapshots) + SNAPSHOT_LOCK.synchronized { + newSubscriber.snapshotsBuffer.add(currentSnapshot) + } + subscribers += newSubscriber + pollingTasks += subscribersExecutor.scheduleWithFixedDelay( + toRunnable(() => callSubscriber(newSubscriber)), + 0L, + processBatchIntervalMillis, + TimeUnit.MILLISECONDS) + } + + override def stop(): Unit = { + pollingTasks.foreach(_.cancel(true)) + ThreadUtils.shutdown(subscribersExecutor) + } + + override def updatePod(updatedPod: Pod): Unit = SNAPSHOT_LOCK.synchronized { + currentSnapshot = currentSnapshot.withUpdate(updatedPod) + addCurrentSnapshotToSubscribers() + } + + override def replaceSnapshot(newSnapshot: Seq[Pod]): Unit = SNAPSHOT_LOCK.synchronized { + currentSnapshot = ExecutorPodsSnapshot(newSnapshot) + addCurrentSnapshotToSubscribers() + } + + private def addCurrentSnapshotToSubscribers(): Unit = { + subscribers.foreach { subscriber => + subscriber.snapshotsBuffer.add(currentSnapshot) + } + } + + private def callSubscriber(subscriber: SnapshotsSubscriber): Unit = { + Utils.tryLogNonFatalError { + val currentSnapshots = mutable.Buffer.empty[ExecutorPodsSnapshot].asJava + subscriber.snapshotsBuffer.drainTo(currentSnapshots) + subscriber.onNewSnapshots(currentSnapshots.asScala) + } + } + + private def toRunnable[T](runnable: () => Unit): Runnable = new Runnable { + override def run(): Unit = runnable() + } + + private case class SnapshotsSubscriber( + snapshotsBuffer: BlockingQueue[ExecutorPodsSnapshot], + onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit) +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala new file mode 100644 index 0000000000000..a6749a644e00c --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.io.Closeable + +import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action + +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +private[spark] class ExecutorPodsWatchSnapshotSource( + snapshotsStore: ExecutorPodsSnapshotsStore, + kubernetesClient: KubernetesClient) extends Logging { + + private var watchConnection: Closeable = _ + + def start(applicationId: String): Unit = { + require(watchConnection == null, "Cannot start the watcher twice.") + logDebug(s"Starting watch for pods with labels $SPARK_APP_ID_LABEL=$applicationId," + + s" $SPARK_ROLE_LABEL=$SPARK_POD_EXECUTOR_ROLE.") + watchConnection = kubernetesClient.pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .watch(new ExecutorPodsWatcher()) + } + + def stop(): Unit = { + if (watchConnection != null) { + Utils.tryLogNonFatalError { + watchConnection.close() + } + watchConnection = null + } + } + + private class ExecutorPodsWatcher extends Watcher[Pod] { + override def eventReceived(action: Action, pod: Pod): Unit = { + val podName = pod.getMetadata.getName + logDebug(s"Received executor pod update for pod named $podName, action $action") + snapshotsStore.updatePod(pod) + } + + override def onClose(e: KubernetesClientException): Unit = { + logWarning("Kubernetes client has been closed (this is expected if the application is" + + " shutting down.)", e) + } + } + +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index 0ea80dfbc0d97..de2a52bc7a0b8 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -17,7 +17,9 @@ package org.apache.spark.scheduler.cluster.k8s import java.io.File +import java.util.concurrent.TimeUnit +import com.google.common.cache.CacheBuilder import io.fabric8.kubernetes.client.Config import org.apache.spark.{SparkContext, SparkException} @@ -26,7 +28,7 @@ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} -import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.{SystemClock, ThreadUtils} private[spark] class KubernetesClusterManager extends ExternalClusterManager with Logging { @@ -46,8 +48,6 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit sc: SparkContext, masterURL: String, scheduler: TaskScheduler): SchedulerBackend = { - val executorSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( - sc.conf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( KUBERNETES_MASTER_INTERNAL_URL, Some(sc.conf.get(KUBERNETES_NAMESPACE)), @@ -56,17 +56,45 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) - val allocatorExecutor = ThreadUtils - .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator") val requestExecutorsService = ThreadUtils.newDaemonCachedThreadPool( "kubernetes-executor-requests") + + val subscribersExecutor = ThreadUtils + .newDaemonThreadPoolScheduledExecutor( + "kubernetes-executor-snapshots-subscribers", 2) + val snapshotsStore = new ExecutorPodsSnapshotsStoreImpl(subscribersExecutor) + val removedExecutorsCache = CacheBuilder.newBuilder() + .expireAfterWrite(3, TimeUnit.MINUTES) + .build[java.lang.Long, java.lang.Long]() + val executorPodsLifecycleEventHandler = new ExecutorPodsLifecycleManager( + sc.conf, + new KubernetesExecutorBuilder(), + kubernetesClient, + snapshotsStore, + removedExecutorsCache) + + val executorPodsAllocator = new ExecutorPodsAllocator( + sc.conf, new KubernetesExecutorBuilder(), kubernetesClient, snapshotsStore, new SystemClock()) + + val podsWatchEventSource = new ExecutorPodsWatchSnapshotSource( + snapshotsStore, + kubernetesClient) + + val eventsPollingExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor( + "kubernetes-executor-pod-polling-sync") + val podsPollingEventSource = new ExecutorPodsPollingSnapshotSource( + sc.conf, kubernetesClient, snapshotsStore, eventsPollingExecutor) + new KubernetesClusterSchedulerBackend( scheduler.asInstanceOf[TaskSchedulerImpl], sc.env.rpcEnv, - new KubernetesExecutorBuilder, kubernetesClient, - allocatorExecutor, - requestExecutorsService) + requestExecutorsService, + snapshotsStore, + executorPodsAllocator, + executorPodsLifecycleEventHandler, + podsWatchEventSource, + podsPollingEventSource) } override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index d86664c81071b..fa6dc2c479bbf 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -16,60 +16,32 @@ */ package org.apache.spark.scheduler.cluster.k8s -import java.io.Closeable -import java.net.InetAddress -import java.util.concurrent.{ConcurrentHashMap, ExecutorService, ScheduledExecutorService, TimeUnit} -import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, AtomicReference} -import javax.annotation.concurrent.GuardedBy +import java.util.concurrent.ExecutorService -import io.fabric8.kubernetes.api.model._ -import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watcher} -import io.fabric8.kubernetes.client.Watcher.Action -import scala.collection.JavaConverters._ -import scala.collection.mutable +import io.fabric8.kubernetes.client.KubernetesClient import scala.concurrent.{ExecutionContext, Future} -import org.apache.spark.SparkException -import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.KubernetesConf -import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress, RpcEnv} -import org.apache.spark.scheduler.{ExecutorExited, SlaveLost, TaskSchedulerImpl} +import org.apache.spark.rpc.{RpcAddress, RpcEnv} +import org.apache.spark.scheduler.{ExecutorLossReason, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SchedulerBackendUtils} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} private[spark] class KubernetesClusterSchedulerBackend( scheduler: TaskSchedulerImpl, rpcEnv: RpcEnv, - executorBuilder: KubernetesExecutorBuilder, kubernetesClient: KubernetesClient, - allocatorExecutor: ScheduledExecutorService, - requestExecutorsService: ExecutorService) + requestExecutorsService: ExecutorService, + snapshotsStore: ExecutorPodsSnapshotsStore, + podAllocator: ExecutorPodsAllocator, + lifecycleEventHandler: ExecutorPodsLifecycleManager, + watchEvents: ExecutorPodsWatchSnapshotSource, + pollEvents: ExecutorPodsPollingSnapshotSource) extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) { - import KubernetesClusterSchedulerBackend._ - - private val EXECUTOR_ID_COUNTER = new AtomicLong(0L) - private val RUNNING_EXECUTOR_PODS_LOCK = new Object - @GuardedBy("RUNNING_EXECUTOR_PODS_LOCK") - private val runningExecutorsToPods = new mutable.HashMap[String, Pod] - private val executorPodsByIPs = new ConcurrentHashMap[String, Pod]() - private val podsWithKnownExitReasons = new ConcurrentHashMap[String, ExecutorExited]() - private val disconnectedPodsByExecutorIdPendingRemoval = new ConcurrentHashMap[String, Pod]() - - private val kubernetesNamespace = conf.get(KUBERNETES_NAMESPACE) - - private val kubernetesDriverPodName = conf - .get(KUBERNETES_DRIVER_POD_NAME) - .getOrElse(throw new SparkException("Must specify the driver pod name")) private implicit val requestExecutorContext = ExecutionContext.fromExecutorService( requestExecutorsService) - private val driverPod = kubernetesClient.pods() - .inNamespace(kubernetesNamespace) - .withName(kubernetesDriverPodName) - .get() - protected override val minRegisteredRatio = if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { 0.8 @@ -77,372 +49,93 @@ private[spark] class KubernetesClusterSchedulerBackend( super.minRegisteredRatio } - private val executorWatchResource = new AtomicReference[Closeable] - private val totalExpectedExecutors = new AtomicInteger(0) - - private val driverUrl = RpcEndpointAddress( - conf.get("spark.driver.host"), - conf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString - private val initialExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf) - private val podAllocationInterval = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY) - - private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE) - - private val executorLostReasonCheckMaxAttempts = conf.get( - KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS) - - private val allocatorRunnable = new Runnable { - - // Maintains a map of executor id to count of checks performed to learn the loss reason - // for an executor. - private val executorReasonCheckAttemptCounts = new mutable.HashMap[String, Int] - - override def run(): Unit = { - handleDisconnectedExecutors() - - val executorsToAllocate = mutable.Map[String, Pod]() - val currentTotalRegisteredExecutors = totalRegisteredExecutors.get - val currentTotalExpectedExecutors = totalExpectedExecutors.get - val currentNodeToLocalTaskCount = getNodesWithLocalTaskCounts() - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - if (currentTotalRegisteredExecutors < runningExecutorsToPods.size) { - logDebug("Waiting for pending executors before scaling") - } else if (currentTotalExpectedExecutors <= runningExecutorsToPods.size) { - logDebug("Maximum allowed executor limit reached. Not scaling up further.") - } else { - for (_ <- 0 until math.min( - currentTotalExpectedExecutors - runningExecutorsToPods.size, podAllocationSize)) { - val executorId = EXECUTOR_ID_COUNTER.incrementAndGet().toString - val executorConf = KubernetesConf.createExecutorConf( - conf, - executorId, - applicationId(), - driverPod) - val executorPod = executorBuilder.buildFromFeatures(executorConf) - val podWithAttachedContainer = new PodBuilder(executorPod.pod) - .editOrNewSpec() - .addToContainers(executorPod.container) - .endSpec() - .build() - - executorsToAllocate(executorId) = podWithAttachedContainer - logInfo( - s"Requesting a new executor, total executors is now ${runningExecutorsToPods.size}") - } - } - } - - val allocatedExecutors = executorsToAllocate.mapValues { pod => - Utils.tryLog { - kubernetesClient.pods().create(pod) - } - } - - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - allocatedExecutors.map { - case (executorId, attemptedAllocatedExecutor) => - attemptedAllocatedExecutor.map { successfullyAllocatedExecutor => - runningExecutorsToPods.put(executorId, successfullyAllocatedExecutor) - } - } - } - } - - def handleDisconnectedExecutors(): Unit = { - // For each disconnected executor, synchronize with the loss reasons that may have been found - // by the executor pod watcher. If the loss reason was discovered by the watcher, - // inform the parent class with removeExecutor. - disconnectedPodsByExecutorIdPendingRemoval.asScala.foreach { - case (executorId, executorPod) => - val knownExitReason = Option(podsWithKnownExitReasons.remove( - executorPod.getMetadata.getName)) - knownExitReason.fold { - removeExecutorOrIncrementLossReasonCheckCount(executorId) - } { executorExited => - logWarning(s"Removing executor $executorId with loss reason " + executorExited.message) - removeExecutor(executorId, executorExited) - // We don't delete the pod running the executor that has an exit condition caused by - // the application from the Kubernetes API server. This allows users to debug later on - // through commands such as "kubectl logs " and - // "kubectl describe pod ". Note that exited containers have terminated and - // therefore won't take CPU and memory resources. - // Otherwise, the executor pod is marked to be deleted from the API server. - if (executorExited.exitCausedByApp) { - logInfo(s"Executor $executorId exited because of the application.") - deleteExecutorFromDataStructures(executorId) - } else { - logInfo(s"Executor $executorId failed because of a framework error.") - deleteExecutorFromClusterAndDataStructures(executorId) - } - } - } - } - - def removeExecutorOrIncrementLossReasonCheckCount(executorId: String): Unit = { - val reasonCheckCount = executorReasonCheckAttemptCounts.getOrElse(executorId, 0) - if (reasonCheckCount >= executorLostReasonCheckMaxAttempts) { - removeExecutor(executorId, SlaveLost("Executor lost for unknown reasons.")) - deleteExecutorFromClusterAndDataStructures(executorId) - } else { - executorReasonCheckAttemptCounts.put(executorId, reasonCheckCount + 1) - } - } - - def deleteExecutorFromClusterAndDataStructures(executorId: String): Unit = { - deleteExecutorFromDataStructures(executorId).foreach { pod => - kubernetesClient.pods().delete(pod) - } - } - - def deleteExecutorFromDataStructures(executorId: String): Option[Pod] = { - disconnectedPodsByExecutorIdPendingRemoval.remove(executorId) - executorReasonCheckAttemptCounts -= executorId - podsWithKnownExitReasons.remove(executorId) - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - runningExecutorsToPods.remove(executorId).orElse { - logWarning(s"Unable to remove pod for unknown executor $executorId") - None - } - } - } - } - - override def sufficientResourcesRegistered(): Boolean = { - totalRegisteredExecutors.get() >= initialExecutors * minRegisteredRatio + // Allow removeExecutor to be accessible by ExecutorPodsLifecycleEventHandler + private[k8s] def doRemoveExecutor(executorId: String, reason: ExecutorLossReason): Unit = { + removeExecutor(executorId, reason) } override def start(): Unit = { super.start() - executorWatchResource.set( - kubernetesClient - .pods() - .withLabel(SPARK_APP_ID_LABEL, applicationId()) - .watch(new ExecutorPodsWatcher())) - - allocatorExecutor.scheduleWithFixedDelay( - allocatorRunnable, 0L, podAllocationInterval, TimeUnit.MILLISECONDS) - if (!Utils.isDynamicAllocationEnabled(conf)) { - doRequestTotalExecutors(initialExecutors) + podAllocator.setTotalExpectedExecutors(initialExecutors) } + lifecycleEventHandler.start(this) + podAllocator.start(applicationId()) + watchEvents.start(applicationId()) + pollEvents.start(applicationId()) } override def stop(): Unit = { - // stop allocation of new resources and caches. - allocatorExecutor.shutdown() - allocatorExecutor.awaitTermination(30, TimeUnit.SECONDS) - - // send stop message to executors so they shut down cleanly super.stop() - try { - val resource = executorWatchResource.getAndSet(null) - if (resource != null) { - resource.close() - } - } catch { - case e: Throwable => logWarning("Failed to close the executor pod watcher", e) + Utils.tryLogNonFatalError { + snapshotsStore.stop() } - // then delete the executor pods Utils.tryLogNonFatalError { - deleteExecutorPodsOnStop() - executorPodsByIPs.clear() + watchEvents.stop() } + Utils.tryLogNonFatalError { - logInfo("Closing kubernetes client") - kubernetesClient.close() + pollEvents.stop() } - } - /** - * @return A map of K8s cluster nodes to the number of tasks that could benefit from data - * locality if an executor launches on the cluster node. - */ - private def getNodesWithLocalTaskCounts() : Map[String, Int] = { - val nodeToLocalTaskCount = synchronized { - mutable.Map[String, Int]() ++ hostToLocalTaskCount + Utils.tryLogNonFatalError { + kubernetesClient.pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId()) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .delete() } - for (pod <- executorPodsByIPs.values().asScala) { - // Remove cluster nodes that are running our executors already. - // TODO: This prefers spreading out executors across nodes. In case users want - // consolidating executors on fewer nodes, introduce a flag. See the spark.deploy.spreadOut - // flag that Spark standalone has: https://spark.apache.org/docs/latest/spark-standalone.html - nodeToLocalTaskCount.remove(pod.getSpec.getNodeName).nonEmpty || - nodeToLocalTaskCount.remove(pod.getStatus.getHostIP).nonEmpty || - nodeToLocalTaskCount.remove( - InetAddress.getByName(pod.getStatus.getHostIP).getCanonicalHostName).nonEmpty + Utils.tryLogNonFatalError { + ThreadUtils.shutdown(requestExecutorsService) + } + + Utils.tryLogNonFatalError { + kubernetesClient.close() } - nodeToLocalTaskCount.toMap[String, Int] } override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = Future[Boolean] { - totalExpectedExecutors.set(requestedTotal) + // TODO when we support dynamic allocation, the pod allocator should be told to process the + // current snapshot in order to decrease/increase the number of executors accordingly. + podAllocator.setTotalExpectedExecutors(requestedTotal) true } - override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future[Boolean] { - val podsToDelete = RUNNING_EXECUTOR_PODS_LOCK.synchronized { - executorIds.flatMap { executorId => - runningExecutorsToPods.remove(executorId) match { - case Some(pod) => - disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) - Some(pod) - - case None => - logWarning(s"Unable to remove pod for unknown executor $executorId") - None - } - } - } - - kubernetesClient.pods().delete(podsToDelete: _*) - true + override def sufficientResourcesRegistered(): Boolean = { + totalRegisteredExecutors.get() >= initialExecutors * minRegisteredRatio } - private def deleteExecutorPodsOnStop(): Unit = { - val executorPodsToDelete = RUNNING_EXECUTOR_PODS_LOCK.synchronized { - val runningExecutorPodsCopy = Seq(runningExecutorsToPods.values.toSeq: _*) - runningExecutorsToPods.clear() - runningExecutorPodsCopy - } - kubernetesClient.pods().delete(executorPodsToDelete: _*) + override def getExecutorIds(): Seq[String] = synchronized { + super.getExecutorIds() } - private class ExecutorPodsWatcher extends Watcher[Pod] { - - private val DEFAULT_CONTAINER_FAILURE_EXIT_STATUS = -1 - - override def eventReceived(action: Action, pod: Pod): Unit = { - val podName = pod.getMetadata.getName - val podIP = pod.getStatus.getPodIP - - action match { - case Action.MODIFIED if (pod.getStatus.getPhase == "Running" - && pod.getMetadata.getDeletionTimestamp == null) => - val clusterNodeName = pod.getSpec.getNodeName - logInfo(s"Executor pod $podName ready, launched at $clusterNodeName as IP $podIP.") - executorPodsByIPs.put(podIP, pod) - - case Action.DELETED | Action.ERROR => - val executorId = getExecutorId(pod) - logDebug(s"Executor pod $podName at IP $podIP was at $action.") - if (podIP != null) { - executorPodsByIPs.remove(podIP) - } - - val executorExitReason = if (action == Action.ERROR) { - logWarning(s"Received error event of executor pod $podName. Reason: " + - pod.getStatus.getReason) - executorExitReasonOnError(pod) - } else if (action == Action.DELETED) { - logWarning(s"Received delete event of executor pod $podName. Reason: " + - pod.getStatus.getReason) - executorExitReasonOnDelete(pod) - } else { - throw new IllegalStateException( - s"Unknown action that should only be DELETED or ERROR: $action") - } - podsWithKnownExitReasons.put(pod.getMetadata.getName, executorExitReason) - - if (!disconnectedPodsByExecutorIdPendingRemoval.containsKey(executorId)) { - log.warn(s"Executor with id $executorId was not marked as disconnected, but the " + - s"watch received an event of type $action for this executor. The executor may " + - "have failed to start in the first place and never registered with the driver.") - } - disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) - - case _ => logDebug(s"Received event of executor pod $podName: " + action) - } - } - - override def onClose(cause: KubernetesClientException): Unit = { - logDebug("Executor pod watch closed.", cause) - } - - private def getExecutorExitStatus(pod: Pod): Int = { - val containerStatuses = pod.getStatus.getContainerStatuses - if (!containerStatuses.isEmpty) { - // we assume the first container represents the pod status. This assumption may not hold - // true in the future. Revisit this if side-car containers start running inside executor - // pods. - getExecutorExitStatus(containerStatuses.get(0)) - } else DEFAULT_CONTAINER_FAILURE_EXIT_STATUS - } - - private def getExecutorExitStatus(containerStatus: ContainerStatus): Int = { - Option(containerStatus.getState).map { containerState => - Option(containerState.getTerminated).map { containerStateTerminated => - containerStateTerminated.getExitCode.intValue() - }.getOrElse(UNKNOWN_EXIT_CODE) - }.getOrElse(UNKNOWN_EXIT_CODE) - } - - private def isPodAlreadyReleased(pod: Pod): Boolean = { - val executorId = pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - !runningExecutorsToPods.contains(executorId) - } - } - - private def executorExitReasonOnError(pod: Pod): ExecutorExited = { - val containerExitStatus = getExecutorExitStatus(pod) - // container was probably actively killed by the driver. - if (isPodAlreadyReleased(pod)) { - ExecutorExited(containerExitStatus, exitCausedByApp = false, - s"Container in pod ${pod.getMetadata.getName} exited from explicit termination " + - "request.") - } else { - val containerExitReason = s"Pod ${pod.getMetadata.getName}'s executor container " + - s"exited with exit status code $containerExitStatus." - ExecutorExited(containerExitStatus, exitCausedByApp = true, containerExitReason) - } - } - - private def executorExitReasonOnDelete(pod: Pod): ExecutorExited = { - val exitMessage = if (isPodAlreadyReleased(pod)) { - s"Container in pod ${pod.getMetadata.getName} exited from explicit termination request." - } else { - s"Pod ${pod.getMetadata.getName} deleted or lost." - } - ExecutorExited(getExecutorExitStatus(pod), exitCausedByApp = false, exitMessage) - } - - private def getExecutorId(pod: Pod): String = { - val executorId = pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) - require(executorId != null, "Unexpected pod metadata; expected all executor pods " + - s"to have label $SPARK_EXECUTOR_ID_LABEL.") - executorId - } + override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future[Boolean] { + kubernetesClient.pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId()) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .withLabelIn(SPARK_EXECUTOR_ID_LABEL, executorIds: _*) + .delete() + // Don't do anything else - let event handling from the Kubernetes API do the Spark changes } override def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { new KubernetesDriverEndpoint(rpcEnv, properties) } - private class KubernetesDriverEndpoint( - rpcEnv: RpcEnv, - sparkProperties: Seq[(String, String)]) + private class KubernetesDriverEndpoint(rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends DriverEndpoint(rpcEnv, sparkProperties) { override def onDisconnected(rpcAddress: RpcAddress): Unit = { - addressToExecutorId.get(rpcAddress).foreach { executorId => - if (disableExecutor(executorId)) { - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - runningExecutorsToPods.get(executorId).foreach { pod => - disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) - } - } - } - } + // Don't do anything besides disabling the executor - allow the Kubernetes API events to + // drive the rest of the lifecycle decisions + // TODO what if we disconnect from a networking issue? Probably want to mark the executor + // to be deleted eventually. + addressToExecutorId.get(rpcAddress).foreach(disableExecutor) } } -} -private object KubernetesClusterSchedulerBackend { - private val UNKNOWN_EXIT_CODE = -1 } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala index 769a0a5a63047..364b6fb367722 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -17,37 +17,41 @@ package org.apache.spark.scheduler.cluster.k8s import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, KubernetesFeatureConfigStep, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features._ +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep} private[spark] class KubernetesExecutorBuilder( - provideBasicStep: (KubernetesConf[KubernetesExecutorSpecificConf]) => BasicExecutorFeatureStep = + provideBasicStep: (KubernetesConf [KubernetesExecutorSpecificConf]) + => BasicExecutorFeatureStep = new BasicExecutorFeatureStep(_), - provideSecretsStep: - (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => MountSecretsFeatureStep = + provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) + => MountSecretsFeatureStep = new MountSecretsFeatureStep(_), provideEnvSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] => EnvSecretsFeatureStep) = new EnvSecretsFeatureStep(_), provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => LocalDirsFeatureStep = - new LocalDirsFeatureStep(_)) { + new LocalDirsFeatureStep(_), + provideVolumesStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] + => MountVolumesFeatureStep) = + new MountVolumesFeatureStep(_)) { def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = { - val baseFeatures = Seq( - provideBasicStep(kubernetesConf), - provideLocalDirsStep(kubernetesConf)) - val maybeRoleSecretNamesStep = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { - Some(provideSecretsStep(kubernetesConf)) } else None + val baseFeatures = Seq(provideBasicStep(kubernetesConf), provideLocalDirsStep(kubernetesConf)) + val secretFeature = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + Seq(provideSecretsStep(kubernetesConf)) + } else Nil + val secretEnvFeature = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { + Seq(provideEnvSecretsStep(kubernetesConf)) + } else Nil + val volumesFeature = if (kubernetesConf.roleVolumes.nonEmpty) { + Seq(provideVolumesStep(kubernetesConf)) + } else Nil - val maybeProvideSecretsStep = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { - Some(provideEnvSecretsStep(kubernetesConf)) } else None - - val allFeatures: Seq[KubernetesFeatureConfigStep] = - baseFeatures ++ - maybeRoleSecretNamesStep.toSeq ++ - maybeProvideSecretsStep.toSeq + val allFeatures = baseFeatures ++ secretFeature ++ secretEnvFeature ++ volumesFeature var executorPod = SparkPod.initialPod() for (feature <- allFeatures) { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/Fabric8Aliases.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/Fabric8Aliases.scala new file mode 100644 index 0000000000000..527fc6b0d8f87 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/Fabric8Aliases.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.{DoneablePod, HasMetadata, Pod, PodList} +import io.fabric8.kubernetes.client.{Watch, Watcher} +import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable, PodResource} + +object Fabric8Aliases { + type PODS = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] + type LABELED_PODS = FilterWatchListDeletable[ + Pod, PodList, java.lang.Boolean, Watch, Watcher[Pod]] + type SINGLE_POD = PodResource[Pod, DoneablePod] + type RESOURCE_LIST = NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable[ + HasMetadata, Boolean] +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala new file mode 100644 index 0000000000000..d795d159773a8 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import org.apache.spark.{SparkConf, SparkFunSuite} + +class KubernetesVolumeUtilsSuite extends SparkFunSuite { + test("Parses hostPath volumes correctly") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.hostPath.volumeName.mount.path", "/path") + sparkConf.set("test.hostPath.volumeName.mount.readOnly", "true") + sparkConf.set("test.hostPath.volumeName.options.path", "/hostPath") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly === true) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesHostPathVolumeConf] === + KubernetesHostPathVolumeConf("/hostPath")) + } + + test("Parses persistentVolumeClaim volumes correctly") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.persistentVolumeClaim.volumeName.mount.path", "/path") + sparkConf.set("test.persistentVolumeClaim.volumeName.mount.readOnly", "true") + sparkConf.set("test.persistentVolumeClaim.volumeName.options.claimName", "claimeName") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly === true) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesPVCVolumeConf] === + KubernetesPVCVolumeConf("claimeName")) + } + + test("Parses emptyDir volumes correctly") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.emptyDir.volumeName.mount.path", "/path") + sparkConf.set("test.emptyDir.volumeName.mount.readOnly", "true") + sparkConf.set("test.emptyDir.volumeName.options.medium", "medium") + sparkConf.set("test.emptyDir.volumeName.options.sizeLimit", "5G") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly === true) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesEmptyDirVolumeConf] === + KubernetesEmptyDirVolumeConf(Some("medium"), Some("5G"))) + } + + test("Parses emptyDir volume options can be optional") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.emptyDir.volumeName.mount.path", "/path") + sparkConf.set("test.emptyDir.volumeName.mount.readOnly", "true") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly === true) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesEmptyDirVolumeConf] === + KubernetesEmptyDirVolumeConf(None, None)) + } + + test("Defaults optional readOnly to false") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.hostPath.volumeName.mount.path", "/path") + sparkConf.set("test.hostPath.volumeName.options.path", "/hostPath") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.mountReadOnly === false) + } + + test("Gracefully fails on missing mount key") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.emptyDir.volumeName.mnt.path", "/path") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head + assert(volumeSpec.isFailure === true) + assert(volumeSpec.failed.get.getMessage === "emptyDir.volumeName.mount.path") + } + + test("Gracefully fails on missing option key") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.hostPath.volumeName.mount.path", "/path") + sparkConf.set("test.hostPath.volumeName.mount.readOnly", "true") + sparkConf.set("test.hostPath.volumeName.options.pth", "/hostPath") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head + assert(volumeSpec.isFailure === true) + assert(volumeSpec.failed.get.getMessage === "hostPath.volumeName.options.path") + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala index 04b909db9d9f3..165f46a07df2f 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala @@ -50,6 +50,12 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { TEST_IMAGE_PULL_SECRETS.map { secret => new LocalObjectReferenceBuilder().withName(secret).build() } + private val emptyDriverSpecificConf = KubernetesDriverSpecificConf( + None, + APP_NAME, + MAIN_CLASS, + APP_ARGS) + test("Check the pod respects all configurations from the user.") { val sparkConf = new SparkConf() @@ -62,11 +68,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { .set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS.mkString(",")) val kubernetesConf = KubernetesConf( sparkConf, - KubernetesDriverSpecificConf( - Some(JavaMainAppResource("")), - APP_NAME, - MAIN_CLASS, - APP_ARGS), + emptyDriverSpecificConf, RESOURCE_NAME_PREFIX, APP_ID, DRIVER_LABELS, @@ -74,6 +76,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { Map.empty, Map.empty, DRIVER_ENVS, + Nil, Seq.empty[String]) val featureStep = new BasicDriverFeatureStep(kubernetesConf) @@ -143,6 +146,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { Map.empty, Map.empty, DRIVER_ENVS, + Nil, Seq.empty[String]) val pythonKubernetesConf = KubernetesConf( pythonSparkConf, @@ -158,6 +162,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { Map.empty, Map.empty, DRIVER_ENVS, + Nil, Seq.empty[String]) val javaFeatureStep = new BasicDriverFeatureStep(javaKubernetesConf) val pythonFeatureStep = new BasicDriverFeatureStep(pythonKubernetesConf) @@ -176,11 +181,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { .set(CONTAINER_IMAGE, "spark-driver:latest") val kubernetesConf = KubernetesConf( sparkConf, - KubernetesDriverSpecificConf( - Some(JavaMainAppResource("")), - APP_NAME, - MAIN_CLASS, - APP_ARGS), + emptyDriverSpecificConf, RESOURCE_NAME_PREFIX, APP_ID, DRIVER_LABELS, @@ -188,7 +189,9 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { Map.empty, Map.empty, DRIVER_ENVS, + Nil, allFiles) + val step = new BasicDriverFeatureStep(kubernetesConf) val additionalProperties = step.getAdditionalPodSystemProperties() val expectedSparkConf = Map( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala index f06030aa55c0c..a44fa1f2ffc63 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -89,6 +89,7 @@ class BasicExecutorFeatureStepSuite Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String])) val executor = step.configurePod(SparkPod.initialPod()) @@ -128,6 +129,7 @@ class BasicExecutorFeatureStepSuite Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String])) assert(step.configurePod(SparkPod.initialPod()).pod.getSpec.getHostname.length === 63) } @@ -148,6 +150,7 @@ class BasicExecutorFeatureStepSuite Map.empty, Map.empty, Map("qux" -> "quux"), + Nil, Seq.empty[String])) val executor = step.configurePod(SparkPod.initialPod()) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala index 7cea83591f3e8..7e916b3854404 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala @@ -61,6 +61,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) @@ -92,6 +93,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) @@ -130,6 +132,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) val resolvedProperties = kubernetesCredentialsStep.getAdditionalPodSystemProperties() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala index 77d38bf19cd10..8b91e93eecd8c 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala @@ -67,6 +67,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String])) assert(configurationStep.configurePod(SparkPod.initialPod()) === SparkPod.initialPod()) assert(configurationStep.getAdditionalKubernetesResources().size === 1) @@ -98,6 +99,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String])) val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX @@ -119,6 +121,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String])) val resolvedService = configurationStep .getAdditionalKubernetesResources() @@ -149,6 +152,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]), clock) val driverService = configurationStep @@ -176,6 +180,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]), clock) fail("The driver bind address should not be allowed.") @@ -201,6 +206,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]), clock) fail("The driver host address should not be allowed.") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala index af6b35eae484a..1c8d84b76c56b 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala @@ -45,6 +45,7 @@ class EnvSecretsFeatureStepSuite extends SparkFunSuite{ Map.empty, envVarsToKeys, Map.empty, + Nil, Seq.empty[String]) val step = new EnvSecretsFeatureStep(kubernetesConf) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala index bd6ce4b42fc8e..a339827b819a9 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala @@ -21,7 +21,7 @@ import org.mockito.Mockito import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf, SparkPod} class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { private val defaultLocalDir = "/var/data/default-local-dir" @@ -45,6 +45,7 @@ class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala index eff75b8a15daa..2b49b72dfa569 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala @@ -43,6 +43,7 @@ class MountSecretsFeatureStepSuite extends SparkFunSuite { secretNamesToMountPaths, Map.empty, Map.empty, + Nil, Seq.empty[String]) val step = new MountSecretsFeatureStep(kubernetesConf) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala new file mode 100644 index 0000000000000..d309aa94ec115 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s._ + +class MountVolumesFeatureStepSuite extends SparkFunSuite { + private val sparkConf = new SparkConf(false) + private val emptyKubernetesConf = KubernetesConf( + sparkConf = sparkConf, + roleSpecificConf = KubernetesDriverSpecificConf( + None, + "app-name", + "main", + Seq.empty), + appResourceNamePrefix = "resource", + appId = "app-id", + roleLabels = Map.empty, + roleAnnotations = Map.empty, + roleSecretNamesToMountPaths = Map.empty, + roleSecretEnvNamesToKeyRefs = Map.empty, + roleEnvs = Map.empty, + roleVolumes = Nil, + sparkFiles = Nil) + + test("Mounts hostPath volumes") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + false, + KubernetesHostPathVolumeConf("/hostPath/tmp") + ) + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + assert(configuredPod.pod.getSpec.getVolumes.get(0).getHostPath.getPath === "/hostPath/tmp") + assert(configuredPod.container.getVolumeMounts.size() === 1) + assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp") + assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume") + assert(configuredPod.container.getVolumeMounts.get(0).getReadOnly === false) + } + + test("Mounts pesistentVolumeClaims") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + true, + KubernetesPVCVolumeConf("pvcClaim") + ) + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val pvcClaim = configuredPod.pod.getSpec.getVolumes.get(0).getPersistentVolumeClaim + assert(pvcClaim.getClaimName === "pvcClaim") + assert(configuredPod.container.getVolumeMounts.size() === 1) + assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp") + assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume") + assert(configuredPod.container.getVolumeMounts.get(0).getReadOnly === true) + + } + + test("Mounts emptyDir") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + false, + KubernetesEmptyDirVolumeConf(Some("Memory"), Some("6G")) + ) + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val emptyDir = configuredPod.pod.getSpec.getVolumes.get(0).getEmptyDir + assert(emptyDir.getMedium === "Memory") + assert(emptyDir.getSizeLimit.getAmount === "6G") + assert(configuredPod.container.getVolumeMounts.size() === 1) + assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp") + assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume") + assert(configuredPod.container.getVolumeMounts.get(0).getReadOnly === false) + } + + test("Mounts emptyDir with no options") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + false, + KubernetesEmptyDirVolumeConf(None, None) + ) + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val emptyDir = configuredPod.pod.getSpec.getVolumes.get(0).getEmptyDir + assert(emptyDir.getMedium === "") + assert(emptyDir.getSizeLimit.getAmount === null) + assert(configuredPod.container.getVolumeMounts.size() === 1) + assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp") + assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume") + assert(configuredPod.container.getVolumeMounts.get(0).getReadOnly === false) + } + + test("Mounts multiple volumes") { + val hpVolumeConf = KubernetesVolumeSpec( + "hpVolume", + "/tmp", + false, + KubernetesHostPathVolumeConf("/hostPath/tmp") + ) + val pvcVolumeConf = KubernetesVolumeSpec( + "checkpointVolume", + "/checkpoints", + true, + KubernetesPVCVolumeConf("pvcClaim") + ) + val volumesConf = hpVolumeConf :: pvcVolumeConf :: Nil + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumesConf) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 2) + assert(configuredPod.container.getVolumeMounts.size() === 2) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala index 0f2bf2fa1d9b5..18874afe6e53a 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala @@ -42,6 +42,7 @@ class JavaDriverFeatureStepSuite extends SparkFunSuite { roleSecretNamesToMountPaths = Map.empty, roleSecretEnvNamesToKeyRefs = Map.empty, roleEnvs = Map.empty, + roleVolumes = Nil, sparkFiles = Seq.empty[String]) val step = new JavaDriverFeatureStep(kubernetesConf) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala index a1f9a5d9e264e..a5dac6869327d 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala @@ -52,6 +52,7 @@ class PythonDriverFeatureStepSuite extends SparkFunSuite { roleSecretNamesToMountPaths = Map.empty, roleSecretEnvNamesToKeyRefs = Map.empty, roleEnvs = Map.empty, + roleVolumes = Nil, sparkFiles = Seq.empty[String]) val step = new PythonDriverFeatureStep(kubernetesConf) @@ -88,6 +89,7 @@ class PythonDriverFeatureStepSuite extends SparkFunSuite { roleSecretNamesToMountPaths = Map.empty, roleSecretEnvNamesToKeyRefs = Map.empty, roleEnvs = Map.empty, + roleVolumes = Nil, sparkFiles = Seq.empty[String]) val step = new PythonDriverFeatureStep(kubernetesConf) val driverContainerwithPySpark = step.configurePod(baseDriverPod).container diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index a8a8218c621ea..4d8e79189ff32 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -27,6 +27,7 @@ import org.scalatest.mockito.MockitoSugar._ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.Fabric8Aliases._ class ClientSuite extends SparkFunSuite with BeforeAndAfter { @@ -103,15 +104,11 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { .build() } - private type ResourceList = NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable[ - HasMetadata, Boolean] - private type Pods = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] - @Mock private var kubernetesClient: KubernetesClient = _ @Mock - private var podOperations: Pods = _ + private var podOperations: PODS = _ @Mock private var namedPods: PodResource[Pod, DoneablePod] = _ @@ -123,7 +120,7 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { private var driverBuilder: KubernetesDriverBuilder = _ @Mock - private var resourceList: ResourceList = _ + private var resourceList: RESOURCE_LIST = _ private var kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf] = _ @@ -144,6 +141,7 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) when(driverBuilder.buildFromFeatures(kubernetesConf)).thenReturn(BUILT_KUBERNETES_SPEC) when(kubernetesClient.pods()).thenReturn(podOperations) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala index 4e8c300543430..046e578b94629 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf} +import org.apache.spark.deploy.k8s._ +import org.apache.spark.deploy.k8s.features._ import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, EnvSecretsFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep} @@ -31,6 +32,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val JAVA_STEP_TYPE = "java-bindings" private val PYSPARK_STEP_TYPE = "pyspark-bindings" private val ENV_SECRETS_STEP_TYPE = "env-secrets" + private val MOUNT_VOLUMES_STEP_TYPE = "mount-volumes" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( BASIC_STEP_TYPE, classOf[BasicDriverFeatureStep]) @@ -56,6 +58,9 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep]) + private val mountVolumesStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + MOUNT_VOLUMES_STEP_TYPE, classOf[MountVolumesFeatureStep]) + private val builderUnderTest: KubernetesDriverBuilder = new KubernetesDriverBuilder( _ => basicFeatureStep, @@ -64,6 +69,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { _ => secretsStep, _ => envSecretsStep, _ => localDirsStep, + _ => mountVolumesStep, _ => javaStep, _ => pythonStep) @@ -82,6 +88,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -107,6 +114,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map("secret" -> "secretMountPath"), Map("EnvName" -> "SecretName:secretKey"), Map.empty, + Nil, Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -134,6 +142,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -159,6 +168,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -169,6 +179,39 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { PYSPARK_STEP_TYPE) } + test("Apply volumes step if mounts are present.") { + val volumeSpec = KubernetesVolumeSpec( + "volume", + "/tmp", + false, + KubernetesHostPathVolumeConf("/path")) + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + None, + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + volumeSpec :: Nil, + Seq.empty[String]) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + MOUNT_VOLUMES_STEP_TYPE, + JAVA_STEP_TYPE) + } + + private def validateStepTypesApplied(resolvedSpec: KubernetesDriverSpec, stepTypes: String*) : Unit = { assert(resolvedSpec.systemProperties.size === stepTypes.size) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/DeterministicExecutorPodsSnapshotsStore.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/DeterministicExecutorPodsSnapshotsStore.scala new file mode 100644 index 0000000000000..f7721e6fd6388 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/DeterministicExecutorPodsSnapshotsStore.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod +import scala.collection.mutable + +class DeterministicExecutorPodsSnapshotsStore extends ExecutorPodsSnapshotsStore { + + private val snapshotsBuffer = mutable.Buffer.empty[ExecutorPodsSnapshot] + private val subscribers = mutable.Buffer.empty[Seq[ExecutorPodsSnapshot] => Unit] + + private var currentSnapshot = ExecutorPodsSnapshot() + + override def addSubscriber + (processBatchIntervalMillis: Long) + (onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit): Unit = { + subscribers += onNewSnapshots + } + + override def stop(): Unit = {} + + def notifySubscribers(): Unit = { + subscribers.foreach(_(snapshotsBuffer)) + snapshotsBuffer.clear() + } + + override def updatePod(updatedPod: Pod): Unit = { + currentSnapshot = currentSnapshot.withUpdate(updatedPod) + snapshotsBuffer += currentSnapshot + } + + override def replaceSnapshot(newSnapshot: Seq[Pod]): Unit = { + currentSnapshot = ExecutorPodsSnapshot(newSnapshot) + snapshotsBuffer += currentSnapshot + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala new file mode 100644 index 0000000000000..c6b667ed85e8c --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, Pod, PodBuilder} + +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.SparkPod + +object ExecutorLifecycleTestUtils { + + val TEST_SPARK_APP_ID = "spark-app-id" + + def failedExecutorWithoutDeletion(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("failed") + .addNewContainerStatus() + .withName("spark-executor") + .withImage("k8s-spark") + .withNewState() + .withNewTerminated() + .withMessage("Failed") + .withExitCode(1) + .endTerminated() + .endState() + .endContainerStatus() + .addNewContainerStatus() + .withName("spark-executor-sidecar") + .withImage("k8s-spark-sidecar") + .withNewState() + .withNewTerminated() + .withMessage("Failed") + .withExitCode(1) + .endTerminated() + .endState() + .endContainerStatus() + .withMessage("Executor failed.") + .withReason("Executor failed because of a thrown error.") + .endStatus() + .build() + } + + def pendingExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("pending") + .endStatus() + .build() + } + + def runningExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("running") + .endStatus() + .build() + } + + def succeededExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("succeeded") + .endStatus() + .build() + } + + def deletedExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewMetadata() + .withNewDeletionTimestamp("523012521") + .endMetadata() + .build() + } + + def unknownExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("unknown") + .endStatus() + .build() + } + + def podWithAttachedContainerForId(executorId: Long): Pod = { + val sparkPod = executorPodWithId(executorId) + val podWithAttachedContainer = new PodBuilder(sparkPod.pod) + .editOrNewSpec() + .addToContainers(sparkPod.container) + .endSpec() + .build() + podWithAttachedContainer + } + + def executorPodWithId(executorId: Long): SparkPod = { + val pod = new PodBuilder() + .withNewMetadata() + .withName(s"spark-executor-$executorId") + .addToLabels(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID) + .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .addToLabels(SPARK_EXECUTOR_ID_LABEL, executorId.toString) + .endMetadata() + .build() + val container = new ContainerBuilder() + .withName("spark-executor") + .withImage("k8s-spark") + .build() + SparkPod(pod, container) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala new file mode 100644 index 0000000000000..0c19f5946b75f --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder} +import io.fabric8.kubernetes.client.KubernetesClient +import io.fabric8.kubernetes.client.dsl.PodResource +import org.mockito.{ArgumentMatcher, Matchers, Mock, MockitoAnnotations} +import org.mockito.Matchers.any +import org.mockito.Mockito.{never, times, verify, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ +import org.apache.spark.util.ManualClock + +class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter { + + private val driverPodName = "driver" + + private val driverPod = new PodBuilder() + .withNewMetadata() + .withName(driverPodName) + .addToLabels(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID) + .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_DRIVER_ROLE) + .withUid("driver-pod-uid") + .endMetadata() + .build() + + private val conf = new SparkConf().set(KUBERNETES_DRIVER_POD_NAME, driverPodName) + + private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE) + private val podAllocationDelay = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY) + private val podCreationTimeout = math.max(podAllocationDelay * 5, 60000L) + + private var waitForExecutorPodsClock: ManualClock = _ + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var labeledPods: LABELED_PODS = _ + + @Mock + private var driverPodOperations: PodResource[Pod, DoneablePod] = _ + + @Mock + private var executorBuilder: KubernetesExecutorBuilder = _ + + private var snapshotsStore: DeterministicExecutorPodsSnapshotsStore = _ + + private var podsAllocatorUnderTest: ExecutorPodsAllocator = _ + + before { + MockitoAnnotations.initMocks(this) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withName(driverPodName)).thenReturn(driverPodOperations) + when(driverPodOperations.get).thenReturn(driverPod) + when(executorBuilder.buildFromFeatures(kubernetesConfWithCorrectFields())) + .thenAnswer(executorPodAnswer()) + snapshotsStore = new DeterministicExecutorPodsSnapshotsStore() + waitForExecutorPodsClock = new ManualClock(0L) + podsAllocatorUnderTest = new ExecutorPodsAllocator( + conf, executorBuilder, kubernetesClient, snapshotsStore, waitForExecutorPodsClock) + podsAllocatorUnderTest.start(TEST_SPARK_APP_ID) + } + + test("Initially request executors in batches. Do not request another batch if the" + + " first has not finished.") { + podsAllocatorUnderTest.setTotalExpectedExecutors(podAllocationSize + 1) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + for (nextId <- 1 to podAllocationSize) { + verify(podOperations).create(podWithAttachedContainerForId(nextId)) + } + verify(podOperations, never()).create(podWithAttachedContainerForId(podAllocationSize + 1)) + } + + test("Request executors in batches. Allow another batch to be requested if" + + " all pending executors start running.") { + podsAllocatorUnderTest.setTotalExpectedExecutors(podAllocationSize + 1) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + for (execId <- 1 until podAllocationSize) { + snapshotsStore.updatePod(runningExecutor(execId)) + } + snapshotsStore.notifySubscribers() + verify(podOperations, never()).create(podWithAttachedContainerForId(podAllocationSize + 1)) + snapshotsStore.updatePod(runningExecutor(podAllocationSize)) + snapshotsStore.notifySubscribers() + verify(podOperations).create(podWithAttachedContainerForId(podAllocationSize + 1)) + snapshotsStore.updatePod(runningExecutor(podAllocationSize)) + snapshotsStore.notifySubscribers() + verify(podOperations, times(podAllocationSize + 1)).create(any(classOf[Pod])) + } + + test("When a current batch reaches error states immediately, re-request" + + " them on the next batch.") { + podsAllocatorUnderTest.setTotalExpectedExecutors(podAllocationSize) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + for (execId <- 1 until podAllocationSize) { + snapshotsStore.updatePod(runningExecutor(execId)) + } + val failedPod = failedExecutorWithoutDeletion(podAllocationSize) + snapshotsStore.updatePod(failedPod) + snapshotsStore.notifySubscribers() + verify(podOperations).create(podWithAttachedContainerForId(podAllocationSize + 1)) + } + + test("When an executor is requested but the API does not report it in a reasonable time, retry" + + " requesting that executor.") { + podsAllocatorUnderTest.setTotalExpectedExecutors(1) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + waitForExecutorPodsClock.setTime(podCreationTimeout + 1) + when(podOperations.withLabel(SPARK_EXECUTOR_ID_LABEL, "1")).thenReturn(labeledPods) + snapshotsStore.notifySubscribers() + verify(labeledPods).delete() + verify(podOperations).create(podWithAttachedContainerForId(2)) + } + + private def executorPodAnswer(): Answer[SparkPod] = { + new Answer[SparkPod] { + override def answer(invocation: InvocationOnMock): SparkPod = { + val k8sConf = invocation.getArgumentAt( + 0, classOf[KubernetesConf[KubernetesExecutorSpecificConf]]) + executorPodWithId(k8sConf.roleSpecificConf.executorId.toInt) + } + } + } + + private def kubernetesConfWithCorrectFields(): KubernetesConf[KubernetesExecutorSpecificConf] = + Matchers.argThat(new ArgumentMatcher[KubernetesConf[KubernetesExecutorSpecificConf]] { + override def matches(argument: scala.Any): Boolean = { + if (!argument.isInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]]) { + false + } else { + val k8sConf = argument.asInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] + val executorSpecificConf = k8sConf.roleSpecificConf + val expectedK8sConf = KubernetesConf.createExecutorConf( + conf, + executorSpecificConf.executorId, + TEST_SPARK_APP_ID, + driverPod) + k8sConf.sparkConf.getAll.toMap == conf.getAll.toMap && + // Since KubernetesConf.createExecutorConf clones the SparkConf object, force + // deep equality comparison for the SparkConf object and use object equality + // comparison on all other fields. + k8sConf.copy(sparkConf = conf) == expectedK8sConf.copy(sparkConf = conf) + } + } + }) + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala new file mode 100644 index 0000000000000..562ace9f49d4d --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import com.google.common.cache.CacheBuilder +import io.fabric8.kubernetes.api.model.{DoneablePod, Pod} +import io.fabric8.kubernetes.client.KubernetesClient +import io.fabric8.kubernetes.client.dsl.PodResource +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Matchers.any +import org.mockito.Mockito.{mock, times, verify, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfter +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.scheduler.ExecutorExited +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ + +class ExecutorPodsLifecycleManagerSuite extends SparkFunSuite with BeforeAndAfter { + + private var namedExecutorPods: mutable.Map[String, PodResource[Pod, DoneablePod]] = _ + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var executorBuilder: KubernetesExecutorBuilder = _ + + @Mock + private var schedulerBackend: KubernetesClusterSchedulerBackend = _ + + private var snapshotsStore: DeterministicExecutorPodsSnapshotsStore = _ + private var eventHandlerUnderTest: ExecutorPodsLifecycleManager = _ + + before { + MockitoAnnotations.initMocks(this) + val removedExecutorsCache = CacheBuilder.newBuilder().build[java.lang.Long, java.lang.Long] + snapshotsStore = new DeterministicExecutorPodsSnapshotsStore() + namedExecutorPods = mutable.Map.empty[String, PodResource[Pod, DoneablePod]] + when(schedulerBackend.getExecutorIds()).thenReturn(Seq.empty[String]) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withName(any(classOf[String]))).thenAnswer(namedPodsAnswer()) + eventHandlerUnderTest = new ExecutorPodsLifecycleManager( + new SparkConf(), + executorBuilder, + kubernetesClient, + snapshotsStore, + removedExecutorsCache) + eventHandlerUnderTest.start(schedulerBackend) + } + + test("When an executor reaches error states immediately, remove from the scheduler backend.") { + val failedPod = failedExecutorWithoutDeletion(1) + snapshotsStore.updatePod(failedPod) + snapshotsStore.notifySubscribers() + val msg = exitReasonMessage(1, failedPod) + val expectedLossReason = ExecutorExited(1, exitCausedByApp = true, msg) + verify(schedulerBackend).doRemoveExecutor("1", expectedLossReason) + verify(namedExecutorPods(failedPod.getMetadata.getName)).delete() + } + + test("Don't remove executors twice from Spark but remove from K8s repeatedly.") { + val failedPod = failedExecutorWithoutDeletion(1) + snapshotsStore.updatePod(failedPod) + snapshotsStore.updatePod(failedPod) + snapshotsStore.notifySubscribers() + val msg = exitReasonMessage(1, failedPod) + val expectedLossReason = ExecutorExited(1, exitCausedByApp = true, msg) + verify(schedulerBackend, times(1)).doRemoveExecutor("1", expectedLossReason) + verify(namedExecutorPods(failedPod.getMetadata.getName), times(2)).delete() + } + + test("When the scheduler backend lists executor ids that aren't present in the cluster," + + " remove those executors from Spark.") { + when(schedulerBackend.getExecutorIds()).thenReturn(Seq("1")) + val msg = s"The executor with ID 1 was not found in the cluster but we didn't" + + s" get a reason why. Marking the executor as failed. The executor may have been" + + s" deleted but the driver missed the deletion event." + val expectedLossReason = ExecutorExited(-1, exitCausedByApp = false, msg) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + verify(schedulerBackend).doRemoveExecutor("1", expectedLossReason) + } + + private def exitReasonMessage(failedExecutorId: Int, failedPod: Pod): String = { + s""" + |The executor with id $failedExecutorId exited with exit code 1. + |The API gave the following brief reason: ${failedPod.getStatus.getReason} + |The API gave the following message: ${failedPod.getStatus.getMessage} + |The API gave the following container statuses: + | + |${failedPod.getStatus.getContainerStatuses.asScala.map(_.toString).mkString("\n===\n")} + """.stripMargin + } + + private def namedPodsAnswer(): Answer[PodResource[Pod, DoneablePod]] = { + new Answer[PodResource[Pod, DoneablePod]] { + override def answer(invocation: InvocationOnMock): PodResource[Pod, DoneablePod] = { + val podName = invocation.getArgumentAt(0, classOf[String]) + namedExecutorPods.getOrElseUpdate( + podName, mock(classOf[PodResource[Pod, DoneablePod]])) + } + } + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSourceSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSourceSuite.scala new file mode 100644 index 0000000000000..1b26d6af296a5 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSourceSuite.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.TimeUnit + +import io.fabric8.kubernetes.api.model.PodListBuilder +import io.fabric8.kubernetes.client.KubernetesClient +import org.jmock.lib.concurrent.DeterministicScheduler +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Mockito.{verify, when} +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ + +class ExecutorPodsPollingSnapshotSourceSuite extends SparkFunSuite with BeforeAndAfter { + + private val sparkConf = new SparkConf + + private val pollingInterval = sparkConf.get(KUBERNETES_EXECUTOR_API_POLLING_INTERVAL) + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var appIdLabeledPods: LABELED_PODS = _ + + @Mock + private var executorRoleLabeledPods: LABELED_PODS = _ + + @Mock + private var eventQueue: ExecutorPodsSnapshotsStore = _ + + private var pollingExecutor: DeterministicScheduler = _ + private var pollingSourceUnderTest: ExecutorPodsPollingSnapshotSource = _ + + before { + MockitoAnnotations.initMocks(this) + pollingExecutor = new DeterministicScheduler() + pollingSourceUnderTest = new ExecutorPodsPollingSnapshotSource( + sparkConf, + kubernetesClient, + eventQueue, + pollingExecutor) + pollingSourceUnderTest.start(TEST_SPARK_APP_ID) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)) + .thenReturn(appIdLabeledPods) + when(appIdLabeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)) + .thenReturn(executorRoleLabeledPods) + } + + test("Items returned by the API should be pushed to the event queue") { + when(executorRoleLabeledPods.list()) + .thenReturn(new PodListBuilder() + .addToItems( + runningExecutor(1), + runningExecutor(2)) + .build()) + pollingExecutor.tick(pollingInterval, TimeUnit.MILLISECONDS) + verify(eventQueue).replaceSnapshot(Seq(runningExecutor(1), runningExecutor(2))) + + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala new file mode 100644 index 0000000000000..70e19c904eddb --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import org.apache.spark.SparkFunSuite +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ + +class ExecutorPodsSnapshotSuite extends SparkFunSuite { + + test("States are interpreted correctly from pod metadata.") { + val pods = Seq( + pendingExecutor(0), + runningExecutor(1), + succeededExecutor(2), + failedExecutorWithoutDeletion(3), + deletedExecutor(4), + unknownExecutor(5)) + val snapshot = ExecutorPodsSnapshot(pods) + assert(snapshot.executorPods === + Map( + 0L -> PodPending(pods(0)), + 1L -> PodRunning(pods(1)), + 2L -> PodSucceeded(pods(2)), + 3L -> PodFailed(pods(3)), + 4L -> PodDeleted(pods(4)), + 5L -> PodUnknown(pods(5)))) + } + + test("Updates add new pods for non-matching ids and edit existing pods for matching ids") { + val originalPods = Seq( + pendingExecutor(0), + runningExecutor(1)) + val originalSnapshot = ExecutorPodsSnapshot(originalPods) + val snapshotWithUpdatedPod = originalSnapshot.withUpdate(succeededExecutor(1)) + assert(snapshotWithUpdatedPod.executorPods === + Map( + 0L -> PodPending(originalPods(0)), + 1L -> PodSucceeded(succeededExecutor(1)))) + val snapshotWithNewPod = snapshotWithUpdatedPod.withUpdate(pendingExecutor(2)) + assert(snapshotWithNewPod.executorPods === + Map( + 0L -> PodPending(originalPods(0)), + 1L -> PodSucceeded(succeededExecutor(1)), + 2L -> PodPending(pendingExecutor(2)))) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreSuite.scala new file mode 100644 index 0000000000000..cf54b3c4eb329 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicReference + +import io.fabric8.kubernetes.api.model.{Pod, PodBuilder} +import org.jmock.lib.concurrent.DeterministicScheduler +import org.scalatest.BeforeAndAfter +import scala.collection.mutable + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.Constants._ + +class ExecutorPodsSnapshotsStoreSuite extends SparkFunSuite with BeforeAndAfter { + + private var eventBufferScheduler: DeterministicScheduler = _ + private var eventQueueUnderTest: ExecutorPodsSnapshotsStoreImpl = _ + + before { + eventBufferScheduler = new DeterministicScheduler() + eventQueueUnderTest = new ExecutorPodsSnapshotsStoreImpl(eventBufferScheduler) + } + + test("Subscribers get notified of events periodically.") { + val receivedSnapshots1 = mutable.Buffer.empty[ExecutorPodsSnapshot] + val receivedSnapshots2 = mutable.Buffer.empty[ExecutorPodsSnapshot] + eventQueueUnderTest.addSubscriber(1000) { + receivedSnapshots1 ++= _ + } + eventQueueUnderTest.addSubscriber(2000) { + receivedSnapshots2 ++= _ + } + + eventBufferScheduler.runUntilIdle() + assert(receivedSnapshots1 === Seq(ExecutorPodsSnapshot())) + assert(receivedSnapshots2 === Seq(ExecutorPodsSnapshot())) + + pushPodWithIndex(1) + // Force time to move forward so that the buffer is emitted, scheduling the + // processing task on the subscription executor... + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + // ... then actually execute the subscribers. + + assert(receivedSnapshots1 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + assert(receivedSnapshots2 === Seq(ExecutorPodsSnapshot())) + + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + + // Don't repeat snapshots + assert(receivedSnapshots1 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + assert(receivedSnapshots2 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + pushPodWithIndex(2) + pushPodWithIndex(3) + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + + assert(receivedSnapshots1 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))), + ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2))), + ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2), podWithIndex(3))))) + assert(receivedSnapshots2 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + assert(receivedSnapshots1 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))), + ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2))), + ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2), podWithIndex(3))))) + assert(receivedSnapshots1 === receivedSnapshots2) + } + + test("Even without sending events, initially receive an empty buffer.") { + val receivedInitialSnapshot = new AtomicReference[Seq[ExecutorPodsSnapshot]](null) + eventQueueUnderTest.addSubscriber(1000) { + receivedInitialSnapshot.set + } + assert(receivedInitialSnapshot.get == null) + eventBufferScheduler.runUntilIdle() + assert(receivedInitialSnapshot.get === Seq(ExecutorPodsSnapshot())) + } + + test("Replacing the snapshot passes the new snapshot to subscribers.") { + val receivedSnapshots = mutable.Buffer.empty[ExecutorPodsSnapshot] + eventQueueUnderTest.addSubscriber(1000) { + receivedSnapshots ++= _ + } + eventQueueUnderTest.updatePod(podWithIndex(1)) + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + assert(receivedSnapshots === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + eventQueueUnderTest.replaceSnapshot(Seq(podWithIndex(2))) + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + assert(receivedSnapshots === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))), + ExecutorPodsSnapshot(Seq(podWithIndex(2))))) + } + + private def pushPodWithIndex(index: Int): Unit = + eventQueueUnderTest.updatePod(podWithIndex(index)) + + private def podWithIndex(index: Int): Pod = + new PodBuilder() + .editOrNewMetadata() + .withName(s"pod-$index") + .addToLabels(SPARK_EXECUTOR_ID_LABEL, index.toString) + .endMetadata() + .editOrNewStatus() + .withPhase("running") + .endStatus() + .build() +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala new file mode 100644 index 0000000000000..ac1968b4ff810 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action +import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations} +import org.mockito.Mockito.{verify, when} +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ + +class ExecutorPodsWatchSnapshotSourceSuite extends SparkFunSuite with BeforeAndAfter { + + @Mock + private var eventQueue: ExecutorPodsSnapshotsStore = _ + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var appIdLabeledPods: LABELED_PODS = _ + + @Mock + private var executorRoleLabeledPods: LABELED_PODS = _ + + @Mock + private var watchConnection: Watch = _ + + private var watch: ArgumentCaptor[Watcher[Pod]] = _ + + private var watchSourceUnderTest: ExecutorPodsWatchSnapshotSource = _ + + before { + MockitoAnnotations.initMocks(this) + watch = ArgumentCaptor.forClass(classOf[Watcher[Pod]]) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)) + .thenReturn(appIdLabeledPods) + when(appIdLabeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)) + .thenReturn(executorRoleLabeledPods) + when(executorRoleLabeledPods.watch(watch.capture())).thenReturn(watchConnection) + watchSourceUnderTest = new ExecutorPodsWatchSnapshotSource( + eventQueue, kubernetesClient) + watchSourceUnderTest.start(TEST_SPARK_APP_ID) + } + + test("Watch events should be pushed to the snapshots store as snapshot updates.") { + watch.getValue.eventReceived(Action.ADDED, runningExecutor(1)) + watch.getValue.eventReceived(Action.MODIFIED, runningExecutor(2)) + verify(eventQueue).updatePod(runningExecutor(1)) + verify(eventQueue).updatePod(runningExecutor(2)) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala index 96065e83f069c..52e7a12dbaf06 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala @@ -16,85 +16,36 @@ */ package org.apache.spark.scheduler.cluster.k8s -import java.util.concurrent.{ExecutorService, ScheduledExecutorService, TimeUnit} - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, DoneablePod, Pod, PodBuilder, PodList} -import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher} -import io.fabric8.kubernetes.client.Watcher.Action -import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NonNamespaceOperation, PodResource} -import org.hamcrest.{BaseMatcher, Description, Matcher} -import org.mockito.{AdditionalAnswers, ArgumentCaptor, Matchers, Mock, MockitoAnnotations} -import org.mockito.Matchers.{any, eq => mockitoEq} -import org.mockito.Mockito.{doNothing, never, times, verify, when} +import io.fabric8.kubernetes.client.KubernetesClient +import org.jmock.lib.concurrent.DeterministicScheduler +import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations} +import org.mockito.Matchers.{eq => mockitoEq} +import org.mockito.Mockito.{never, verify, when} import org.scalatest.BeforeAndAfter -import org.scalatest.mockito.MockitoSugar._ -import scala.collection.JavaConverters._ -import scala.concurrent.Future import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.rpc._ -import org.apache.spark.scheduler.{ExecutorExited, LiveListenerBus, SlaveLost, TaskSchedulerImpl} -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RegisterExecutor, RemoveExecutor} +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} +import org.apache.spark.scheduler.{ExecutorKilled, TaskSchedulerImpl} +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.ThreadUtils +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils.TEST_SPARK_APP_ID class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAndAfter { - private val APP_ID = "test-spark-app" - private val DRIVER_POD_NAME = "spark-driver-pod" - private val NAMESPACE = "test-namespace" - private val SPARK_DRIVER_HOST = "localhost" - private val SPARK_DRIVER_PORT = 7077 - private val POD_ALLOCATION_INTERVAL = "1m" - private val FIRST_EXECUTOR_POD = new PodBuilder() - .withNewMetadata() - .withName("pod1") - .endMetadata() - .withNewSpec() - .withNodeName("node1") - .endSpec() - .withNewStatus() - .withHostIP("192.168.99.100") - .endStatus() - .build() - private val SECOND_EXECUTOR_POD = new PodBuilder() - .withNewMetadata() - .withName("pod2") - .endMetadata() - .withNewSpec() - .withNodeName("node2") - .endSpec() - .withNewStatus() - .withHostIP("192.168.99.101") - .endStatus() - .build() - - private type PODS = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] - private type LABELED_PODS = FilterWatchListDeletable[ - Pod, PodList, java.lang.Boolean, Watch, Watcher[Pod]] - private type IN_NAMESPACE_PODS = NonNamespaceOperation[ - Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] - - @Mock - private var sparkContext: SparkContext = _ - - @Mock - private var listenerBus: LiveListenerBus = _ - - @Mock - private var taskSchedulerImpl: TaskSchedulerImpl = _ + private val requestExecutorsService = new DeterministicScheduler() + private val sparkConf = new SparkConf(false) + .set("spark.executor.instances", "3") @Mock - private var allocatorExecutor: ScheduledExecutorService = _ + private var sc: SparkContext = _ @Mock - private var requestExecutorsService: ExecutorService = _ + private var rpcEnv: RpcEnv = _ @Mock - private var executorBuilder: KubernetesExecutorBuilder = _ + private var driverEndpointRef: RpcEndpointRef = _ @Mock private var kubernetesClient: KubernetesClient = _ @@ -103,347 +54,97 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn private var podOperations: PODS = _ @Mock - private var podsWithLabelOperations: LABELED_PODS = _ + private var labeledPods: LABELED_PODS = _ @Mock - private var podsInNamespace: IN_NAMESPACE_PODS = _ + private var taskScheduler: TaskSchedulerImpl = _ @Mock - private var podsWithDriverName: PodResource[Pod, DoneablePod] = _ + private var eventQueue: ExecutorPodsSnapshotsStore = _ @Mock - private var rpcEnv: RpcEnv = _ + private var podAllocator: ExecutorPodsAllocator = _ @Mock - private var driverEndpointRef: RpcEndpointRef = _ + private var lifecycleEventHandler: ExecutorPodsLifecycleManager = _ @Mock - private var executorPodsWatch: Watch = _ + private var watchEvents: ExecutorPodsWatchSnapshotSource = _ @Mock - private var successFuture: Future[Boolean] = _ + private var pollEvents: ExecutorPodsPollingSnapshotSource = _ - private var sparkConf: SparkConf = _ - private var executorPodsWatcherArgument: ArgumentCaptor[Watcher[Pod]] = _ - private var allocatorRunnable: ArgumentCaptor[Runnable] = _ - private var requestExecutorRunnable: ArgumentCaptor[Runnable] = _ private var driverEndpoint: ArgumentCaptor[RpcEndpoint] = _ - - private val driverPod = new PodBuilder() - .withNewMetadata() - .withName(DRIVER_POD_NAME) - .addToLabels(SPARK_APP_ID_LABEL, APP_ID) - .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_DRIVER_ROLE) - .endMetadata() - .build() + private var schedulerBackendUnderTest: KubernetesClusterSchedulerBackend = _ before { MockitoAnnotations.initMocks(this) - sparkConf = new SparkConf() - .set(KUBERNETES_DRIVER_POD_NAME, DRIVER_POD_NAME) - .set(KUBERNETES_NAMESPACE, NAMESPACE) - .set("spark.driver.host", SPARK_DRIVER_HOST) - .set("spark.driver.port", SPARK_DRIVER_PORT.toString) - .set(KUBERNETES_ALLOCATION_BATCH_DELAY.key, POD_ALLOCATION_INTERVAL) - executorPodsWatcherArgument = ArgumentCaptor.forClass(classOf[Watcher[Pod]]) - allocatorRunnable = ArgumentCaptor.forClass(classOf[Runnable]) - requestExecutorRunnable = ArgumentCaptor.forClass(classOf[Runnable]) + when(taskScheduler.sc).thenReturn(sc) + when(sc.conf).thenReturn(sparkConf) driverEndpoint = ArgumentCaptor.forClass(classOf[RpcEndpoint]) - when(sparkContext.conf).thenReturn(sparkConf) - when(sparkContext.listenerBus).thenReturn(listenerBus) - when(taskSchedulerImpl.sc).thenReturn(sparkContext) - when(kubernetesClient.pods()).thenReturn(podOperations) - when(podOperations.withLabel(SPARK_APP_ID_LABEL, APP_ID)).thenReturn(podsWithLabelOperations) - when(podsWithLabelOperations.watch(executorPodsWatcherArgument.capture())) - .thenReturn(executorPodsWatch) - when(podOperations.inNamespace(NAMESPACE)).thenReturn(podsInNamespace) - when(podsInNamespace.withName(DRIVER_POD_NAME)).thenReturn(podsWithDriverName) - when(podsWithDriverName.get()).thenReturn(driverPod) - when(allocatorExecutor.scheduleWithFixedDelay( - allocatorRunnable.capture(), - mockitoEq(0L), - mockitoEq(TimeUnit.MINUTES.toMillis(1)), - mockitoEq(TimeUnit.MILLISECONDS))).thenReturn(null) - // Creating Futures in Scala backed by a Java executor service resolves to running - // ExecutorService#execute (as opposed to submit) - doNothing().when(requestExecutorsService).execute(requestExecutorRunnable.capture()) when(rpcEnv.setupEndpoint( mockitoEq(CoarseGrainedSchedulerBackend.ENDPOINT_NAME), driverEndpoint.capture())) .thenReturn(driverEndpointRef) - - // Used by the CoarseGrainedSchedulerBackend when making RPC calls. - when(driverEndpointRef.ask[Boolean] - (any(classOf[Any])) - (any())).thenReturn(successFuture) - when(successFuture.failed).thenReturn(Future[Throwable] { - // emulate behavior of the Future.failed method. - throw new NoSuchElementException() - }(ThreadUtils.sameThread)) - } - - test("Basic lifecycle expectations when starting and stopping the scheduler.") { - val scheduler = newSchedulerBackend() - scheduler.start() - assert(executorPodsWatcherArgument.getValue != null) - assert(allocatorRunnable.getValue != null) - scheduler.stop() - verify(executorPodsWatch).close() - } - - test("Static allocation should request executors upon first allocator run.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) - val scheduler = newSchedulerBackend() - scheduler.start() - requestExecutorRunnable.getValue.run() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - allocatorRunnable.getValue.run() - verify(podOperations).create(firstResolvedPod) - verify(podOperations).create(secondResolvedPod) - } - - test("Killing executors deletes the executor pods") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) - val scheduler = newSchedulerBackend() - scheduler.start() - requestExecutorRunnable.getValue.run() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - when(podOperations.create(any(classOf[Pod]))) - .thenAnswer(AdditionalAnswers.returnsFirstArg()) - allocatorRunnable.getValue.run() - scheduler.doKillExecutors(Seq("2")) - requestExecutorRunnable.getAllValues.asScala.last.run() - verify(podOperations).delete(secondResolvedPod) - verify(podOperations, never()).delete(firstResolvedPod) - } - - test("Executors should be requested in batches.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) - val scheduler = newSchedulerBackend() - scheduler.start() - requestExecutorRunnable.getValue.run() - when(podOperations.create(any(classOf[Pod]))) - .thenAnswer(AdditionalAnswers.returnsFirstArg()) - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - allocatorRunnable.getValue.run() - verify(podOperations).create(firstResolvedPod) - verify(podOperations, never()).create(secondResolvedPod) - val registerFirstExecutorMessage = RegisterExecutor( - "1", mock[RpcEndpointRef], "localhost", 1, Map.empty[String, String]) - when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) - driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) - .apply(registerFirstExecutorMessage) - allocatorRunnable.getValue.run() - verify(podOperations).create(secondResolvedPod) - } - - test("Scaled down executors should be cleaned up") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val scheduler = newSchedulerBackend() - scheduler.start() - - // The scheduler backend spins up one executor pod. - requestExecutorRunnable.getValue.run() - when(podOperations.create(any(classOf[Pod]))) - .thenAnswer(AdditionalAnswers.returnsFirstArg()) - val resolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - allocatorRunnable.getValue.run() - val executorEndpointRef = mock[RpcEndpointRef] - when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) - val registerFirstExecutorMessage = RegisterExecutor( - "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) - when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) - driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) - .apply(registerFirstExecutorMessage) - - // Request that there are 0 executors and trigger deletion from driver. - scheduler.doRequestTotalExecutors(0) - requestExecutorRunnable.getAllValues.asScala.last.run() - scheduler.doKillExecutors(Seq("1")) - requestExecutorRunnable.getAllValues.asScala.last.run() - verify(podOperations, times(1)).delete(resolvedPod) - driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) - - val exitedPod = exitPod(resolvedPod, 0) - executorPodsWatcherArgument.getValue.eventReceived(Action.DELETED, exitedPod) - allocatorRunnable.getValue.run() - - // No more deletion attempts of the executors. - // This is graceful termination and should not be detected as a failure. - verify(podOperations, times(1)).delete(resolvedPod) - verify(driverEndpointRef, times(1)).send( - RemoveExecutor("1", ExecutorExited( - 0, - exitCausedByApp = false, - s"Container in pod ${exitedPod.getMetadata.getName} exited from" + - s" explicit termination request."))) - } - - test("Executors that fail should not be deleted.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - - val scheduler = newSchedulerBackend() - scheduler.start() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getValue.run() - val executorEndpointRef = mock[RpcEndpointRef] - when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) - val registerFirstExecutorMessage = RegisterExecutor( - "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) - when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) - driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) - .apply(registerFirstExecutorMessage) - driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) - executorPodsWatcherArgument.getValue.eventReceived( - Action.ERROR, exitPod(firstResolvedPod, 1)) - - // A replacement executor should be created but the error pod should persist. - val replacementPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - scheduler.doRequestTotalExecutors(1) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getAllValues.asScala.last.run() - verify(podOperations, never()).delete(firstResolvedPod) - verify(driverEndpointRef).send( - RemoveExecutor("1", ExecutorExited( - 1, - exitCausedByApp = true, - s"Pod ${FIRST_EXECUTOR_POD.getMetadata.getName}'s executor container exited with" + - " exit status code 1."))) - } - - test("Executors disconnected due to unknown reasons are deleted and replaced.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val executorLostReasonCheckMaxAttempts = sparkConf.get( - KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS) - - val scheduler = newSchedulerBackend() - scheduler.start() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getValue.run() - val executorEndpointRef = mock[RpcEndpointRef] - when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) - val registerFirstExecutorMessage = RegisterExecutor( - "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) - when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) - driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) - .apply(registerFirstExecutorMessage) - - driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) - 1 to executorLostReasonCheckMaxAttempts foreach { _ => - allocatorRunnable.getValue.run() - verify(podOperations, never()).delete(FIRST_EXECUTOR_POD) + when(kubernetesClient.pods()).thenReturn(podOperations) + schedulerBackendUnderTest = new KubernetesClusterSchedulerBackend( + taskScheduler, + rpcEnv, + kubernetesClient, + requestExecutorsService, + eventQueue, + podAllocator, + lifecycleEventHandler, + watchEvents, + pollEvents) { + override def applicationId(): String = TEST_SPARK_APP_ID } - - val recreatedResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - allocatorRunnable.getValue.run() - verify(podOperations).delete(firstResolvedPod) - verify(driverEndpointRef).send( - RemoveExecutor("1", SlaveLost("Executor lost for unknown reasons."))) } - test("Executors that fail to start on the Kubernetes API call rebuild in the next batch.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val scheduler = newSchedulerBackend() - scheduler.start() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - when(podOperations.create(firstResolvedPod)) - .thenThrow(new RuntimeException("test")) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getValue.run() - verify(podOperations, times(1)).create(firstResolvedPod) - val recreatedResolvedPod = expectPodCreationWithId(2, FIRST_EXECUTOR_POD) - allocatorRunnable.getValue.run() - verify(podOperations).create(recreatedResolvedPod) + test("Start all components") { + schedulerBackendUnderTest.start() + verify(podAllocator).setTotalExpectedExecutors(3) + verify(podAllocator).start(TEST_SPARK_APP_ID) + verify(lifecycleEventHandler).start(schedulerBackendUnderTest) + verify(watchEvents).start(TEST_SPARK_APP_ID) + verify(pollEvents).start(TEST_SPARK_APP_ID) } - test("Executors that are initially created but the watch notices them fail are rebuilt" + - " in the next batch.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val scheduler = newSchedulerBackend() - scheduler.start() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - when(podOperations.create(FIRST_EXECUTOR_POD)).thenAnswer(AdditionalAnswers.returnsFirstArg()) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getValue.run() - verify(podOperations, times(1)).create(firstResolvedPod) - executorPodsWatcherArgument.getValue.eventReceived(Action.ERROR, firstResolvedPod) - val recreatedResolvedPod = expectPodCreationWithId(2, FIRST_EXECUTOR_POD) - allocatorRunnable.getValue.run() - verify(podOperations).create(recreatedResolvedPod) + test("Stop all components") { + when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)).thenReturn(labeledPods) + when(labeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)).thenReturn(labeledPods) + schedulerBackendUnderTest.stop() + verify(eventQueue).stop() + verify(watchEvents).stop() + verify(pollEvents).stop() + verify(labeledPods).delete() + verify(kubernetesClient).close() } - private def newSchedulerBackend(): KubernetesClusterSchedulerBackend = { - new KubernetesClusterSchedulerBackend( - taskSchedulerImpl, - rpcEnv, - executorBuilder, - kubernetesClient, - allocatorExecutor, - requestExecutorsService) { - - override def applicationId(): String = APP_ID - } + test("Remove executor") { + schedulerBackendUnderTest.start() + schedulerBackendUnderTest.doRemoveExecutor( + "1", ExecutorKilled) + verify(driverEndpointRef).send(RemoveExecutor("1", ExecutorKilled)) } - private def exitPod(basePod: Pod, exitCode: Int): Pod = { - new PodBuilder(basePod) - .editStatus() - .addNewContainerStatus() - .withNewState() - .withNewTerminated() - .withExitCode(exitCode) - .endTerminated() - .endState() - .endContainerStatus() - .endStatus() - .build() + test("Kill executors") { + schedulerBackendUnderTest.start() + when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)).thenReturn(labeledPods) + when(labeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)).thenReturn(labeledPods) + when(labeledPods.withLabelIn(SPARK_EXECUTOR_ID_LABEL, "1", "2")).thenReturn(labeledPods) + schedulerBackendUnderTest.doKillExecutors(Seq("1", "2")) + verify(labeledPods, never()).delete() + requestExecutorsService.runNextPendingCommand() + verify(labeledPods).delete() } - private def expectPodCreationWithId(executorId: Int, expectedPod: Pod): Pod = { - val resolvedPod = new PodBuilder(expectedPod) - .editMetadata() - .addToLabels(SPARK_EXECUTOR_ID_LABEL, executorId.toString) - .endMetadata() - .build() - val resolvedContainer = new ContainerBuilder().build() - when(executorBuilder.buildFromFeatures(Matchers.argThat( - new BaseMatcher[KubernetesConf[KubernetesExecutorSpecificConf]] { - override def matches(argument: scala.Any) - : Boolean = { - argument.isInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] && - argument.asInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] - .roleSpecificConf.executorId == executorId.toString - } - - override def describeTo(description: Description): Unit = {} - }))).thenReturn(SparkPod(resolvedPod, resolvedContainer)) - new PodBuilder(resolvedPod) - .editSpec() - .addToContainers(resolvedContainer) - .endSpec() - .build() + test("Request total executors") { + schedulerBackendUnderTest.start() + schedulerBackendUnderTest.doRequestTotalExecutors(5) + verify(podAllocator).setTotalExpectedExecutors(3) + verify(podAllocator, never()).setTotalExpectedExecutors(5) + requestExecutorsService.runNextPendingCommand() + verify(podAllocator).setTotalExpectedExecutors(5) } + } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala index a6bc8bce32926..d0b4127065eb7 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -19,14 +19,15 @@ package org.apache.spark.scheduler.cluster.k8s import io.fabric8.kubernetes.api.model.PodBuilder import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s._ +import org.apache.spark.deploy.k8s.features._ class KubernetesExecutorBuilderSuite extends SparkFunSuite { private val BASIC_STEP_TYPE = "basic" private val SECRETS_STEP_TYPE = "mount-secrets" private val ENV_SECRETS_STEP_TYPE = "env-secrets" private val LOCAL_DIRS_STEP_TYPE = "local-dirs" + private val MOUNT_VOLUMES_STEP_TYPE = "mount-volumes" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( BASIC_STEP_TYPE, classOf[BasicExecutorFeatureStep]) @@ -36,12 +37,15 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep]) private val localDirsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep]) + private val mountVolumesStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + MOUNT_VOLUMES_STEP_TYPE, classOf[MountVolumesFeatureStep]) private val builderUnderTest = new KubernetesExecutorBuilder( _ => basicFeatureStep, _ => mountSecretsStep, _ => envSecretsStep, - _ => localDirsStep) + _ => localDirsStep, + _ => mountVolumesStep) test("Basic steps are consistently applied.") { val conf = KubernetesConf( @@ -55,6 +59,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE) @@ -72,6 +77,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map("secret" -> "secretMountPath"), Map("secret-name" -> "secret-key"), Map.empty, + Nil, Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -81,6 +87,32 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { ENV_SECRETS_STEP_TYPE) } + test("Apply volumes step if mounts are present.") { + val volumeSpec = KubernetesVolumeSpec( + "volume", + "/tmp", + false, + KubernetesHostPathVolumeConf("/checkpoint")) + val conf = KubernetesConf( + new SparkConf(false), + KubernetesExecutorSpecificConf( + "executor-id", new PodBuilder().build()), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + volumeSpec :: Nil, + Seq.empty[String]) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + MOUNT_VOLUMES_STEP_TYPE) + } + private def validateStepTypesApplied(resolvedPod: SparkPod, stepTypes: String*): Unit = { assert(resolvedPod.pod.getMetadata.getLabels.size === stepTypes.size) stepTypes.foreach { stepType => diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index acdb4b1f09e0a..8bdb0f7a10795 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -37,20 +37,24 @@ if [ -z "$uidentry" ] ; then fi SPARK_K8S_CMD="$1" -if [ -z "$SPARK_K8S_CMD" ]; then - echo "No command to execute has been provided." 1>&2 - exit 1 -fi -shift 1 +case "$SPARK_K8S_CMD" in + driver | driver-py | executor) + shift 1 + ;; + "") + ;; + *) + echo "Non-spark-on-k8s command provided, proceeding in pass-through mode..." + exec /sbin/tini -s -- "$@" + ;; +esac SPARK_CLASSPATH="$SPARK_CLASSPATH:${SPARK_HOME}/jars/*" env | grep SPARK_JAVA_OPT_ | sort -t_ -k4 -n | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt -readarray -t SPARK_JAVA_OPTS < /tmp/java_opts.txt -if [ -n "$SPARK_MOUNTED_CLASSPATH" ]; then - SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_MOUNTED_CLASSPATH" -fi -if [ -n "$SPARK_MOUNTED_FILES_DIR" ]; then - cp -R "$SPARK_MOUNTED_FILES_DIR/." . +readarray -t SPARK_EXECUTOR_JAVA_OPTS < /tmp/java_opts.txt + +if [ -n "$SPARK_EXTRA_CLASSPATH" ]; then + SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_EXTRA_CLASSPATH" fi if [ -n "$PYSPARK_FILES" ]; then @@ -92,11 +96,10 @@ case "$SPARK_K8S_CMD" in "$@" $PYSPARK_PRIMARY $PYSPARK_ARGS ) ;; - executor) CMD=( ${JAVA_HOME}/bin/java - "${SPARK_JAVA_OPTS[@]}" + "${SPARK_EXECUTOR_JAVA_OPTS[@]}" -Xms$SPARK_EXECUTOR_MEMORY -Xmx$SPARK_EXECUTOR_MEMORY -cp "$SPARK_CLASSPATH" diff --git a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh index ea893fa39eede..3acd0f5cd3349 100755 --- a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh +++ b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh @@ -27,6 +27,8 @@ IMAGE_TAG="N/A" SPARK_MASTER= NAMESPACE= SERVICE_ACCOUNT= +INCLUDE_TAGS= +EXCLUDE_TAGS= # Parse arguments while (( "$#" )); do @@ -59,6 +61,14 @@ while (( "$#" )); do SERVICE_ACCOUNT="$2" shift ;; + --include-tags) + INCLUDE_TAGS="$2" + shift + ;; + --exclude-tags) + EXCLUDE_TAGS="$2" + shift + ;; *) break ;; @@ -90,4 +100,14 @@ then properties=( ${properties[@]} -Dspark.kubernetes.test.master=$SPARK_MASTER ) fi +if [ -n $EXCLUDE_TAGS ]; +then + properties=( ${properties[@]} -Dtest.exclude.tags=$EXCLUDE_TAGS ) +fi + +if [ -n $INCLUDE_TAGS ]; +then + properties=( ${properties[@]} -Dtest.include.tags=$INCLUDE_TAGS ) +fi + ../../../build/mvn integration-test ${properties[@]} diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index 520bda89e034d..29334cc6d891d 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -40,6 +40,7 @@ minikube docker.io/kubespark + jar Spark Project Kubernetes Integration Tests @@ -62,6 +63,11 @@ kubernetes-client ${kubernetes-client.version} + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + @@ -102,6 +108,15 @@ + + + org.apache.maven.plugins + maven-surefire-plugin + + true + + + @@ -126,6 +141,7 @@ ${spark.kubernetes.test.serviceAccountName} ${test.exclude.tags} + ${test.include.tags} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 65c513cf241a4..774c3936b877c 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -21,17 +21,17 @@ import java.nio.file.{Path, Paths} import java.util.UUID import java.util.regex.Pattern -import scala.collection.JavaConverters._ - import com.google.common.io.PatternFilenameFilter -import io.fabric8.kubernetes.api.model.{Container, Pod} +import io.fabric8.kubernetes.api.model.Pod import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.{Eventually, PatienceConfiguration} import org.scalatest.time.{Minutes, Seconds, Span} +import scala.collection.JavaConverters._ import org.apache.spark.SparkFunSuite import org.apache.spark.deploy.k8s.integrationtest.backend.{IntegrationTestBackend, IntegrationTestBackendFactory} import org.apache.spark.deploy.k8s.integrationtest.config._ +import org.apache.spark.launcher.SparkLauncher private[spark] class KubernetesSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter { @@ -43,6 +43,7 @@ private[spark] class KubernetesSuite extends SparkFunSuite private var kubernetesTestComponents: KubernetesTestComponents = _ private var sparkAppConf: SparkAppConf = _ private var image: String = _ + private var pyImage: String = _ private var containerLocalSparkDistroExamplesJar: String = _ private var appLocator: String = _ private var driverPodName: String = _ @@ -65,6 +66,7 @@ private[spark] class KubernetesSuite extends SparkFunSuite val imageTag = getTestImageTag val imageRepo = getTestImageRepo image = s"$imageRepo/spark:$imageTag" + pyImage = s"$imageRepo/spark-py:$imageTag" val sparkDistroExamplesJarFile: File = sparkHomeDir.resolve(Paths.get("examples", "jars")) .toFile @@ -109,6 +111,12 @@ private[spark] class KubernetesSuite extends SparkFunSuite runSparkPiAndVerifyCompletion() } + test("Use SparkLauncher.NO_RESOURCE") { + sparkAppConf.setJars(Seq(containerLocalSparkDistroExamplesJar)) + runSparkPiAndVerifyCompletion( + appResource = SparkLauncher.NO_RESOURCE) + } + test("Run SparkPi with a master URL without a scheme.") { val url = kubernetesTestComponents.kubernetesClient.getMasterUrl val k8sMasterUrl = if (url.getPort < 0) { @@ -150,22 +158,77 @@ private[spark] class KubernetesSuite extends SparkFunSuite }) } - // TODO(ssuchter): Enable the below after debugging - // test("Run PageRank using remote data file") { - // sparkAppConf - // .set("spark.kubernetes.mountDependencies.filesDownloadDir", - // CONTAINER_LOCAL_FILE_DOWNLOAD_PATH) - // .set("spark.files", REMOTE_PAGE_RANK_DATA_FILE) - // runSparkPageRankAndVerifyCompletion( - // appArgs = Array(CONTAINER_LOCAL_DOWNLOADED_PAGE_RANK_DATA_FILE)) - // } + test("Run extraJVMOptions check on driver") { + sparkAppConf + .set("spark.driver.extraJavaOptions", "-Dspark.test.foo=spark.test.bar") + runSparkJVMCheckAndVerifyCompletion( + expectedJVMValue = Seq("(spark.test.foo,spark.test.bar)")) + } + + test("Run SparkRemoteFileTest using a remote data file") { + sparkAppConf + .set("spark.files", REMOTE_PAGE_RANK_DATA_FILE) + runSparkRemoteCheckAndVerifyCompletion( + appArgs = Array(REMOTE_PAGE_RANK_FILE_NAME)) + } + + test("Run PySpark on simple pi.py example") { + sparkAppConf + .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") + runSparkApplicationAndVerifyCompletion( + appResource = PYSPARK_PI, + mainClass = "", + expectedLogOnCompletion = Seq("Pi is roughly 3"), + appArgs = Array("5"), + driverPodChecker = doBasicDriverPyPodCheck, + executorPodChecker = doBasicExecutorPyPodCheck, + appLocator = appLocator, + isJVM = false) + } + + test("Run PySpark with Python2 to test a pyfiles example") { + sparkAppConf + .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") + .set("spark.kubernetes.pyspark.pythonversion", "2") + runSparkApplicationAndVerifyCompletion( + appResource = PYSPARK_FILES, + mainClass = "", + expectedLogOnCompletion = Seq( + "Python runtime version check is: True", + "Python environment version check is: True"), + appArgs = Array("python"), + driverPodChecker = doBasicDriverPyPodCheck, + executorPodChecker = doBasicExecutorPyPodCheck, + appLocator = appLocator, + isJVM = false, + pyFiles = Some(PYSPARK_CONTAINER_TESTS)) + } + + test("Run PySpark with Python3 to test a pyfiles example") { + sparkAppConf + .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") + .set("spark.kubernetes.pyspark.pythonversion", "3") + runSparkApplicationAndVerifyCompletion( + appResource = PYSPARK_FILES, + mainClass = "", + expectedLogOnCompletion = Seq( + "Python runtime version check is: True", + "Python environment version check is: True"), + appArgs = Array("python3"), + driverPodChecker = doBasicDriverPyPodCheck, + executorPodChecker = doBasicExecutorPyPodCheck, + appLocator = appLocator, + isJVM = false, + pyFiles = Some(PYSPARK_CONTAINER_TESTS)) + } private def runSparkPiAndVerifyCompletion( appResource: String = containerLocalSparkDistroExamplesJar, driverPodChecker: Pod => Unit = doBasicDriverPodCheck, executorPodChecker: Pod => Unit = doBasicExecutorPodCheck, appArgs: Array[String] = Array.empty[String], - appLocator: String = appLocator): Unit = { + appLocator: String = appLocator, + isJVM: Boolean = true ): Unit = { runSparkApplicationAndVerifyCompletion( appResource, SPARK_PI_MAIN_CLASS, @@ -173,10 +236,11 @@ private[spark] class KubernetesSuite extends SparkFunSuite appArgs, driverPodChecker, executorPodChecker, - appLocator) + appLocator, + isJVM) } - private def runSparkPageRankAndVerifyCompletion( + private def runSparkRemoteCheckAndVerifyCompletion( appResource: String = containerLocalSparkDistroExamplesJar, driverPodChecker: Pod => Unit = doBasicDriverPodCheck, executorPodChecker: Pod => Unit = doBasicExecutorPodCheck, @@ -184,12 +248,50 @@ private[spark] class KubernetesSuite extends SparkFunSuite appLocator: String = appLocator): Unit = { runSparkApplicationAndVerifyCompletion( appResource, - SPARK_PAGE_RANK_MAIN_CLASS, - Seq("1 has rank", "2 has rank", "3 has rank", "4 has rank"), + SPARK_REMOTE_MAIN_CLASS, + Seq(s"Mounting of ${appArgs.head} was true"), appArgs, driverPodChecker, executorPodChecker, - appLocator) + appLocator, + true) + } + + private def runSparkJVMCheckAndVerifyCompletion( + appResource: String = containerLocalSparkDistroExamplesJar, + mainClass: String = SPARK_DRIVER_MAIN_CLASS, + driverPodChecker: Pod => Unit = doBasicDriverPodCheck, + appArgs: Array[String] = Array("5"), + expectedJVMValue: Seq[String]): Unit = { + val appArguments = SparkAppArguments( + mainAppResource = appResource, + mainClass = mainClass, + appArgs = appArgs) + SparkAppLauncher.launch( + appArguments, + sparkAppConf, + TIMEOUT.value.toSeconds.toInt, + sparkHomeDir, + true) + + val driverPod = kubernetesTestComponents.kubernetesClient + .pods() + .withLabel("spark-app-locator", appLocator) + .withLabel("spark-role", "driver") + .list() + .getItems + .get(0) + doBasicDriverPodCheck(driverPod) + + Eventually.eventually(TIMEOUT, INTERVAL) { + expectedJVMValue.foreach { e => + assert(kubernetesTestComponents.kubernetesClient + .pods() + .withName(driverPod.getMetadata.getName) + .getLog + .contains(e), "The application did not complete.") + } + } } private def runSparkApplicationAndVerifyCompletion( @@ -199,12 +301,20 @@ private[spark] class KubernetesSuite extends SparkFunSuite appArgs: Array[String], driverPodChecker: Pod => Unit, executorPodChecker: Pod => Unit, - appLocator: String): Unit = { + appLocator: String, + isJVM: Boolean, + pyFiles: Option[String] = None): Unit = { val appArguments = SparkAppArguments( mainAppResource = appResource, mainClass = mainClass, appArgs = appArgs) - SparkAppLauncher.launch(appArguments, sparkAppConf, TIMEOUT.value.toSeconds.toInt, sparkHomeDir) + SparkAppLauncher.launch( + appArguments, + sparkAppConf, + TIMEOUT.value.toSeconds.toInt, + sparkHomeDir, + isJVM, + pyFiles) val driverPod = kubernetesTestComponents.kubernetesClient .pods() @@ -242,11 +352,22 @@ private[spark] class KubernetesSuite extends SparkFunSuite assert(driverPod.getSpec.getContainers.get(0).getName === "spark-kubernetes-driver") } + private def doBasicDriverPyPodCheck(driverPod: Pod): Unit = { + assert(driverPod.getMetadata.getName === driverPodName) + assert(driverPod.getSpec.getContainers.get(0).getImage === pyImage) + assert(driverPod.getSpec.getContainers.get(0).getName === "spark-kubernetes-driver") + } + private def doBasicExecutorPodCheck(executorPod: Pod): Unit = { assert(executorPod.getSpec.getContainers.get(0).getImage === image) assert(executorPod.getSpec.getContainers.get(0).getName === "executor") } + private def doBasicExecutorPyPodCheck(executorPod: Pod): Unit = { + assert(executorPod.getSpec.getContainers.get(0).getImage === pyImage) + assert(executorPod.getSpec.getContainers.get(0).getName === "executor") + } + private def checkCustomSettings(pod: Pod): Unit = { assert(pod.getMetadata.getLabels.get("label1") === "label1-value") assert(pod.getMetadata.getLabels.get("label2") === "label2-value") @@ -281,14 +402,22 @@ private[spark] object KubernetesSuite { val TIMEOUT = PatienceConfiguration.Timeout(Span(2, Minutes)) val INTERVAL = PatienceConfiguration.Interval(Span(2, Seconds)) val SPARK_PI_MAIN_CLASS: String = "org.apache.spark.examples.SparkPi" + val SPARK_REMOTE_MAIN_CLASS: String = "org.apache.spark.examples.SparkRemoteFileTest" + val SPARK_DRIVER_MAIN_CLASS: String = "org.apache.spark.examples.DriverSubmissionTest" val SPARK_PAGE_RANK_MAIN_CLASS: String = "org.apache.spark.examples.SparkPageRank" + val CONTAINER_LOCAL_PYSPARK: String = "local:///opt/spark/examples/src/main/python/" + val PYSPARK_PI: String = CONTAINER_LOCAL_PYSPARK + "pi.py" + val PYSPARK_FILES: String = CONTAINER_LOCAL_PYSPARK + "pyfiles.py" + val PYSPARK_CONTAINER_TESTS: String = CONTAINER_LOCAL_PYSPARK + "py_container_checks.py" - // val CONTAINER_LOCAL_FILE_DOWNLOAD_PATH = "/var/spark-data/spark-files" + val TEST_SECRET_NAME_PREFIX = "test-secret-" + val TEST_SECRET_KEY = "test-key" + val TEST_SECRET_VALUE = "test-data" + val TEST_SECRET_MOUNT_PATH = "/etc/secrets" - // val REMOTE_PAGE_RANK_DATA_FILE = - // "https://storage.googleapis.com/spark-k8s-integration-tests/files/pagerank_data.txt" - // val CONTAINER_LOCAL_DOWNLOADED_PAGE_RANK_DATA_FILE = - // s"$CONTAINER_LOCAL_FILE_DOWNLOAD_PATH/pagerank_data.txt" + val REMOTE_PAGE_RANK_DATA_FILE = + "https://storage.googleapis.com/spark-k8s-integration-tests/files/pagerank_data.txt" + val REMOTE_PAGE_RANK_FILE_NAME = "pagerank_data.txt" - // case object ShuffleNotReadyException extends Exception + case object ShuffleNotReadyException extends Exception } diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala index 48727142dd052..a9b49a8e5a610 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala @@ -97,24 +97,33 @@ private[spark] case class SparkAppArguments( appArgs: Array[String]) private[spark] object SparkAppLauncher extends Logging { - def launch( appArguments: SparkAppArguments, appConf: SparkAppConf, timeoutSecs: Int, - sparkHomeDir: Path): Unit = { + sparkHomeDir: Path, + isJVM: Boolean, + pyFiles: Option[String] = None): Unit = { val sparkSubmitExecutable = sparkHomeDir.resolve(Paths.get("bin", "spark-submit")) logInfo(s"Launching a spark app with arguments $appArguments and conf $appConf") - val appArgsArray = - if (appArguments.appArgs.length > 0) Array(appArguments.appArgs.mkString(" ")) - else Array[String]() - val commandLine = (Array(sparkSubmitExecutable.toFile.getAbsolutePath, + val preCommandLine = if (isJVM) { + mutable.ArrayBuffer(sparkSubmitExecutable.toFile.getAbsolutePath, "--deploy-mode", "cluster", "--class", appArguments.mainClass, - "--master", appConf.get("spark.master") - ) ++ appConf.toStringArray :+ - appArguments.mainAppResource) ++ - appArgsArray - ProcessUtils.executeProcess(commandLine, timeoutSecs) + "--master", appConf.get("spark.master")) + } else { + mutable.ArrayBuffer(sparkSubmitExecutable.toFile.getAbsolutePath, + "--deploy-mode", "cluster", + "--master", appConf.get("spark.master")) + } + val commandLine = + pyFiles.map(s => preCommandLine ++ Array("--py-files", s)).getOrElse(preCommandLine) ++ + appConf.toStringArray :+ appArguments.mainAppResource + + if (appArguments.appArgs.nonEmpty) { + commandLine += appArguments.appArgs.mkString(" ") + } + logInfo(s"Launching a spark app with command line: ${commandLine.mkString(" ")}") + ProcessUtils.executeProcess(commandLine.toArray, timeoutSecs) } } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala index 604978967d6db..15bbe60d6c8fb 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala @@ -40,7 +40,7 @@ private[spark] class MesosClusterUI( override def initialize() { attachPage(new MesosClusterPage(this)) attachPage(new DriverPage(this)) - attachHandler(createStaticHandler(MesosClusterUI.STATIC_RESOURCE_DIR, "/static")) + addStaticHandler(MesosClusterUI.STATIC_RESOURCE_DIR) } } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index d35bea4aca311..1ce2f816dffb2 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -634,7 +634,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( slave.hostname, externalShufflePort, sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", - s"${sc.conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000L}ms"), + s"${sc.conf.getTimeAsSeconds("spark.network.timeout", "120s")}s"), sc.conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s")) slave.shuffleRegistered = true } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index f4bd1ee9da6f7..b790c7cd27794 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -789,6 +789,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.nodeBlacklist).thenReturn(Set[String]()) when(taskScheduler.sc).thenReturn(sc) externalShuffleClient = mock[MesosExternalShuffleClient] diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 3d6ee50b070a3..ecc576910db9e 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -515,6 +515,10 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_MAX_EXECUTOR_FAILURES, s"Max number of executor failures ($maxNumExecutorFailures) reached") + } else if (allocator.isAllNodeBlacklisted) { + finish(FinalApplicationStatus.FAILED, + ApplicationMaster.EXIT_MAX_EXECUTOR_FAILURES, + "Due to executor failures all available nodes are blacklisted") } else { logDebug("Sending progress") allocator.allocateResources() diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 7225ff03dc34e..793d012218490 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -899,7 +899,8 @@ private[spark] class Client( val libraryPaths = Seq(sparkConf.get(DRIVER_LIBRARY_PATH), sys.props.get("spark.driver.libraryPath")).flatten if (libraryPaths.nonEmpty) { - prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(libraryPaths))) + prefixEnv = Some(createLibraryPathPrefix(libraryPaths.mkString(File.pathSeparator), + sparkConf)) } if (sparkConf.get(AM_JAVA_OPTIONS).isDefined) { logWarning(s"${AM_JAVA_OPTIONS.key} will not take effect in cluster mode") @@ -921,7 +922,7 @@ private[spark] class Client( .map(YarnSparkHadoopUtil.escapeForShell) } sparkConf.get(AM_LIBRARY_PATH).foreach { paths => - prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(paths)))) + prefixEnv = Some(createLibraryPathPrefix(paths, sparkConf)) } } @@ -1485,6 +1486,23 @@ private object Client extends Logging { YarnAppReport(report.getYarnApplicationState(), report.getFinalApplicationStatus(), diagsOpt) } + /** + * Create a properly quoted and escaped library path string to be added as a prefix to the command + * executed by YARN. This is different from normal quoting / escaping due to YARN executing the + * command through "bash -c". + */ + def createLibraryPathPrefix(libpath: String, conf: SparkConf): String = { + val cmdPrefix = if (Utils.isWindows) { + Utils.libraryPathEnvPrefix(Seq(libpath)) + } else { + val envName = Utils.libraryPathEnvName + // For quotes, escape both the quote and the escape character when encoding in the command + // string. + val quoted = libpath.replace("\"", "\\\\\\\"") + envName + "=\\\"" + quoted + File.pathSeparator + "$" + envName + "\\\"" + } + getClusterPath(conf, cmdPrefix) + } } private[spark] class YarnClusterApplication extends SparkApplication { diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index a2a18cdff65af..49a0b93aa5c40 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -131,10 +131,6 @@ private[yarn] class ExecutorRunnable( // Extra options for the JVM val javaOpts = ListBuffer[String]() - // Set the environment variable through a command prefix - // to append to the existing value of the variable - var prefixEnv: Option[String] = None - // Set the JVM memory val executorMemoryString = executorMemory + "m" javaOpts += "-Xmx" + executorMemoryString @@ -144,8 +140,11 @@ private[yarn] class ExecutorRunnable( val subsOpt = Utils.substituteAppNExecIds(opts, appId, executorId) javaOpts ++= Utils.splitCommandString(subsOpt).map(YarnSparkHadoopUtil.escapeForShell) } - sparkConf.get(EXECUTOR_LIBRARY_PATH).foreach { p => - prefixEnv = Some(Client.getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(p)))) + + // Set the library path through a command prefix to append to the existing value of the + // env variable. + val prefixEnv = sparkConf.get(EXECUTOR_LIBRARY_PATH).map { libPath => + Client.createLibraryPathPrefix(libPath, sparkConf) } javaOpts += "-Djava.io.tmpdir=" + diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index ebee3d431744d..fae054e0eea00 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -24,7 +24,7 @@ import java.util.regex.Pattern import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.util.control.NonFatal import org.apache.hadoop.yarn.api.records._ @@ -66,7 +66,8 @@ private[yarn] class YarnAllocator( appAttemptId: ApplicationAttemptId, securityMgr: SecurityManager, localResources: Map[String, LocalResource], - resolver: SparkRackResolver) + resolver: SparkRackResolver, + clock: Clock = new SystemClock) extends Logging { import YarnAllocator._ @@ -102,18 +103,14 @@ private[yarn] class YarnAllocator( private var executorIdCounter: Int = driverRef.askSync[Int](RetrieveLastAllocatedExecutorId) - // Queue to store the timestamp of failed executors - private val failedExecutorsTimeStamps = new Queue[Long]() + private[spark] val failureTracker = new FailureTracker(sparkConf, clock) - private var clock: Clock = new SystemClock - - private val executorFailuresValidityInterval = - sparkConf.get(EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS).getOrElse(-1L) + private val allocatorBlacklistTracker = + new YarnAllocatorBlacklistTracker(sparkConf, amClient, failureTracker) @volatile private var targetNumExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(sparkConf) - private var currentNodeBlacklist = Set.empty[String] // Executor loss reason requests that are pending - maps from executor ID for inquiry to a // list of requesters that should be responded to once we find out why the given executor @@ -149,7 +146,6 @@ private[yarn] class YarnAllocator( private val labelExpression = sparkConf.get(EXECUTOR_NODE_LABEL_EXPRESSION) - // A map to store preferred hostname and possible task numbers running on it. private var hostToLocalTaskCounts: Map[String, Int] = Map.empty @@ -160,26 +156,11 @@ private[yarn] class YarnAllocator( private[yarn] val containerPlacementStrategy = new LocalityPreferredContainerPlacementStrategy(sparkConf, conf, resource, resolver) - /** - * Use a different clock for YarnAllocator. This is mainly used for testing. - */ - def setClock(newClock: Clock): Unit = { - clock = newClock - } - def getNumExecutorsRunning: Int = runningExecutors.size() - def getNumExecutorsFailed: Int = synchronized { - val endTime = clock.getTimeMillis() + def getNumExecutorsFailed: Int = failureTracker.numFailedExecutors - while (executorFailuresValidityInterval > 0 - && failedExecutorsTimeStamps.nonEmpty - && failedExecutorsTimeStamps.head < endTime - executorFailuresValidityInterval) { - failedExecutorsTimeStamps.dequeue() - } - - failedExecutorsTimeStamps.size - } + def isAllNodeBlacklisted: Boolean = allocatorBlacklistTracker.isAllNodeBlacklisted /** * A sequence of pending container requests that have not yet been fulfilled. @@ -204,9 +185,8 @@ private[yarn] class YarnAllocator( * @param localityAwareTasks number of locality aware tasks to be used as container placement hint * @param hostToLocalTaskCount a map of preferred hostname to possible task counts to be used as * container placement hint. - * @param nodeBlacklist a set of blacklisted nodes, which is passed in to avoid allocating new - * containers on them. It will be used to update the application master's - * blacklist. + * @param nodeBlacklist blacklisted nodes, which is passed in to avoid allocating new containers + * on them. It will be used to update the application master's blacklist. * @return Whether the new requested total is different than the old value. */ def requestTotalExecutorsWithPreferredLocalities( @@ -220,19 +200,7 @@ private[yarn] class YarnAllocator( if (requestedTotal != targetNumExecutors) { logInfo(s"Driver requested a total number of $requestedTotal executor(s).") targetNumExecutors = requestedTotal - - // Update blacklist infomation to YARN ResouceManager for this application, - // in order to avoid allocating new Containers on the problematic nodes. - val blacklistAdditions = nodeBlacklist -- currentNodeBlacklist - val blacklistRemovals = currentNodeBlacklist -- nodeBlacklist - if (blacklistAdditions.nonEmpty) { - logInfo(s"adding nodes to YARN application master's blacklist: $blacklistAdditions") - } - if (blacklistRemovals.nonEmpty) { - logInfo(s"removing nodes from YARN application master's blacklist: $blacklistRemovals") - } - amClient.updateBlacklist(blacklistAdditions.toList.asJava, blacklistRemovals.toList.asJava) - currentNodeBlacklist = nodeBlacklist + allocatorBlacklistTracker.setSchedulerBlacklistedNodes(nodeBlacklist) true } else { false @@ -268,6 +236,7 @@ private[yarn] class YarnAllocator( val allocateResponse = amClient.allocate(progressIndicator) val allocatedContainers = allocateResponse.getAllocatedContainers() + allocatorBlacklistTracker.setNumClusterNodes(allocateResponse.getNumClusterNodes) if (allocatedContainers.size > 0) { logDebug(("Allocated containers: %d. Current executor count: %d. " + @@ -602,8 +571,9 @@ private[yarn] class YarnAllocator( completedContainer.getDiagnostics, PMEM_EXCEEDED_PATTERN)) case _ => - // Enqueue the timestamp of failed executor - failedExecutorsTimeStamps.enqueue(clock.getTimeMillis()) + // all the failures which not covered above, like: + // disk failure, kill by app master or resource manager, ... + allocatorBlacklistTracker.handleResourceAllocationFailure(hostOpt) (true, "Container marked as failed: " + containerId + onHostStr + ". Exit status: " + completedContainer.getExitStatus + ". Diagnostics: " + completedContainer.getDiagnostics) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala new file mode 100644 index 0000000000000..1b48a0ee7ad32 --- /dev/null +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.yarn + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.collection.mutable.HashMap + +import org.apache.hadoop.yarn.client.api.AMRMClient +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.yarn.config._ +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ +import org.apache.spark.scheduler.BlacklistTracker +import org.apache.spark.util.{Clock, SystemClock, Utils} + +/** + * YarnAllocatorBlacklistTracker is responsible for tracking the blacklisted nodes + * and synchronizing the node list to YARN. + * + * Blacklisted nodes are coming from two different sources: + * + *
      + *
    • from the scheduler as task level blacklisted nodes + *
    • from this class (tracked here) as YARN resource allocation problems + *
    + * + * The reason to realize this logic here (and not in the driver) is to avoid possible delays + * between synchronizing the blacklisted nodes with YARN and resource allocations. + */ +private[spark] class YarnAllocatorBlacklistTracker( + sparkConf: SparkConf, + amClient: AMRMClient[ContainerRequest], + failureTracker: FailureTracker) + extends Logging { + + private val blacklistTimeoutMillis = BlacklistTracker.getBlacklistTimeout(sparkConf) + + private val launchBlacklistEnabled = sparkConf.get(YARN_EXECUTOR_LAUNCH_BLACKLIST_ENABLED) + + private val maxFailuresPerHost = sparkConf.get(MAX_FAILED_EXEC_PER_NODE) + + private val allocatorBlacklist = new HashMap[String, Long]() + + private var currentBlacklistedYarnNodes = Set.empty[String] + + private var schedulerBlacklist = Set.empty[String] + + private var numClusterNodes = Int.MaxValue + + def setNumClusterNodes(numClusterNodes: Int): Unit = { + this.numClusterNodes = numClusterNodes + } + + def handleResourceAllocationFailure(hostOpt: Option[String]): Unit = { + hostOpt match { + case Some(hostname) if launchBlacklistEnabled => + // failures on an already blacklisted nodes are not even tracked. + // otherwise, such failures could shutdown the application + // as resource requests are asynchronous + // and a late failure response could exceed MAX_EXECUTOR_FAILURES + if (!schedulerBlacklist.contains(hostname) && + !allocatorBlacklist.contains(hostname)) { + failureTracker.registerFailureOnHost(hostname) + updateAllocationBlacklistedNodes(hostname) + } + case _ => + failureTracker.registerExecutorFailure() + } + } + + private def updateAllocationBlacklistedNodes(hostname: String): Unit = { + val failuresOnHost = failureTracker.numFailuresOnHost(hostname) + if (failuresOnHost > maxFailuresPerHost) { + logInfo(s"blacklisting $hostname as YARN allocation failed $failuresOnHost times") + allocatorBlacklist.put( + hostname, + failureTracker.clock.getTimeMillis() + blacklistTimeoutMillis) + refreshBlacklistedNodes() + } + } + + def setSchedulerBlacklistedNodes(schedulerBlacklistedNodesWithExpiry: Set[String]): Unit = { + this.schedulerBlacklist = schedulerBlacklistedNodesWithExpiry + refreshBlacklistedNodes() + } + + def isAllNodeBlacklisted: Boolean = currentBlacklistedYarnNodes.size >= numClusterNodes + + private def refreshBlacklistedNodes(): Unit = { + removeExpiredYarnBlacklistedNodes() + val allBlacklistedNodes = schedulerBlacklist ++ allocatorBlacklist.keySet + synchronizeBlacklistedNodeWithYarn(allBlacklistedNodes) + } + + private def synchronizeBlacklistedNodeWithYarn(nodesToBlacklist: Set[String]): Unit = { + // Update blacklist information to YARN ResourceManager for this application, + // in order to avoid allocating new Containers on the problematic nodes. + val additions = (nodesToBlacklist -- currentBlacklistedYarnNodes).toList.sorted + val removals = (currentBlacklistedYarnNodes -- nodesToBlacklist).toList.sorted + if (additions.nonEmpty) { + logInfo(s"adding nodes to YARN application master's blacklist: $additions") + } + if (removals.nonEmpty) { + logInfo(s"removing nodes from YARN application master's blacklist: $removals") + } + amClient.updateBlacklist(additions.asJava, removals.asJava) + currentBlacklistedYarnNodes = nodesToBlacklist + } + + private def removeExpiredYarnBlacklistedNodes(): Unit = { + val now = failureTracker.clock.getTimeMillis() + allocatorBlacklist.retain { (_, expiryTime) => expiryTime > now } + } +} + +/** + * FailureTracker is responsible for tracking executor failures both for each host separately + * and for all hosts altogether. + */ +private[spark] class FailureTracker( + sparkConf: SparkConf, + val clock: Clock = new SystemClock) extends Logging { + + private val executorFailuresValidityInterval = + sparkConf.get(config.EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS).getOrElse(-1L) + + // Queue to store the timestamp of failed executors for each host + private val failedExecutorsTimeStampsPerHost = mutable.Map[String, mutable.Queue[Long]]() + + private val failedExecutorsTimeStamps = new mutable.Queue[Long]() + + private def updateAndCountFailures(failedExecutorsWithTimeStamps: mutable.Queue[Long]): Int = { + val endTime = clock.getTimeMillis() + while (executorFailuresValidityInterval > 0 && + failedExecutorsWithTimeStamps.nonEmpty && + failedExecutorsWithTimeStamps.head < endTime - executorFailuresValidityInterval) { + failedExecutorsWithTimeStamps.dequeue() + } + failedExecutorsWithTimeStamps.size + } + + def numFailedExecutors: Int = synchronized { + updateAndCountFailures(failedExecutorsTimeStamps) + } + + def registerFailureOnHost(hostname: String): Unit = synchronized { + val timeMillis = clock.getTimeMillis() + failedExecutorsTimeStamps.enqueue(timeMillis) + val failedExecutorsOnHost = + failedExecutorsTimeStampsPerHost.getOrElse(hostname, { + val failureOnHost = mutable.Queue[Long]() + failedExecutorsTimeStampsPerHost.put(hostname, failureOnHost) + failureOnHost + }) + failedExecutorsOnHost.enqueue(timeMillis) + } + + def registerExecutorFailure(): Unit = synchronized { + val timeMillis = clock.getTimeMillis() + failedExecutorsTimeStamps.enqueue(timeMillis) + } + + def numFailuresOnHost(hostname: String): Int = { + failedExecutorsTimeStampsPerHost.get(hostname).map { failedExecutorsOnHost => + updateAndCountFailures(failedExecutorsOnHost) + }.getOrElse(0) + } + +} + diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index 1a99b3bd57672..129084a86597a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -328,4 +328,10 @@ package object config { CACHED_FILES_TYPES, CACHED_CONF_ARCHIVE) + /* YARN allocator-level blacklisting related config entries. */ + private[spark] val YARN_EXECUTOR_LAUNCH_BLACKLIST_ENABLED = + ConfigBuilder("spark.yarn.blacklist.executor.launch.blacklisting.enabled") + .booleanConf + .createWithDefault(false) + } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala index ac67f2196e0a0..b0abcc9149d08 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -36,6 +36,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark._ import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.launcher._ import org.apache.spark.util.Utils @@ -216,6 +217,14 @@ abstract class BaseYarnClusterSuite props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"") props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"") + // SPARK-24446: make sure special characters in the library path do not break containers. + if (!Utils.isWindows) { + val libPath = """/tmp/does not exist:$PWD/tmp:/tmp/quote":/tmp/ampersand&""" + props.setProperty(AM_LIBRARY_PATH.key, libPath) + props.setProperty(DRIVER_LIBRARY_PATH.key, libPath) + props.setProperty(EXECUTOR_LIBRARY_PATH.key, libPath) + } + yarnCluster.getConfig().asScala.foreach { e => props.setProperty("spark.hadoop." + e.getKey(), e.getValue()) } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/FailureTrackerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/FailureTrackerSuite.scala new file mode 100644 index 0000000000000..4f77b9c99dd25 --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/FailureTrackerSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.yarn + +import org.scalatest.Matchers + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.util.ManualClock + +class FailureTrackerSuite extends SparkFunSuite with Matchers { + + override def beforeAll(): Unit = { + super.beforeAll() + } + + test("failures expire if validity interval is set") { + val sparkConf = new SparkConf() + sparkConf.set(config.EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS, 100L) + + val clock = new ManualClock() + val failureTracker = new FailureTracker(sparkConf, clock) + + clock.setTime(0) + failureTracker.registerFailureOnHost("host1") + failureTracker.numFailuresOnHost("host1") should be (1) + failureTracker.numFailedExecutors should be (1) + + clock.setTime(10) + failureTracker.registerFailureOnHost("host2") + failureTracker.numFailuresOnHost("host2") should be (1) + failureTracker.numFailedExecutors should be (2) + + clock.setTime(20) + failureTracker.registerFailureOnHost("host1") + failureTracker.numFailuresOnHost("host1") should be (2) + failureTracker.numFailedExecutors should be (3) + + clock.setTime(30) + failureTracker.registerFailureOnHost("host2") + failureTracker.numFailuresOnHost("host2") should be (2) + failureTracker.numFailedExecutors should be (4) + + clock.setTime(101) + failureTracker.numFailuresOnHost("host1") should be (1) + failureTracker.numFailedExecutors should be (3) + + clock.setTime(231) + failureTracker.numFailuresOnHost("host1") should be (0) + failureTracker.numFailuresOnHost("host2") should be (0) + failureTracker.numFailedExecutors should be (0) + } + + + test("failures never expire if validity interval is not set (-1)") { + val sparkConf = new SparkConf() + + val clock = new ManualClock() + val failureTracker = new FailureTracker(sparkConf, clock) + + clock.setTime(0) + failureTracker.registerFailureOnHost("host1") + failureTracker.numFailuresOnHost("host1") should be (1) + failureTracker.numFailedExecutors should be (1) + + clock.setTime(10) + failureTracker.registerFailureOnHost("host2") + failureTracker.numFailuresOnHost("host2") should be (1) + failureTracker.numFailedExecutors should be (2) + + clock.setTime(20) + failureTracker.registerFailureOnHost("host1") + failureTracker.numFailuresOnHost("host1") should be (2) + failureTracker.numFailedExecutors should be (3) + + clock.setTime(30) + failureTracker.registerFailureOnHost("host2") + failureTracker.numFailuresOnHost("host2") should be (2) + failureTracker.numFailedExecutors should be (4) + + clock.setTime(1000) + failureTracker.numFailuresOnHost("host1") should be (2) + failureTracker.numFailuresOnHost("host2") should be (2) + failureTracker.numFailedExecutors should be (4) + } + +} diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTrackerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTrackerSuite.scala new file mode 100644 index 0000000000000..aeac68e6ed330 --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTrackerSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.yarn + +import java.util.Arrays +import java.util.Collections + +import org.apache.hadoop.yarn.client.api.AMRMClient +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest +import org.mockito.Mockito._ +import org.scalatest.{BeforeAndAfterEach, Matchers} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.yarn.config.YARN_EXECUTOR_LAUNCH_BLACKLIST_ENABLED +import org.apache.spark.internal.config.{BLACKLIST_TIMEOUT_CONF, MAX_FAILED_EXEC_PER_NODE} +import org.apache.spark.util.ManualClock + +class YarnAllocatorBlacklistTrackerSuite extends SparkFunSuite with Matchers + with BeforeAndAfterEach { + + val BLACKLIST_TIMEOUT = 100L + val MAX_FAILED_EXEC_PER_NODE_VALUE = 2 + + var amClientMock: AMRMClient[ContainerRequest] = _ + var yarnBlacklistTracker: YarnAllocatorBlacklistTracker = _ + var failureTracker: FailureTracker = _ + var clock: ManualClock = _ + + override def beforeEach(): Unit = { + val sparkConf = new SparkConf() + sparkConf.set(BLACKLIST_TIMEOUT_CONF, BLACKLIST_TIMEOUT) + sparkConf.set(YARN_EXECUTOR_LAUNCH_BLACKLIST_ENABLED, true) + sparkConf.set(MAX_FAILED_EXEC_PER_NODE, MAX_FAILED_EXEC_PER_NODE_VALUE) + clock = new ManualClock() + + amClientMock = mock(classOf[AMRMClient[ContainerRequest]]) + failureTracker = new FailureTracker(sparkConf, clock) + yarnBlacklistTracker = + new YarnAllocatorBlacklistTracker(sparkConf, amClientMock, failureTracker) + yarnBlacklistTracker.setNumClusterNodes(4) + super.beforeEach() + } + + test("expiring its own blacklisted nodes") { + (1 to MAX_FAILED_EXEC_PER_NODE_VALUE).foreach { + _ => { + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host")) + // host should not be blacklisted at these failures as MAX_FAILED_EXEC_PER_NODE is 2 + verify(amClientMock, never()) + .updateBlacklist(Arrays.asList("host"), Collections.emptyList()) + } + } + + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host")) + // the third failure on the host triggers the blacklisting + verify(amClientMock).updateBlacklist(Arrays.asList("host"), Collections.emptyList()) + + clock.advance(BLACKLIST_TIMEOUT) + + // trigger synchronisation of blacklisted nodes with YARN + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set()) + verify(amClientMock).updateBlacklist(Collections.emptyList(), Arrays.asList("host")) + } + + test("not handling the expiry of scheduler blacklisted nodes") { + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host1", "host2")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host1", "host2"), Collections.emptyList()) + + // advance timer more then host1, host2 expiry time + clock.advance(200L) + + // expired blacklisted nodes (simulating a resource request) + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host1", "host2")) + // no change is communicated to YARN regarding the blacklisting + verify(amClientMock).updateBlacklist(Collections.emptyList(), Collections.emptyList()) + } + + test("combining scheduler and allocation blacklist") { + (1 to MAX_FAILED_EXEC_PER_NODE_VALUE).foreach { + _ => { + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host1")) + // host1 should not be blacklisted at these failures as MAX_FAILED_EXEC_PER_NODE is 2 + verify(amClientMock, never()) + .updateBlacklist(Arrays.asList("host1"), Collections.emptyList()) + } + } + + // as this is the third failure on host1 the node will be blacklisted + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host1")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host1"), Collections.emptyList()) + + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host2", "host3")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host2", "host3"), Collections.emptyList()) + + clock.advance(10L) + + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host3", "host4")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host4"), Arrays.asList("host2")) + } + + test("blacklist all available nodes") { + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host1", "host2", "host3")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host1", "host2", "host3"), Collections.emptyList()) + + clock.advance(60L) + (1 to MAX_FAILED_EXEC_PER_NODE_VALUE).foreach { + _ => { + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host4")) + // host4 should not be blacklisted at these failures as MAX_FAILED_EXEC_PER_NODE is 2 + verify(amClientMock, never()) + .updateBlacklist(Arrays.asList("host4"), Collections.emptyList()) + } + } + + // the third failure on the host triggers the blacklisting + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host4")) + + verify(amClientMock).updateBlacklist(Arrays.asList("host4"), Collections.emptyList()) + assert(yarnBlacklistTracker.isAllNodeBlacklisted === true) + } +} diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 525abb6f2b350..3f783baed110d 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -59,6 +59,8 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter var rmClient: AMRMClient[ContainerRequest] = _ + var clock: ManualClock = _ + var containerNum = 0 override def beforeEach() { @@ -66,6 +68,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter rmClient = AMRMClient.createAMRMClient() rmClient.init(conf) rmClient.start() + clock = new ManualClock() } override def afterEach() { @@ -101,7 +104,8 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter appAttemptId, new SecurityManager(sparkConf), Map(), - new MockResolver()) + new MockResolver(), + clock) } def createContainer(host: String): Container = { @@ -332,10 +336,14 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.requestTotalExecutorsWithPreferredLocalities(1, 0, Map(), Set("hostA")) verify(mockAmClient).updateBlacklist(Seq("hostA").asJava, Seq[String]().asJava) - handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map(), Set("hostA", "hostB")) + val blacklistedNodes = Set( + "hostA", + "hostB" + ) + handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map(), blacklistedNodes) verify(mockAmClient).updateBlacklist(Seq("hostB").asJava, Seq[String]().asJava) - handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map(), Set()) + handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map(), Set.empty) verify(mockAmClient).updateBlacklist(Seq[String]().asJava, Seq("hostA", "hostB").asJava) } @@ -353,8 +361,6 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter test("window based failure executor counting") { sparkConf.set("spark.yarn.executor.failuresValidityInterval", "100s") val handler = createAllocator(4) - val clock = new ManualClock(0L) - handler.setClock(clock) handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 3fe00eefde7d8..dc95751bf905c 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -539,18 +539,11 @@ expression booleanExpression : NOT booleanExpression #logicalNot | EXISTS '(' query ')' #exists - | predicated #booleanDefault + | valueExpression predicate? #predicated | left=booleanExpression operator=AND right=booleanExpression #logicalBinary | left=booleanExpression operator=OR right=booleanExpression #logicalBinary ; -// workaround for: -// https://github.com/antlr/antlr4/issues/780 -// https://github.com/antlr/antlr4/issues/781 -predicated - : valueExpression predicate? - ; - predicate : NOT? kind=BETWEEN lower=valueExpression AND upper=valueExpression | NOT? kind=IN '(' expression (',' expression)* ')' diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java deleted file mode 100644 index 05879902a4ed9..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions; - -/** - * Contains all the Utils methods used in the masking expressions. - */ -public class MaskExpressionsUtils { - static final int UNMASKED_VAL = -1; - - /** - * Returns the masking character for {@param c} or {@param c} is it should not be masked. - * @param c the character to transform - * @param maskedUpperChar the character to use instead of a uppercase letter - * @param maskedLowerChar the character to use instead of a lowercase letter - * @param maskedDigitChar the character to use instead of a digit - * @param maskedOtherChar the character to use instead of a any other character - * @return masking character for {@param c} - */ - public static int transformChar( - final int c, - int maskedUpperChar, - int maskedLowerChar, - int maskedDigitChar, - int maskedOtherChar) { - switch(Character.getType(c)) { - case Character.UPPERCASE_LETTER: - if(maskedUpperChar != UNMASKED_VAL) { - return maskedUpperChar; - } - break; - - case Character.LOWERCASE_LETTER: - if(maskedLowerChar != UNMASKED_VAL) { - return maskedLowerChar; - } - break; - - case Character.DECIMAL_DIGIT_NUMBER: - if(maskedDigitChar != UNMASKED_VAL) { - return maskedDigitChar; - } - break; - - default: - if(maskedOtherChar != UNMASKED_VAL) { - return maskedOtherChar; - } - break; - } - - return c; - } - - /** - * Returns the replacement char to use according to the {@param rep} specified by the user and - * the {@param def} default. - */ - public static int getReplacementChar(String rep, int def) { - if (rep != null && rep.length() > 0) { - return rep.codePointAt(0); - } - return def; - } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index d5d934bc91cab..cf2a5ed2e27f9 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -83,7 +83,7 @@ public static long calculateSizeOfUnderlyingByteArray(long numFields, int elemen private long elementOffset; private long getElementOffset(int ordinal, int elementSize) { - return elementOffset + ordinal * elementSize; + return elementOffset + ordinal * (long)elementSize; } public Object getBaseObject() { return baseObject; } @@ -414,7 +414,7 @@ public byte[] toByteArray() { public short[] toShortArray() { short[] values = new short[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.SHORT_ARRAY_OFFSET, numElements * 2); + baseObject, elementOffset, values, Platform.SHORT_ARRAY_OFFSET, numElements * 2L); return values; } @@ -422,7 +422,7 @@ public short[] toShortArray() { public int[] toIntArray() { int[] values = new int[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.INT_ARRAY_OFFSET, numElements * 4); + baseObject, elementOffset, values, Platform.INT_ARRAY_OFFSET, numElements * 4L); return values; } @@ -430,7 +430,7 @@ public int[] toIntArray() { public long[] toLongArray() { long[] values = new long[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.LONG_ARRAY_OFFSET, numElements * 8); + baseObject, elementOffset, values, Platform.LONG_ARRAY_OFFSET, numElements * 8L); return values; } @@ -438,7 +438,7 @@ public long[] toLongArray() { public float[] toFloatArray() { float[] values = new float[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.FLOAT_ARRAY_OFFSET, numElements * 4); + baseObject, elementOffset, values, Platform.FLOAT_ARRAY_OFFSET, numElements * 4L); return values; } @@ -446,14 +446,14 @@ public float[] toFloatArray() { public double[] toDoubleArray() { double[] values = new double[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.DOUBLE_ARRAY_OFFSET, numElements * 8); + baseObject, elementOffset, values, Platform.DOUBLE_ARRAY_OFFSET, numElements * 8L); return values; } - private static UnsafeArrayData fromPrimitiveArray( + public static UnsafeArrayData fromPrimitiveArray( Object arr, int offset, int length, int elementSize) { final long headerInBytes = calculateHeaderPortionInBytes(length); - final long valueRegionInBytes = elementSize * length; + final long valueRegionInBytes = (long)elementSize * length; final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8; if (totalSizeInLongs > Integer.MAX_VALUE / 8) { throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " + @@ -463,14 +463,27 @@ private static UnsafeArrayData fromPrimitiveArray( final long[] data = new long[(int)totalSizeInLongs]; Platform.putLong(data, Platform.LONG_ARRAY_OFFSET, length); - Platform.copyMemory(arr, offset, data, - Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes); + if (arr != null) { + Platform.copyMemory(arr, offset, data, + Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes); + } UnsafeArrayData result = new UnsafeArrayData(); result.pointTo(data, Platform.LONG_ARRAY_OFFSET, (int)totalSizeInLongs * 8); return result; } + public static UnsafeArrayData forPrimitiveArray(int offset, int length, int elementSize) { + return fromPrimitiveArray(null, offset, length, elementSize); + } + + public static boolean shouldUseGenericArrayData(int elementSize, int length) { + final long headerInBytes = calculateHeaderPortionInBytes(length); + final long valueRegionInBytes = (long)elementSize * length; + final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8; + return totalSizeInLongs > Integer.MAX_VALUE / 8; + } + public static UnsafeArrayData fromPrimitiveArray(boolean[] arr) { return fromPrimitiveArray(arr, Platform.BOOLEAN_ARRAY_OFFSET, arr.length, 1); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java index 905e6820ce6e2..c823de4810f2b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java @@ -41,7 +41,7 @@ public final class VariableLengthRowBasedKeyValueBatch extends RowBasedKeyValueB @Override public UnsafeRow appendRow(Object kbase, long koff, int klen, Object vbase, long voff, int vlen) { - final long recordLength = 8 + klen + vlen + 8; + final long recordLength = 8L + klen + vlen + 8; // if run out of max supported rows or page size, return null if (numRows >= capacity || page == null || page.size() - pageCursor < recordLength) { return null; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java index d224332d8a6c9..023ec139652c5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java @@ -21,6 +21,9 @@ import java.io.Reader; import javax.xml.namespace.QName; +import javax.xml.parsers.DocumentBuilder; +import javax.xml.parsers.DocumentBuilderFactory; +import javax.xml.parsers.ParserConfigurationException; import javax.xml.xpath.XPath; import javax.xml.xpath.XPathConstants; import javax.xml.xpath.XPathExpression; @@ -37,9 +40,15 @@ * This is based on Hive's UDFXPathUtil implementation. */ public class UDFXPathUtil { + public static final String SAX_FEATURE_PREFIX = "http://xml.org/sax/features/"; + public static final String EXTERNAL_GENERAL_ENTITIES_FEATURE = "external-general-entities"; + public static final String EXTERNAL_PARAMETER_ENTITIES_FEATURE = "external-parameter-entities"; + private DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance(); + private DocumentBuilder builder = null; private XPath xpath = XPathFactory.newInstance().newXPath(); private ReusableStringReader reader = new ReusableStringReader(); private InputSource inputSource = new InputSource(reader); + private XPathExpression expression = null; private String oldPath = null; @@ -65,14 +74,31 @@ public Object eval(String xml, String path, QName qname) throws XPathExpressionE return null; } + if (builder == null){ + try { + initializeDocumentBuilderFactory(); + builder = dbf.newDocumentBuilder(); + } catch (ParserConfigurationException e) { + throw new RuntimeException( + "Error instantiating DocumentBuilder, cannot build xml parser", e); + } + } + reader.set(xml); try { - return expression.evaluate(inputSource, qname); + return expression.evaluate(builder.parse(inputSource), qname); } catch (XPathExpressionException e) { throw new RuntimeException("Invalid XML document: " + e.getMessage() + "\n" + xml, e); + } catch (Exception e) { + throw new RuntimeException("Error loading expression '" + oldPath + "'", e); } } + private void initializeDocumentBuilderFactory() throws ParserConfigurationException { + dbf.setFeature(SAX_FEATURE_PREFIX + EXTERNAL_GENERAL_ENTITIES_FEATURE, false); + dbf.setFeature(SAX_FEATURE_PREFIX + EXTERNAL_PARAMETER_ENTITIES_FEATURE, false); + } + public Boolean evalBoolean(String xml, String path) throws XPathExpressionException { return (Boolean) eval(xml, path, XPathConstants.BOOLEAN); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala index 0b95a8821b05a..b47ec0b72c638 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -132,7 +132,7 @@ object Encoders { * - primitive types: boolean, int, double, etc. * - boxed types: Boolean, Integer, Double, etc. * - String - * - java.math.BigDecimal + * - java.math.BigDecimal, java.math.BigInteger * - time related: java.sql.Date, java.sql.Timestamp * - collection types: only array and java.util.List currently, map support is in progress * - nested java bean. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 9e9105a157abe..6f5fbdd79e668 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -286,6 +286,7 @@ object CatalystTypeConverters { override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match { case str: String => UTF8String.fromString(str) case utf8: UTF8String => utf8 + case chr: Char => UTF8String.fromString(chr.toString) case other => throw new IllegalArgumentException( s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + s"cannot be converted to the string type") @@ -430,6 +431,12 @@ object CatalystTypeConverters { map, (key: Any) => convertToCatalyst(key), (value: Any) => convertToCatalyst(value)) + case (keys: Array[_], values: Array[_]) => + // case for mapdata with duplicate keys + new ArrayBasedMapData( + new GenericArrayData(keys.map(convertToCatalyst)), + new GenericArrayData(values.map(convertToCatalyst)) + ) case other => other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index f9acc208b715e..4543bba8f6ed4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -798,7 +798,12 @@ object ScalaReflection extends ScalaReflection { * Whether the fields of the given type is defined entirely by its constructor parameters. */ def definedByConstructorParams(tpe: Type): Boolean = cleanUpReflectionObjects { - tpe.dealias <:< localTypeOf[Product] || tpe.dealias <:< localTypeOf[DefinedByConstructorParams] + tpe.dealias match { + // `Option` is a `Product`, but we don't wanna treat `Option[Int]` as a struct type. + case t if t <:< localTypeOf[Option[_]] => definedByConstructorParams(t.typeArgs.head) + case _ => tpe.dealias <:< localTypeOf[Product] || + tpe.dealias <:< localTypeOf[DefinedByConstructorParams] + } } private val javaKeywords = Set("abstract", "assert", "boolean", "break", "byte", "case", "catch", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6e3107f1c6f75..36f14ccdc6989 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.catalyst.analysis +import java.util.Locale + +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.Random @@ -509,12 +512,7 @@ class Analyzer( || !p.pivotColumn.resolved => p case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => // Check all aggregate expressions. - aggregates.foreach { e => - if (!isAggregateExpression(e)) { - throw new AnalysisException( - s"Aggregate expression required for pivot, found '$e'") - } - } + aggregates.foreach(checkValidAggregateExpression) // Group-by expressions coming from SQL are implicit and need to be deduced. val groupByExprs = groupByExprsOpt.getOrElse( (child.outputSet -- aggregates.flatMap(_.references) -- pivotColumn.references).toSeq) @@ -586,12 +584,17 @@ class Analyzer( } } - private def isAggregateExpression(expr: Expression): Boolean = { - expr match { - case Alias(e, _) => isAggregateExpression(e) - case AggregateExpression(_, _, _, _) => true - case _ => false - } + // Support any aggregate expression that can appear in an Aggregate plan except Pandas UDF. + // TODO: Support Pandas UDF. + private def checkValidAggregateExpression(expr: Expression): Unit = expr match { + case _: AggregateExpression => // OK and leave the argument check to CheckAnalysis. + case expr: PythonUDF if PythonUDF.isGroupedAggPandasUDF(expr) => + failAnalysis("Pandas UDF aggregate expressions are currently not supported in pivot.") + case e: Attribute => + failAnalysis( + s"Aggregate expression required for pivot, but '${e.sql}' " + + s"did not appear in any aggregate function.") + case e => e.children.foreach(checkValidAggregateExpression) } } @@ -738,6 +741,10 @@ class Analyzer( if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))) + case oldVersion @ FlatMapGroupsInPandas(_, _, output, _) + if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => + (oldVersion, oldVersion.copy(output = output.map(_.newInstance()))) + case oldVersion: Generate if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty => val newOutput = oldVersion.generatorOutput.map(_.newInstance()) @@ -1125,7 +1132,8 @@ class Analyzer( case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa case sa @ Sort(_, _, child: Aggregate) => sa - case s @ Sort(order, _, child) if !s.resolved && child.resolved => + case s @ Sort(order, _, child) + if (!s.resolved || s.missingInput.nonEmpty) && child.resolved => val (newOrder, newChild) = resolveExprsAndAddMissingAttrs(order, child) val ordering = newOrder.map(_.asInstanceOf[SortOrder]) if (child.output == newChild.output) { @@ -1136,7 +1144,7 @@ class Analyzer( Project(child.output, newSort) } - case f @ Filter(cond, child) if !f.resolved && child.resolved => + case f @ Filter(cond, child) if (!f.resolved || f.missingInput.nonEmpty) && child.resolved => val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(cond), child) if (child.output == newChild.output) { f.copy(condition = newCond.head) @@ -1147,10 +1155,17 @@ class Analyzer( } } + /** + * This method tries to resolve expressions and find missing attributes recursively. Specially, + * when the expressions used in `Sort` or `Filter` contain unresolved attributes or resolved + * attributes which are missed from child output. This method tries to find the missing + * attributes out and add into the projection. + */ private def resolveExprsAndAddMissingAttrs( exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], LogicalPlan) = { - if (exprs.forall(_.resolved)) { - // All given expressions are resolved, no need to continue anymore. + // Missing attributes can be unresolved attributes or resolved attributes which are not in + // the output attributes of the plan. + if (exprs.forall(e => e.resolved && e.references.subsetOf(plan.outputSet))) { (exprs, plan) } else { plan match { @@ -1161,15 +1176,19 @@ class Analyzer( (newExprs, AnalysisBarrier(newChild)) case p: Project => + // Resolving expressions against current plan. val maybeResolvedExprs = exprs.map(resolveExpression(_, p)) + // Recursively resolving expressions on the child of current plan. val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child) - val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs) + // If some attributes used by expressions are resolvable only on the rewritten child + // plan, we need to add them into original projection. + val missingAttrs = (AttributeSet(newExprs) -- p.outputSet).intersect(newChild.outputSet) (newExprs, Project(p.projectList ++ missingAttrs, newChild)) case a @ Aggregate(groupExprs, aggExprs, child) => val maybeResolvedExprs = exprs.map(resolveExpression(_, a)) val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child) - val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs) + val missingAttrs = (AttributeSet(newExprs) -- a.outputSet).intersect(newChild.outputSet) if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) { // All the missing attributes are grouping expressions, valid case. (newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild)) @@ -1204,16 +1223,46 @@ class Analyzer( * only performs simple existence check according to the function identifier to quickly identify * undefined functions without triggering relation resolution, which may incur potentially * expensive partition/schema discovery process in some cases. - * + * In order to avoid duplicate external functions lookup, the external function identifier will + * store in the local hash set externalFunctionNameSet. * @see [[ResolveFunctions]] * @see https://issues.apache.org/jira/browse/SPARK-19737 */ object LookupFunctions extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { - case f: UnresolvedFunction if !catalog.functionExists(f.name) => - withPosition(f) { - throw new NoSuchFunctionException(f.name.database.getOrElse("default"), f.name.funcName) - } + override def apply(plan: LogicalPlan): LogicalPlan = { + val externalFunctionNameSet = new mutable.HashSet[FunctionIdentifier]() + plan.transformAllExpressions { + case f: UnresolvedFunction + if externalFunctionNameSet.contains(normalizeFuncName(f.name)) => f + case f: UnresolvedFunction if catalog.isRegisteredFunction(f.name) => f + case f: UnresolvedFunction if catalog.isPersistentFunction(f.name) => + externalFunctionNameSet.add(normalizeFuncName(f.name)) + f + case f: UnresolvedFunction => + withPosition(f) { + throw new NoSuchFunctionException(f.name.database.getOrElse(catalog.getCurrentDatabase), + f.name.funcName) + } + } + } + + def normalizeFuncName(name: FunctionIdentifier): FunctionIdentifier = { + val funcName = if (conf.caseSensitiveAnalysis) { + name.funcName + } else { + name.funcName.toLowerCase(Locale.ROOT) + } + + val databaseName = name.database match { + case Some(a) => formatDatabaseName(a) + case None => catalog.getCurrentDatabase + } + + FunctionIdentifier(funcName, Some(databaseName)) + } + + protected def formatDatabaseName(name: String): String = { + if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT) } } @@ -1489,7 +1538,11 @@ class Analyzer( // Try resolving the ordering as though it is in the aggregate clause. try { - val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s)) + // If a sort order is unresolved, containing references not in aggregate, or containing + // `AggregateExpression`, we need to push down it to the underlying aggregate operator. + val unresolvedSortOrders = sortOrder.filter { s => + !s.resolved || !s.references.subsetOf(aggregate.outputSet) || containsAggregate(s) + } val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) @@ -1923,6 +1976,9 @@ class Analyzer( // "Aggregate with Having clause" will be triggered. def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + case Filter(condition, _) if hasWindowFunction(condition) => + failAnalysis("It is not allowed to use window functions inside WHERE and HAVING clauses") + // Aggregate with Having clause. This rule works with an unresolved Aggregate because // a resolved Aggregate will not have Window Functions. case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 3700c63d817ea..d696ce9a766d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -414,6 +414,7 @@ object FunctionRegistry { expression[ArrayJoin]("array_join"), expression[ArrayPosition]("array_position"), expression[ArraySort]("array_sort"), + expression[ArrayUnion]("array_union"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), expression[ElementAt]("element_at"), @@ -421,6 +422,8 @@ object FunctionRegistry { expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), expression[MapEntries]("map_entries"), + expression[MapFromEntries]("map_from_entries"), + expression[MapConcat]("map_concat"), expression[Size]("size"), expression[Slice]("slice"), expression[Size]("cardinality"), @@ -431,18 +434,12 @@ object FunctionRegistry { expression[Reverse]("reverse"), expression[Concat]("concat"), expression[Flatten]("flatten"), + expression[Sequence]("sequence"), expression[ArrayRepeat]("array_repeat"), expression[ArrayRemove]("array_remove"), + expression[ArrayDistinct]("array_distinct"), CreateStruct.registryEntry, - // mask functions - expression[Mask]("mask"), - expression[MaskFirstN]("mask_first_n"), - expression[MaskLastN]("mask_last_n"), - expression[MaskShowFirstN]("mask_show_first_n"), - expression[MaskShowLastN]("mask_show_last_n"), - expression[MaskHash]("mask_hash"), - // misc functions expression[AssertTrue]("assert_true"), expression[Crc32]("crc32"), @@ -502,6 +499,7 @@ object FunctionRegistry { // json expression[StructsToJson]("to_json"), expression[JsonToStructs]("from_json"), + expression[SchemaOfJson]("schema_of_json"), // cast expression[Cast]("cast"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala index f5aae60431c15..68d29c2936822 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala @@ -40,10 +40,10 @@ class NoSuchPartitionException( class NoSuchPermanentFunctionException(db: String, func: String) extends AnalysisException(s"Function '$func' not found in database '$db'") -class NoSuchFunctionException(db: String, func: String) +class NoSuchFunctionException(db: String, func: String, rootCause: Option[String] = None) extends AnalysisException( s"Undefined function: '$func'. This function is neither a registered temporary function nor " + - s"a permanent function registered in the database '$db'.") + s"a permanent function registered in the database '$db'." + rootCause.getOrElse("")) class NoSuchPartitionsException(db: String, table: String, specs: Seq[TablePartitionSpec]) extends AnalysisException( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index b2817b0538a7f..316aebdeaffa1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -102,25 +102,7 @@ object TypeCoercion { case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) => Some(TimestampType) - case (t1 @ StructType(fields1), t2 @ StructType(fields2)) if t1.sameType(t2) => - Some(StructType(fields1.zip(fields2).map { case (f1, f2) => - // Since `t1.sameType(t2)` is true, two StructTypes have the same DataType - // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`. - // - Different names: use f1.name - // - Different nullabilities: `nullable` is true iff one of them is nullable. - val dataType = findTightestCommonType(f1.dataType, f2.dataType).get - StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) - })) - - case (a1 @ ArrayType(et1, hasNull1), a2 @ ArrayType(et2, hasNull2)) if a1.sameType(a2) => - findTightestCommonType(et1, et2).map(ArrayType(_, hasNull1 || hasNull2)) - - case (m1 @ MapType(kt1, vt1, hasNull1), m2 @ MapType(kt2, vt2, hasNull2)) if m1.sameType(m2) => - val keyType = findTightestCommonType(kt1, kt2) - val valueType = findTightestCommonType(vt1, vt2) - Some(MapType(keyType.get, valueType.get, hasNull1 || hasNull2)) - - case _ => None + case (t1, t2) => findTypeForComplex(t1, t2, findTightestCommonType) } /** Promotes all the way to StringType. */ @@ -166,6 +148,53 @@ object TypeCoercion { case (l, r) => None } + private def findTypeForComplex( + t1: DataType, + t2: DataType, + findTypeFunc: (DataType, DataType) => Option[DataType]): Option[DataType] = (t1, t2) match { + case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => + findTypeFunc(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) + case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) => + findTypeFunc(kt1, kt2).flatMap { kt => + findTypeFunc(vt1, vt2).map { vt => + MapType(kt, vt, valueContainsNull1 || valueContainsNull2) + } + } + case (StructType(fields1), StructType(fields2)) if fields1.length == fields2.length => + val resolver = SQLConf.get.resolver + fields1.zip(fields2).foldLeft(Option(new StructType())) { + case (Some(struct), (field1, field2)) if resolver(field1.name, field2.name) => + findTypeFunc(field1.dataType, field2.dataType).map { + dt => struct.add(field1.name, dt, field1.nullable || field2.nullable) + } + case _ => None + } + case _ => None + } + + /** + * The method finds a common type for data types that differ only in nullable, containsNull + * and valueContainsNull flags. If the input types are too different, None is returned. + */ + def findCommonTypeDifferentOnlyInNullFlags(t1: DataType, t2: DataType): Option[DataType] = { + if (t1 == t2) { + Some(t1) + } else { + findTypeForComplex(t1, t2, findCommonTypeDifferentOnlyInNullFlags) + } + } + + def findCommonTypeDifferentOnlyInNullFlags(types: Seq[DataType]): Option[DataType] = { + if (types.isEmpty) { + None + } else { + types.tail.foldLeft[Option[DataType]](Some(types.head)) { + case (Some(t1), t2) => findCommonTypeDifferentOnlyInNullFlags(t1, t2) + case _ => None + } + } + } + /** * Case 2 type widening (see the classdoc comment above for TypeCoercion). * @@ -176,11 +205,7 @@ object TypeCoercion { findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) .orElse(stringPromotion(t1, t2)) - .orElse((t1, t2) match { - case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) - case _ => None - }) + .orElse(findTypeForComplex(t1, t2, findWiderTypeForTwo)) } /** @@ -216,12 +241,7 @@ object TypeCoercion { t2: DataType): Option[DataType] = { findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) - .orElse((t1, t2) match { - case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeWithoutStringPromotionForTwo(et1, et2) - .map(ArrayType(_, containsNull1 || containsNull2)) - case _ => None - }) + .orElse(findTypeForComplex(t1, t2, findWiderTypeWithoutStringPromotionForTwo)) } def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = { @@ -250,8 +270,25 @@ object TypeCoercion { } } - private def haveSameType(exprs: Seq[Expression]): Boolean = - exprs.map(_.dataType).distinct.length == 1 + /** + * Check whether the given types are equal ignoring nullable, containsNull and valueContainsNull. + */ + def haveSameType(types: Seq[DataType]): Boolean = { + if (types.size <= 1) { + true + } else { + val head = types.head + types.tail.forall(_.sameType(head)) + } + } + + private def castIfNotSameType(expr: Expression, dt: DataType): Expression = { + if (!expr.dataType.sameType(dt)) { + Cast(expr, dt) + } else { + expr + } + } /** * Widens numeric types and converts strings to numbers when appropriate. @@ -516,46 +553,63 @@ object TypeCoercion { * This ensure that the types for various functions are as expected. */ object FunctionArgumentConversion extends TypeCoercionRule { + override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case a @ CreateArray(children) if !haveSameType(children) => + case a @ CreateArray(children) if !haveSameType(children.map(_.dataType)) => val types = children.map(_.dataType) findWiderCommonType(types) match { - case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType))) + case Some(finalDataType) => CreateArray(children.map(castIfNotSameType(_, finalDataType))) case None => a } case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) && - !haveSameType(children) => + !haveSameType(c.inputTypesForMerging) => val types = children.map(_.dataType) findWiderCommonType(types) match { - case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType))) + case Some(finalDataType) => Concat(children.map(castIfNotSameType(_, finalDataType))) case None => c } + case aj @ ArrayJoin(arr, d, nr) if !ArrayType(StringType).acceptsType(arr.dataType) && + ArrayType.acceptsType(arr.dataType) => + val containsNull = arr.dataType.asInstanceOf[ArrayType].containsNull + ImplicitTypeCasts.implicitCast(arr, ArrayType(StringType, containsNull)) match { + case Some(castedArr) => ArrayJoin(castedArr, d, nr) + case None => aj + } + + case s @ Sequence(_, _, _, timeZoneId) + if !haveSameType(s.coercibleChildren.map(_.dataType)) => + val types = s.coercibleChildren.map(_.dataType) + findWiderCommonType(types) match { + case Some(widerDataType) => s.castChildrenTo(widerDataType) + case None => s + } + + case m @ MapConcat(children) if children.forall(c => MapType.acceptsType(c.dataType)) && + !haveSameType(m.inputTypesForMerging) => + val types = children.map(_.dataType) + findWiderCommonType(types) match { + case Some(finalDataType) => MapConcat(children.map(castIfNotSameType(_, finalDataType))) + case None => m + } + case m @ CreateMap(children) if m.keys.length == m.values.length && - (!haveSameType(m.keys) || !haveSameType(m.values)) => - val newKeys = if (haveSameType(m.keys)) { - m.keys - } else { - val types = m.keys.map(_.dataType) - findWiderCommonType(types) match { - case Some(finalDataType) => m.keys.map(Cast(_, finalDataType)) - case None => m.keys - } + (!haveSameType(m.keys.map(_.dataType)) || !haveSameType(m.values.map(_.dataType))) => + val keyTypes = m.keys.map(_.dataType) + val newKeys = findWiderCommonType(keyTypes) match { + case Some(finalDataType) => m.keys.map(castIfNotSameType(_, finalDataType)) + case None => m.keys } - val newValues = if (haveSameType(m.values)) { - m.values - } else { - val types = m.values.map(_.dataType) - findWiderCommonType(types) match { - case Some(finalDataType) => m.values.map(Cast(_, finalDataType)) - case None => m.values - } + val valueTypes = m.values.map(_.dataType) + val newValues = findWiderCommonType(valueTypes) match { + case Some(finalDataType) => m.values.map(castIfNotSameType(_, finalDataType)) + case None => m.values } CreateMap(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) }) @@ -578,27 +632,27 @@ object TypeCoercion { // Coalesce should return the first non-null value, which could be any column // from the list. So we need to make sure the return type is deterministic and // compatible with every child column. - case c @ Coalesce(es) if !haveSameType(es) => + case c @ Coalesce(es) if !haveSameType(c.inputTypesForMerging) => val types = es.map(_.dataType) findWiderCommonType(types) match { - case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) + case Some(finalDataType) => Coalesce(es.map(castIfNotSameType(_, finalDataType))) case None => c } // When finding wider type for `Greatest` and `Least`, we should handle decimal types even if // we need to truncate, but we should not promote one side to string if the other side is // string.g - case g @ Greatest(children) if !haveSameType(children) => + case g @ Greatest(children) if !haveSameType(g.inputTypesForMerging) => val types = children.map(_.dataType) findWiderTypeWithoutStringPromotion(types) match { - case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType))) + case Some(finalDataType) => Greatest(children.map(castIfNotSameType(_, finalDataType))) case None => g } - case l @ Least(children) if !haveSameType(children) => + case l @ Least(children) if !haveSameType(l.inputTypesForMerging) => val types = children.map(_.dataType) findWiderTypeWithoutStringPromotion(types) match { - case Some(finalDataType) => Least(children.map(Cast(_, finalDataType))) + case Some(finalDataType) => Least(children.map(castIfNotSameType(_, finalDataType))) case None => l } @@ -640,27 +694,14 @@ object TypeCoercion { object CaseWhenCoercion extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual => - val maybeCommonType = findWiderCommonType(c.valueTypes) + case c: CaseWhen if c.childrenResolved && !haveSameType(c.inputTypesForMerging) => + val maybeCommonType = findWiderCommonType(c.inputTypesForMerging) maybeCommonType.map { commonType => - var changed = false val newBranches = c.branches.map { case (condition, value) => - if (value.dataType.sameType(commonType)) { - (condition, value) - } else { - changed = true - (condition, Cast(value, commonType)) - } - } - val newElseValue = c.elseValue.map { value => - if (value.dataType.sameType(commonType)) { - value - } else { - changed = true - Cast(value, commonType) - } + (condition, castIfNotSameType(value, commonType)) } - if (changed) CaseWhen(newBranches, newElseValue) else c + val newElseValue = c.elseValue.map(castIfNotSameType(_, commonType)) + CaseWhen(newBranches, newElseValue) }.getOrElse(c) } } @@ -673,10 +714,10 @@ object TypeCoercion { plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. - case i @ If(pred, left, right) if left.dataType != right.dataType => + case i @ If(pred, left, right) if !haveSameType(i.inputTypesForMerging) => findWiderTypeForTwo(left.dataType, right.dataType).map { widestType => - val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) - val newRight = if (right.dataType == widestType) right else Cast(right, widestType) + val newLeft = castIfNotSameType(left, widestType) + val newRight = castIfNotSameType(right, widestType) If(pred, newLeft, newRight) }.getOrElse(i) // If there is no applicable conversion, leave expression unchanged. case If(Literal(null, NullType), left, right) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 2bed41672fe33..f68df5d29b545 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -315,8 +315,10 @@ object UnsupportedOperationChecker { case GroupingSets(_, _, child, _) if child.isStreaming => throwError("GroupingSets is not supported on streaming DataFrames/Datasets") - case GlobalLimit(_, _) | LocalLimit(_, _) if subPlan.children.forall(_.isStreaming) => - throwError("Limits are not supported on streaming DataFrames/Datasets") + case GlobalLimit(_, _) | LocalLimit(_, _) + if subPlan.children.forall(_.isStreaming) && outputMode == InternalOutputModes.Update => + throwError("Limits are not supported on streaming DataFrames/Datasets in Update " + + "output mode") case Sort(_, _, _) if !containsCompleteData(subPlan) => throwError("Sorting is not supported on streaming DataFrames/Datasets, unless it is on " + @@ -349,6 +351,17 @@ object UnsupportedOperationChecker { _: DeserializeToObject | _: SerializeFromObject | _: SubqueryAlias | _: TypedFilter) => case node if node.nodeName == "StreamingRelationV2" => + case Repartition(1, false, _) => + case node: Aggregate => + val aboveSinglePartitionCoalesce = node.find { + case Repartition(1, false, _) => true + case _ => false + }.isDefined + + if (!aboveSinglePartitionCoalesce) { + throwError(s"In continuous processing mode, coalesce(1) must be called before " + + s"aggregate operation ${node.nodeName}.") + } case node => throwError(s"Continuous processing does not support ${node.nodeName} operations.") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index c390337c03ff5..5e4b3c5528b23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -32,7 +32,6 @@ import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst._ -import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} @@ -619,6 +618,7 @@ class SessionCatalog( requireTableExists(TableIdentifier(oldTableName, Some(db))) requireTableNotExists(TableIdentifier(newTableName, Some(db))) validateName(newTableName) + validateNewLocationOfRename(oldName, newName) externalCatalog.renameTable(db, oldTableName, newTableName) } else { if (newName.database.isDefined) { @@ -1192,9 +1192,26 @@ class SessionCatalog( !hiveFunctions.contains(name.funcName.toLowerCase(Locale.ROOT)) } - protected def failFunctionLookup(name: FunctionIdentifier): Nothing = { + /** + * Return whether this function has been registered in the function registry of the current + * session. If not existed, return false. + */ + def isRegisteredFunction(name: FunctionIdentifier): Boolean = { + functionRegistry.functionExists(name) + } + + /** + * Returns whether it is a persistent function. If not existed, returns false. + */ + def isPersistentFunction(name: FunctionIdentifier): Boolean = { + val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) + databaseExists(db) && externalCatalog.functionExists(db, name.funcName) + } + + protected def failFunctionLookup( + name: FunctionIdentifier, rootCause: Option[String] = None): Nothing = { throw new NoSuchFunctionException( - db = name.database.getOrElse(getCurrentDatabase), func = name.funcName) + db = name.database.getOrElse(getCurrentDatabase), func = name.funcName, rootCause) } /** @@ -1366,4 +1383,23 @@ class SessionCatalog( // copy over temporary views tempViews.foreach(kv => target.tempViews.put(kv._1, kv._2)) } + + /** + * Validate the new locatoin before renaming a managed table, which should be non-existent. + */ + private def validateNewLocationOfRename( + oldName: TableIdentifier, + newName: TableIdentifier): Unit = { + val oldTable = getTableMetadata(oldName) + if (oldTable.tableType == CatalogTableType.MANAGED) { + val databaseLocation = + externalCatalog.getDatabase(oldName.database.getOrElse(currentDb)).locationUri + val newTableLocation = new Path(new Path(databaseLocation), formatTableName(newName.table)) + val fs = newTableLocation.getFileSystem(hadoopConf) + if (fs.exists(newTableLocation)) { + throw new AnalysisException(s"Can not rename the managed table('$oldName')" + + s". The associated location('$newTableLocation') already exists.") + } + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index f3e67dc4e975c..c6105c5526049 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -93,12 +93,16 @@ object CatalogStorageFormat { * @param spec partition spec values indexed by column name * @param storage storage format of the partition * @param parameters some parameters for the partition + * @param createTime creation time of the partition, in milliseconds + * @param lastAccessTime last access time, in milliseconds * @param stats optional statistics (number of rows, total size, etc.) */ case class CatalogTablePartition( spec: CatalogTypes.TablePartitionSpec, storage: CatalogStorageFormat, parameters: Map[String, String] = Map.empty, + createTime: Long = System.currentTimeMillis, + lastAccessTime: Long = -1, stats: Option[CatalogStatistics] = None) { def toLinkedHashMap: mutable.LinkedHashMap[String, String] = { @@ -109,6 +113,8 @@ case class CatalogTablePartition( if (parameters.nonEmpty) { map.put("Partition Parameters", s"{${parameters.map(p => p._1 + "=" + p._2).mkString(", ")}}") } + map.put("Created Time", new Date(createTime).toString) + map.put("Last Access", new Date(lastAccessTime).toString) stats.foreach(s => map.put("Partition Statistics", s.simpleString)) map } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index efb2eba655e15..89e8c998f740d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -149,6 +149,7 @@ package object dsl { } } + def rand(e: Long): Expression = Rand(e) def sum(e: Expression): Expression = Sum(e).toAggregateExpression() def sumDistinct(e: Expression): Expression = Sum(e).toAggregateExpression(isDistinct = true) def count(e: Expression): Expression = Count(e).toAggregateExpression() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index efc2882f0a3d3..cbea3c017a265 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -128,7 +128,7 @@ object ExpressionEncoder { case b: BoundReference if b == originalInputObject => newInputObject }) - if (enc.flat) { + val serializerExpr = if (enc.flat) { newSerializer.head } else { // For non-flat encoder, the input object is not top level anymore after being combined to @@ -146,6 +146,7 @@ object ExpressionEncoder { Invoke(Literal.fromObject(None), "equals", BooleanType, newInputObject :: Nil)) If(nullCheck, Literal.create(null, struct.dataType), struct) } + Alias(serializerExpr, s"_${index + 1}")() } val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 699ea53b5df0f..7971ae602bd37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -134,6 +134,26 @@ object Cast { toPrecedence > 0 && fromPrecedence > toPrecedence } + /** + * Returns true iff we can safely cast the `from` type to `to` type without any truncating or + * precision lose, e.g. int -> long, date -> timestamp. + */ + def canSafeCast(from: AtomicType, to: AtomicType): Boolean = (from, to) match { + case _ if from == to => true + case (from: NumericType, to: DecimalType) if to.isWiderThan(from) => true + case (from: DecimalType, to: NumericType) if from.isTighterThan(to) => true + case (from, to) if legalNumericPrecedence(from, to) => true + case (DateType, TimestampType) => true + case (_, StringType) => true + case _ => false + } + + private def legalNumericPrecedence(from: DataType, to: DataType): Boolean = { + val fromPrecedence = TypeCoercion.numericPrecedence.indexOf(from) + val toPrecedence = TypeCoercion.numericPrecedence.indexOf(to) + fromPrecedence >= 0 && fromPrecedence < toPrecedence + } + def forceNullable(from: DataType, to: DataType): Boolean = (from, to) match { case (NullType, _) => true case (_, _) if from == to => false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 9b9fa41a47d0f..f7d1b105964d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Locale import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TreeNode @@ -695,6 +695,31 @@ abstract class TernaryExpression extends Expression { } } +/** + * A trait resolving nullable, containsNull, valueContainsNull flags of the output date type. + * This logic is usually utilized by expressions combining data from multiple child expressions + * of non-primitive types (e.g. [[CaseWhen]]). + */ +trait ComplexTypeMergingExpression extends Expression { + + /** + * A collection of data types used for resolution the output type of the expression. By default, + * data types of all child expressions. The collection must not be empty. + */ + @transient + lazy val inputTypesForMerging: Seq[DataType] = children.map(_.dataType) + + override def dataType: DataType = { + require( + inputTypesForMerging.nonEmpty, + "The collection of input data types must not be empty.") + require( + TypeCoercion.haveSameType(inputTypesForMerging), + "All input types must be the same except nullable, containsNull, valueContainsNull flags.") + inputTypesForMerging.reduceLeft(TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(_, _).get) + } +} + /** * Common base trait for user-defined functions, including UDF/UDAF/UDTF of different languages * and Hive function wrappers. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index fe91e520169b4..55940410cc4d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.TypeUtils @@ -514,7 +514,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { > SELECT _FUNC_(10, 9, 2, 4, 3); 2 """) -case class Least(children: Seq[Expression]) extends Expression { +case class Least(children: Seq[Expression]) extends ComplexTypeMergingExpression { override def nullable: Boolean = children.forall(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -525,7 +525,7 @@ case class Least(children: Seq[Expression]) extends Expression { if (children.length <= 1) { TypeCheckResult.TypeCheckFailure( s"input to function $prettyName requires at least two arguments") - } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { + } else if (!TypeCoercion.haveSameType(inputTypesForMerging)) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + s" got LEAST(${children.map(_.dataType.simpleString).mkString(", ")}).") @@ -534,8 +534,6 @@ case class Least(children: Seq[Expression]) extends Expression { } } - override def dataType: DataType = children.head.dataType - override def eval(input: InternalRow): Any = { children.foldLeft[Any](null)((r, c) => { val evalc = c.eval(input) @@ -589,7 +587,7 @@ case class Least(children: Seq[Expression]) extends Expression { > SELECT _FUNC_(10, 9, 2, 4, 3); 10 """) -case class Greatest(children: Seq[Expression]) extends Expression { +case class Greatest(children: Seq[Expression]) extends ComplexTypeMergingExpression { override def nullable: Boolean = children.forall(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -600,7 +598,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { if (children.length <= 1) { TypeCheckResult.TypeCheckFailure( s"input to function $prettyName requires at least two arguments") - } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { + } else if (!TypeCoercion.haveSameType(inputTypesForMerging)) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + s" got GREATEST(${children.map(_.dataType.simpleString).mkString(", ")}).") @@ -609,8 +607,6 @@ case class Greatest(children: Seq[Expression]) extends Expression { } } - override def dataType: DataType = children.head.dataType - override def eval(input: InternalRow): Any = { children.foldLeft[Any](null)((r, c) => { val evalc = c.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 66315e5906253..838c045d5bcce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -819,6 +819,36 @@ class CodegenContext { } } + /** + * Generates code to do null safe execution when accessing properties of complex + * ArrayData elements. + * + * @param nullElements used to decide whether the ArrayData might contain null or not. + * @param isNull a variable indicating whether the result will be evaluated to null or not. + * @param arrayData a variable name representing the ArrayData. + * @param execute the code that should be executed only if the ArrayData doesn't contain + * any null. + */ + def nullArrayElementsSaveExec( + nullElements: Boolean, + isNull: String, + arrayData: String)( + execute: String): String = { + val i = freshName("idx") + if (nullElements) { + s""" + |for (int $i = 0; !$isNull && $i < $arrayData.numElements(); $i++) { + | $isNull |= $arrayData.isNullAt($i); + |} + |if (!$isNull) { + | $execute + |} + """.stripMargin + } else { + execute + } + } + /** * Splits the generated code of expressions into multiple functions, because function has * 64kb code size limit in JVM. If the class to which the function would be inlined would grow @@ -1385,7 +1415,7 @@ object CodeGenerator extends Logging { * weak keys/values and thus does not respond to memory pressure. */ private val cache = CacheBuilder.newBuilder() - .maximumSize(100) + .maximumSize(SQLConf.get.codegenCacheMaxEntries) .build( new CacheLoader[CodeAndComment, (GeneratedClass, Int)]() { override def load(code: CodeAndComment): (GeneratedClass, Int) = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 250ce48d059e0..2f8c853e836ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -22,6 +22,7 @@ import java.lang.{Boolean => JBool} import scala.collection.mutable.ArrayBuffer import scala.language.{existentials, implicitConversions} +import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{BooleanType, DataType} /** @@ -118,10 +119,8 @@ object JavaCode { /** * A trait representing a block of java code. */ -trait Block extends JavaCode { - - // The expressions to be evaluated inside this block. - def exprValues: Set[ExprValue] +trait Block extends TreeNode[Block] with JavaCode { + import Block._ // Returns java code string for this code block. override def toString: String = _marginChar match { @@ -147,15 +146,48 @@ trait Block extends JavaCode { this } + /** + * Apply a map function to each java expression codes present in this java code, and return a new + * java code based on the mapped java expression codes. + */ + def transformExprValues(f: PartialFunction[ExprValue, ExprValue]): this.type = { + var changed = false + + @inline def transform(e: ExprValue): ExprValue = { + val newE = f lift e + if (!newE.isDefined || newE.get.equals(e)) { + e + } else { + changed = true + newE.get + } + } + + def doTransform(arg: Any): AnyRef = arg match { + case e: ExprValue => transform(e) + case Some(value) => Some(doTransform(value)) + case seq: Traversable[_] => seq.map(doTransform) + case other: AnyRef => other + } + + val newArgs = mapProductIterator(doTransform) + if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this + } + // Concatenates this block with other block. - def + (other: Block): Block + def + (other: Block): Block = other match { + case EmptyBlock => this + case _ => code"$this\n$other" + } + + override def verboseString: String = toString } object Block { val CODE_BLOCK_BUFFER_LENGTH: Int = 512 - implicit def blocksToBlock(blocks: Seq[Block]): Block = Blocks(blocks) + implicit def blocksToBlock(blocks: Seq[Block]): Block = blocks.reduceLeft(_ + _) implicit class BlockHelper(val sc: StringContext) extends AnyVal { def code(args: Any*): Block = { @@ -190,18 +222,17 @@ object Block { while (strings.hasNext) { val input = inputs.next input match { - case _: ExprValue | _: Block => + case _: ExprValue | _: CodeBlock => codeParts += buf.toString buf.clear blockInputs += input.asInstanceOf[JavaCode] + case EmptyBlock => case _ => buf.append(input) } buf.append(strings.next) } - if (buf.nonEmpty) { - codeParts += buf.toString - } + codeParts += buf.toString (codeParts.toSeq, blockInputs.toSeq) } @@ -209,15 +240,15 @@ object Block { /** * A block of java code. Including a sequence of code parts and some inputs to this block. - * The actual java code is generated by embedding the inputs into the code parts. + * The actual java code is generated by embedding the inputs into the code parts. Here we keep + * inputs of `JavaCode` instead of simply folding them as a string of code, because we need to + * track expressions (`ExprValue`) in this code block. We need to be able to manipulate the + * expressions later without changing the behavior of this code block in some applications, e.g., + * method splitting. */ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends Block { - override lazy val exprValues: Set[ExprValue] = { - blockInputs.flatMap { - case b: Block => b.exprValues - case e: ExprValue => Set(e) - }.toSet - } + override def children: Seq[Block] = + blockInputs.filter(_.isInstanceOf[Block]).asInstanceOf[Seq[Block]] override lazy val code: String = { val strings = codeParts.iterator @@ -230,30 +261,11 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends } buf.toString } - - override def + (other: Block): Block = other match { - case c: CodeBlock => Blocks(Seq(this, c)) - case b: Blocks => Blocks(Seq(this) ++ b.blocks) - case EmptyBlock => this - } } -case class Blocks(blocks: Seq[Block]) extends Block { - override lazy val exprValues: Set[ExprValue] = blocks.flatMap(_.exprValues).toSet - override lazy val code: String = blocks.map(_.toString).mkString("\n") - - override def + (other: Block): Block = other match { - case c: CodeBlock => Blocks(blocks :+ c) - case b: Blocks => Blocks(blocks ++ b.blocks) - case EmptyBlock => this - } -} - -object EmptyBlock extends Block with Serializable { +case object EmptyBlock extends Block with Serializable { override val code: String = "" - override val exprValues: Set[ExprValue] = Set.empty - - override def + (other: Block): Block = other + override def children: Seq[Block] = Seq.empty } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index d76f3013f0c41..972bc6e57892c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -16,21 +16,26 @@ */ package org.apache.spark.sql.catalyst.expressions -import java.util.Comparator +import java.util.{Comparator, TimeZone} import scala.collection.mutable +import scala.reflect.ClassTag import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH import org.apache.spark.unsafe.types.{ByteArray, UTF8String} +import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.util.collection.OpenHashSet /** * Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit @@ -66,37 +71,55 @@ trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression /** - * Given an array or map, returns its size. Returns -1 if null. + * Given an array or map, returns total number of elements in it. */ @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the size of an array or a map. Returns -1 if null.", + usage = """ + _FUNC_(expr) - Returns the size of an array or a map. + The function returns -1 if its input is null and spark.sql.legacy.sizeOfNull is set to true. + If spark.sql.legacy.sizeOfNull is set to false, the function returns null for null input. + By default, the spark.sql.legacy.sizeOfNull parameter is set to true. + """, examples = """ Examples: > SELECT _FUNC_(array('b', 'd', 'c', 'a')); 4 + > SELECT _FUNC_(map('a', 1, 'b', 2)); + 2 + > SELECT _FUNC_(NULL); + -1 """) case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + val legacySizeOfNull = SQLConf.get.legacySizeOfNull + override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType)) - override def nullable: Boolean = false + override def nullable: Boolean = if (legacySizeOfNull) false else super.nullable override def eval(input: InternalRow): Any = { val value = child.eval(input) if (value == null) { - -1 + if (legacySizeOfNull) -1 else null } else child.dataType match { case _: ArrayType => value.asInstanceOf[ArrayData].numElements() case _: MapType => value.asInstanceOf[MapData].numElements() + case other => throw new UnsupportedOperationException( + s"The size function doesn't support the operand type ${other.getClass.getCanonicalName}") } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val childGen = child.genCode(ctx) - ev.copy(code = code""" + if (legacySizeOfNull) { + val childGen = child.genCode(ctx) + ev.copy(code = code""" boolean ${ev.isNull} = false; ${childGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : (${childGen.value}).numElements();""", isNull = FalseLiteral) + } else { + defineCodeGen(ctx, ev, c => s"($c).numElements()") + } } } @@ -199,7 +222,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI """.stripMargin } - val splittedGetValuesAndCardinalities = ctx.splitExpressions( + val splittedGetValuesAndCardinalities = ctx.splitExpressionsWithCurrentInputs( expressions = getValuesAndCardinalities, funcName = "getValuesAndCardinalities", returnType = "int", @@ -209,7 +232,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI |return $biggestCardinality; """.stripMargin, foldFunctions = _.map(funcCall => s"$biggestCardinality = $funcCall;").mkString("\n"), - arguments = + extraArguments = ("ArrayData[]", arrVals) :: ("int", biggestCardinality) :: Nil) @@ -474,6 +497,450 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp override def prettyName: String = "map_entries" } +/** + * Returns the union of all the given maps. + */ +@ExpressionDescription( + usage = "_FUNC_(map, ...) - Returns the union of all the given maps", + examples = """ + Examples: + > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd')); + [[1 -> "a"], [2 -> "b"], [2 -> "c"], [3 -> "d"]] + """, since = "2.4.0") +case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpression { + + override def checkInputDataTypes(): TypeCheckResult = { + var funcName = s"function $prettyName" + if (children.exists(!_.dataType.isInstanceOf[MapType])) { + TypeCheckResult.TypeCheckFailure( + s"input to $funcName should all be of type map, but it's " + + children.map(_.dataType.simpleString).mkString("[", ", ", "]")) + } else { + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), funcName) + } + } + + override def dataType: MapType = { + if (children.isEmpty) { + MapType(StringType, StringType) + } else { + super.dataType.asInstanceOf[MapType] + } + } + + override def nullable: Boolean = children.exists(_.nullable) + + override def eval(input: InternalRow): Any = { + val maps = children.map(_.eval(input)) + if (maps.contains(null)) { + return null + } + val keyArrayDatas = maps.map(_.asInstanceOf[MapData].keyArray()) + val valueArrayDatas = maps.map(_.asInstanceOf[MapData].valueArray()) + + val numElements = keyArrayDatas.foldLeft(0L)((sum, ad) => sum + ad.numElements()) + if (numElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful attempt to concat maps with $numElements " + + s"elements due to exceeding the map size limit " + + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") + } + val finalKeyArray = new Array[AnyRef](numElements.toInt) + val finalValueArray = new Array[AnyRef](numElements.toInt) + var position = 0 + for (i <- keyArrayDatas.indices) { + val keyArray = keyArrayDatas(i).toObjectArray(dataType.keyType) + val valueArray = valueArrayDatas(i).toObjectArray(dataType.valueType) + Array.copy(keyArray, 0, finalKeyArray, position, keyArray.length) + Array.copy(valueArray, 0, finalValueArray, position, valueArray.length) + position += keyArray.length + } + + new ArrayBasedMapData(new GenericArrayData(finalKeyArray), + new GenericArrayData(finalValueArray)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val mapCodes = children.map(_.genCode(ctx)) + val keyType = dataType.keyType + val valueType = dataType.valueType + val argsName = ctx.freshName("args") + val hasNullName = ctx.freshName("hasNull") + val mapDataClass = classOf[MapData].getName + val arrayBasedMapDataClass = classOf[ArrayBasedMapData].getName + val arrayDataClass = classOf[ArrayData].getName + + val init = + s""" + |$mapDataClass[] $argsName = new $mapDataClass[${mapCodes.size}]; + |boolean ${ev.isNull}, $hasNullName = false; + |$mapDataClass ${ev.value} = null; + """.stripMargin + + val assignments = mapCodes.zipWithIndex.map { case (m, i) => + s""" + |if (!$hasNullName) { + | ${m.code} + | $argsName[$i] = ${m.value}; + | if (${m.isNull}) { + | $hasNullName = true; + | } + |} + """.stripMargin + } + + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = assignments, + funcName = "getMapConcatInputs", + extraArguments = (s"$mapDataClass[]", argsName) :: ("boolean", hasNullName) :: Nil, + returnType = "boolean", + makeSplitFunction = body => + s""" + |$body + |return $hasNullName; + """.stripMargin, + foldFunctions = _.map(funcCall => s"$hasNullName = $funcCall;").mkString("\n") + ) + + val idxName = ctx.freshName("idx") + val numElementsName = ctx.freshName("numElems") + val finKeysName = ctx.freshName("finalKeys") + val finValsName = ctx.freshName("finalValues") + + val keyConcatenator = if (CodeGenerator.isPrimitiveType(keyType)) { + genCodeForPrimitiveArrays(ctx, keyType, false) + } else { + genCodeForNonPrimitiveArrays(ctx, keyType) + } + + val valueConcatenator = if (CodeGenerator.isPrimitiveType(valueType)) { + genCodeForPrimitiveArrays(ctx, valueType, dataType.valueContainsNull) + } else { + genCodeForNonPrimitiveArrays(ctx, valueType) + } + + val keyArgsName = ctx.freshName("keyArgs") + val valArgsName = ctx.freshName("valArgs") + + val mapMerge = + s""" + |${ev.isNull} = $hasNullName; + |if (!${ev.isNull}) { + | $arrayDataClass[] $keyArgsName = new $arrayDataClass[${mapCodes.size}]; + | $arrayDataClass[] $valArgsName = new $arrayDataClass[${mapCodes.size}]; + | long $numElementsName = 0; + | for (int $idxName = 0; $idxName < $argsName.length; $idxName++) { + | $keyArgsName[$idxName] = $argsName[$idxName].keyArray(); + | $valArgsName[$idxName] = $argsName[$idxName].valueArray(); + | $numElementsName += $argsName[$idxName].numElements(); + | } + | if ($numElementsName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Unsuccessful attempt to concat maps with " + + | $numElementsName + " elements due to exceeding the map size limit " + + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); + | } + | $arrayDataClass $finKeysName = $keyConcatenator.concat($keyArgsName, + | (int) $numElementsName); + | $arrayDataClass $finValsName = $valueConcatenator.concat($valArgsName, + | (int) $numElementsName); + | ${ev.value} = new $arrayBasedMapDataClass($finKeysName, $finValsName); + |} + """.stripMargin + + ev.copy( + code = code""" + |$init + |$codes + |$mapMerge + """.stripMargin) + } + + private def genCodeForPrimitiveArrays( + ctx: CodegenContext, + elementType: DataType, + checkForNull: Boolean): String = { + val counter = ctx.freshName("counter") + val arrayData = ctx.freshName("arrayData") + val argsName = ctx.freshName("args") + val numElemName = ctx.freshName("numElements") + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + + val setterCode1 = + s""" + |$arrayData.set$primitiveValueTypeName( + | $counter, + | ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")} + |);""".stripMargin + + val setterCode = if (checkForNull) { + s""" + |if ($argsName[y].isNullAt(z)) { + | $arrayData.setNullAt($counter); + |} else { + | $setterCode1 + |}""".stripMargin + } else { + setterCode1 + } + + s""" + |new Object() { + | public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) { + | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < $argsName[y].numElements(); z++) { + | $setterCode + | $counter++; + | } + | } + | return $arrayData; + | } + |}""".stripMargin.stripPrefix("\n") + } + + private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val arrayData = ctx.freshName("arrayObjects") + val counter = ctx.freshName("counter") + val argsName = ctx.freshName("args") + val numElemName = ctx.freshName("numElements") + + s""" + |new Object() { + | public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) {; + | Object[] $arrayData = new Object[$numElemName]; + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < $argsName[y].numElements(); z++) { + | $arrayData[$counter] = ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")}; + | $counter++; + | } + | } + | return new $genericArrayClass($arrayData); + | } + |}""".stripMargin.stripPrefix("\n") + } + + override def prettyName: String = "map_concat" +} + +/** + * Returns a map created from the given array of entries. + */ +@ExpressionDescription( + usage = "_FUNC_(arrayOfEntries) - Returns a map created from the given array of entries.", + examples = """ + Examples: + > SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b'))); + {1:"a",2:"b"} + """, + since = "2.4.0") +case class MapFromEntries(child: Expression) extends UnaryExpression { + + @transient + private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = child.dataType match { + case ArrayType( + StructType(Array( + StructField(_, keyType, keyNullable, _), + StructField(_, valueType, valueNullable, _))), + containsNull) => Some((MapType(keyType, valueType, valueNullable), keyNullable, containsNull)) + case _ => None + } + + private def nullEntries: Boolean = dataTypeDetails.get._3 + + override def nullable: Boolean = child.nullable || nullEntries + + override def dataType: MapType = dataTypeDetails.get._1 + + override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match { + case Some(_) => TypeCheckResult.TypeCheckSuccess + case None => TypeCheckResult.TypeCheckFailure(s"'${child.sql}' is of " + + s"${child.dataType.simpleString} type. $prettyName accepts only arrays of pair structs.") + } + + override protected def nullSafeEval(input: Any): Any = { + val arrayData = input.asInstanceOf[ArrayData] + val numEntries = arrayData.numElements() + var i = 0 + if(nullEntries) { + while (i < numEntries) { + if (arrayData.isNullAt(i)) return null + i += 1 + } + } + val keyArray = new Array[AnyRef](numEntries) + val valueArray = new Array[AnyRef](numEntries) + i = 0 + while (i < numEntries) { + val entry = arrayData.getStruct(i, 2) + val key = entry.get(0, dataType.keyType) + if (key == null) { + throw new RuntimeException("The first field from a struct (key) can't be null.") + } + keyArray.update(i, key) + val value = entry.get(1, dataType.valueType) + valueArray.update(i, value) + i += 1 + } + ArrayBasedMapData(keyArray, valueArray) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => { + val numEntries = ctx.freshName("numEntries") + val isKeyPrimitive = CodeGenerator.isPrimitiveType(dataType.keyType) + val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType) + val code = if (isKeyPrimitive && isValuePrimitive) { + genCodeForPrimitiveElements(ctx, c, ev.value, numEntries) + } else { + genCodeForAnyElements(ctx, c, ev.value, numEntries) + } + ctx.nullArrayElementsSaveExec(nullEntries, ev.isNull, c) { + s""" + |final int $numEntries = $c.numElements(); + |$code + """.stripMargin + } + }) + } + + private def genCodeForAssignmentLoop( + ctx: CodegenContext, + childVariable: String, + mapData: String, + numEntries: String, + keyAssignment: (String, String) => String, + valueAssignment: (String, String) => String): String = { + val entry = ctx.freshName("entry") + val i = ctx.freshName("idx") + + val nullKeyCheck = if (dataTypeDetails.get._2) { + s""" + |if ($entry.isNullAt(0)) { + | throw new RuntimeException("The first field from a struct (key) can't be null."); + |} + """.stripMargin + } else { + "" + } + + s""" + |for (int $i = 0; $i < $numEntries; $i++) { + | InternalRow $entry = $childVariable.getStruct($i, 2); + | $nullKeyCheck + | ${keyAssignment(CodeGenerator.getValue(entry, dataType.keyType, "0"), i)} + | ${valueAssignment(entry, i)} + |} + """.stripMargin + } + + private def genCodeForPrimitiveElements( + ctx: CodegenContext, + childVariable: String, + mapData: String, + numEntries: String): String = { + val byteArraySize = ctx.freshName("byteArraySize") + val keySectionSize = ctx.freshName("keySectionSize") + val valueSectionSize = ctx.freshName("valueSectionSize") + val data = ctx.freshName("byteArray") + val unsafeMapData = ctx.freshName("unsafeMapData") + val keyArrayData = ctx.freshName("keyArrayData") + val valueArrayData = ctx.freshName("valueArrayData") + + val baseOffset = Platform.BYTE_ARRAY_OFFSET + val keySize = dataType.keyType.defaultSize + val valueSize = dataType.valueType.defaultSize + val kByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $keySize)" + val vByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $valueSize)" + val keyTypeName = CodeGenerator.primitiveTypeName(dataType.keyType) + val valueTypeName = CodeGenerator.primitiveTypeName(dataType.valueType) + + val keyAssignment = (key: String, idx: String) => s"$keyArrayData.set$keyTypeName($idx, $key);" + val valueAssignment = (entry: String, idx: String) => { + val value = CodeGenerator.getValue(entry, dataType.valueType, "1") + val valueNullUnsafeAssignment = s"$valueArrayData.set$valueTypeName($idx, $value);" + if (dataType.valueContainsNull) { + s""" + |if ($entry.isNullAt(1)) { + | $valueArrayData.setNullAt($idx); + |} else { + | $valueNullUnsafeAssignment + |} + """.stripMargin + } else { + valueNullUnsafeAssignment + } + } + val assignmentLoop = genCodeForAssignmentLoop( + ctx, + childVariable, + mapData, + numEntries, + keyAssignment, + valueAssignment + ) + + s""" + |final long $keySectionSize = $kByteSize; + |final long $valueSectionSize = $vByteSize; + |final long $byteArraySize = 8 + $keySectionSize + $valueSectionSize; + |if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | ${genCodeForAnyElements(ctx, childVariable, mapData, numEntries)} + |} else { + | final byte[] $data = new byte[(int)$byteArraySize]; + | UnsafeMapData $unsafeMapData = new UnsafeMapData(); + | Platform.putLong($data, $baseOffset, $keySectionSize); + | Platform.putLong($data, ${baseOffset + 8}, $numEntries); + | Platform.putLong($data, ${baseOffset + 8} + $keySectionSize, $numEntries); + | $unsafeMapData.pointTo($data, $baseOffset, (int)$byteArraySize); + | ArrayData $keyArrayData = $unsafeMapData.keyArray(); + | ArrayData $valueArrayData = $unsafeMapData.valueArray(); + | $assignmentLoop + | $mapData = $unsafeMapData; + |} + """.stripMargin + } + + private def genCodeForAnyElements( + ctx: CodegenContext, + childVariable: String, + mapData: String, + numEntries: String): String = { + val keys = ctx.freshName("keys") + val values = ctx.freshName("values") + val mapDataClass = classOf[ArrayBasedMapData].getName() + + val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType) + val valueAssignment = (entry: String, idx: String) => { + val value = CodeGenerator.getValue(entry, dataType.valueType, "1") + if (dataType.valueContainsNull && isValuePrimitive) { + s"$values[$idx] = $entry.isNullAt(1) ? null : (Object)$value;" + } else { + s"$values[$idx] = $value;" + } + } + val keyAssignment = (key: String, idx: String) => s"$keys[$idx] = $key;" + val assignmentLoop = genCodeForAssignmentLoop( + ctx, + childVariable, + mapData, + numEntries, + keyAssignment, + valueAssignment) + + s""" + |final Object[] $keys = new Object[$numEntries]; + |final Object[] $values = new Object[$numEntries]; + |$assignmentLoop + |$mapData = $mapDataClass.apply($keys, $values); + """.stripMargin + } + + override def prettyName: String = "map_from_entries" +} + + /** * Common base class for [[SortArray]] and [[ArraySort]]. */ @@ -522,7 +989,7 @@ trait ArraySortLike extends ExpectsInputTypes { } else if (o2 == null) { nullOrder } else { - -ordering.compare(o1, o2) + ordering.compare(o2, o1) } } } @@ -839,7 +1306,7 @@ case class ArrayContains(left: Expression, right: Expression) if (right.dataType == NullType) { TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as arguments") } else if (!left.dataType.isInstanceOf[ArrayType] - || left.dataType.asInstanceOf[ArrayType].elementType != right.dataType) { + || !left.dataType.asInstanceOf[ArrayType].elementType.sameType(right.dataType)) { TypeCheckResult.TypeCheckFailure( "Arguments must be an array followed by a value of same type as the array members") } else { @@ -1403,6 +1870,7 @@ case class ArrayJoin( override def dataType: DataType = StringType + override def prettyName: String = "array_join" } /** @@ -1739,7 +2207,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); | [1,2,3,4,5,6] """) -case class Concat(children: Seq[Expression]) extends Expression { +case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpression { private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH @@ -1760,7 +2228,13 @@ case class Concat(children: Seq[Expression]) extends Expression { } } - override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) + override def dataType: DataType = { + if (children.isEmpty) { + StringType + } else { + super.dataType + } + } lazy val javaType: String = CodeGenerator.javaType(dataType) @@ -1989,24 +2463,10 @@ case class Flatten(child: Expression) extends UnaryExpression { } else { genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value) } - if (childDataType.containsNull) nullElementsProtection(ev, c, code) else code + ctx.nullArrayElementsSaveExec(childDataType.containsNull, ev.isNull, c)(code) }) } - private def nullElementsProtection( - ev: ExprCode, - childVariableName: String, - coreLogic: String): String = { - s""" - |for (int z = 0; !${ev.isNull} && z < $childVariableName.numElements(); z++) { - | ${ev.isNull} |= $childVariableName.isNullAt(z); - |} - |if (!${ev.isNull}) { - | $coreLogic - |} - """.stripMargin - } - private def genCodeForNumberOfElements( ctx: CodegenContext, childVariableName: String) : (String, String) = { @@ -2084,48 +2544,443 @@ case class Flatten(child: Expression) extends UnaryExpression { override def prettyName: String = "flatten" } -/** - * Returns the array containing the given input value (left) count (right) times. - */ @ExpressionDescription( - usage = "_FUNC_(element, count) - Returns the array containing element count times.", + usage = """ + _FUNC_(start, stop, step) - Generates an array of elements from start to stop (inclusive), + incrementing by step. The type of the returned elements is the same as the type of argument + expressions. + + Supported types are: byte, short, integer, long, date, timestamp. + + The start and stop expressions must resolve to the same type. + If start and stop expressions resolve to the 'date' or 'timestamp' type + then the step expression must resolve to the 'interval' type, otherwise to the same type + as the start and stop expressions. + """, + arguments = """ + Arguments: + * start - an expression. The start of the range. + * stop - an expression. The end the range (inclusive). + * step - an optional expression. The step of the range. + By default step is 1 if start is less than or equal to stop, otherwise -1. + For the temporal sequences it's 1 day and -1 day respectively. + If start is greater than stop then the step must be negative, and vice versa. + """, examples = """ Examples: - > SELECT _FUNC_('123', 2); - ['123', '123'] + > SELECT _FUNC_(1, 5); + [1, 2, 3, 4, 5] + > SELECT _FUNC_(5, 1); + [5, 4, 3, 2, 1] + > SELECT _FUNC_(to_date('2018-01-01'), to_date('2018-03-01'), interval 1 month); + [2018-01-01, 2018-02-01, 2018-03-01] """, - since = "2.4.0") -case class ArrayRepeat(left: Expression, right: Expression) - extends BinaryExpression with ExpectsInputTypes { + since = "2.4.0" +) +case class Sequence( + start: Expression, + stop: Expression, + stepOpt: Option[Expression], + timeZoneId: Option[String] = None) + extends Expression + with TimeZoneAwareExpression { - private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + import Sequence._ - override def dataType: ArrayType = ArrayType(left.dataType, left.nullable) + def this(start: Expression, stop: Expression) = + this(start, stop, None, None) - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType) + def this(start: Expression, stop: Expression, step: Expression) = + this(start, stop, Some(step), None) - override def nullable: Boolean = right.nullable + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Some(timeZoneId)) - override def eval(input: InternalRow): Any = { - val count = right.eval(input) - if (count == null) { - null + override def children: Seq[Expression] = Seq(start, stop) ++ stepOpt + + override def foldable: Boolean = children.forall(_.foldable) + + override def nullable: Boolean = children.exists(_.nullable) + + override lazy val dataType: ArrayType = ArrayType(start.dataType, containsNull = false) + + override def checkInputDataTypes(): TypeCheckResult = { + val startType = start.dataType + def stepType = stepOpt.get.dataType + val typesCorrect = + startType.sameType(stop.dataType) && + (startType match { + case TimestampType | DateType => + stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepType) + case _: IntegralType => + stepOpt.isEmpty || stepType.sameType(startType) + case _ => false + }) + + if (typesCorrect) { + TypeCheckResult.TypeCheckSuccess } else { - if (count.asInstanceOf[Int] > MAX_ARRAY_LENGTH) { - throw new RuntimeException(s"Unsuccessful try to create array with $count elements " + - s"due to exceeding the array size limit $MAX_ARRAY_LENGTH."); - } - val element = left.eval(input) - new GenericArrayData(Array.fill(count.asInstanceOf[Int])(element)) + TypeCheckResult.TypeCheckFailure( + s"$prettyName only supports integral, timestamp or date types") } } - override def prettyName: String = "array_repeat" + def coercibleChildren: Seq[Expression] = children.filter(_.dataType != CalendarIntervalType) - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val leftGen = left.genCode(ctx) - val rightGen = right.genCode(ctx) - val element = leftGen.value + def castChildrenTo(widerType: DataType): Expression = Sequence( + Cast(start, widerType), + Cast(stop, widerType), + stepOpt.map(step => if (step.dataType != CalendarIntervalType) Cast(step, widerType) else step), + timeZoneId) + + private lazy val impl: SequenceImpl = dataType.elementType match { + case iType: IntegralType => + type T = iType.InternalType + val ct = ClassTag[T](iType.tag.mirror.runtimeClass(iType.tag.tpe)) + new IntegralSequenceImpl(iType)(ct, iType.integral) + + case TimestampType => + new TemporalSequenceImpl[Long](LongType, 1, identity, timeZone) + + case DateType => + new TemporalSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, timeZone) + } + + override def eval(input: InternalRow): Any = { + val startVal = start.eval(input) + if (startVal == null) return null + val stopVal = stop.eval(input) + if (stopVal == null) return null + val stepVal = stepOpt.map(_.eval(input)).getOrElse(impl.defaultStep(startVal, stopVal)) + if (stepVal == null) return null + + ArrayData.toArrayData(impl.eval(startVal, stopVal, stepVal)) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val startGen = start.genCode(ctx) + val stopGen = stop.genCode(ctx) + val stepGen = stepOpt.map(_.genCode(ctx)).getOrElse( + impl.defaultStep.genCode(ctx, startGen, stopGen)) + + val resultType = CodeGenerator.javaType(dataType) + val resultCode = { + val arr = ctx.freshName("arr") + val arrElemType = CodeGenerator.javaType(dataType.elementType) + s""" + |final $arrElemType[] $arr = null; + |${impl.genCode(ctx, startGen.value, stopGen.value, stepGen.value, arr, arrElemType)} + |${ev.value} = UnsafeArrayData.fromPrimitiveArray($arr); + """.stripMargin + } + + if (nullable) { + val nullSafeEval = + startGen.code + ctx.nullSafeExec(start.nullable, startGen.isNull) { + stopGen.code + ctx.nullSafeExec(stop.nullable, stopGen.isNull) { + stepGen.code + ctx.nullSafeExec(stepOpt.exists(_.nullable), stepGen.isNull) { + s""" + |${ev.isNull} = false; + |$resultCode + """.stripMargin + } + } + } + ev.copy(code = + code""" + |boolean ${ev.isNull} = true; + |$resultType ${ev.value} = null; + |$nullSafeEval + """.stripMargin) + + } else { + ev.copy(code = + code""" + |${startGen.code} + |${stopGen.code} + |${stepGen.code} + |$resultType ${ev.value} = null; + |$resultCode + """.stripMargin, + isNull = FalseLiteral) + } + } +} + +object Sequence { + + private type LessThanOrEqualFn = (Any, Any) => Boolean + + private class DefaultStep(lteq: LessThanOrEqualFn, stepType: DataType, one: Any) { + private val negativeOne = UnaryMinus(Literal(one)).eval() + + def apply(start: Any, stop: Any): Any = { + if (lteq(start, stop)) one else negativeOne + } + + def genCode(ctx: CodegenContext, startGen: ExprCode, stopGen: ExprCode): ExprCode = { + val Seq(oneVal, negativeOneVal) = Seq(one, negativeOne).map(Literal(_).genCode(ctx).value) + ExprCode.forNonNullValue(JavaCode.expression( + s"${startGen.value} <= ${stopGen.value} ? $oneVal : $negativeOneVal", + stepType)) + } + } + + private trait SequenceImpl { + def eval(start: Any, stop: Any, step: Any): Any + + def genCode( + ctx: CodegenContext, + start: String, + stop: String, + step: String, + arr: String, + elemType: String): String + + val defaultStep: DefaultStep + } + + private class IntegralSequenceImpl[T: ClassTag] + (elemType: IntegralType)(implicit num: Integral[T]) extends SequenceImpl { + + override val defaultStep: DefaultStep = new DefaultStep( + (elemType.ordering.lteq _).asInstanceOf[LessThanOrEqualFn], + elemType, + num.one) + + override def eval(input1: Any, input2: Any, input3: Any): Array[T] = { + import num._ + + val start = input1.asInstanceOf[T] + val stop = input2.asInstanceOf[T] + val step = input3.asInstanceOf[T] + + var i: Int = getSequenceLength(start, stop, step) + val arr = new Array[T](i) + while (i > 0) { + i -= 1 + arr(i) = start + step * num.fromInt(i) + } + arr + } + + override def genCode( + ctx: CodegenContext, + start: String, + stop: String, + step: String, + arr: String, + elemType: String): String = { + val i = ctx.freshName("i") + s""" + |${genSequenceLengthCode(ctx, start, stop, step, i)} + |$arr = new $elemType[$i]; + |while ($i > 0) { + | $i--; + | $arr[$i] = ($elemType) ($start + $step * $i); + |} + """.stripMargin + } + } + + private class TemporalSequenceImpl[T: ClassTag] + (dt: IntegralType, scale: Long, fromLong: Long => T, timeZone: TimeZone) + (implicit num: Integral[T]) extends SequenceImpl { + + override val defaultStep: DefaultStep = new DefaultStep( + (dt.ordering.lteq _).asInstanceOf[LessThanOrEqualFn], + CalendarIntervalType, + new CalendarInterval(0, MICROS_PER_DAY)) + + private val backedSequenceImpl = new IntegralSequenceImpl[T](dt) + private val microsPerMonth = 28 * CalendarInterval.MICROS_PER_DAY + + override def eval(input1: Any, input2: Any, input3: Any): Array[T] = { + val start = input1.asInstanceOf[T] + val stop = input2.asInstanceOf[T] + val step = input3.asInstanceOf[CalendarInterval] + val stepMonths = step.months + val stepMicros = step.microseconds + + if (stepMonths == 0) { + backedSequenceImpl.eval(start, stop, fromLong(stepMicros / scale)) + + } else { + // To estimate the resulted array length we need to make assumptions + // about a month length in microseconds + val intervalStepInMicros = stepMicros + stepMonths * microsPerMonth + val startMicros: Long = num.toLong(start) * scale + val stopMicros: Long = num.toLong(stop) * scale + val maxEstimatedArrayLength = + getSequenceLength(startMicros, stopMicros, intervalStepInMicros) + + val stepSign = if (stopMicros > startMicros) +1 else -1 + val exclusiveItem = stopMicros + stepSign + val arr = new Array[T](maxEstimatedArrayLength) + var t = startMicros + var i = 0 + + while (t < exclusiveItem ^ stepSign < 0) { + arr(i) = fromLong(t / scale) + t = timestampAddInterval(t, stepMonths, stepMicros, timeZone) + i += 1 + } + + // truncate array to the correct length + if (arr.length == i) arr else arr.slice(0, i) + } + } + + override def genCode( + ctx: CodegenContext, + start: String, + stop: String, + step: String, + arr: String, + elemType: String): String = { + val stepMonths = ctx.freshName("stepMonths") + val stepMicros = ctx.freshName("stepMicros") + val stepScaled = ctx.freshName("stepScaled") + val intervalInMicros = ctx.freshName("intervalInMicros") + val startMicros = ctx.freshName("startMicros") + val stopMicros = ctx.freshName("stopMicros") + val arrLength = ctx.freshName("arrLength") + val stepSign = ctx.freshName("stepSign") + val exclusiveItem = ctx.freshName("exclusiveItem") + val t = ctx.freshName("t") + val i = ctx.freshName("i") + val genTimeZone = ctx.addReferenceObj("timeZone", timeZone, classOf[TimeZone].getName) + + val sequenceLengthCode = + s""" + |final long $intervalInMicros = $stepMicros + $stepMonths * ${microsPerMonth}L; + |${genSequenceLengthCode(ctx, startMicros, stopMicros, intervalInMicros, arrLength)} + """.stripMargin + + val timestampAddIntervalCode = + s""" + |$t = org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampAddInterval( + | $t, $stepMonths, $stepMicros, $genTimeZone); + """.stripMargin + + s""" + |final int $stepMonths = $step.months; + |final long $stepMicros = $step.microseconds; + | + |if ($stepMonths == 0) { + | final $elemType $stepScaled = ($elemType) ($stepMicros / ${scale}L); + | ${backedSequenceImpl.genCode(ctx, start, stop, stepScaled, arr, elemType)}; + | + |} else { + | final long $startMicros = $start * ${scale}L; + | final long $stopMicros = $stop * ${scale}L; + | + | $sequenceLengthCode + | + | final int $stepSign = $stopMicros > $startMicros ? +1 : -1; + | final long $exclusiveItem = $stopMicros + $stepSign; + | + | $arr = new $elemType[$arrLength]; + | long $t = $startMicros; + | int $i = 0; + | + | while ($t < $exclusiveItem ^ $stepSign < 0) { + | $arr[$i] = ($elemType) ($t / ${scale}L); + | $timestampAddIntervalCode + | $i += 1; + | } + | + | if ($arr.length > $i) { + | $arr = java.util.Arrays.copyOf($arr, $i); + | } + |} + """.stripMargin + } + } + + private def getSequenceLength[U](start: U, stop: U, step: U)(implicit num: Integral[U]): Int = { + import num._ + require( + (step > num.zero && start <= stop) + || (step < num.zero && start >= stop) + || (step == num.zero && start == stop), + s"Illegal sequence boundaries: $start to $stop by $step") + + val len = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / step.toLong + + require( + len <= MAX_ROUNDED_ARRAY_LENGTH, + s"Too long sequence: $len. Should be <= $MAX_ROUNDED_ARRAY_LENGTH") + + len.toInt + } + + private def genSequenceLengthCode( + ctx: CodegenContext, + start: String, + stop: String, + step: String, + len: String): String = { + val longLen = ctx.freshName("longLen") + s""" + |if (!(($step > 0 && $start <= $stop) || + | ($step < 0 && $start >= $stop) || + | ($step == 0 && $start == $stop))) { + | throw new IllegalArgumentException( + | "Illegal sequence boundaries: " + $start + " to " + $stop + " by " + $step); + |} + |long $longLen = $stop == $start ? 1L : 1L + ((long) $stop - $start) / $step; + |if ($longLen > $MAX_ROUNDED_ARRAY_LENGTH) { + | throw new IllegalArgumentException( + | "Too long sequence: " + $longLen + ". Should be <= $MAX_ROUNDED_ARRAY_LENGTH"); + |} + |int $len = (int) $longLen; + """.stripMargin + } +} + +/** + * Returns the array containing the given input value (left) count (right) times. + */ +@ExpressionDescription( + usage = "_FUNC_(element, count) - Returns the array containing element count times.", + examples = """ + Examples: + > SELECT _FUNC_('123', 2); + ['123', '123'] + """, + since = "2.4.0") +case class ArrayRepeat(left: Expression, right: Expression) + extends BinaryExpression with ExpectsInputTypes { + + private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + override def dataType: ArrayType = ArrayType(left.dataType, left.nullable) + + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType) + + override def nullable: Boolean = right.nullable + + override def eval(input: InternalRow): Any = { + val count = right.eval(input) + if (count == null) { + null + } else { + if (count.asInstanceOf[Int] > MAX_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful try to create array with $count elements " + + s"due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + } + val element = left.eval(input) + new GenericArrayData(Array.fill(count.asInstanceOf[Int])(element)) + } + } + + override def prettyName: String = "array_repeat" + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val leftGen = left.genCode(ctx) + val rightGen = right.genCode(ctx) + val element = leftGen.value val count = rightGen.value val et = dataType.elementType @@ -2355,3 +3210,600 @@ case class ArrayRemove(left: Expression, right: Expression) override def prettyName: String = "array_remove" } + +/** + * Removes duplicate values from the array. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Removes duplicate values from the array.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3, null, 3)); + [1,2,3,null] + """, since = "2.4.0") +case class ArrayDistinct(child: Expression) + extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + override def dataType: DataType = child.dataType + + @transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(elementType) + + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName") + } + } + + @transient private lazy val elementTypeSupportEquals = elementType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } + + override def nullSafeEval(array: Any): Any = { + val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) + if (elementTypeSupportEquals) { + new GenericArrayData(data.distinct.asInstanceOf[Array[Any]]) + } else { + var foundNullElement = false + var pos = 0 + for (i <- 0 until data.length) { + if (data(i) == null) { + if (!foundNullElement) { + foundNullElement = true + pos = pos + 1 + } + } else { + var j = 0 + var done = false + while (j <= i && !done) { + if (data(j) != null && ordering.equiv(data(j), data(i))) { + done = true + } + j = j + 1 + } + if (i == j - 1) { + pos = pos + 1 + } + } + } + new GenericArrayData(data.slice(0, pos)) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (array) => { + val i = ctx.freshName("i") + val j = ctx.freshName("j") + val sizeOfDistinctArray = ctx.freshName("sizeOfDistinctArray") + val getValue1 = CodeGenerator.getValue(array, elementType, i) + val getValue2 = CodeGenerator.getValue(array, elementType, j) + val foundNullElement = ctx.freshName("foundNullElement") + val openHashSet = classOf[OpenHashSet[_]].getName + val hs = ctx.freshName("hs") + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()" + if (elementTypeSupportEquals) { + s""" + |int $sizeOfDistinctArray = 0; + |boolean $foundNullElement = false; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $array.numElements(); $i ++) { + | if ($array.isNullAt($i)) { + | $foundNullElement = true; + | } else { + | $hs.add($getValue1); + | } + |} + |$sizeOfDistinctArray = $hs.size() + ($foundNullElement ? 1 : 0); + |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)} + """.stripMargin + } else { + s""" + |int $sizeOfDistinctArray = 0; + |boolean $foundNullElement = false; + |for (int $i = 0; $i < $array.numElements(); $i ++) { + | if ($array.isNullAt($i)) { + | if (!($foundNullElement)) { + | $sizeOfDistinctArray = $sizeOfDistinctArray + 1; + | $foundNullElement = true; + | } + | } else { + | int $j; + | for ($j = 0; $j < $i; $j ++) { + | if (!$array.isNullAt($j) && ${ctx.genEqual(elementType, getValue1, getValue2)}) { + | break; + | } + | } + | if ($i == $j) { + | $sizeOfDistinctArray = $sizeOfDistinctArray + 1; + | } + | } + |} + | + |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)} + """.stripMargin + } + }) + } + + private def setNull( + isPrimitive: Boolean, + foundNullElement: String, + distinctArray: String, + pos: String): String = { + val setNullValue = + if (!isPrimitive) { + s"$distinctArray[$pos] = null"; + } else { + s"$distinctArray.setNullAt($pos)"; + } + + s""" + |if (!($foundNullElement)) { + | $setNullValue; + | $pos = $pos + 1; + | $foundNullElement = true; + |} + """.stripMargin + } + + private def setNotNullValue(isPrimitive: Boolean, + distinctArray: String, + pos: String, + getValue1: String, + primitiveValueTypeName: String): String = { + if (!isPrimitive) { + s"$distinctArray[$pos] = $getValue1"; + } else { + s"$distinctArray.set$primitiveValueTypeName($pos, $getValue1)"; + } + } + + private def setValueForFastEval( + isPrimitive: Boolean, + hs: String, + distinctArray: String, + pos: String, + getValue1: String, + primitiveValueTypeName: String): String = { + val setValue = setNotNullValue(isPrimitive, + distinctArray, pos, getValue1, primitiveValueTypeName) + s""" + |if (!($hs.contains($getValue1))) { + | $hs.add($getValue1); + | $setValue; + | $pos = $pos + 1; + |} + """.stripMargin + } + + private def setValueForBruteForceEval( + isPrimitive: Boolean, + i: String, + j: String, + inputArray: String, + distinctArray: String, + pos: String, + getValue1: String, + isEqual: String, + primitiveValueTypeName: String): String = { + val setValue = setNotNullValue(isPrimitive, + distinctArray, pos, getValue1, primitiveValueTypeName) + s""" + |int $j; + |for ($j = 0; $j < $i; $j ++) { + | if (!$inputArray.isNullAt($j) && $isEqual) { + | break; + | } + |} + |if ($i == $j) { + | $setValue; + | $pos = $pos + 1; + |} + """.stripMargin + } + + def genCodeForResult( + ctx: CodegenContext, + ev: ExprCode, + inputArray: String, + size: String): String = { + val distinctArray = ctx.freshName("distinctArray") + val i = ctx.freshName("i") + val j = ctx.freshName("j") + val pos = ctx.freshName("pos") + val getValue1 = CodeGenerator.getValue(inputArray, elementType, i) + val getValue2 = CodeGenerator.getValue(inputArray, elementType, j) + val isEqual = ctx.genEqual(elementType, getValue1, getValue2) + val foundNullElement = ctx.freshName("foundNullElement") + val hs = ctx.freshName("hs") + val openHashSet = classOf[OpenHashSet[_]].getName + if (!CodeGenerator.isPrimitiveType(elementType)) { + val arrayClass = classOf[GenericArrayData].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()" + val setNullForNonPrimitive = + setNull(false, foundNullElement, distinctArray, pos) + if (elementTypeSupportEquals) { + val setValueForFast = setValueForFastEval(false, hs, distinctArray, pos, getValue1, "") + s""" + |int $pos = 0; + |Object[] $distinctArray = new Object[$size]; + |boolean $foundNullElement = false; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $setNullForNonPrimitive; + | } else { + | $setValueForFast; + | } + |} + |${ev.value} = new $arrayClass($distinctArray); + """.stripMargin + } else { + val setValueForBruteForce = setValueForBruteForceEval( + false, i, j, inputArray, distinctArray, pos, getValue1, isEqual, "") + s""" + |int $pos = 0; + |Object[] $distinctArray = new Object[$size]; + |boolean $foundNullElement = false; + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $setNullForNonPrimitive; + | } else { + | $setValueForBruteForce; + | } + |} + |${ev.value} = new $arrayClass($distinctArray); + """.stripMargin + } + } else { + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + val setNullForPrimitive = setNull(true, foundNullElement, distinctArray, pos) + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$primitiveValueTypeName()" + val setValueForFast = + setValueForFastEval(true, hs, distinctArray, pos, getValue1, primitiveValueTypeName) + s""" + |${ctx.createUnsafeArray(distinctArray, size, elementType, s" $prettyName failed.")} + |int $pos = 0; + |boolean $foundNullElement = false; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $setNullForPrimitive; + | } else { + | $setValueForFast; + | } + |} + |${ev.value} = $distinctArray; + """.stripMargin + } + } + + override def prettyName: String = "array_distinct" +} + +/** + * Will become common base class for [[ArrayUnion]], ArrayIntersect, and ArrayExcept. + */ +abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { + override def dataType: DataType = { + val dataTypes = children.map(_.dataType.asInstanceOf[ArrayType]) + ArrayType(elementType, dataTypes.exists(_.containsNull)) + } + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = super.checkInputDataTypes() + if (typeCheckResult.isSuccess) { + TypeUtils.checkForOrderingExpr(dataType.asInstanceOf[ArrayType].elementType, + s"function $prettyName") + } else { + typeCheckResult + } + } + + @transient protected lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(elementType) + + @transient protected lazy val elementTypeSupportEquals = elementType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } +} + +object ArraySetLike { + def throwUnionLengthOverflowException(length: Int): Unit = { + throw new RuntimeException(s"Unsuccessful try to union arrays with $length " + + s"elements due to exceeding the array size limit " + + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") + } +} + + +/** + * Returns an array of the elements in the union of x and y, without duplicates + */ +@ExpressionDescription( + usage = """ + _FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2, + without duplicates. + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); + array(1, 2, 3, 5) + """, + since = "2.4.0") +case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike { + var hsInt: OpenHashSet[Int] = _ + var hsLong: OpenHashSet[Long] = _ + + def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { + val elem = array.getInt(idx) + if (!hsInt.contains(elem)) { + if (resultArray != null) { + resultArray.setInt(pos, elem) + } + hsInt.add(elem) + true + } else { + false + } + } + + def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { + val elem = array.getLong(idx) + if (!hsLong.contains(elem)) { + if (resultArray != null) { + resultArray.setLong(pos, elem) + } + hsLong.add(elem) + true + } else { + false + } + } + + def evalIntLongPrimitiveType( + array1: ArrayData, + array2: ArrayData, + resultArray: ArrayData, + isLongType: Boolean): Int = { + // store elements into resultArray + var nullElementSize = 0 + var pos = 0 + Seq(array1, array2).foreach { array => + var i = 0 + while (i < array.numElements()) { + val size = if (!isLongType) hsInt.size else hsLong.size + if (size + nullElementSize > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(size) + } + if (array.isNullAt(i)) { + if (nullElementSize == 0) { + if (resultArray != null) { + resultArray.setNullAt(pos) + } + pos += 1 + nullElementSize = 1 + } + } else { + val assigned = if (!isLongType) { + assignInt(array, i, resultArray, pos) + } else { + assignLong(array, i, resultArray, pos) + } + if (assigned) { + pos += 1 + } + } + i += 1 + } + } + pos + } + + override def nullSafeEval(input1: Any, input2: Any): Any = { + val array1 = input1.asInstanceOf[ArrayData] + val array2 = input2.asInstanceOf[ArrayData] + + if (elementTypeSupportEquals) { + elementType match { + case IntegerType => + // avoid boxing of primitive int array elements + // calculate result array size + hsInt = new OpenHashSet[Int] + val elements = evalIntLongPrimitiveType(array1, array2, null, false) + hsInt = new OpenHashSet[Int] + val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( + IntegerType.defaultSize, elements)) { + new GenericArrayData(new Array[Any](elements)) + } else { + UnsafeArrayData.forPrimitiveArray( + Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize) + } + evalIntLongPrimitiveType(array1, array2, resultArray, false) + resultArray + case LongType => + // avoid boxing of primitive long array elements + // calculate result array size + hsLong = new OpenHashSet[Long] + val elements = evalIntLongPrimitiveType(array1, array2, null, true) + hsLong = new OpenHashSet[Long] + val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( + LongType.defaultSize, elements)) { + new GenericArrayData(new Array[Any](elements)) + } else { + UnsafeArrayData.forPrimitiveArray( + Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize) + } + evalIntLongPrimitiveType(array1, array2, resultArray, true) + resultArray + case _ => + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + val hs = new OpenHashSet[Any] + var foundNullElement = false + Seq(array1, array2).foreach { array => + var i = 0 + while (i < array.numElements()) { + if (array.isNullAt(i)) { + if (!foundNullElement) { + arrayBuffer += null + foundNullElement = true + } + } else { + val elem = array.get(i, elementType) + if (!hs.contains(elem)) { + if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size) + } + arrayBuffer += elem + hs.add(elem) + } + } + i += 1 + } + } + new GenericArrayData(arrayBuffer) + } + } else { + ArrayUnion.unionOrdering(array1, array2, elementType, ordering) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val value = ctx.freshName("value") + val size = ctx.freshName("size") + val (postFix, openHashElementType, getter, setter, javaTypeName, castOp, arrayBuilder) = + if (elementTypeSupportEquals) { + elementType match { + case ByteType | ShortType | IntegerType | LongType => + val ptName = CodeGenerator.primitiveTypeName(elementType) + val unsafeArray = ctx.freshName("unsafeArray") + (if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp", + if (elementType == LongType) "Long" else "Int", + s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), + if (elementType == LongType) "(long)" else "(int)", + s""" + |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} + |${ev.value} = $unsafeArray; + """.stripMargin) + case _ => + val genericArrayData = classOf[GenericArrayData].getName + val et = ctx.addReferenceObj("elementType", elementType) + ("", "Object", + s"get($i, $et)", s"update($pos, $value)", "Object", "", + s"${ev.value} = new $genericArrayData(new Object[$size]);") + } + } else { + ("", "", "", "", "", "", "") + } + + nullSafeCodeGen(ctx, ev, (array1, array2) => { + if (openHashElementType != "") { + // Here, we ensure elementTypeSupportEquals is true + val foundNullElement = ctx.freshName("foundNullElement") + val openHashSet = classOf[OpenHashSet[_]].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()" + val hs = ctx.freshName("hs") + val arrayData = classOf[ArrayData].getName + val arrays = ctx.freshName("arrays") + val array = ctx.freshName("array") + val arrayDataIdx = ctx.freshName("arrayDataIdx") + s""" + |$openHashSet $hs = new $openHashSet$postFix($classTag); + |boolean $foundNullElement = false; + |$arrayData[] $arrays = new $arrayData[]{$array1, $array2}; + |for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) { + | $arrayData $array = $arrays[$arrayDataIdx]; + | for (int $i = 0; $i < $array.numElements(); $i++) { + | if ($array.isNullAt($i)) { + | $foundNullElement = true; + | } else { + | $hs.add$postFix($array.$getter); + | } + | } + |} + |int $size = $hs.size() + ($foundNullElement ? 1 : 0); + |$arrayBuilder + |$hs = new $openHashSet$postFix($classTag); + |$foundNullElement = false; + |int $pos = 0; + |for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) { + | $arrayData $array = $arrays[$arrayDataIdx]; + | for (int $i = 0; $i < $array.numElements(); $i++) { + | if ($array.isNullAt($i)) { + | if (!$foundNullElement) { + | ${ev.value}.setNullAt($pos++); + | $foundNullElement = true; + | } + | } else { + | $javaTypeName $value = $array.$getter; + | if (!$hs.contains($castOp $value)) { + | $hs.add$postFix($value); + | ${ev.value}.$setter; + | $pos++; + | } + | } + | } + |} + """.stripMargin + } else { + val arrayUnion = classOf[ArrayUnion].getName + val et = ctx.addReferenceObj("elementTypeUnion", elementType) + val order = ctx.addReferenceObj("orderingUnion", ordering) + val method = "unionOrdering" + s"${ev.value} = $arrayUnion$$.MODULE$$.$method($array1, $array2, $et, $order);" + } + }) + } + + override def prettyName: String = "array_union" +} + +object ArrayUnion { + def unionOrdering( + array1: ArrayData, + array2: ArrayData, + elementType: DataType, + ordering: Ordering[Any]): ArrayData = { + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var alreadyIncludeNull = false + Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => { + var found = false + if (elem == null) { + if (alreadyIncludeNull) { + found = true + } else { + alreadyIncludeNull = true + } + } else { + // check elem is already stored in arrayBuffer or not? + var j = 0 + while (!found && j < arrayBuffer.size) { + val va = arrayBuffer(j) + if (va != null && ordering.equiv(va, elem)) { + found = true + } + j = j + 1 + } + } + if (!found) { + if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length) + } + arrayBuffer += elem + } + })) + new GenericArrayData(arrayBuffer) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 0a5f8a907b50a..a43de028360b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util._ @@ -48,7 +48,8 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def dataType: ArrayType = { ArrayType( - children.headOption.map(_.dataType).getOrElse(StringType), + TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(children.map(_.dataType)) + .getOrElse(StringType), containsNull = children.exists(_.nullable)) } @@ -179,11 +180,11 @@ case class CreateMap(children: Seq[Expression]) extends Expression { if (children.size % 2 != 0) { TypeCheckResult.TypeCheckFailure( s"$prettyName expects a positive even number of arguments.") - } else if (keys.map(_.dataType).distinct.length > 1) { + } else if (!TypeCoercion.haveSameType(keys.map(_.dataType))) { TypeCheckResult.TypeCheckFailure( "The given keys of function map should all be the same type, but they are " + keys.map(_.dataType.simpleString).mkString("[", ", ", "]")) - } else if (values.map(_.dataType).distinct.length > 1) { + } else if (!TypeCoercion.haveSameType(values.map(_.dataType))) { TypeCheckResult.TypeCheckFailure( "The given values of function map should all be the same type, but they are " + values.map(_.dataType.simpleString).mkString("[", ", ", "]")) @@ -194,8 +195,10 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def dataType: DataType = { MapType( - keyType = keys.headOption.map(_.dataType).getOrElse(StringType), - valueType = values.headOption.map(_.dataType).getOrElse(StringType), + keyType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(keys.map(_.dataType)) + .getOrElse(StringType), + valueType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(values.map(_.dataType)) + .getOrElse(StringType), valueContainsNull = values.exists(_.nullable)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 77ac6c088022e..3b597e8b5263b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ @@ -33,7 +33,12 @@ import org.apache.spark.sql.types._ """) // scalastyle:on line.size.limit case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) - extends Expression { + extends ComplexTypeMergingExpression { + + @transient + override lazy val inputTypesForMerging: Seq[DataType] = { + Seq(trueValue.dataType, falseValue.dataType) + } override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil override def nullable: Boolean = trueValue.nullable || falseValue.nullable @@ -43,7 +48,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi TypeCheckResult.TypeCheckFailure( "type of predicate expression in If should be boolean, " + s"not ${predicate.dataType.simpleString}") - } else if (!trueValue.dataType.sameType(falseValue.dataType)) { + } else if (!TypeCoercion.haveSameType(inputTypesForMerging)) { TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " + s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).") } else { @@ -51,8 +56,6 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } } - override def dataType: DataType = trueValue.dataType - override def eval(input: InternalRow): Any = { if (java.lang.Boolean.TRUE.equals(predicate.eval(input))) { trueValue.eval(input) @@ -118,27 +121,24 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi case class CaseWhen( branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None) - extends Expression with Serializable { + extends ComplexTypeMergingExpression with Serializable { override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue // both then and else expressions should be considered. - def valueTypes: Seq[DataType] = branches.map(_._2.dataType) ++ elseValue.map(_.dataType) - - def valueTypesEqual: Boolean = valueTypes.size <= 1 || valueTypes.sliding(2, 1).forall { - case Seq(dt1, dt2) => dt1.sameType(dt2) + @transient + override lazy val inputTypesForMerging: Seq[DataType] = { + branches.map(_._2.dataType) ++ elseValue.map(_.dataType) } - override def dataType: DataType = branches.head._2.dataType - override def nullable: Boolean = { // Result is nullable if any of the branch is nullable, or if the else value is nullable branches.exists(_._2.nullable) || elseValue.map(_.nullable).getOrElse(true) } override def checkInputDataTypes(): TypeCheckResult = { - // Make sure all branch conditions are boolean types. - if (valueTypesEqual) { + if (TypeCoercion.haveSameType(inputTypesForMerging)) { + // Make sure all branch conditions are boolean types. if (branches.forall(_._1.dataType == BooleanType)) { TypeCheckResult.TypeCheckSuccess } else { @@ -294,7 +294,7 @@ object CaseWhen { case cond :: value :: Nil => Some((cond, value)) case value :: Nil => None }.toArray.toSeq // force materialization to make the seq serializable - val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None + val elseValue = if (branches.size % 2 != 0) Some(branches.last) else None CaseWhen(cases, elseValue) } } @@ -309,7 +309,7 @@ object CaseKeyWhen { case Seq(cond, value) => Some((EqualTo(key, cond), value)) case Seq(value) => None }.toArray.toSeq // force materialization to make the seq serializable - val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None + val elseValue = if (branches.size % 2 != 0) Some(branches.last) else None CaseWhen(cases, elseValue) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 04a4eb0ffc032..63943b1a4351a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, CharArrayWriter, InputStreamReader, StringWriter} +import java.io._ import scala.util.parsing.combinator.RegexParsers @@ -28,8 +28,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.json._ -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, FailFastMode, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.json.JsonInferSchema.inferField +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -514,10 +514,11 @@ case class JsonToStructs( schema: DataType, options: Map[String, String], child: Expression, - timeZoneId: Option[String], - forceNullableSchema: Boolean) + timeZoneId: Option[String] = None) extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { + val forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA) + // The JSON input data might be missing certain fields. We force the nullability // of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder // can generate incorrect files if values are missing in columns declared as non-nullable. @@ -526,26 +527,21 @@ case class JsonToStructs( override def nullable: Boolean = true // Used in `FunctionRegistry` - def this(child: Expression, schema: Expression) = + def this(child: Expression, schema: Expression, options: Map[String, String]) = this( - schema = JsonExprUtils.validateSchemaLiteral(schema), - options = Map.empty[String, String], + schema = JsonExprUtils.evalSchemaExpr(schema), + options = options, child = child, - timeZoneId = None, - forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) + timeZoneId = None) + + def this(child: Expression, schema: Expression) = this(child, schema, Map.empty[String, String]) def this(child: Expression, schema: Expression, options: Expression) = this( - schema = JsonExprUtils.validateSchemaLiteral(schema), + schema = JsonExprUtils.evalSchemaExpr(schema), options = JsonExprUtils.convertToMapData(options), child = child, - timeZoneId = None, - forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) - - // Used in `org.apache.spark.sql.functions` - def this(schema: DataType, options: Map[String, String], child: Expression) = - this(schema, options, child, timeZoneId = None, - forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) + timeZoneId = None) override def checkInputDataTypes(): TypeCheckResult = nullableSchema match { case _: StructType | ArrayType(_: StructType, _) | _: MapType => @@ -745,11 +741,44 @@ case class StructsToJson( override def inputTypes: Seq[AbstractDataType] = TypeCollection(ArrayType, StructType) :: Nil } +/** + * A function infers schema of JSON string. + */ +@ExpressionDescription( + usage = "_FUNC_(json[, options]) - Returns schema in the DDL format of JSON string.", + examples = """ + Examples: + > SELECT _FUNC_('[{"col":0}]'); + array> + """, + since = "2.4.0") +case class SchemaOfJson(child: Expression) + extends UnaryExpression with String2StringExpression with CodegenFallback { + + private val jsonOptions = new JSONOptions(Map.empty, "UTC") + private val jsonFactory = new JsonFactory() + jsonOptions.setJacksonOptions(jsonFactory) + + override def convert(v: UTF8String): UTF8String = { + val dt = Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, v)) { parser => + parser.nextToken() + inferField(parser, jsonOptions) + } + + UTF8String.fromString(dt.catalogString) + } +} + object JsonExprUtils { - def validateSchemaLiteral(exp: Expression): StructType = exp match { - case Literal(s, StringType) => CatalystSqlParser.parseTableSchema(s.toString) - case e => throw new AnalysisException(s"Expected a string literal instead of $e") + def evalSchemaExpr(exp: Expression): DataType = exp match { + case Literal(s, StringType) => DataType.fromDDL(s.toString) + case e @ SchemaOfJson(_: Literal) => + val ddlSchema = e.eval().asInstanceOf[UTF8String] + DataType.fromDDL(ddlSchema.toString) + case e => throw new AnalysisException( + "Schema should be specified in DDL format as a string literal" + + s" or output of the schema_of_json function instead of ${e.sql}") } def convertToMapData(exp: Expression): Map[String, String] = exp match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 246025b82d59e..0efd1224f1bca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -57,6 +57,7 @@ object Literal { case b: Byte => Literal(b, ByteType) case s: Short => Literal(s, ShortType) case s: String => Literal(UTF8String.fromString(s), StringType) + case c: Char => Literal(UTF8String.fromString(c.toString), StringType) case b: Boolean => Literal(b, BooleanType) case d: BigDecimal => Literal(Decimal(d), DecimalType.fromBigDecimal(d)) case d: JavaBigDecimal => @@ -185,7 +186,7 @@ object Literal { case map: MapType => create(Map(), map) case struct: StructType => create(InternalRow.fromSeq(struct.fields.map(f => default(f.dataType).value)), struct) - case udt: UserDefinedType[_] => default(udt.sqlType) + case udt: UserDefinedType[_] => Literal(default(udt.sqlType).value, udt) case other => throw new RuntimeException(s"no default for type $dataType") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala deleted file mode 100644 index 276a57266a6e0..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala +++ /dev/null @@ -1,569 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.commons.codec.digest.DigestUtils - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.MaskExpressionsUtils._ -import org.apache.spark.sql.catalyst.expressions.MaskLike._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - - -trait MaskLike { - def upper: String - def lower: String - def digit: String - - protected lazy val upperReplacement: Int = getReplacementChar(upper, defaultMaskedUppercase) - protected lazy val lowerReplacement: Int = getReplacementChar(lower, defaultMaskedLowercase) - protected lazy val digitReplacement: Int = getReplacementChar(digit, defaultMaskedDigit) - - protected val maskUtilsClassName: String = classOf[MaskExpressionsUtils].getName - - def inputStringLengthCode(inputString: String, length: String): String = { - s"${CodeGenerator.JAVA_INT} $length = $inputString.codePointCount(0, $inputString.length());" - } - - def appendMaskedToStringBuilderCode( - ctx: CodegenContext, - sb: String, - inputString: String, - offset: String, - numChars: String): String = { - val i = ctx.freshName("i") - val codePoint = ctx.freshName("codePoint") - s""" - |for (${CodeGenerator.JAVA_INT} $i = 0; $i < $numChars; $i++) { - | ${CodeGenerator.JAVA_INT} $codePoint = $inputString.codePointAt($offset); - | $sb.appendCodePoint($maskUtilsClassName.transformChar($codePoint, - | $upperReplacement, $lowerReplacement, - | $digitReplacement, $defaultMaskedOther)); - | $offset += Character.charCount($codePoint); - |} - """.stripMargin - } - - def appendUnchangedToStringBuilderCode( - ctx: CodegenContext, - sb: String, - inputString: String, - offset: String, - numChars: String): String = { - val i = ctx.freshName("i") - val codePoint = ctx.freshName("codePoint") - s""" - |for (${CodeGenerator.JAVA_INT} $i = 0; $i < $numChars; $i++) { - | ${CodeGenerator.JAVA_INT} $codePoint = $inputString.codePointAt($offset); - | $sb.appendCodePoint($codePoint); - | $offset += Character.charCount($codePoint); - |} - """.stripMargin - } - - def appendMaskedToStringBuilder( - sb: java.lang.StringBuilder, - inputString: String, - startOffset: Int, - numChars: Int): Int = { - var offset = startOffset - (1 to numChars) foreach { _ => - val codePoint = inputString.codePointAt(offset) - sb.appendCodePoint(transformChar( - codePoint, - upperReplacement, - lowerReplacement, - digitReplacement, - defaultMaskedOther)) - offset += Character.charCount(codePoint) - } - offset - } - - def appendUnchangedToStringBuilder( - sb: java.lang.StringBuilder, - inputString: String, - startOffset: Int, - numChars: Int): Int = { - var offset = startOffset - (1 to numChars) foreach { _ => - val codePoint = inputString.codePointAt(offset) - sb.appendCodePoint(codePoint) - offset += Character.charCount(codePoint) - } - offset - } -} - -trait MaskLikeWithN extends MaskLike { - def n: Int - protected lazy val charCount: Int = if (n < 0) 0 else n -} - -/** - * Utils for mask operations. - */ -object MaskLike { - val defaultCharCount = 4 - val defaultMaskedUppercase: Int = 'X' - val defaultMaskedLowercase: Int = 'x' - val defaultMaskedDigit: Int = 'n' - val defaultMaskedOther: Int = MaskExpressionsUtils.UNMASKED_VAL - - def extractCharCount(e: Expression): Int = e match { - case Literal(i, IntegerType | NullType) => - if (i == null) defaultCharCount else i.asInstanceOf[Int] - case Literal(_, dt) => throw new AnalysisException("Expected literal expression of type " + - s"${IntegerType.simpleString}, but got literal of ${dt.simpleString}") - case other => throw new AnalysisException(s"Expected literal expression, but got ${other.sql}") - } - - def extractReplacement(e: Expression): String = e match { - case Literal(s, StringType | NullType) => if (s == null) null else s.toString - case Literal(_, dt) => throw new AnalysisException("Expected literal expression of type " + - s"${StringType.simpleString}, but got literal of ${dt.simpleString}") - case other => throw new AnalysisException(s"Expected literal expression, but got ${other.sql}") - } -} - -/** - * Masks the input string. Additional parameters can be set to change the masking chars for - * uppercase letters, lowercase letters and digits. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(str[, upper[, lower[, digit]]]) - Masks str. By default, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", - examples = """ - Examples: - > SELECT _FUNC_("abcd-EFGH-8765-4321", "U", "l", "#"); - llll-UUUU-####-#### - """) -// scalastyle:on line.size.limit -case class Mask(child: Expression, upper: String, lower: String, digit: String) - extends UnaryExpression with ExpectsInputTypes with MaskLike { - - def this(child: Expression) = this(child, null.asInstanceOf[String], null, null) - - def this(child: Expression, upper: Expression) = - this(child, extractReplacement(upper), null, null) - - def this(child: Expression, upper: Expression, lower: Expression) = - this(child, extractReplacement(upper), extractReplacement(lower), null) - - def this(child: Expression, upper: Expression, lower: Expression, digit: Expression) = - this(child, extractReplacement(upper), extractReplacement(lower), extractReplacement(digit)) - - override def nullSafeEval(input: Any): Any = { - val str = input.asInstanceOf[UTF8String].toString - val length = str.codePointCount(0, str.length()) - val sb = new java.lang.StringBuilder(length) - appendMaskedToStringBuilder(sb, str, 0, length) - UTF8String.fromString(sb.toString) - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val sb = ctx.freshName("sb") - val length = ctx.freshName("length") - val offset = ctx.freshName("offset") - val inputString = ctx.freshName("inputString") - s""" - |String $inputString = $input.toString(); - |${inputStringLengthCode(inputString, length)} - |StringBuilder $sb = new StringBuilder($length); - |${CodeGenerator.JAVA_INT} $offset = 0; - |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, length)} - |${ev.value} = UTF8String.fromString($sb.toString()); - """.stripMargin - }) - } - - override def dataType: DataType = StringType - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) -} - -/** - * Masks the first N chars of the input string. N defaults to 4. Additional parameters can be set - * to change the masking chars for uppercase letters, lowercase letters and digits. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks the first n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", - examples = """ - Examples: - > SELECT _FUNC_("1234-5678-8765-4321", 4); - nnnn-5678-8765-4321 - """) -// scalastyle:on line.size.limit -case class MaskFirstN( - child: Expression, - n: Int, - upper: String, - lower: String, - digit: String) - extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { - - def this(child: Expression) = - this(child, defaultCharCount, null, null, null) - - def this(child: Expression, n: Expression) = - this(child, extractCharCount(n), null, null, null) - - def this(child: Expression, n: Expression, upper: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), null, null) - - def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) - - def this( - child: Expression, - n: Expression, - upper: Expression, - lower: Expression, - digit: Expression) = - this(child, - extractCharCount(n), - extractReplacement(upper), - extractReplacement(lower), - extractReplacement(digit)) - - override def nullSafeEval(input: Any): Any = { - val str = input.asInstanceOf[UTF8String].toString - val length = str.codePointCount(0, str.length()) - val endOfMask = if (charCount > length) length else charCount - val sb = new java.lang.StringBuilder(length) - val offset = appendMaskedToStringBuilder(sb, str, 0, endOfMask) - appendUnchangedToStringBuilder(sb, str, offset, length - endOfMask) - UTF8String.fromString(sb.toString) - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val sb = ctx.freshName("sb") - val length = ctx.freshName("length") - val offset = ctx.freshName("offset") - val inputString = ctx.freshName("inputString") - val endOfMask = ctx.freshName("endOfMask") - s""" - |String $inputString = $input.toString(); - |${inputStringLengthCode(inputString, length)} - |${CodeGenerator.JAVA_INT} $endOfMask = $charCount > $length ? $length : $charCount; - |${CodeGenerator.JAVA_INT} $offset = 0; - |StringBuilder $sb = new StringBuilder($length); - |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, endOfMask)} - |${appendUnchangedToStringBuilderCode( - ctx, sb, inputString, offset, s"$length - $endOfMask")} - |${ev.value} = UTF8String.fromString($sb.toString()); - |""".stripMargin - }) - } - - override def dataType: DataType = StringType - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) - - override def prettyName: String = "mask_first_n" -} - -/** - * Masks the last N chars of the input string. N defaults to 4. Additional parameters can be set - * to change the masking chars for uppercase letters, lowercase letters and digits. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks the last n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", - examples = """ - Examples: - > SELECT _FUNC_("1234-5678-8765-4321", 4); - 1234-5678-8765-nnnn - """, since = "2.4.0") -// scalastyle:on line.size.limit -case class MaskLastN( - child: Expression, - n: Int, - upper: String, - lower: String, - digit: String) - extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { - - def this(child: Expression) = - this(child, defaultCharCount, null, null, null) - - def this(child: Expression, n: Expression) = - this(child, extractCharCount(n), null, null, null) - - def this(child: Expression, n: Expression, upper: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), null, null) - - def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) - - def this( - child: Expression, - n: Expression, - upper: Expression, - lower: Expression, - digit: Expression) = - this(child, - extractCharCount(n), - extractReplacement(upper), - extractReplacement(lower), - extractReplacement(digit)) - - override def nullSafeEval(input: Any): Any = { - val str = input.asInstanceOf[UTF8String].toString - val length = str.codePointCount(0, str.length()) - val startOfMask = if (charCount >= length) 0 else length - charCount - val sb = new java.lang.StringBuilder(length) - val offset = appendUnchangedToStringBuilder(sb, str, 0, startOfMask) - appendMaskedToStringBuilder(sb, str, offset, length - startOfMask) - UTF8String.fromString(sb.toString) - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val sb = ctx.freshName("sb") - val length = ctx.freshName("length") - val offset = ctx.freshName("offset") - val inputString = ctx.freshName("inputString") - val startOfMask = ctx.freshName("startOfMask") - s""" - |String $inputString = $input.toString(); - |${inputStringLengthCode(inputString, length)} - |${CodeGenerator.JAVA_INT} $startOfMask = $charCount >= $length ? - | 0 : $length - $charCount; - |${CodeGenerator.JAVA_INT} $offset = 0; - |StringBuilder $sb = new StringBuilder($length); - |${appendUnchangedToStringBuilderCode(ctx, sb, inputString, offset, startOfMask)} - |${appendMaskedToStringBuilderCode( - ctx, sb, inputString, offset, s"$length - $startOfMask")} - |${ev.value} = UTF8String.fromString($sb.toString()); - |""".stripMargin - }) - } - - override def dataType: DataType = StringType - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) - - override def prettyName: String = "mask_last_n" -} - -/** - * Masks all but the first N chars of the input string. N defaults to 4. Additional parameters can - * be set to change the masking chars for uppercase letters, lowercase letters and digits. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks all but the first n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", - examples = """ - Examples: - > SELECT _FUNC_("1234-5678-8765-4321", 4); - 1234-nnnn-nnnn-nnnn - """, since = "2.4.0") -// scalastyle:on line.size.limit -case class MaskShowFirstN( - child: Expression, - n: Int, - upper: String, - lower: String, - digit: String) - extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { - - def this(child: Expression) = - this(child, defaultCharCount, null, null, null) - - def this(child: Expression, n: Expression) = - this(child, extractCharCount(n), null, null, null) - - def this(child: Expression, n: Expression, upper: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), null, null) - - def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) - - def this( - child: Expression, - n: Expression, - upper: Expression, - lower: Expression, - digit: Expression) = - this(child, - extractCharCount(n), - extractReplacement(upper), - extractReplacement(lower), - extractReplacement(digit)) - - override def nullSafeEval(input: Any): Any = { - val str = input.asInstanceOf[UTF8String].toString - val length = str.codePointCount(0, str.length()) - val startOfMask = if (charCount > length) length else charCount - val sb = new java.lang.StringBuilder(length) - val offset = appendUnchangedToStringBuilder(sb, str, 0, startOfMask) - appendMaskedToStringBuilder(sb, str, offset, length - startOfMask) - UTF8String.fromString(sb.toString) - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val sb = ctx.freshName("sb") - val length = ctx.freshName("length") - val offset = ctx.freshName("offset") - val inputString = ctx.freshName("inputString") - val startOfMask = ctx.freshName("startOfMask") - s""" - |String $inputString = $input.toString(); - |${inputStringLengthCode(inputString, length)} - |${CodeGenerator.JAVA_INT} $startOfMask = $charCount > $length ? $length : $charCount; - |${CodeGenerator.JAVA_INT} $offset = 0; - |StringBuilder $sb = new StringBuilder($length); - |${appendUnchangedToStringBuilderCode(ctx, sb, inputString, offset, startOfMask)} - |${appendMaskedToStringBuilderCode( - ctx, sb, inputString, offset, s"$length - $startOfMask")} - |${ev.value} = UTF8String.fromString($sb.toString()); - |""".stripMargin - }) - } - - override def dataType: DataType = StringType - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) - - override def prettyName: String = "mask_show_first_n" -} - -/** - * Masks all but the last N chars of the input string. N defaults to 4. Additional parameters can - * be set to change the masking chars for uppercase letters, lowercase letters and digits. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks all but the last n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", - examples = """ - Examples: - > SELECT _FUNC_("1234-5678-8765-4321", 4); - nnnn-nnnn-nnnn-4321 - """, since = "2.4.0") -// scalastyle:on line.size.limit -case class MaskShowLastN( - child: Expression, - n: Int, - upper: String, - lower: String, - digit: String) - extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { - - def this(child: Expression) = - this(child, defaultCharCount, null, null, null) - - def this(child: Expression, n: Expression) = - this(child, extractCharCount(n), null, null, null) - - def this(child: Expression, n: Expression, upper: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), null, null) - - def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) - - def this( - child: Expression, - n: Expression, - upper: Expression, - lower: Expression, - digit: Expression) = - this(child, - extractCharCount(n), - extractReplacement(upper), - extractReplacement(lower), - extractReplacement(digit)) - - override def nullSafeEval(input: Any): Any = { - val str = input.asInstanceOf[UTF8String].toString - val length = str.codePointCount(0, str.length()) - val endOfMask = if (charCount >= length) 0 else length - charCount - val sb = new java.lang.StringBuilder(length) - val offset = appendMaskedToStringBuilder(sb, str, 0, endOfMask) - appendUnchangedToStringBuilder(sb, str, offset, length - endOfMask) - UTF8String.fromString(sb.toString) - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val sb = ctx.freshName("sb") - val length = ctx.freshName("length") - val offset = ctx.freshName("offset") - val inputString = ctx.freshName("inputString") - val endOfMask = ctx.freshName("endOfMask") - s""" - |String $inputString = $input.toString(); - |${inputStringLengthCode(inputString, length)} - |${CodeGenerator.JAVA_INT} $endOfMask = $charCount >= $length ? 0 : $length - $charCount; - |${CodeGenerator.JAVA_INT} $offset = 0; - |StringBuilder $sb = new StringBuilder($length); - |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, endOfMask)} - |${appendUnchangedToStringBuilderCode( - ctx, sb, inputString, offset, s"$length - $endOfMask")} - |${ev.value} = UTF8String.fromString($sb.toString()); - |""".stripMargin - }) - } - - override def dataType: DataType = StringType - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) - - override def prettyName: String = "mask_show_last_n" -} - -/** - * Returns a hashed value based on str. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(str) - Returns a hashed value based on str. The hash is consistent and can be used to join masked values together across tables.", - examples = """ - Examples: - > SELECT _FUNC_("abcd-EFGH-8765-4321"); - 60c713f5ec6912229d2060df1c322776 - """) -// scalastyle:on line.size.limit -case class MaskHash(child: Expression) - extends UnaryExpression with ExpectsInputTypes { - - override def nullSafeEval(input: Any): Any = { - UTF8String.fromString(DigestUtils.md5Hex(input.asInstanceOf[UTF8String].toString)) - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val digestUtilsClass = classOf[DigestUtils].getName.stripSuffix("$") - s""" - |${ev.value} = UTF8String.fromString($digestUtilsClass.md5Hex($input.toString())); - |""".stripMargin - }) - } - - override def dataType: DataType = StringType - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) - - override def prettyName: String = "mask_hash" -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 2eeed3bbb2d91..b683d2a7e9ef3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.TypeUtils @@ -44,7 +44,7 @@ import org.apache.spark.sql.types._ 1 """) // scalastyle:on line.size.limit -case class Coalesce(children: Seq[Expression]) extends Expression { +case class Coalesce(children: Seq[Expression]) extends ComplexTypeMergingExpression { /** Coalesce is nullable if all of its children are nullable, or if it has no children. */ override def nullable: Boolean = children.forall(_.nullable) @@ -61,8 +61,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } } - override def dataType: DataType = children.head.dataType - override def eval(input: InternalRow): Any = { var result: Any = null val childIterator = children.iterator diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 2ff12acb2946f..47eeb70e00427 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -73,6 +73,9 @@ private[sql] class JSONOptions( val columnNameOfCorruptRecord = parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) + // Whether to ignore column of all null values or empty array/struct during schema inference + val dropFieldIfAllNull = parameters.get("dropFieldIfAllNull").map(_.toBoolean).getOrElse(false) + val timeZone: TimeZone = DateTimeUtils.getTimeZone( parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) @@ -94,32 +97,16 @@ private[sql] class JSONOptions( sep } + protected def checkedEncoding(enc: String): String = enc + /** * Standard encoding (charset) name. For example UTF-8, UTF-16LE and UTF-32BE. - * If the encoding is not specified (None), it will be detected automatically - * when the multiLine option is set to `true`. + * If the encoding is not specified (None) in read, it will be detected automatically + * when the multiLine option is set to `true`. If encoding is not specified in write, + * UTF-8 is used by default. */ val encoding: Option[String] = parameters.get("encoding") - .orElse(parameters.get("charset")).map { enc => - // The following encodings are not supported in per-line mode (multiline is false) - // because they cause some problems in reading files with BOM which is supposed to - // present in the files with such encodings. After splitting input files by lines, - // only the first lines will have the BOM which leads to impossibility for reading - // the rest lines. Besides of that, the lineSep option must have the BOM in such - // encodings which can never present between lines. - val blacklist = Seq(Charset.forName("UTF-16"), Charset.forName("UTF-32")) - val isBlacklisted = blacklist.contains(Charset.forName(enc)) - require(multiLine || !isBlacklisted, - s"""The $enc encoding in the blacklist is not allowed when multiLine is disabled. - |Blacklist: ${blacklist.mkString(", ")}""".stripMargin) - - val isLineSepRequired = - multiLine || Charset.forName(enc) == StandardCharsets.UTF_8 || lineSeparator.nonEmpty - - require(isLineSepRequired, s"The lineSep option must be specified for the $enc encoding") - - enc - } + .orElse(parameters.get("charset")).map(checkedEncoding) val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep => lineSep.getBytes(encoding.getOrElse("UTF-8")) @@ -138,3 +125,46 @@ private[sql] class JSONOptions( factory.configure(JsonParser.Feature.ALLOW_UNQUOTED_CONTROL_CHARS, allowUnquotedControlChars) } } + +private[sql] class JSONOptionsInRead( + @transient override val parameters: CaseInsensitiveMap[String], + defaultTimeZoneId: String, + defaultColumnNameOfCorruptRecord: String) + extends JSONOptions(parameters, defaultTimeZoneId, defaultColumnNameOfCorruptRecord) { + + def this( + parameters: Map[String, String], + defaultTimeZoneId: String, + defaultColumnNameOfCorruptRecord: String = "") = { + this( + CaseInsensitiveMap(parameters), + defaultTimeZoneId, + defaultColumnNameOfCorruptRecord) + } + + protected override def checkedEncoding(enc: String): String = { + val isBlacklisted = JSONOptionsInRead.blacklist.contains(Charset.forName(enc)) + require(multiLine || !isBlacklisted, + s"""The ${enc} encoding must not be included in the blacklist when multiLine is disabled: + |Blacklist: ${JSONOptionsInRead.blacklist.mkString(", ")}""".stripMargin) + + val isLineSepRequired = + multiLine || Charset.forName(enc) == StandardCharsets.UTF_8 || lineSeparator.nonEmpty + require(isLineSepRequired, s"The lineSep option must be specified for the $enc encoding") + + enc + } +} + +private[sql] object JSONOptionsInRead { + // The following encodings are not supported in per-line mode (multiline is false) + // because they cause some problems in reading files with BOM which is supposed to + // present in the files with such encodings. After splitting input files by lines, + // only the first lines will have the BOM which leads to impossibility for reading + // the rest lines. Besides of that, the lineSep option must have the BOM in such + // encodings which can never present between lines. + val blacklist = Seq( + Charset.forName("UTF-16"), + Charset.forName("UTF-32") + ) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala similarity index 92% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index e7eed95a560a3..491ca005877f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.json +package org.apache.spark.sql.catalyst.json import java.util.Comparator @@ -25,7 +25,6 @@ import org.apache.spark.SparkException import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil -import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -75,7 +74,7 @@ private[sql] object JsonInferSchema { // active SparkSession and `SQLConf.get` may point to the wrong configs. val rootType = mergedTypesFromPartitions.toLocalIterator.fold(StructType(Nil))(typeMerger) - canonicalizeType(rootType) match { + canonicalizeType(rootType, configOptions) match { case Some(st: StructType) => st case _ => // canonicalizeType erases all empty structs, including the only one we want to keep @@ -103,7 +102,7 @@ private[sql] object JsonInferSchema { /** * Infer the type of a json document from the parser's token stream */ - private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = { + def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = { import com.fasterxml.jackson.core.JsonToken._ parser.getCurrentToken match { case null | VALUE_NULL => NullType @@ -181,33 +180,33 @@ private[sql] object JsonInferSchema { } /** - * Convert NullType to StringType and remove StructTypes with no fields + * Recursively canonicalizes inferred types, e.g., removes StructTypes with no fields, + * drops NullTypes or converts them to StringType based on provided options. */ - private def canonicalizeType(tpe: DataType): Option[DataType] = tpe match { - case at @ ArrayType(elementType, _) => - for { - canonicalType <- canonicalizeType(elementType) - } yield { - at.copy(canonicalType) - } + private def canonicalizeType(tpe: DataType, options: JSONOptions): Option[DataType] = tpe match { + case at: ArrayType => + canonicalizeType(at.elementType, options) + .map(t => at.copy(elementType = t)) case StructType(fields) => - val canonicalFields: Array[StructField] = for { - field <- fields - if field.name.length > 0 - canonicalType <- canonicalizeType(field.dataType) - } yield { - field.copy(dataType = canonicalType) + val canonicalFields = fields.filter(_.name.nonEmpty).flatMap { f => + canonicalizeType(f.dataType, options) + .map(t => f.copy(dataType = t)) } - - if (canonicalFields.length > 0) { - Some(StructType(canonicalFields)) + // SPARK-8093: empty structs should be deleted + if (canonicalFields.isEmpty) { + None } else { - // per SPARK-8093: empty structs should be deleted + Some(StructType(canonicalFields)) + } + + case NullType => + if (options.dropFieldIfAllNull) { None + } else { + Some(StringType) } - case NullType => Some(StringType) case other => Some(other) } @@ -334,8 +333,8 @@ private[sql] object JsonInferSchema { ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) // The case that given `DecimalType` is capable of given `IntegralType` is handled in - // `findTightestCommonTypeOfTwo`. Both cases below will be executed only when - // the given `DecimalType` is not capable of the given `IntegralType`. + // `findTightestCommonType`. Both cases below will be executed only when the given + // `DecimalType` is not capable of the given `IntegralType`. case (t1: IntegralType, t2: DecimalType) => compatibleType(DecimalType.forType(t1), t2) case (t1: DecimalType, t2: IntegralType) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index aa992def1ce6c..2cc27d82f7d20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -526,9 +526,10 @@ object ColumnPruning extends Rule[LogicalPlan] { /** * The Project before Filter is not necessary but conflict with PushPredicatesThroughProject, - * so remove it. + * so remove it. Since the Projects have been added top-down, we need to remove in bottom-up + * order, otherwise lower Projects can be missed. */ - private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transform { + private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transformUp { case p1 @ Project(_, f @ Filter(_, p2 @ Project(_, child))) if p2.outputSet.subsetOf(child.outputSet) => p1.copy(child = f.copy(child = child)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 383ebde3229d6..f398b479dc273 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1507,7 +1507,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case "TIMESTAMP" => Literal(Timestamp.valueOf(value)) case "X" => - val padding = if (value.length % 2 == 1) "0" else "" + val padding = if (value.length % 2 != 0) "0" else "" Literal(DatatypeConverter.parseHexBinary(padding + value)) case other => throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 64cb8c726772f..4b4722b0b2117 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -27,8 +27,6 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT /** * The active config object within the current scope. - * Note that if you want to refer config values during execution, you have to capture them - * in Driver and use the captured values in Executors. * See [[SQLConf.get]] for more information. */ def conf: SQLConf = SQLConf.get @@ -119,6 +117,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT case Some(value) => Some(recursiveTransform(value)) case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs + case stream: Stream[_] => stream.map(recursiveTransform).force case seq: Traversable[_] => seq.map(recursiveTransform) case other: AnyRef => other case null => null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 4d9a9925fe3ff..cc1a5e835d9cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -99,16 +99,19 @@ case class ClusteredDistribution( * This is a strictly stronger guarantee than [[ClusteredDistribution]]. Given a tuple and the * number of partitions, this distribution strictly requires which partition the tuple should be in. */ -case class HashClusteredDistribution(expressions: Seq[Expression]) extends Distribution { +case class HashClusteredDistribution( + expressions: Seq[Expression], + requiredNumPartitions: Option[Int] = None) extends Distribution { require( expressions != Nil, - "The expressions for hash of a HashPartitionedDistribution should not be Nil. " + + "The expressions for hash of a HashClusteredDistribution should not be Nil. " + "An AllTuples should be used to represent a distribution that only has " + "a single partition.") - override def requiredNumPartitions: Option[Int] = None - override def createPartitioning(numPartitions: Int): Partitioning = { + assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions, + s"This HashClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " + + s"the actual number of partitions is $numPartitions.") HashPartitioning(expressions, numPartitions) } } @@ -163,11 +166,22 @@ trait Partitioning { * i.e. the current dataset does not need to be re-partitioned for the `required` * Distribution (it is possible that tuples within a partition need to be reorganized). * + * A [[Partitioning]] can never satisfy a [[Distribution]] if its `numPartitions` does't match + * [[Distribution.requiredNumPartitions]]. + */ + final def satisfies(required: Distribution): Boolean = { + required.requiredNumPartitions.forall(_ == numPartitions) && satisfies0(required) + } + + /** + * The actual method that defines whether this [[Partitioning]] can satisfy the given + * [[Distribution]], after the `numPartitions` check. + * * By default a [[Partitioning]] can satisfy [[UnspecifiedDistribution]], and [[AllTuples]] if - * the [[Partitioning]] only have one partition. Implementations can overwrite this method with - * special logic. + * the [[Partitioning]] only have one partition. Implementations can also overwrite this method + * with special logic. */ - def satisfies(required: Distribution): Boolean = required match { + protected def satisfies0(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true case AllTuples => numPartitions == 1 case _ => false @@ -186,9 +200,8 @@ case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning case object SinglePartition extends Partitioning { val numPartitions = 1 - override def satisfies(required: Distribution): Boolean = required match { + override def satisfies0(required: Distribution): Boolean = required match { case _: BroadcastDistribution => false - case ClusteredDistribution(_, Some(requiredNumPartitions)) => requiredNumPartitions == 1 case _ => true } } @@ -205,16 +218,15 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - override def satisfies(required: Distribution): Boolean = { - super.satisfies(required) || { + override def satisfies0(required: Distribution): Boolean = { + super.satisfies0(required) || { required match { case h: HashClusteredDistribution => expressions.length == h.expressions.length && expressions.zip(h.expressions).forall { case (l, r) => l.semanticEquals(r) } - case ClusteredDistribution(requiredClustering, requiredNumPartitions) => - expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) && - (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions) + case ClusteredDistribution(requiredClustering, _) => + expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) case _ => false } } @@ -246,15 +258,14 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - override def satisfies(required: Distribution): Boolean = { - super.satisfies(required) || { + override def satisfies0(required: Distribution): Boolean = { + super.satisfies0(required) || { required match { case OrderedDistribution(requiredOrdering) => val minSize = Seq(requiredOrdering.size, ordering.size).min requiredOrdering.take(minSize) == ordering.take(minSize) - case ClusteredDistribution(requiredClustering, requiredNumPartitions) => - ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) && - (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions) + case ClusteredDistribution(requiredClustering, _) => + ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) case _ => false } } @@ -295,7 +306,7 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) * Returns true if any `partitioning` of this collection satisfies the given * [[Distribution]]. */ - override def satisfies(required: Distribution): Boolean = + override def satisfies0(required: Distribution): Boolean = partitionings.exists(_.satisfies(required)) override def toString: String = { @@ -310,7 +321,7 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning { override val numPartitions: Int = 1 - override def satisfies(required: Distribution): Boolean = required match { + override def satisfies0(required: Distribution): Boolean = required match { case BroadcastDistribution(m) if m == mode => true case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 9c7d47f99ee10..becfa8d982213 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -199,44 +199,33 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { var changed = false val remainingNewChildren = newChildren.toBuffer val remainingOldChildren = children.toBuffer + def mapTreeNode(node: TreeNode[_]): TreeNode[_] = { + val newChild = remainingNewChildren.remove(0) + val oldChild = remainingOldChildren.remove(0) + if (newChild fastEquals oldChild) { + oldChild + } else { + changed = true + newChild + } + } + def mapChild(child: Any): Any = child match { + case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg) + case nonChild: AnyRef => nonChild + case null => null + } val newArgs = mapProductIterator { case s: StructType => s // Don't convert struct types to some other type of Seq[StructField] // Handle Seq[TreeNode] in TreeNode parameters. - case s: Seq[_] => s.map { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = remainingNewChildren.remove(0) - val oldChild = remainingOldChildren.remove(0) - if (newChild fastEquals oldChild) { - oldChild - } else { - changed = true - newChild - } - case nonChild: AnyRef => nonChild - case null => null - } - case m: Map[_, _] => m.mapValues { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = remainingNewChildren.remove(0) - val oldChild = remainingOldChildren.remove(0) - if (newChild fastEquals oldChild) { - oldChild - } else { - changed = true - newChild - } - case nonChild: AnyRef => nonChild - case null => null - }.view.force // `mapValues` is lazy and we need to force it to materialize - case arg: TreeNode[_] if containsChild(arg) => - val newChild = remainingNewChildren.remove(0) - val oldChild = remainingOldChildren.remove(0) - if (newChild fastEquals oldChild) { - oldChild - } else { - changed = true - newChild - } + case s: Stream[_] => + // Stream is lazy so we need to force materialization + s.map(mapChild).force + case s: Seq[_] => + s.map(mapChild) + case m: Map[_, _] => + // `mapValues` is lazy and we need to force it to materialize + m.mapValues(mapChild).view.force + case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg) case nonChild: AnyRef => nonChild case null => null } @@ -301,6 +290,37 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { def mapChildren(f: BaseType => BaseType): BaseType = { if (children.nonEmpty) { var changed = false + def mapChild(child: Any): Any = child match { + case arg: TreeNode[_] if containsChild(arg) => + val newChild = f(arg.asInstanceOf[BaseType]) + if (!(newChild fastEquals arg)) { + changed = true + newChild + } else { + arg + } + case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) => + val newChild1 = if (containsChild(arg1)) { + f(arg1.asInstanceOf[BaseType]) + } else { + arg1.asInstanceOf[BaseType] + } + + val newChild2 = if (containsChild(arg2)) { + f(arg2.asInstanceOf[BaseType]) + } else { + arg2.asInstanceOf[BaseType] + } + + if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) { + changed = true + (newChild1, newChild2) + } else { + tuple + } + case other => other + } + val newArgs = mapProductIterator { case arg: TreeNode[_] if containsChild(arg) => val newChild = f(arg.asInstanceOf[BaseType]) @@ -330,36 +350,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case other => other }.view.force // `mapValues` is lazy and we need to force it to materialize case d: DataType => d // Avoid unpacking Structs - case args: Traversable[_] => args.map { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = f(arg.asInstanceOf[BaseType]) - if (!(newChild fastEquals arg)) { - changed = true - newChild - } else { - arg - } - case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) => - val newChild1 = if (containsChild(arg1)) { - f(arg1.asInstanceOf[BaseType]) - } else { - arg1.asInstanceOf[BaseType] - } - - val newChild2 = if (containsChild(arg2)) { - f(arg2.asInstanceOf[BaseType]) - } else { - arg2.asInstanceOf[BaseType] - } - - if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) { - changed = true - (newChild1, newChild2) - } else { - tuple - } - case other => other - } + case args: Stream[_] => args.map(mapChild).force // Force materialization on stream + case args: Traversable[_] => args.map(mapChild) case nonChild: AnyRef => nonChild case null => null } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index 9e39ed9c3a778..83ad08d8e1758 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -122,7 +122,7 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData { if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { return false } - case _ => if (o1 != o2) { + case _ => if (!o1.equals(o2)) { return false } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 1dcda49a3af6a..b795abea95a74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.RowOrdering import org.apache.spark.sql.types._ @@ -42,18 +42,12 @@ object TypeUtils { } def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = { - if (types.size <= 1) { + if (TypeCoercion.haveSameType(types)) { TypeCheckResult.TypeCheckSuccess } else { - val firstType = types.head - types.foreach { t => - if (!t.sameType(firstType)) { - return TypeCheckResult.TypeCheckFailure( - s"input to $caller should all be the same type, but it's " + - types.map(_.simpleString).mkString("[", ", ", "]")) - } - } - TypeCheckResult.TypeCheckSuccess + return TypeCheckResult.TypeCheckFailure( + s"input to $caller should all be the same type, but it's " + + types.map(_.simpleString).mkString("[", ", ", "]")) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 8d2320d8a6ed7..41fe0c3b60d9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -378,6 +378,43 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED = + buildConf("spark.sql.parquet.filterPushdown.timestamp") + .doc("If true, enables Parquet filter push-down optimization for Timestamp. " + + "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is " + + "enabled and Timestamp stored as TIMESTAMP_MICROS or TIMESTAMP_MILLIS type.") + .internal() + .booleanConf + .createWithDefault(true) + + val PARQUET_FILTER_PUSHDOWN_DECIMAL_ENABLED = + buildConf("spark.sql.parquet.filterPushdown.decimal") + .doc("If true, enables Parquet filter push-down optimization for Decimal. " + + "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is enabled.") + .internal() + .booleanConf + .createWithDefault(true) + + val PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED = + buildConf("spark.sql.parquet.filterPushdown.string.startsWith") + .doc("If true, enables Parquet filter push-down optimization for string startsWith function. " + + "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is enabled.") + .internal() + .booleanConf + .createWithDefault(true) + + val PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD = + buildConf("spark.sql.parquet.pushdown.inFilterThreshold") + .doc("The maximum number of values to filter push-down optimization for IN predicate. " + + "Large threshold won't necessarily provide much better performance. " + + "The experiment argued that 300 is the limit threshold. " + + "By setting this value to 0 this feature can be disabled. " + + "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is enabled.") + .internal() + .intConf + .checkValue(threshold => threshold >= 0, "The threshold must not be negative.") + .createWithDefault(10) + val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat") .doc("Whether to be compatible with the legacy Parquet format adopted by Spark 1.4 and prior " + "versions, when converting Parquet schema to Spark SQL schema and vice versa.") @@ -867,6 +904,21 @@ object SQLConf { .stringConf .createWithDefault("org.apache.spark.sql.execution.streaming.ManifestFileCommitProtocol") + val STREAMING_MULTIPLE_WATERMARK_POLICY = + buildConf("spark.sql.streaming.multipleWatermarkPolicy") + .doc("Policy to calculate the global watermark value when there are multiple watermark " + + "operators in a streaming query. The default value is 'min' which chooses " + + "the minimum watermark reported across multiple operators. Other alternative value is" + + "'max' which chooses the maximum across multiple operators." + + "Note: This configuration cannot be changed between query restarts from the same " + + "checkpoint location.") + .stringConf + .checkValue( + str => Set("min", "max").contains(str.toLowerCase), + "Invalid value for 'spark.sql.streaming.multipleWatermarkPolicy'. " + + "Valid values are 'min' and 'max'") + .createWithDefault("min") // must be same as MultipleWatermarkPolicy.DEFAULT_POLICY_NAME + val OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD = buildConf("spark.sql.objectHashAggregate.sortBased.fallbackThreshold") .internal() @@ -1161,6 +1213,16 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION = + buildConf("spark.sql.execution.pandas.groupedMap.assignColumnsByPosition") + .internal() + .doc("When true, a grouped map Pandas UDF will assign columns from the returned " + + "Pandas DataFrame based on position, regardless of column label type. When false, " + + "columns will be looked up by name if labeled with a string and fallback to use " + + "position if not. This configuration will be deprecated in future releases.") + .booleanConf + .createWithDefault(false) + val REPLACE_EXCEPT_WITH_FILTER = buildConf("spark.sql.optimizer.replaceExceptWithFilter") .internal() .doc("When true, the apply function of the rule verifies whether the right node of the" + @@ -1314,6 +1376,35 @@ object SQLConf { "Other column values can be ignored during parsing even if they are malformed.") .booleanConf .createWithDefault(true) + + val LEGACY_SIZE_OF_NULL = buildConf("spark.sql.legacy.sizeOfNull") + .doc("If it is set to true, size of null returns -1. This behavior was inherited from Hive. " + + "The size function returns null for null input if the flag is disabled.") + .booleanConf + .createWithDefault(true) + + val REPL_EAGER_EVAL_ENABLED = buildConf("spark.sql.repl.eagerEval.enabled") + .doc("Enables eager evaluation or not. When true, the top K rows of Dataset will be " + + "displayed if and only if the REPL supports the eager evaluation. Currently, the " + + "eager evaluation is only supported in PySpark. For the notebooks like Jupyter, " + + "the HTML table (generated by _repr_html_) will be returned. For plain Python REPL, " + + "the returned outputs are formatted like dataframe.show().") + .booleanConf + .createWithDefault(false) + + val REPL_EAGER_EVAL_MAX_NUM_ROWS = buildConf("spark.sql.repl.eagerEval.maxNumRows") + .doc("The max number of rows that are returned by eager evaluation. This only takes " + + "effect when spark.sql.repl.eagerEval.enabled is set to true. The valid range of this " + + "config is from 0 to (Int.MaxValue - 1), so the invalid config like negative and " + + "greater than (Int.MaxValue - 1) will be normalized to 0 and (Int.MaxValue - 1).") + .intConf + .createWithDefault(20) + + val REPL_EAGER_EVAL_TRUNCATE = buildConf("spark.sql.repl.eagerEval.truncate") + .doc("The max number of characters for each cell that is returned by eager evaluation. " + + "This only takes effect when spark.sql.repl.eagerEval.enabled is set to true.") + .intConf + .createWithDefault(20) } /** @@ -1420,6 +1511,16 @@ class SQLConf extends Serializable with Logging { def parquetFilterPushDownDate: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_DATE_ENABLED) + def parquetFilterPushDownTimestamp: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED) + + def parquetFilterPushDownDecimal: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_DECIMAL_ENABLED) + + def parquetFilterPushDownStringStartWith: Boolean = + getConf(PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED) + + def parquetFilterPushDownInFilterThreshold: Int = + getConf(PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD) + def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) @@ -1458,6 +1559,8 @@ class SQLConf extends Serializable with Logging { def tableRelationCacheSize: Int = getConf(StaticSQLConf.FILESOURCE_TABLE_RELATION_CACHE_SIZE) + def codegenCacheMaxEntries: Int = getConf(StaticSQLConf.CODEGEN_CACHE_MAX_ENTRIES) + def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED) def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) @@ -1647,6 +1750,9 @@ class SQLConf extends Serializable with Logging { def pandasRespectSessionTimeZone: Boolean = getConf(PANDAS_RESPECT_SESSION_LOCAL_TIMEZONE) + def pandasGroupedMapAssignColumnssByPosition: Boolean = + getConf(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION) + def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS) @@ -1673,6 +1779,14 @@ class SQLConf extends Serializable with Logging { def csvColumnPruning: Boolean = getConf(SQLConf.CSV_PARSER_COLUMN_PRUNING) + def legacySizeOfNull: Boolean = getConf(SQLConf.LEGACY_SIZE_OF_NULL) + + def isReplEagerEvalEnabled: Boolean = getConf(SQLConf.REPL_EAGER_EVAL_ENABLED) + + def replEagerEvalMaxNumRows: Int = getConf(SQLConf.REPL_EAGER_EVAL_MAX_NUM_ROWS) + + def replEagerEvalTruncate: Int = getConf(SQLConf.REPL_EAGER_EVAL_TRUNCATE) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ @@ -1829,4 +1943,8 @@ class SQLConf extends Serializable with Logging { } cloned } + + def isModifiable(key: String): Boolean = { + sqlConfEntries.containsKey(key) && !staticConfKeys.contains(key) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala index fe0ad39c29025..384b1917a1f79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -66,6 +66,14 @@ object StaticSQLConf { .checkValue(cacheSize => cacheSize >= 0, "The maximum size of the cache must not be negative") .createWithDefault(1000) + val CODEGEN_CACHE_MAX_ENTRIES = buildStaticConf("spark.sql.codegen.cache.maxEntries") + .internal() + .doc("When nonzero, enable caching of generated classes for operators and expressions. " + + "All jobs share the cache that can use up to the specified number for generated classes.") + .intConf + .checkValue(maxEntries => maxEntries >= 0, "The maximum must not be negative") + .createWithDefault(100) + // When enabling the debug, Spark SQL internal table properties are not filtered out; however, // some related DDL commands (e.g., ANALYZE TABLE and CREATE TABLE LIKE) might not work properly. val DEBUG_MODE = buildStaticConf("spark.sql.debug") @@ -96,6 +104,14 @@ object StaticSQLConf { .toSequence .createOptional + val STREAMING_QUERY_LISTENERS = buildStaticConf("spark.sql.streaming.streamingQueryListeners") + .doc("List of class names implementing StreamingQueryListener that will be automatically " + + "added to newly created sessions. The classes should have either a no-arg constructor, " + + "or a constructor that expects a SparkConf argument.") + .stringConf + .toSequence + .createOptional + val UI_RETAINED_EXECUTIONS = buildStaticConf("spark.sql.ui.retainedExecutions") .doc("Number of executions to retain in the Spark UI.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 0bef11659fc9e..fd40741cfb5f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.types import java.util.Locale +import scala.util.control.NonFatal + import org.json4s._ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ @@ -26,6 +28,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils @@ -110,6 +113,14 @@ abstract class DataType extends AbstractDataType { @InterfaceStability.Stable object DataType { + def fromDDL(ddl: String): DataType = { + try { + CatalystSqlParser.parseDataType(ddl) + } catch { + case NonFatal(_) => CatalystSqlParser.parseTableSchema(ddl) + } + } + def fromJson(json: String): DataType = parseDataType(parse(json)) private val nonDecimalNameToType = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala index f99af9b84d959..89452ee05cff3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class CatalystTypeConvertersSuite extends SparkFunSuite { @@ -139,4 +140,11 @@ class CatalystTypeConvertersSuite extends SparkFunSuite { assert(exception.getMessage.contains("The value (0.1) of the type " + "(java.lang.Double) cannot be converted to the string type")) } + + test("SPARK-24571: convert Char to String") { + val chr: Char = 'X' + val converter = CatalystTypeConverters.createToCatalystConverter(StringType) + val expected = UTF8String.fromString("X") + assert(converter(chr) === expected) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index b47b8adfe5d55..39228102682b9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -41,34 +41,127 @@ class DistributionSuite extends SparkFunSuite { } } - test("HashPartitioning (with nullSafe = true) is the output partitioning") { - // Cases which do not need an exchange between two data properties. + test("UnspecifiedDistribution and AllTuples") { + // except `BroadcastPartitioning`, all other partitioning can satisfy UnspecifiedDistribution checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10), + UnknownPartitioning(-1), UnspecifiedDistribution, true) checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10), - ClusteredDistribution(Seq('a, 'b, 'c)), + RoundRobinPartitioning(10), + UnspecifiedDistribution, true) checkSatisfied( - HashPartitioning(Seq('b, 'c), 10), - ClusteredDistribution(Seq('a, 'b, 'c)), + SinglePartition, + UnspecifiedDistribution, + true) + + checkSatisfied( + HashPartitioning(Seq('a), 10), + UnspecifiedDistribution, + true) + + checkSatisfied( + RangePartitioning(Seq('a.asc), 10), + UnspecifiedDistribution, + true) + + checkSatisfied( + BroadcastPartitioning(IdentityBroadcastMode), + UnspecifiedDistribution, + false) + + // except `BroadcastPartitioning`, all other partitioning can satisfy AllTuples if they have + // only one partition. + checkSatisfied( + UnknownPartitioning(1), + AllTuples, + true) + + checkSatisfied( + UnknownPartitioning(10), + AllTuples, + false) + + checkSatisfied( + RoundRobinPartitioning(1), + AllTuples, + true) + + checkSatisfied( + RoundRobinPartitioning(10), + AllTuples, + false) + + checkSatisfied( + SinglePartition, + AllTuples, + true) + + checkSatisfied( + HashPartitioning(Seq('a), 1), + AllTuples, true) + checkSatisfied( + HashPartitioning(Seq('a), 10), + AllTuples, + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc), 1), + AllTuples, + true) + + checkSatisfied( + RangePartitioning(Seq('a.asc), 10), + AllTuples, + false) + + checkSatisfied( + BroadcastPartitioning(IdentityBroadcastMode), + AllTuples, + false) + } + + test("SinglePartition is the output partitioning") { + // SinglePartition can satisfy all the distributions except `BroadcastDistribution` checkSatisfied( SinglePartition, ClusteredDistribution(Seq('a, 'b, 'c)), true) + checkSatisfied( + SinglePartition, + HashClusteredDistribution(Seq('a, 'b, 'c)), + true) + checkSatisfied( SinglePartition, OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), true) - // Cases which need an exchange between two data properties. + checkSatisfied( + SinglePartition, + BroadcastDistribution(IdentityBroadcastMode), + false) + } + + test("HashPartitioning is the output partitioning") { + // HashPartitioning can satisfy ClusteredDistribution iff its hash expressions are a subset of + // the required clustering expressions. + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c)), + true) + + checkSatisfied( + HashPartitioning(Seq('b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c)), + true) + checkSatisfied( HashPartitioning(Seq('a, 'b, 'c), 10), ClusteredDistribution(Seq('b, 'c)), @@ -79,37 +172,43 @@ class DistributionSuite extends SparkFunSuite { ClusteredDistribution(Seq('d, 'e)), false) + // HashPartitioning can satisfy HashClusteredDistribution iff its hash expressions are exactly + // same with the required hash clustering expressions. checkSatisfied( HashPartitioning(Seq('a, 'b, 'c), 10), - AllTuples, + HashClusteredDistribution(Seq('a, 'b, 'c)), + true) + + checkSatisfied( + HashPartitioning(Seq('c, 'b, 'a), 10), + HashClusteredDistribution(Seq('a, 'b, 'c)), false) + checkSatisfied( + HashPartitioning(Seq('a, 'b), 10), + HashClusteredDistribution(Seq('a, 'b, 'c)), + false) + + // HashPartitioning cannot satisfy OrderedDistribution checkSatisfied( HashPartitioning(Seq('a, 'b, 'c), 10), OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), false) checkSatisfied( - HashPartitioning(Seq('b, 'c), 10), + HashPartitioning(Seq('a, 'b, 'c), 1), OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), - false) + false) // TODO: this can be relaxed. - // TODO: We should check functional dependencies - /* checkSatisfied( - ClusteredDistribution(Seq('b)), - ClusteredDistribution(Seq('b + 1)), - true) - */ + HashPartitioning(Seq('b, 'c), 10), + OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), + false) } test("RangePartitioning is the output partitioning") { - // Cases which do not need an exchange between two data properties. - checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - UnspecifiedDistribution, - true) - + // RangePartitioning can satisfy OrderedDistribution iff its ordering is a prefix + // of the required ordering, or the required ordering is a prefix of its ordering. checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), @@ -125,6 +224,27 @@ class DistributionSuite extends SparkFunSuite { OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc, 'd.desc)), true) + // TODO: We can have an optimization to first sort the dataset + // by a.asc and then sort b, and c in a partition. This optimization + // should tradeoff the benefit of a less number of Exchange operators + // and the parallelism. + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + OrderedDistribution(Seq('a.asc, 'b.desc, 'c.asc)), + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + OrderedDistribution(Seq('b.asc, 'a.asc)), + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + OrderedDistribution(Seq('a.asc, 'b.asc, 'd.desc)), + false) + + // RangePartitioning can satisfy ClusteredDistribution iff its ordering expressions are a subset + // of the required clustering expressions. checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), ClusteredDistribution(Seq('a, 'b, 'c)), @@ -140,34 +260,47 @@ class DistributionSuite extends SparkFunSuite { ClusteredDistribution(Seq('b, 'c, 'a, 'd)), true) - // Cases which need an exchange between two data properties. - // TODO: We can have an optimization to first sort the dataset - // by a.asc and then sort b, and c in a partition. This optimization - // should tradeoff the benefit of a less number of Exchange operators - // and the parallelism. checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - OrderedDistribution(Seq('a.asc, 'b.desc, 'c.asc)), + ClusteredDistribution(Seq('a, 'b)), false) checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - OrderedDistribution(Seq('b.asc, 'a.asc)), + ClusteredDistribution(Seq('c, 'd)), false) + // RangePartitioning cannot satisfy HashClusteredDistribution checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('a, 'b)), + HashClusteredDistribution(Seq('a, 'b, 'c)), false) + } + test("Partitioning.numPartitions must match Distribution.requiredNumPartitions to satisfy it") { checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('c, 'd)), + SinglePartition, + ClusteredDistribution(Seq('a, 'b, 'c), Some(10)), + false) + + checkSatisfied( + SinglePartition, + HashClusteredDistribution(Seq('a, 'b, 'c), Some(10)), + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c), Some(5)), + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + HashClusteredDistribution(Seq('a, 'b, 'c), Some(5)), false) checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - AllTuples, + ClusteredDistribution(Seq('a, 'b, 'c), Some(5)), false) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index cd8579584eada..bbcdf6c1b8481 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -21,6 +21,7 @@ import java.util.TimeZone import org.scalatest.Matchers +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -557,4 +558,21 @@ class AnalysisSuite extends AnalysisTest with Matchers { SubqueryAlias("tbl", testRelation))) assertAnalysisError(barrier, Seq("cannot resolve '`tbl.b`'")) } + + test("SPARK-24208: analysis fails on self-join with FlatMapGroupsInPandas") { + val pythonUdf = PythonUDF("pyUDF", null, + StructType(Seq(StructField("a", LongType))), + Seq.empty, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + true) + val output = pythonUdf.dataType.asInstanceOf[StructType].toAttributes + val project = Project(Seq(UnresolvedAttribute("a")), testRelation) + val flatMapGroupsInPandas = FlatMapGroupsInPandas( + Seq(UnresolvedAttribute("a")), pythonUdf, output, project) + val left = SubqueryAlias("temp0", flatMapGroupsInPandas) + val right = SubqueryAlias("temp1", flatMapGroupsInPandas) + val join = Join(left, right, Inner, None) + assertAnalysisSuccess( + Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala new file mode 100644 index 0000000000000..cea0f2a9cbc97 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import java.net.URI + +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf + +class LookupFunctionsSuite extends PlanTest { + + test("SPARK-23486: the functionExists for the Persistent function check") { + val externalCatalog = new CustomInMemoryCatalog + val conf = new SQLConf() + val catalog = new SessionCatalog(externalCatalog, FunctionRegistry.builtin, conf) + val analyzer = { + catalog.createDatabase( + CatalogDatabase("default", "", new URI("loc"), Map.empty), + ignoreIfExists = false) + new Analyzer(catalog, conf) + } + + def table(ref: String): LogicalPlan = UnresolvedRelation(TableIdentifier(ref)) + val unresolvedPersistentFunc = UnresolvedFunction("func", Seq.empty, false) + val unresolvedRegisteredFunc = UnresolvedFunction("max", Seq.empty, false) + val plan = Project( + Seq(Alias(unresolvedPersistentFunc, "call1")(), Alias(unresolvedPersistentFunc, "call2")(), + Alias(unresolvedPersistentFunc, "call3")(), Alias(unresolvedRegisteredFunc, "call4")(), + Alias(unresolvedRegisteredFunc, "call5")()), + table("TaBlE")) + analyzer.LookupFunctions.apply(plan) + + assert(externalCatalog.getFunctionExistsCalledTimes == 1) + assert(analyzer.LookupFunctions.normalizeFuncName + (unresolvedPersistentFunc.name).database == Some("default")) + } + + test("SPARK-23486: the functionExists for the Registered function check") { + val externalCatalog = new InMemoryCatalog + val conf = new SQLConf() + val customerFunctionReg = new CustomerFunctionRegistry + val catalog = new SessionCatalog(externalCatalog, customerFunctionReg, conf) + val analyzer = { + catalog.createDatabase( + CatalogDatabase("default", "", new URI("loc"), Map.empty), + ignoreIfExists = false) + new Analyzer(catalog, conf) + } + + def table(ref: String): LogicalPlan = UnresolvedRelation(TableIdentifier(ref)) + val unresolvedRegisteredFunc = UnresolvedFunction("max", Seq.empty, false) + val plan = Project( + Seq(Alias(unresolvedRegisteredFunc, "call1")(), Alias(unresolvedRegisteredFunc, "call2")()), + table("TaBlE")) + analyzer.LookupFunctions.apply(plan) + + assert(customerFunctionReg.getIsRegisteredFunctionCalledTimes == 2) + assert(analyzer.LookupFunctions.normalizeFuncName + (unresolvedRegisteredFunc.name).database == Some("default")) + } +} + +class CustomerFunctionRegistry extends SimpleFunctionRegistry { + + private var isRegisteredFunctionCalledTimes: Int = 0; + + override def functionExists(funcN: FunctionIdentifier): Boolean = synchronized { + isRegisteredFunctionCalledTimes = isRegisteredFunctionCalledTimes + 1 + true + } + + def getIsRegisteredFunctionCalledTimes: Int = isRegisteredFunctionCalledTimes +} + +class CustomInMemoryCatalog extends InMemoryCatalog { + + private var functionExistsCalledTimes: Int = 0 + + override def functionExists(db: String, funcName: String): Boolean = synchronized { + functionExistsCalledTimes = functionExistsCalledTimes + 1 + true + } + + def getFunctionExistsCalledTimes: Int = functionExistsCalledTimes +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 0acd3b490447d..4161f09c63190 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -54,8 +54,9 @@ class TypeCoercionSuite extends AnalysisTest { // | NullType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType | MapType | StructType | NullType | CalendarIntervalType | DecimalType(38, 18) | DoubleType | IntegerType | // | CalendarIntervalType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | CalendarIntervalType | X | X | X | // +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+ - // Note: MapType*, StructType* are castable only when the internal child types also match; otherwise, not castable. + // Note: StructType* is castable when all the internal child types are castable according to the table. // Note: ArrayType* is castable when the element type is castable according to the table. + // Note: MapType* is castable when both the key type and the value type are castable according to the table. // scalastyle:on line.size.limit private def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = { @@ -396,7 +397,7 @@ class TypeCoercionSuite extends AnalysisTest { widenTest( StructType(Seq(StructField("a", IntegerType, nullable = false))), StructType(Seq(StructField("a", DoubleType, nullable = false))), - None) + Some(StructType(Seq(StructField("a", DoubleType, nullable = false))))) widenTest( StructType(Seq(StructField("a", IntegerType, nullable = false))), @@ -453,15 +454,18 @@ class TypeCoercionSuite extends AnalysisTest { def widenTestWithStringPromotion( t1: DataType, t2: DataType, - expected: Option[DataType]): Unit = { - checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected) + expected: Option[DataType], + isSymmetric: Boolean = true): Unit = { + checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected, isSymmetric) } def widenTestWithoutStringPromotion( t1: DataType, t2: DataType, - expected: Option[DataType]): Unit = { - checkWidenType(TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, expected) + expected: Option[DataType], + isSymmetric: Boolean = true): Unit = { + checkWidenType( + TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, expected, isSymmetric) } // Decimal @@ -487,12 +491,108 @@ class TypeCoercionSuite extends AnalysisTest { ArrayType(ArrayType(IntegerType), containsNull = false), ArrayType(ArrayType(LongType), containsNull = false), Some(ArrayType(ArrayType(LongType), containsNull = false))) + widenTestWithStringPromotion( + ArrayType(MapType(IntegerType, FloatType), containsNull = false), + ArrayType(MapType(LongType, DoubleType), containsNull = false), + Some(ArrayType(MapType(LongType, DoubleType), containsNull = false))) + widenTestWithStringPromotion( + ArrayType(new StructType().add("num", ShortType), containsNull = false), + ArrayType(new StructType().add("num", LongType), containsNull = false), + Some(ArrayType(new StructType().add("num", LongType), containsNull = false))) + + // MapType + widenTestWithStringPromotion( + MapType(ShortType, TimestampType, valueContainsNull = true), + MapType(DoubleType, StringType, valueContainsNull = false), + Some(MapType(DoubleType, StringType, valueContainsNull = true))) + widenTestWithStringPromotion( + MapType(IntegerType, ArrayType(TimestampType), valueContainsNull = false), + MapType(LongType, ArrayType(StringType), valueContainsNull = true), + Some(MapType(LongType, ArrayType(StringType), valueContainsNull = true))) + widenTestWithStringPromotion( + MapType(IntegerType, MapType(ShortType, TimestampType), valueContainsNull = false), + MapType(LongType, MapType(DoubleType, StringType), valueContainsNull = false), + Some(MapType(LongType, MapType(DoubleType, StringType), valueContainsNull = false))) + widenTestWithStringPromotion( + MapType(IntegerType, new StructType().add("num", ShortType), valueContainsNull = false), + MapType(LongType, new StructType().add("num", LongType), valueContainsNull = false), + Some(MapType(LongType, new StructType().add("num", LongType), valueContainsNull = false))) + + // StructType + widenTestWithStringPromotion( + new StructType() + .add("num", ShortType, nullable = true).add("ts", StringType, nullable = false), + new StructType() + .add("num", DoubleType, nullable = false).add("ts", TimestampType, nullable = true), + Some(new StructType() + .add("num", DoubleType, nullable = true).add("ts", StringType, nullable = true))) + widenTestWithStringPromotion( + new StructType() + .add("arr", ArrayType(ShortType, containsNull = false), nullable = false), + new StructType() + .add("arr", ArrayType(DoubleType, containsNull = true), nullable = false), + Some(new StructType() + .add("arr", ArrayType(DoubleType, containsNull = true), nullable = false))) + widenTestWithStringPromotion( + new StructType() + .add("map", MapType(ShortType, TimestampType, valueContainsNull = true), nullable = false), + new StructType() + .add("map", MapType(DoubleType, StringType, valueContainsNull = false), nullable = false), + Some(new StructType() + .add("map", MapType(DoubleType, StringType, valueContainsNull = true), nullable = false))) + + widenTestWithStringPromotion( + new StructType().add("num", IntegerType), + new StructType().add("num", LongType).add("str", StringType), + None) + widenTestWithoutStringPromotion( + new StructType().add("num", IntegerType), + new StructType().add("num", LongType).add("str", StringType), + None) + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + widenTestWithStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + None) + widenTestWithoutStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + None) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + widenTestWithStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + Some(new StructType().add("a", LongType)), + isSymmetric = false) + widenTestWithoutStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + Some(new StructType().add("a", LongType)), + isSymmetric = false) + } // Without string promotion widenTestWithoutStringPromotion(IntegerType, StringType, None) widenTestWithoutStringPromotion(StringType, TimestampType, None) widenTestWithoutStringPromotion(ArrayType(LongType), ArrayType(StringType), None) widenTestWithoutStringPromotion(ArrayType(StringType), ArrayType(TimestampType), None) + widenTestWithoutStringPromotion( + MapType(LongType, IntegerType), MapType(StringType, IntegerType), None) + widenTestWithoutStringPromotion( + MapType(IntegerType, LongType), MapType(IntegerType, StringType), None) + widenTestWithoutStringPromotion( + MapType(StringType, IntegerType), MapType(TimestampType, IntegerType), None) + widenTestWithoutStringPromotion( + MapType(IntegerType, StringType), MapType(IntegerType, TimestampType), None) + widenTestWithoutStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("a", StringType), + None) + widenTestWithoutStringPromotion( + new StructType().add("a", StringType), + new StructType().add("a", IntegerType), + None) // String promotion widenTestWithStringPromotion(IntegerType, StringType, Some(StringType)) @@ -501,6 +601,30 @@ class TypeCoercionSuite extends AnalysisTest { ArrayType(LongType), ArrayType(StringType), Some(ArrayType(StringType))) widenTestWithStringPromotion( ArrayType(StringType), ArrayType(TimestampType), Some(ArrayType(StringType))) + widenTestWithStringPromotion( + MapType(LongType, IntegerType), + MapType(StringType, IntegerType), + Some(MapType(StringType, IntegerType))) + widenTestWithStringPromotion( + MapType(IntegerType, LongType), + MapType(IntegerType, StringType), + Some(MapType(IntegerType, StringType))) + widenTestWithStringPromotion( + MapType(StringType, IntegerType), + MapType(TimestampType, IntegerType), + Some(MapType(StringType, IntegerType))) + widenTestWithStringPromotion( + MapType(IntegerType, StringType), + MapType(IntegerType, TimestampType), + Some(MapType(IntegerType, StringType))) + widenTestWithStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("a", StringType), + Some(new StructType().add("a", StringType))) + widenTestWithStringPromotion( + new StructType().add("a", StringType), + new StructType().add("a", IntegerType), + Some(new StructType().add("a", StringType))) } private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) { @@ -563,46 +687,43 @@ class TypeCoercionSuite extends AnalysisTest { ruleTest(rule, Coalesce(Seq(doubleLit, intLit, floatLit)), - Coalesce(Seq(Cast(doubleLit, DoubleType), - Cast(intLit, DoubleType), Cast(floatLit, DoubleType)))) + Coalesce(Seq(doubleLit, Cast(intLit, DoubleType), Cast(floatLit, DoubleType)))) ruleTest(rule, Coalesce(Seq(longLit, intLit, decimalLit)), Coalesce(Seq(Cast(longLit, DecimalType(22, 0)), - Cast(intLit, DecimalType(22, 0)), Cast(decimalLit, DecimalType(22, 0))))) + Cast(intLit, DecimalType(22, 0)), decimalLit))) ruleTest(rule, Coalesce(Seq(nullLit, intLit)), - Coalesce(Seq(Cast(nullLit, IntegerType), Cast(intLit, IntegerType)))) + Coalesce(Seq(Cast(nullLit, IntegerType), intLit))) ruleTest(rule, Coalesce(Seq(timestampLit, stringLit)), - Coalesce(Seq(Cast(timestampLit, StringType), Cast(stringLit, StringType)))) + Coalesce(Seq(Cast(timestampLit, StringType), stringLit))) ruleTest(rule, Coalesce(Seq(nullLit, floatNullLit, intLit)), - Coalesce(Seq(Cast(nullLit, FloatType), Cast(floatNullLit, FloatType), - Cast(intLit, FloatType)))) + Coalesce(Seq(Cast(nullLit, FloatType), floatNullLit, Cast(intLit, FloatType)))) ruleTest(rule, Coalesce(Seq(nullLit, intLit, decimalLit, doubleLit)), Coalesce(Seq(Cast(nullLit, DoubleType), Cast(intLit, DoubleType), - Cast(decimalLit, DoubleType), Cast(doubleLit, DoubleType)))) + Cast(decimalLit, DoubleType), doubleLit))) ruleTest(rule, Coalesce(Seq(nullLit, floatNullLit, doubleLit, stringLit)), Coalesce(Seq(Cast(nullLit, StringType), Cast(floatNullLit, StringType), - Cast(doubleLit, StringType), Cast(stringLit, StringType)))) + Cast(doubleLit, StringType), stringLit))) ruleTest(rule, Coalesce(Seq(timestampLit, intLit, stringLit)), - Coalesce(Seq(Cast(timestampLit, StringType), Cast(intLit, StringType), - Cast(stringLit, StringType)))) + Coalesce(Seq(Cast(timestampLit, StringType), Cast(intLit, StringType), stringLit))) ruleTest(rule, Coalesce(Seq(tsArrayLit, intArrayLit, strArrayLit)), Coalesce(Seq(Cast(tsArrayLit, ArrayType(StringType)), - Cast(intArrayLit, ArrayType(StringType)), Cast(strArrayLit, ArrayType(StringType))))) + Cast(intArrayLit, ArrayType(StringType)), strArrayLit))) } test("CreateArray casts") { @@ -611,7 +732,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(1) :: Literal.create(1.0, FloatType) :: Nil), - CreateArray(Cast(Literal(1.0), DoubleType) + CreateArray(Literal(1.0) :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) @@ -623,7 +744,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Nil), CreateArray(Cast(Literal(1.0), StringType) :: Cast(Literal(1), StringType) - :: Cast(Literal("a"), StringType) + :: Literal("a") :: Nil)) ruleTest(TypeCoercion.FunctionArgumentConversion, @@ -641,7 +762,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Nil), CreateArray(Literal.create(null, DecimalType(5, 3)).cast(DecimalType(38, 38)) :: Literal.create(null, DecimalType(22, 10)).cast(DecimalType(38, 38)) - :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38)) + :: Literal.create(null, DecimalType(38, 38)) :: Nil)) } @@ -655,7 +776,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Nil), CreateMap(Cast(Literal(1), FloatType) :: Literal("a") - :: Cast(Literal.create(2.0, FloatType), FloatType) + :: Literal.create(2.0, FloatType) :: Literal("b") :: Nil)) ruleTest(TypeCoercion.FunctionArgumentConversion, @@ -677,7 +798,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(3.0) :: Nil), CreateMap(Literal(1) - :: Cast(Literal("a"), StringType) + :: Literal("a") :: Literal(2) :: Cast(Literal(3.0), StringType) :: Nil)) @@ -690,7 +811,7 @@ class TypeCoercionSuite extends AnalysisTest { CreateMap(Literal(1) :: Literal.create(null, DecimalType(38, 0)).cast(DecimalType(38, 38)) :: Literal(2) - :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38)) + :: Literal.create(null, DecimalType(38, 38)) :: Nil)) // type coercion for both map keys and values ruleTest(TypeCoercion.FunctionArgumentConversion, @@ -700,8 +821,8 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(3.0) :: Nil), CreateMap(Cast(Literal(1), DoubleType) - :: Cast(Literal("a"), StringType) - :: Cast(Literal(2.0), DoubleType) + :: Literal("a") + :: Literal(2.0) :: Cast(Literal(3.0), StringType) :: Nil)) } @@ -713,7 +834,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(1) :: Literal.create(1.0, FloatType) :: Nil), - operator(Cast(Literal(1.0), DoubleType) + operator(Literal(1.0) :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) @@ -724,14 +845,14 @@ class TypeCoercionSuite extends AnalysisTest { :: Nil), operator(Cast(Literal(1L), DecimalType(22, 0)) :: Cast(Literal(1), DecimalType(22, 0)) - :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) + :: Literal(new java.math.BigDecimal("1000000000000000000000")) :: Nil)) ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal(1.0) :: Literal.create(null, DecimalType(10, 5)) :: Literal(1) :: Nil), - operator(Literal(1.0).cast(DoubleType) + operator(Literal(1.0) :: Literal.create(null, DecimalType(10, 5)).cast(DoubleType) :: Literal(1).cast(DoubleType) :: Nil)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 6abab0073cca3..50496a0410528 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -1114,11 +1114,13 @@ abstract class SessionCatalogSuite extends AnalysisTest { // And for hive serde table, hive metastore will set some values(e.g.transient_lastDdlTime) // in table's parameters and storage's properties, here we also ignore them. val actualPartsNormalize = actualParts.map(p => - p.copy(parameters = Map.empty, storage = p.storage.copy( + p.copy(parameters = Map.empty, createTime = -1, lastAccessTime = -1, + storage = p.storage.copy( properties = Map.empty, locationUri = None, serde = None))).toSet val expectedPartsNormalize = expectedParts.map(p => - p.copy(parameters = Map.empty, storage = p.storage.copy( + p.copy(parameters = Map.empty, createTime = -1, lastAccessTime = -1, + storage = p.storage.copy( properties = Map.empty, locationUri = None, serde = None))).toSet actualPartsNormalize == expectedPartsNormalize @@ -1215,6 +1217,42 @@ abstract class SessionCatalogSuite extends AnalysisTest { } } + test("isRegisteredFunction") { + withBasicCatalog { catalog => + // Returns false when the function does not register + assert(!catalog.isRegisteredFunction(FunctionIdentifier("temp1"))) + + // Returns true when the function does register + val tempFunc1 = (e: Seq[Expression]) => e.head + catalog.registerFunction(newFunc("iff", None), overrideIfExists = false, + functionBuilder = Some(tempFunc1) ) + assert(catalog.isRegisteredFunction(FunctionIdentifier("iff"))) + + // Returns false when using the createFunction + catalog.createFunction(newFunc("sum", Some("db2")), ignoreIfExists = false) + assert(!catalog.isRegisteredFunction(FunctionIdentifier("sum"))) + assert(!catalog.isRegisteredFunction(FunctionIdentifier("sum", Some("db2")))) + } + } + + test("isPersistentFunction") { + withBasicCatalog { catalog => + // Returns false when the function does not register + assert(!catalog.isPersistentFunction(FunctionIdentifier("temp2"))) + + // Returns false when the function does register + val tempFunc2 = (e: Seq[Expression]) => e.head + catalog.registerFunction(newFunc("iff", None), overrideIfExists = false, + functionBuilder = Some(tempFunc2)) + assert(!catalog.isPersistentFunction(FunctionIdentifier("iff"))) + + // Return true when using the createFunction + catalog.createFunction(newFunc("sum", Some("db2")), ignoreIfExists = false) + assert(catalog.isPersistentFunction(FunctionIdentifier("sum", Some("db2")))) + assert(!catalog.isPersistentFunction(FunctionIdentifier("db2.sum"))) + } + } + test("drop function") { withBasicCatalog { catalog => assert(catalog.externalCatalog.listFunctions("db2", "*").toSet == Set("func1")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 6edb4348f8309..021217606dc03 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -282,6 +282,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper DataTypeTestUtils.ordered.foreach { dt => checkConsistencyBetweenInterpretedAndCodegen(Least, dt, 2) } + + val least = Least(Seq( + Literal.create(Seq(1, 2), ArrayType(IntegerType, containsNull = false)), + Literal.create(Seq(1, 3, null), ArrayType(IntegerType, containsNull = true)))) + assert(least.dataType === ArrayType(IntegerType, containsNull = true)) + checkEvaluation(least, Seq(1, 2)) } test("function greatest") { @@ -334,6 +340,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper DataTypeTestUtils.ordered.foreach { dt => checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2) } + + val greatest = Greatest(Seq( + Literal.create(Seq(1, 2), ArrayType(IntegerType, containsNull = false)), + Literal.create(Seq(1, 3, null), ArrayType(IntegerType, containsNull = true)))) + assert(greatest.dataType === ArrayType(IntegerType, containsNull = true)) + checkEvaluation(greatest, Seq(1, 3, null)) } test("SPARK-22499: Least and greatest should not generate codes beyond 64KB") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 85e692bdc4ef1..f1e3bd091565d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -17,14 +17,21 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.{Date, Timestamp} +import java.util.TimeZone + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeTestUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH +import org.apache.spark.unsafe.types.CalendarInterval class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { - test("Array and Map Size") { + def testSize(sizeOfNull: Any): Unit = { val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) val a2 = Literal.create(Seq(1, 2), ArrayType(IntegerType)) @@ -41,8 +48,24 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Size(m1), 0) checkEvaluation(Size(m2), 1) - checkEvaluation(Size(Literal.create(null, MapType(StringType, StringType))), -1) - checkEvaluation(Size(Literal.create(null, ArrayType(StringType))), -1) + checkEvaluation( + Size(Literal.create(null, MapType(StringType, StringType))), + expected = sizeOfNull) + checkEvaluation( + Size(Literal.create(null, ArrayType(StringType))), + expected = sizeOfNull) + } + + test("Array and Map Size - legacy") { + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") { + testSize(sizeOfNull = -1) + } + } + + test("Array and Map Size") { + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "false") { + testSize(sizeOfNull = null) + } } test("MapKeys/MapValues") { @@ -80,6 +103,204 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapEntries(ms2), null) } + test("Map Concat") { + val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType, + valueContainsNull = false)) + val m1 = Literal.create(Map("c" -> "3", "a" -> "4"), MapType(StringType, StringType, + valueContainsNull = false)) + val m2 = Literal.create(Map("d" -> "4", "e" -> "5"), MapType(StringType, StringType)) + val m3 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType)) + val m4 = Literal.create(Map("a" -> null, "c" -> "3"), MapType(StringType, StringType)) + val m5 = Literal.create(Map("a" -> 1, "b" -> 2), MapType(StringType, IntegerType)) + val m6 = Literal.create(Map("a" -> null, "c" -> 3), MapType(StringType, IntegerType)) + val m7 = Literal.create(Map(List(1, 2) -> 1, List(3, 4) -> 2), + MapType(ArrayType(IntegerType), IntegerType)) + val m8 = Literal.create(Map(List(5, 6) -> 3, List(1, 2) -> 4), + MapType(ArrayType(IntegerType), IntegerType)) + val m9 = Literal.create(Map(Map(1 -> 2, 3 -> 4) -> 1, Map(5 -> 6, 7 -> 8) -> 2), + MapType(MapType(IntegerType, IntegerType), IntegerType)) + val m10 = Literal.create(Map(Map(9 -> 10, 11 -> 12) -> 3, Map(1 -> 2, 3 -> 4) -> 4), + MapType(MapType(IntegerType, IntegerType), IntegerType)) + val m11 = Literal.create(Map(1 -> "1", 2 -> "2"), MapType(IntegerType, StringType, + valueContainsNull = false)) + val m12 = Literal.create(Map(3 -> "3", 4 -> "4"), MapType(IntegerType, StringType, + valueContainsNull = false)) + val mNull = Literal.create(null, MapType(StringType, StringType)) + + // overlapping maps + checkEvaluation(MapConcat(Seq(m0, m1)), + ( + Array("a", "b", "c", "a"), // keys + Array("1", "2", "3", "4") // values + ) + ) + + // maps with no overlap + checkEvaluation(MapConcat(Seq(m0, m2)), + Map("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5")) + + // 3 maps + checkEvaluation(MapConcat(Seq(m0, m1, m2)), + ( + Array("a", "b", "c", "a", "d", "e"), // keys + Array("1", "2", "3", "4", "4", "5") // values + ) + ) + + // null reference values + checkEvaluation(MapConcat(Seq(m3, m4)), + ( + Array("a", "b", "a", "c"), // keys + Array("1", "2", null, "3") // values + ) + ) + + // null primitive values + checkEvaluation(MapConcat(Seq(m5, m6)), + ( + Array("a", "b", "a", "c"), // keys + Array(1, 2, null, 3) // values + ) + ) + + // keys that are primitive + checkEvaluation(MapConcat(Seq(m11, m12)), + ( + Array(1, 2, 3, 4), // keys + Array("1", "2", "3", "4") // values + ) + ) + + // keys that are arrays, with overlap + checkEvaluation(MapConcat(Seq(m7, m8)), + ( + Array(List(1, 2), List(3, 4), List(5, 6), List(1, 2)), // keys + Array(1, 2, 3, 4) // values + ) + ) + + // keys that are maps, with overlap + checkEvaluation(MapConcat(Seq(m9, m10)), + ( + Array(Map(1 -> 2, 3 -> 4), Map(5 -> 6, 7 -> 8), Map(9 -> 10, 11 -> 12), + Map(1 -> 2, 3 -> 4)), // keys + Array(1, 2, 3, 4) // values + ) + ) + + // null map + checkEvaluation(MapConcat(Seq(m0, mNull)), null) + checkEvaluation(MapConcat(Seq(mNull, m0)), null) + checkEvaluation(MapConcat(Seq(mNull, mNull)), null) + checkEvaluation(MapConcat(Seq(mNull)), null) + + // single map + checkEvaluation(MapConcat(Seq(m0)), Map("a" -> "1", "b" -> "2")) + + // no map + checkEvaluation(MapConcat(Seq.empty), Map.empty) + + // force split expressions for input in generated code + val expectedKeys = Array.fill(65)(Seq("a", "b")).flatten ++ Array("d", "e") + val expectedValues = Array.fill(65)(Seq("1", "2")).flatten ++ Array("4", "5") + checkEvaluation(MapConcat( + Seq( + m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, + m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, + m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m2 + )), + (expectedKeys, expectedValues)) + + // argument checking + assert(MapConcat(Seq(m0, m1)).checkInputDataTypes().isSuccess) + assert(MapConcat(Seq(m5, m6)).checkInputDataTypes().isSuccess) + assert(MapConcat(Seq(m0, m5)).checkInputDataTypes().isFailure) + assert(MapConcat(Seq(m0, Literal(12))).checkInputDataTypes().isFailure) + assert(MapConcat(Seq(m0, m1)).dataType.keyType == StringType) + assert(MapConcat(Seq(m0, m1)).dataType.valueType == StringType) + assert(!MapConcat(Seq(m0, m1)).dataType.valueContainsNull) + assert(MapConcat(Seq(m5, m6)).dataType.keyType == StringType) + assert(MapConcat(Seq(m5, m6)).dataType.valueType == IntegerType) + assert(MapConcat(Seq.empty).dataType.keyType == StringType) + assert(MapConcat(Seq.empty).dataType.valueType == StringType) + assert(MapConcat(Seq(m5, m6)).dataType.valueContainsNull) + assert(MapConcat(Seq(m6, m5)).dataType.valueContainsNull) + assert(!MapConcat(Seq(m1, m2)).nullable) + assert(MapConcat(Seq(m1, mNull)).nullable) + + val mapConcat = MapConcat(Seq( + Literal.create(Map(Seq(1, 2) -> Seq("a", "b")), + MapType( + ArrayType(IntegerType, containsNull = false), + ArrayType(StringType, containsNull = false), + valueContainsNull = false)), + Literal.create(Map(Seq(3, 4, null) -> Seq("c", "d", null), Seq(6) -> null), + MapType( + ArrayType(IntegerType, containsNull = true), + ArrayType(StringType, containsNull = true), + valueContainsNull = true)))) + assert(mapConcat.dataType === + MapType( + ArrayType(IntegerType, containsNull = true), + ArrayType(StringType, containsNull = true), + valueContainsNull = true)) + checkEvaluation(mapConcat, Map( + Seq(1, 2) -> Seq("a", "b"), + Seq(3, 4, null) -> Seq("c", "d", null), + Seq(6) -> null)) + } + + test("MapFromEntries") { + def arrayType(keyType: DataType, valueType: DataType) : DataType = { + ArrayType( + StructType(Seq( + StructField("a", keyType), + StructField("b", valueType))), + true) + } + def r(values: Any*): InternalRow = create_row(values: _*) + + // Primitive-type keys and values + val aiType = arrayType(IntegerType, IntegerType) + val ai0 = Literal.create(Seq(r(1, 10), r(2, 20), r(3, 20)), aiType) + val ai1 = Literal.create(Seq(r(1, null), r(2, 20), r(3, null)), aiType) + val ai2 = Literal.create(Seq.empty, aiType) + val ai3 = Literal.create(null, aiType) + val ai4 = Literal.create(Seq(r(1, 10), r(1, 20)), aiType) + val ai5 = Literal.create(Seq(r(1, 10), r(null, 20)), aiType) + val ai6 = Literal.create(Seq(null, r(2, 20), null), aiType) + + checkEvaluation(MapFromEntries(ai0), Map(1 -> 10, 2 -> 20, 3 -> 20)) + checkEvaluation(MapFromEntries(ai1), Map(1 -> null, 2 -> 20, 3 -> null)) + checkEvaluation(MapFromEntries(ai2), Map.empty) + checkEvaluation(MapFromEntries(ai3), null) + checkEvaluation(MapKeys(MapFromEntries(ai4)), Seq(1, 1)) + checkExceptionInExpression[RuntimeException]( + MapFromEntries(ai5), + "The first field from a struct (key) can't be null.") + checkEvaluation(MapFromEntries(ai6), null) + + // Non-primitive-type keys and values + val asType = arrayType(StringType, StringType) + val as0 = Literal.create(Seq(r("a", "aa"), r("b", "bb"), r("c", "bb")), asType) + val as1 = Literal.create(Seq(r("a", null), r("b", "bb"), r("c", null)), asType) + val as2 = Literal.create(Seq.empty, asType) + val as3 = Literal.create(null, asType) + val as4 = Literal.create(Seq(r("a", "aa"), r("a", "bb")), asType) + val as5 = Literal.create(Seq(r("a", "aa"), r(null, "bb")), asType) + val as6 = Literal.create(Seq(null, r("b", "bb"), null), asType) + + checkEvaluation(MapFromEntries(as0), Map("a" -> "aa", "b" -> "bb", "c" -> "bb")) + checkEvaluation(MapFromEntries(as1), Map("a" -> null, "b" -> "bb", "c" -> null)) + checkEvaluation(MapFromEntries(as2), Map.empty) + checkEvaluation(MapFromEntries(as3), null) + checkEvaluation(MapKeys(MapFromEntries(as4)), Seq("a", "a")) + checkExceptionInExpression[RuntimeException]( + MapFromEntries(as5), + "The first field from a struct (key) can't be null.") + checkEvaluation(MapFromEntries(as6), null) + } + test("Sort Array") { val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) @@ -144,6 +365,8 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) val a2 = Literal.create(Seq(null), ArrayType(LongType)) val a3 = Literal.create(null, ArrayType(StringType)) + val a4 = Literal.create(Seq(create_row(1)), ArrayType(StructType(Seq( + StructField("a", IntegerType, true))))) checkEvaluation(ArrayContains(a0, Literal(1)), true) checkEvaluation(ArrayContains(a0, Literal(0)), false) @@ -159,6 +382,11 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) + checkEvaluation(ArrayContains(a4, Literal.create(create_row(1), StructType(Seq( + StructField("a", IntegerType, false))))), true) + checkEvaluation(ArrayContains(a4, Literal.create(create_row(0), StructType(Seq( + StructField("a", IntegerType, false))))), false) + // binary val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), ArrayType(BinaryType)) @@ -421,6 +649,292 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123) } + test("Sequence of numbers") { + // test null handling + + checkEvaluation(new Sequence(Literal(null, LongType), Literal(1L)), null) + checkEvaluation(new Sequence(Literal(1L), Literal(null, LongType)), null) + checkEvaluation(new Sequence(Literal(null, LongType), Literal(1L), Literal(1L)), null) + checkEvaluation(new Sequence(Literal(1L), Literal(null, LongType), Literal(1L)), null) + checkEvaluation(new Sequence(Literal(1L), Literal(1L), Literal(null, LongType)), null) + + // test sequence boundaries checking + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)), + EmptyRow, s"Too long sequence: 4294967296. Should be <= $MAX_ROUNDED_ARRAY_LENGTH") + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence(Literal(1), Literal(2), Literal(0)), EmptyRow, "boundaries: 1 to 2 by 0") + checkExceptionInExpression[IllegalArgumentException]( + new Sequence(Literal(2), Literal(1), Literal(0)), EmptyRow, "boundaries: 2 to 1 by 0") + checkExceptionInExpression[IllegalArgumentException]( + new Sequence(Literal(2), Literal(1), Literal(1)), EmptyRow, "boundaries: 2 to 1 by 1") + checkExceptionInExpression[IllegalArgumentException]( + new Sequence(Literal(1), Literal(2), Literal(-1)), EmptyRow, "boundaries: 1 to 2 by -1") + + // test sequence with one element (zero step or equal start and stop) + + checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(-1)), Seq(1)) + checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(0)), Seq(1)) + checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(1)), Seq(1)) + checkEvaluation(new Sequence(Literal(1), Literal(2), Literal(2)), Seq(1)) + checkEvaluation(new Sequence(Literal(1), Literal(0), Literal(-2)), Seq(1)) + + // test sequence of different integral types (ascending and descending) + + checkEvaluation(new Sequence(Literal(1L), Literal(3L), Literal(1L)), Seq(1L, 2L, 3L)) + checkEvaluation(new Sequence(Literal(-3), Literal(3), Literal(3)), Seq(-3, 0, 3)) + checkEvaluation( + new Sequence(Literal(3.toShort), Literal(-3.toShort), Literal(-3.toShort)), + Seq(3.toShort, 0.toShort, -3.toShort)) + checkEvaluation( + new Sequence(Literal(-1.toByte), Literal(-3.toByte), Literal(-1.toByte)), + Seq(-1.toByte, -2.toByte, -3.toByte)) + } + + test("Sequence of timestamps") { + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-02 00:00:00")), + Literal(CalendarInterval.fromString("interval 12 hours"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-02 00:00:01")), + Literal(CalendarInterval.fromString("interval 12 hours"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-02 00:00:00")), + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(CalendarInterval.fromString("interval 12 hours").negate())), + Seq( + Timestamp.valueOf("2018-01-02 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-02 00:00:00")), + Literal(Timestamp.valueOf("2017-12-31 23:59:59")), + Literal(CalendarInterval.fromString("interval 12 hours").negate())), + Seq( + Timestamp.valueOf("2018-01-02 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-03-01 00:00:00")), + Literal(CalendarInterval.fromString("interval 1 month"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-02-01 00:00:00"), + Timestamp.valueOf("2018-03-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-03-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(CalendarInterval.fromString("interval 1 month").negate())), + Seq( + Timestamp.valueOf("2018-03-01 00:00:00"), + Timestamp.valueOf("2018-02-01 00:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-03-03 00:00:00")), + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(CalendarInterval.fromString("interval 1 month 1 day").negate())), + Seq( + Timestamp.valueOf("2018-03-03 00:00:00"), + Timestamp.valueOf("2018-02-02 00:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-31 00:00:00")), + Literal(Timestamp.valueOf("2018-04-30 00:00:00")), + Literal(CalendarInterval.fromString("interval 1 month"))), + Seq( + Timestamp.valueOf("2018-01-31 00:00:00"), + Timestamp.valueOf("2018-02-28 00:00:00"), + Timestamp.valueOf("2018-03-31 00:00:00"), + Timestamp.valueOf("2018-04-30 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-03-01 00:00:00")), + Literal(CalendarInterval.fromString("interval 1 month 1 second"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-02-01 00:00:01"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-03-01 00:04:06")), + Literal(CalendarInterval.fromString("interval 1 month 2 minutes 3 seconds"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-02-01 00:02:03"), + Timestamp.valueOf("2018-03-01 00:04:06"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2023-01-01 00:00:00")), + Literal(CalendarInterval.fromYearMonthString("1-5"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00.000"), + Timestamp.valueOf("2019-06-01 00:00:00.000"), + Timestamp.valueOf("2020-11-01 00:00:00.000"), + Timestamp.valueOf("2022-04-01 00:00:00.000"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2022-04-01 00:00:00")), + Literal(Timestamp.valueOf("2017-01-01 00:00:00")), + Literal(CalendarInterval.fromYearMonthString("1-5").negate())), + Seq( + Timestamp.valueOf("2022-04-01 00:00:00.000"), + Timestamp.valueOf("2020-11-01 00:00:00.000"), + Timestamp.valueOf("2019-06-01 00:00:00.000"), + Timestamp.valueOf("2018-01-01 00:00:00.000"))) + } + + test("Sequence on DST boundaries") { + val timeZone = TimeZone.getTimeZone("Europe/Prague") + val dstOffset = timeZone.getDSTSavings + + def noDST(t: Timestamp): Timestamp = new Timestamp(t.getTime - dstOffset) + + DateTimeTestUtils.withDefaultTimeZone(timeZone) { + // Spring time change + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-03-25 01:30:00")), + Literal(Timestamp.valueOf("2018-03-25 03:30:00")), + Literal(CalendarInterval.fromString("interval 30 minutes"))), + Seq( + Timestamp.valueOf("2018-03-25 01:30:00"), + Timestamp.valueOf("2018-03-25 03:00:00"), + Timestamp.valueOf("2018-03-25 03:30:00"))) + + // Autumn time change + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-10-28 01:30:00")), + Literal(Timestamp.valueOf("2018-10-28 03:30:00")), + Literal(CalendarInterval.fromString("interval 30 minutes"))), + Seq( + Timestamp.valueOf("2018-10-28 01:30:00"), + noDST(Timestamp.valueOf("2018-10-28 02:00:00")), + noDST(Timestamp.valueOf("2018-10-28 02:30:00")), + Timestamp.valueOf("2018-10-28 02:00:00"), + Timestamp.valueOf("2018-10-28 02:30:00"), + Timestamp.valueOf("2018-10-28 03:00:00"), + Timestamp.valueOf("2018-10-28 03:30:00"))) + } + } + + test("Sequence of dates") { + DateTimeTestUtils.withDefaultTimeZone(TimeZone.getTimeZone("UTC")) { + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-01")), + Literal(Date.valueOf("2018-01-05")), + Literal(CalendarInterval.fromString("interval 2 days"))), + Seq( + Date.valueOf("2018-01-01"), + Date.valueOf("2018-01-03"), + Date.valueOf("2018-01-05"))) + + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-01")), + Literal(Date.valueOf("2018-03-01")), + Literal(CalendarInterval.fromString("interval 1 month"))), + Seq( + Date.valueOf("2018-01-01"), + Date.valueOf("2018-02-01"), + Date.valueOf("2018-03-01"))) + + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-31")), + Literal(Date.valueOf("2018-04-30")), + Literal(CalendarInterval.fromString("interval 1 month"))), + Seq( + Date.valueOf("2018-01-31"), + Date.valueOf("2018-02-28"), + Date.valueOf("2018-03-31"), + Date.valueOf("2018-04-30"))) + + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-01")), + Literal(Date.valueOf("2023-01-01")), + Literal(CalendarInterval.fromYearMonthString("1-5"))), + Seq( + Date.valueOf("2018-01-01"), + Date.valueOf("2019-06-01"), + Date.valueOf("2020-11-01"), + Date.valueOf("2022-04-01"))) + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence( + Literal(Date.valueOf("1970-01-02")), + Literal(Date.valueOf("1970-01-01")), + Literal(CalendarInterval.fromString("interval 1 day"))), + EmptyRow, "sequence boundaries: 1 to 0 by 1") + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence( + Literal(Date.valueOf("1970-01-01")), + Literal(Date.valueOf("1970-02-01")), + Literal(CalendarInterval.fromString("interval 1 month").negate())), + EmptyRow, + s"sequence boundaries: 0 to 2678400000000 by -${28 * CalendarInterval.MICROS_PER_DAY}") + } + } + + test("Sequence with default step") { + // +/- 1 for integral type + checkEvaluation(new Sequence(Literal(1), Literal(3)), Seq(1, 2, 3)) + checkEvaluation(new Sequence(Literal(3), Literal(1)), Seq(3, 2, 1)) + + // +/- 1 day for timestamps + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-03 00:00:00"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"), + Timestamp.valueOf("2018-01-03 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-03 00:00:00")), + Literal(Timestamp.valueOf("2018-01-01 00:00:00"))), + Seq( + Timestamp.valueOf("2018-01-03 00:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + // +/- 1 day for dates + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-01")), + Literal(Date.valueOf("2018-01-03"))), + Seq( + Date.valueOf("2018-01-01"), + Date.valueOf("2018-01-02"), + Date.valueOf("2018-01-03"))) + + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-03")), + Literal(Date.valueOf("2018-01-01"))), + Seq( + Date.valueOf("2018-01-03"), + Date.valueOf("2018-01-02"), + Date.valueOf("2018-01-01"))) + } + test("Reverse") { // Primitive-type elements val ai0 = Literal.create(Seq(2, 1, 4, 3), ArrayType(IntegerType)) @@ -557,11 +1071,11 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper test("Concat") { // Primitive-type elements - val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) - val ai1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) - val ai2 = Literal.create(Seq(4, null, 5), ArrayType(IntegerType)) - val ai3 = Literal.create(Seq(null, null), ArrayType(IntegerType)) - val ai4 = Literal.create(null, ArrayType(IntegerType)) + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val ai1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType, containsNull = false)) + val ai2 = Literal.create(Seq(4, null, 5), ArrayType(IntegerType, containsNull = true)) + val ai3 = Literal.create(Seq(null, null), ArrayType(IntegerType, containsNull = true)) + val ai4 = Literal.create(null, ArrayType(IntegerType, containsNull = false)) checkEvaluation(Concat(Seq(ai0)), Seq(1, 2, 3)) checkEvaluation(Concat(Seq(ai0, ai1)), Seq(1, 2, 3)) @@ -574,14 +1088,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Concat(Seq(ai4, ai0)), null) // Non-primitive-type elements - val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType)) - val as1 = Literal.create(Seq.empty[String], ArrayType(StringType)) - val as2 = Literal.create(Seq("d", null, "e"), ArrayType(StringType)) - val as3 = Literal.create(Seq(null, null), ArrayType(StringType)) - val as4 = Literal.create(null, ArrayType(StringType)) - - val aa0 = Literal.create(Seq(Seq("a", "b"), Seq("c")), ArrayType(ArrayType(StringType))) - val aa1 = Literal.create(Seq(Seq("d"), Seq("e", "f")), ArrayType(ArrayType(StringType))) + val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false)) + val as1 = Literal.create(Seq.empty[String], ArrayType(StringType, containsNull = false)) + val as2 = Literal.create(Seq("d", null, "e"), ArrayType(StringType, containsNull = true)) + val as3 = Literal.create(Seq(null, null), ArrayType(StringType, containsNull = true)) + val as4 = Literal.create(null, ArrayType(StringType, containsNull = false)) + + val aa0 = Literal.create(Seq(Seq("a", "b"), Seq("c")), + ArrayType(ArrayType(StringType, containsNull = false), containsNull = false)) + val aa1 = Literal.create(Seq(Seq("d"), Seq("e", "f")), + ArrayType(ArrayType(StringType, containsNull = false), containsNull = false)) + val aa2 = Literal.create(Seq(Seq("g", null), null), + ArrayType(ArrayType(StringType, containsNull = true), containsNull = true)) checkEvaluation(Concat(Seq(as0)), Seq("a", "b", "c")) checkEvaluation(Concat(Seq(as0, as1)), Seq("a", "b", "c")) @@ -594,6 +1112,15 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Concat(Seq(as4, as0)), null) checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"), Seq("d"), Seq("e", "f"))) + + assert(Concat(Seq(ai0, ai1)).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(Concat(Seq(ai0, ai2)).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(Concat(Seq(as0, as1)).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(Concat(Seq(as0, as2)).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(Concat(Seq(aa0, aa1)).dataType === + ArrayType(ArrayType(StringType, containsNull = false), containsNull = false)) + assert(Concat(Seq(aa0, aa2)).dataType === + ArrayType(ArrayType(StringType, containsNull = true), containsNull = true)) } test("Flatten") { @@ -766,4 +1293,130 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayRemove(c1, dataToRemove2), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1))) checkEvaluation(ArrayRemove(c2, dataToRemove2), Seq[Seq[Int]](null, Seq[Int](2, 1))) } + + test("Array Distinct") { + val a0 = Literal.create(Seq(2, 1, 2, 3, 4, 4, 5), ArrayType(IntegerType)) + val a1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) + val a2 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType)) + val a3 = Literal.create(Seq("b", null, "a", null, "a", null), ArrayType(StringType)) + val a4 = Literal.create(Seq(null, null, null), ArrayType(NullType)) + val a5 = Literal.create(Seq(true, false, false, true), ArrayType(BooleanType)) + val a6 = Literal.create(Seq(1.123, 0.1234, 1.121, 1.123, 1.1230, 1.121, 0.1234), + ArrayType(DoubleType)) + val a7 = Literal.create(Seq(1.123f, 0.1234f, 1.121f, 1.123f, 1.1230f, 1.121f, 0.1234f), + ArrayType(FloatType)) + + checkEvaluation(new ArrayDistinct(a0), Seq(2, 1, 3, 4, 5)) + checkEvaluation(new ArrayDistinct(a1), Seq.empty[Integer]) + checkEvaluation(new ArrayDistinct(a2), Seq("b", "a", "c")) + checkEvaluation(new ArrayDistinct(a3), Seq("b", null, "a")) + checkEvaluation(new ArrayDistinct(a4), Seq(null)) + checkEvaluation(new ArrayDistinct(a5), Seq(true, false)) + checkEvaluation(new ArrayDistinct(a6), Seq(1.123, 0.1234, 1.121)) + checkEvaluation(new ArrayDistinct(a7), Seq(1.123f, 0.1234f, 1.121f)) + + // complex data types + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2), + Array[Byte](1, 2), Array[Byte](5, 6)), ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), null, Array[Byte](1, 2), + null, Array[Byte](5, 6), null), ArrayType(BinaryType)) + + checkEvaluation(ArrayDistinct(b0), Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2))) + checkEvaluation(ArrayDistinct(b1), Seq[Array[Byte]](Array[Byte](2, 1), null)) + checkEvaluation(ArrayDistinct(b2), Seq[Array[Byte]](Array[Byte](5, 6), null, + Array[Byte](1, 2))) + + val c0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](1, 2), + Seq[Int](3, 4), Seq[Int](1, 2)), ArrayType(ArrayType(IntegerType))) + val c1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + val c2 = Literal.create(Seq[Seq[Int]](null, Seq[Int](2, 1), null, null, Seq[Int](2, 1), null), + ArrayType(ArrayType(IntegerType))) + checkEvaluation(ArrayDistinct(c0), Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4))) + checkEvaluation(ArrayDistinct(c1), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1))) + checkEvaluation(ArrayDistinct(c2), Seq[Seq[Int]](null, Seq[Int](2, 1))) + } + + test("Array Union") { + val a00 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, containsNull = false)) + val a02 = Literal.create(Seq(1, 2, null, 4, 5), ArrayType(IntegerType, containsNull = true)) + val a03 = Literal.create(Seq(-5, 4, -3, 2, 4), ArrayType(IntegerType, containsNull = false)) + val a04 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false)) + val a05 = Literal.create(Seq[Byte](1, 2, 3), ArrayType(ByteType, containsNull = false)) + val a06 = Literal.create(Seq[Byte](4, 2), ArrayType(ByteType, containsNull = false)) + val a07 = Literal.create(Seq[Short](1, 2, 3), ArrayType(ShortType, containsNull = false)) + val a08 = Literal.create(Seq[Short](4, 2), ArrayType(ShortType, containsNull = false)) + + val a10 = Literal.create(Seq(1L, 2L, 3L), ArrayType(LongType, containsNull = false)) + val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, containsNull = false)) + val a12 = Literal.create(Seq(1L, 2L, null, 4L, 5L), ArrayType(LongType, containsNull = true)) + val a13 = Literal.create(Seq(-5L, 4L, -3L, 2L, -1L), ArrayType(LongType, containsNull = false)) + val a14 = Literal.create(Seq.empty[Long], ArrayType(LongType, containsNull = false)) + + val a20 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, containsNull = false)) + val a21 = Literal.create(Seq("c", "d", "a", "f"), ArrayType(StringType, containsNull = false)) + val a22 = Literal.create(Seq("b", null, "a", "g"), ArrayType(StringType, containsNull = true)) + + val a30 = Literal.create(Seq(null, null), ArrayType(IntegerType)) + val a31 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayUnion(a00, a01), Seq(1, 2, 3, 4)) + checkEvaluation(ArrayUnion(a02, a03), Seq(1, 2, null, 4, 5, -5, -3)) + checkEvaluation(ArrayUnion(a03, a02), Seq(-5, 4, -3, 2, 1, null, 5)) + checkEvaluation(ArrayUnion(a02, a04), Seq(1, 2, null, 4, 5)) + checkEvaluation(ArrayUnion(a05, a06), Seq[Byte](1, 2, 3, 4)) + checkEvaluation(ArrayUnion(a07, a08), Seq[Short](1, 2, 3, 4)) + + checkEvaluation(ArrayUnion(a10, a11), Seq(1L, 2L, 3L, 4L)) + checkEvaluation(ArrayUnion(a12, a13), Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L)) + checkEvaluation(ArrayUnion(a13, a12), Seq(-5L, 4L, -3L, 2L, -1L, 1L, null, 5L)) + checkEvaluation(ArrayUnion(a12, a14), Seq(1L, 2L, null, 4L, 5L)) + + checkEvaluation(ArrayUnion(a20, a21), Seq("b", "a", "c", "d", "f")) + checkEvaluation(ArrayUnion(a20, a22), Seq("b", "a", "c", null, "g")) + + checkEvaluation(ArrayUnion(a30, a30), Seq(null)) + checkEvaluation(ArrayUnion(a20, a31), null) + checkEvaluation(ArrayUnion(a31, a20), null) + + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](4, 3)), + ArrayType(BinaryType)) + val b3 = Literal.create(Seq[Array[Byte]]( + Array[Byte](1, 2), Array[Byte](4, 3), Array[Byte](1, 2)), ArrayType(BinaryType)) + val b4 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), null), ArrayType(BinaryType)) + val b5 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)), ArrayType(BinaryType)) + val b6 = Literal.create(Seq.empty, ArrayType(BinaryType)) + val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType)) + + checkEvaluation(ArrayUnion(b0, b1), + Seq(Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](2, 1), Array[Byte](4, 3))) + checkEvaluation(ArrayUnion(b0, b2), + Seq(Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](4, 3))) + checkEvaluation(ArrayUnion(b2, b4), Seq(Array[Byte](1, 2), Array[Byte](4, 3), null)) + checkEvaluation(ArrayUnion(b3, b0), + Seq(Array[Byte](1, 2), Array[Byte](4, 3), Array[Byte](5, 6))) + checkEvaluation(ArrayUnion(b4, b0), Seq(Array[Byte](1, 2), null, Array[Byte](5, 6))) + checkEvaluation(ArrayUnion(b4, b5), Seq(Array[Byte](1, 2), null)) + checkEvaluation(ArrayUnion(b6, b4), Seq(Array[Byte](1, 2), null)) + checkEvaluation(ArrayUnion(b4, arrayWithBinaryNull), Seq(Array[Byte](1, 2), null)) + + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + checkEvaluation(ArrayUnion(aa0, aa1), + Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](5, 6), Seq[Int](2, 1))) + + assert(ArrayUnion(a00, a01).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayUnion(a00, a02).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(ArrayUnion(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayUnion(a20, a22).dataType.asInstanceOf[ArrayType].containsNull === true) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 726193b411737..77aaf55480ec2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -144,6 +144,13 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(CreateArray(byteWithNull), byteSeq :+ null, EmptyRow) checkEvaluation(CreateArray(strWithNull), strSeq :+ null, EmptyRow) checkEvaluation(CreateArray(Literal.create(null, IntegerType) :: Nil), null :: Nil) + + val array = CreateArray(Seq( + Literal.create(intSeq, ArrayType(IntegerType, containsNull = false)), + Literal.create(intSeq :+ null, ArrayType(IntegerType, containsNull = true)))) + assert(array.dataType === + ArrayType(ArrayType(IntegerType, containsNull = true), containsNull = false)) + checkEvaluation(array, Seq(intSeq, intSeq :+ null)) } test("CreateMap") { @@ -184,6 +191,18 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))), null, null) } + + val map = CreateMap(Seq( + Literal.create(intSeq, ArrayType(IntegerType, containsNull = false)), + Literal.create(strSeq, ArrayType(StringType, containsNull = false)), + Literal.create(intSeq :+ null, ArrayType(IntegerType, containsNull = true)), + Literal.create(strSeq :+ null, ArrayType(StringType, containsNull = true)))) + assert(map.dataType === + MapType( + ArrayType(IntegerType, containsNull = true), + ArrayType(StringType, containsNull = true), + valueContainsNull = false)) + checkEvaluation(map, createMap(Seq(intSeq, intSeq :+ null), Seq(strSeq, strSeq :+ null))) } test("MapFromArrays") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index a099119732e25..e068c32500cfc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -113,6 +113,76 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5))).nullable === true) } + test("if/case when - null flags of non-primitive types") { + val arrayWithNulls = Literal.create(Seq("a", null, "b"), ArrayType(StringType, true)) + val arrayWithoutNulls = Literal.create(Seq("c", "d"), ArrayType(StringType, false)) + val structWithNulls = Literal.create( + create_row(null, null), + StructType(Seq(StructField("a", IntegerType, true), StructField("b", StringType, true)))) + val structWithoutNulls = Literal.create( + create_row(1, "a"), + StructType(Seq(StructField("a", IntegerType, false), StructField("b", StringType, false)))) + val mapWithNulls = Literal.create(Map(1 -> null), MapType(IntegerType, StringType, true)) + val mapWithoutNulls = Literal.create(Map(1 -> "a"), MapType(IntegerType, StringType, false)) + + val arrayIf1 = If(Literal.FalseLiteral, arrayWithNulls, arrayWithoutNulls) + val arrayIf2 = If(Literal.FalseLiteral, arrayWithoutNulls, arrayWithNulls) + val arrayIf3 = If(Literal.TrueLiteral, arrayWithNulls, arrayWithoutNulls) + val arrayIf4 = If(Literal.TrueLiteral, arrayWithoutNulls, arrayWithNulls) + val structIf1 = If(Literal.FalseLiteral, structWithNulls, structWithoutNulls) + val structIf2 = If(Literal.FalseLiteral, structWithoutNulls, structWithNulls) + val structIf3 = If(Literal.TrueLiteral, structWithNulls, structWithoutNulls) + val structIf4 = If(Literal.TrueLiteral, structWithoutNulls, structWithNulls) + val mapIf1 = If(Literal.FalseLiteral, mapWithNulls, mapWithoutNulls) + val mapIf2 = If(Literal.FalseLiteral, mapWithoutNulls, mapWithNulls) + val mapIf3 = If(Literal.TrueLiteral, mapWithNulls, mapWithoutNulls) + val mapIf4 = If(Literal.TrueLiteral, mapWithoutNulls, mapWithNulls) + + val arrayCaseWhen1 = CaseWhen(Seq((Literal.FalseLiteral, arrayWithNulls)), arrayWithoutNulls) + val arrayCaseWhen2 = CaseWhen(Seq((Literal.FalseLiteral, arrayWithoutNulls)), arrayWithNulls) + val arrayCaseWhen3 = CaseWhen(Seq((Literal.TrueLiteral, arrayWithNulls)), arrayWithoutNulls) + val arrayCaseWhen4 = CaseWhen(Seq((Literal.TrueLiteral, arrayWithoutNulls)), arrayWithNulls) + val structCaseWhen1 = CaseWhen(Seq((Literal.FalseLiteral, structWithNulls)), structWithoutNulls) + val structCaseWhen2 = CaseWhen(Seq((Literal.FalseLiteral, structWithoutNulls)), structWithNulls) + val structCaseWhen3 = CaseWhen(Seq((Literal.TrueLiteral, structWithNulls)), structWithoutNulls) + val structCaseWhen4 = CaseWhen(Seq((Literal.TrueLiteral, structWithoutNulls)), structWithNulls) + val mapCaseWhen1 = CaseWhen(Seq((Literal.FalseLiteral, mapWithNulls)), mapWithoutNulls) + val mapCaseWhen2 = CaseWhen(Seq((Literal.FalseLiteral, mapWithoutNulls)), mapWithNulls) + val mapCaseWhen3 = CaseWhen(Seq((Literal.TrueLiteral, mapWithNulls)), mapWithoutNulls) + val mapCaseWhen4 = CaseWhen(Seq((Literal.TrueLiteral, mapWithoutNulls)), mapWithNulls) + + def checkResult(expectedType: DataType, expectedValue: Any, result: Expression): Unit = { + assert(expectedType == result.dataType) + checkEvaluation(result, expectedValue) + } + + checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayIf1) + checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayIf2) + checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayIf3) + checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayIf4) + checkResult(structWithNulls.dataType, structWithoutNulls.value, structIf1) + checkResult(structWithNulls.dataType, structWithNulls.value, structIf2) + checkResult(structWithNulls.dataType, structWithNulls.value, structIf3) + checkResult(structWithNulls.dataType, structWithoutNulls.value, structIf4) + checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapIf1) + checkResult(mapWithNulls.dataType, mapWithNulls.value, mapIf2) + checkResult(mapWithNulls.dataType, mapWithNulls.value, mapIf3) + checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapIf4) + + checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayCaseWhen1) + checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayCaseWhen2) + checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayCaseWhen3) + checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayCaseWhen4) + checkResult(structWithNulls.dataType, structWithoutNulls.value, structCaseWhen1) + checkResult(structWithNulls.dataType, structWithNulls.value, structCaseWhen2) + checkResult(structWithNulls.dataType, structWithNulls.value, structCaseWhen3) + checkResult(structWithNulls.dataType, structWithoutNulls.value, structCaseWhen4) + checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapCaseWhen1) + checkResult(mapWithNulls.dataType, mapWithNulls.value, mapCaseWhen2) + checkResult(mapWithNulls.dataType, mapWithNulls.value, mapCaseWhen3) + checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapCaseWhen4) + } + test("case key when") { val row = create_row(null, 1, 2, "a", "b", "c") val c1 = 'a.int.at(0) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 00e97637eee7e..04f1c8ce0b83d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -392,7 +392,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val jsonData = """{"a": 1}""" val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId, true), + JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId), InternalRow(1) ) } @@ -401,13 +401,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val jsonData = """{"a" 1}""" val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId, true), + JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId), null ) // Other modes should still return `null`. checkEvaluation( - JsonToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(jsonData), gmtId, true), + JsonToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(jsonData), gmtId), null ) } @@ -416,62 +416,62 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val input = """[{"a": 1}, {"a": 2}]""" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(1) :: InternalRow(2) :: Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=object, schema=array, output=array of single row") { val input = """{"a": 1}""" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(1) :: Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=empty array, schema=array, output=empty array") { val input = "[ ]" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=empty object, schema=array, output=array of single row with null") { val input = "{ }" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(null) :: Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=array of single object, schema=struct, output=single row") { val input = """[{"a": 1}]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = InternalRow(1) - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=array, schema=struct, output=null") { val input = """[{"a": 1}, {"a": 2}]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = null - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=empty array, schema=struct, output=null") { val input = """[]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = null - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=empty object, schema=struct, output=single row with null") { val input = """{ }""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = InternalRow(null) - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json null input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId, true), + JsonToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId), null ) } @@ -479,7 +479,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("SPARK-20549: from_json bad UTF-8") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(badJson), gmtId, true), + JsonToStructs(schema, Map.empty, Literal(badJson), gmtId), null) } @@ -491,14 +491,14 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with c.set(2016, 0, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 123) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData1), gmtId, true), + JsonToStructs(schema, Map.empty, Literal(jsonData1), gmtId), InternalRow(c.getTimeInMillis * 1000L) ) // The result doesn't change because the json string includes timezone string ("Z" here), // which means the string represents the timestamp string in the timezone regardless of // the timeZoneId parameter. checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData1), Option("PST"), true), + JsonToStructs(schema, Map.empty, Literal(jsonData1), Option("PST")), InternalRow(c.getTimeInMillis * 1000L) ) @@ -512,8 +512,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with schema, Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss"), Literal(jsonData2), - Option(tz.getID), - true), + Option(tz.getID)), InternalRow(c.getTimeInMillis * 1000L) ) checkEvaluation( @@ -522,8 +521,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", DateTimeUtils.TIMEZONE_OPTION -> tz.getID), Literal(jsonData2), - gmtId, - true), + gmtId), InternalRow(c.getTimeInMillis * 1000L) ) } @@ -532,7 +530,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("SPARK-19543: from_json empty input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId, true), + JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId), null ) } @@ -687,23 +685,31 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("from_json missing fields") { for (forceJsonNullableSchema <- Seq(false, true)) { - val input = - """{ + withSQLConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA.key -> forceJsonNullableSchema.toString) { + val input = + """{ | "a": 1, | "c": "foo" |} |""".stripMargin - val jsonSchema = new StructType() - .add("a", LongType, nullable = false) - .add("b", StringType, nullable = false) - .add("c", StringType, nullable = false) - val output = InternalRow(1L, null, UTF8String.fromString("foo")) - val expr = JsonToStructs( - jsonSchema, Map.empty, Literal.create(input, StringType), gmtId, forceJsonNullableSchema) - checkEvaluation(expr, output) - val schema = expr.dataType - val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema - assert(schemaToCompare == schema) + val jsonSchema = new StructType() + .add("a", LongType, nullable = false) + .add("b", StringType, nullable = false) + .add("c", StringType, nullable = false) + val output = InternalRow(1L, null, UTF8String.fromString("foo")) + val expr = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId) + checkEvaluation(expr, output) + val schema = expr.dataType + val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema + assert(schemaToCompare == schema) + } } } + + test("SPARK-24709: infer schema of json strings") { + checkEvaluation(SchemaOfJson(Literal.create("""{"col":0}""")), "struct") + checkEvaluation( + SchemaOfJson(Literal.create("""{"col0":["a"], "col1": {"col2": "b"}}""")), + "struct,col1:struct>") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index a9e0eb0e377a6..86f80fe66d28b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -219,4 +219,11 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkUnsupportedTypeInLiteral(Map("key1" -> 1, "key2" -> 2)) checkUnsupportedTypeInLiteral(("mike", 29, 1.0)) } + + test("SPARK-24571: char literals") { + checkEvaluation(Literal('X'), "X") + checkEvaluation(Literal.create('0'), "0") + checkEvaluation(Literal('\u0000'), "\u0000") + checkEvaluation(Literal.create('\n'), "\n") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala deleted file mode 100644 index 4d69dc32ace82..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala +++ /dev/null @@ -1,236 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.types.{IntegerType, StringType} - -class MaskExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { - - test("mask") { - checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), "U", "l", "#"), "llll-UUUU-####-####") - checkEvaluation( - new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U"), Literal("l"), Literal("#")), - "llll-UUUU-####-####") - checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U"), Literal("l")), - "llll-UUUU-nnnn-nnnn") - checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U")), "xxxx-UUUU-nnnn-nnnn") - checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321")), "xxxx-XXXX-nnnn-nnnn") - checkEvaluation(new Mask(Literal(null, StringType)), null) - checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), null, "l", "#"), "llll-XXXX-####-####") - checkEvaluation(new Mask( - Literal("abcd-EFGH-8765-4321"), - Literal(null, StringType), - Literal(null, StringType), - Literal(null, StringType)), "xxxx-XXXX-nnnn-nnnn") - checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("Upper")), - "xxxx-UUUU-nnnn-nnnn") - checkEvaluation(new Mask(Literal("")), "") - checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("")), "xxxx-XXXX-nnnn-nnnn") - checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), "", "", ""), "xxxx-XXXX-nnnn-nnnn") - // scalastyle:off nonascii - checkEvaluation(Mask(Literal("Ul9U"), "\u2200", null, null), "\u2200xn\u2200") - checkEvaluation(new Mask(Literal("Hello World, こんにちは, 𠀋"), Literal("あ"), Literal("𡈽")), - "あ𡈽𡈽𡈽𡈽 あ𡈽𡈽𡈽𡈽, こんにちは, 𠀋") - // scalastyle:on nonascii - intercept[AnalysisException] { - checkEvaluation(new Mask(Literal(""), Literal(1)), "") - } - } - - test("mask_first_n") { - checkEvaluation(MaskFirstN(Literal("aB3d-EFGH-8765"), 6, "U", "l", "#"), - "lU#l-UFGH-8765") - checkEvaluation(new MaskFirstN( - Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")), - "llll-UFGH-8765-4321") - checkEvaluation( - new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")), - "llll-UFGH-8765-4321") - checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")), - "xxxx-UFGH-8765-4321") - checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6)), - "xxxx-XFGH-8765-4321") - intercept[AnalysisException] { - checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "") - } - checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321")), "xxxx-EFGH-8765-4321") - checkEvaluation(new MaskFirstN(Literal(null, StringType)), null) - checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), - "llll-EFGH-8765-4321") - checkEvaluation(new MaskFirstN( - Literal("abcd-EFGH-8765-4321"), - Literal(null, IntegerType), - Literal(null, StringType), - Literal(null, StringType), - Literal(null, StringType)), "xxxx-EFGH-8765-4321") - checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")), - "xxxx-UFGH-8765-4321") - checkEvaluation(new MaskFirstN(Literal("")), "") - checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")), - "xxxx-EFGH-8765-4321") - checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), - "xxxx-XXXX-nnnn-nnnn") - checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), - "abcd-EFGH-8765-4321") - // scalastyle:off nonascii - checkEvaluation(MaskFirstN(Literal("Ul9U"), 2, "\u2200", null, null), "\u2200x9U") - checkEvaluation(new MaskFirstN(Literal("あ, 𠀋, Hello World"), Literal(10)), - "あ, 𠀋, Xxxxo World") - // scalastyle:on nonascii - } - - test("mask_last_n") { - checkEvaluation(MaskLastN(Literal("abcd-EFGH-aB3d"), 6, "U", "l", "#"), - "abcd-EFGU-lU#l") - checkEvaluation(new MaskLastN( - Literal("abcd-EFGH-8765"), Literal(6), Literal("U"), Literal("l"), Literal("#")), - "abcd-EFGU-####") - checkEvaluation( - new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6), Literal("U"), Literal("l")), - "abcd-EFGU-nnnn") - checkEvaluation( - new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6), Literal("U")), - "abcd-EFGU-nnnn") - checkEvaluation( - new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6)), - "abcd-EFGX-nnnn") - intercept[AnalysisException] { - checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765"), Literal("U")), "") - } - checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321")), "abcd-EFGH-8765-nnnn") - checkEvaluation(new MaskLastN(Literal(null, StringType)), null) - checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), - "abcd-EFGH-8765-nnnn") - checkEvaluation(new MaskLastN( - Literal("abcd-EFGH-8765-4321"), - Literal(null, IntegerType), - Literal(null, StringType), - Literal(null, StringType), - Literal(null, StringType)), "abcd-EFGH-8765-nnnn") - checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321"), Literal(12), Literal("Upper")), - "abcd-EFUU-nnnn-nnnn") - checkEvaluation(new MaskLastN(Literal("")), "") - checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321"), Literal(16), Literal("")), - "abcx-XXXX-nnnn-nnnn") - checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), - "xxxx-XXXX-nnnn-nnnn") - checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), - "abcd-EFGH-8765-4321") - // scalastyle:off nonascii - checkEvaluation(MaskLastN(Literal("Ul9U"), 2, "\u2200", null, null), "Uln\u2200") - checkEvaluation(new MaskLastN(Literal("あ, 𠀋, Hello World あ 𠀋"), Literal(10)), - "あ, 𠀋, Hello Xxxxx あ 𠀋") - // scalastyle:on nonascii - } - - test("mask_show_first_n") { - checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-aB3d"), 6, "U", "l", "#"), - "abcd-EUUU-####-lU#l") - checkEvaluation(new MaskShowFirstN( - Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")), - "abcd-EUUU-####-####") - checkEvaluation( - new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")), - "abcd-EUUU-nnnn-nnnn") - checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")), - "abcd-EUUU-nnnn-nnnn") - checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6)), - "abcd-EXXX-nnnn-nnnn") - intercept[AnalysisException] { - checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "") - } - checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321")), "abcd-XXXX-nnnn-nnnn") - checkEvaluation(new MaskShowFirstN(Literal(null, StringType)), null) - checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), - "abcd-UUUU-nnnn-nnnn") - checkEvaluation(new MaskShowFirstN( - Literal("abcd-EFGH-8765-4321"), - Literal(null, IntegerType), - Literal(null, StringType), - Literal(null, StringType), - Literal(null, StringType)), "abcd-XXXX-nnnn-nnnn") - checkEvaluation( - new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")), - "abcd-EUUU-nnnn-nnnn") - checkEvaluation(new MaskShowFirstN(Literal("")), "") - checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")), - "abcd-XXXX-nnnn-nnnn") - checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), - "abcd-EFGH-8765-4321") - checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), - "xxxx-XXXX-nnnn-nnnn") - // scalastyle:off nonascii - checkEvaluation(MaskShowFirstN(Literal("Ul9U"), 2, "\u2200", null, null), "Uln\u2200") - checkEvaluation(new MaskShowFirstN(Literal("あ, 𠀋, Hello World"), Literal(10)), - "あ, 𠀋, Hellx Xxxxx") - // scalastyle:on nonascii - } - - test("mask_show_last_n") { - checkEvaluation(MaskShowLastN(Literal("aB3d-EFGH-8765"), 6, "U", "l", "#"), - "lU#l-UUUH-8765") - checkEvaluation(new MaskShowLastN( - Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")), - "llll-UUUU-###5-4321") - checkEvaluation( - new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")), - "llll-UUUU-nnn5-4321") - checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")), - "xxxx-UUUU-nnn5-4321") - checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6)), - "xxxx-XXXX-nnn5-4321") - intercept[AnalysisException] { - checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "") - } - checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321")), "xxxx-XXXX-nnnn-4321") - checkEvaluation(new MaskShowLastN(Literal(null, StringType)), null) - checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), - "llll-UUUU-nnnn-4321") - checkEvaluation(new MaskShowLastN( - Literal("abcd-EFGH-8765-4321"), - Literal(null, IntegerType), - Literal(null, StringType), - Literal(null, StringType), - Literal(null, StringType)), "xxxx-XXXX-nnnn-4321") - checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")), - "xxxx-UUUU-nnn5-4321") - checkEvaluation(new MaskShowLastN(Literal("")), "") - checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")), - "xxxx-XXXX-nnnn-4321") - checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), - "abcd-EFGH-8765-4321") - checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), - "xxxx-XXXX-nnnn-nnnn") - // scalastyle:off nonascii - checkEvaluation(MaskShowLastN(Literal("Ul9U"), 2, "\u2200", null, null), "\u2200x9U") - checkEvaluation(new MaskShowLastN(Literal("あ, 𠀋, Hello World"), Literal(10)), - "あ, 𠀋, Xello World") - // scalastyle:on nonascii - } - - test("mask_hash") { - checkEvaluation(MaskHash(Literal("abcd-EFGH-8765-4321")), "60c713f5ec6912229d2060df1c322776") - checkEvaluation(MaskHash(Literal("")), "d41d8cd98f00b204e9800998ecf8427e") - checkEvaluation(MaskHash(Literal(null, StringType)), null) - // scalastyle:off nonascii - checkEvaluation(MaskHash(Literal("\u2200x9U")), "f1243ef123d516b1f32a3a75309e5711") - // scalastyle:on nonascii - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index 424c3a4696077..6e07f7a59b730 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -86,6 +86,13 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Coalesce(Seq(nullLit, lit, lit)), value) checkEvaluation(Coalesce(Seq(nullLit, nullLit, lit)), value) } + + val coalesce = Coalesce(Seq( + Literal.create(null, ArrayType(IntegerType, containsNull = false)), + Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)), + Literal.create(Seq(1, 2, 3, null), ArrayType(IntegerType, containsNull = true)))) + assert(coalesce.dataType === ArrayType(IntegerType, containsNull = true)) + checkEvaluation(coalesce, Seq(1, 2, 3)) } test("SPARK-16602 Nvl should support numeric-string cases") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala index d2c6420eadb20..55569b6f2933e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala @@ -65,7 +65,9 @@ class CodeBlockSuite extends SparkFunSuite { |boolean $isNull = false; |int $value = -1; """.stripMargin - val exprValues = code.exprValues + val exprValues = code.asInstanceOf[CodeBlock].blockInputs.collect { + case e: ExprValue => e + }.toSet assert(exprValues.size == 2) assert(exprValues === Set(value, isNull)) } @@ -94,7 +96,9 @@ class CodeBlockSuite extends SparkFunSuite { assert(code.toString == expected) - val exprValues = code.exprValues + val exprValues = code.children.flatMap(_.asInstanceOf[CodeBlock].blockInputs.collect { + case e: ExprValue => e + }).toSet assert(exprValues.size == 5) assert(exprValues === Set(isNull1, value1, isNull2, value2, literal)) } @@ -107,7 +111,7 @@ class CodeBlockSuite extends SparkFunSuite { assert(e.getMessage().contains(s"Can not interpolate ${obj.getClass.getName}")) } - test("replace expr values in code block") { + test("transform expr in code block") { val expr = JavaCode.expression("1 + 1", IntegerType) val isNull = JavaCode.isNullVariable("expr1_isNull") val exprInFunc = JavaCode.variable("expr1", IntegerType) @@ -120,11 +124,11 @@ class CodeBlockSuite extends SparkFunSuite { |}""".stripMargin val aliasedParam = JavaCode.variable("aliased", expr.javaType) - val aliasedInputs = code.asInstanceOf[CodeBlock].blockInputs.map { - case _: SimpleExprValue => aliasedParam - case other => other + + // We want to replace all occurrences of `expr` with the variable `aliasedParam`. + val aliasedCode = code.transformExprValues { + case SimpleExprValue("1 + 1", java.lang.Integer.TYPE) => aliasedParam } - val aliasedCode = CodeBlock(code.asInstanceOf[CodeBlock].codeParts, aliasedInputs).stripMargin val expected = code""" |callFunc(int $aliasedParam) { @@ -133,4 +137,61 @@ class CodeBlockSuite extends SparkFunSuite { |}""".stripMargin assert(aliasedCode.toString == expected.toString) } + + test ("transform expr in nested blocks") { + val expr = JavaCode.expression("1 + 1", IntegerType) + val isNull = JavaCode.isNullVariable("expr1_isNull") + val exprInFunc = JavaCode.variable("expr1", IntegerType) + + val funcs = Seq("callFunc1", "callFunc2", "callFunc3") + val subBlocks = funcs.map { funcName => + code""" + |$funcName(int $expr) { + | boolean $isNull = false; + | int $exprInFunc = $expr + 1; + |}""".stripMargin + } + + val aliasedParam = JavaCode.variable("aliased", expr.javaType) + + val block = code"${subBlocks(0)}\n${subBlocks(1)}\n${subBlocks(2)}" + val transformedBlock = block.transform { + case b: Block => b.transformExprValues { + case SimpleExprValue("1 + 1", java.lang.Integer.TYPE) => aliasedParam + } + }.asInstanceOf[CodeBlock] + + val expected1 = + code""" + |callFunc1(int aliased) { + | boolean expr1_isNull = false; + | int expr1 = aliased + 1; + |}""".stripMargin + + val expected2 = + code""" + |callFunc2(int aliased) { + | boolean expr1_isNull = false; + | int expr1 = aliased + 1; + |}""".stripMargin + + val expected3 = + code""" + |callFunc3(int aliased) { + | boolean expr1_isNull = false; + | int expr1 = aliased + 1; + |}""".stripMargin + + val exprValues = transformedBlock.children.flatMap { block => + block.asInstanceOf[CodeBlock].blockInputs.collect { + case e: ExprValue => e + } + }.toSet + + assert(transformedBlock.children(0).toString == expected1.toString) + assert(transformedBlock.children(1).toString == expected2.toString) + assert(transformedBlock.children(2).toString == expected3.toString) + assert(transformedBlock.toString == (expected1 + expected2 + expected3).toString) + assert(exprValues === Set(isNull, exprInFunc, aliasedParam)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala index c4cde7091154b..0fec15bc42c17 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala @@ -77,6 +77,27 @@ class UDFXPathUtilSuite extends SparkFunSuite { assert(ret == "foo") } + test("embedFailure") { + import org.apache.commons.io.FileUtils + import java.io.File + val secretValue = String.valueOf(Math.random) + val tempFile = File.createTempFile("verifyembed", ".tmp") + tempFile.deleteOnExit() + val fname = tempFile.getAbsolutePath + + FileUtils.writeStringToFile(tempFile, secretValue) + + val xml = + s""" + | + |]> + |&embed; + """.stripMargin + val evaled = new UDFXPathUtil().evalString(xml, "/foo") + assert(evaled.isEmpty) + } + test("number eval") { var ret = util.evalNumber("truefalseb3c1-77", "a/c[2]") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala index bfa18a0919e45..c6f6d3abb860c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala @@ -40,8 +40,9 @@ class XPathExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { // Test error message for invalid XML document val e1 = intercept[RuntimeException] { testExpr("/a>", "a", null.asInstanceOf[T]) } - assert(e1.getCause.getMessage.contains("Invalid XML document") && - e1.getCause.getMessage.contains("/a>")) + assert(e1.getCause.getCause.getMessage.contains( + "XML document structures must start and end within the same entity.")) + assert(e1.getMessage.contains("/a>")) // Test error message for invalid xpath val e2 = intercept[RuntimeException] { testExpr("", "!#$", null.asInstanceOf[T]) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 3f41f4b144096..8b05ba32e6eef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.optimizer import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -370,5 +369,13 @@ class ColumnPruningSuite extends PlanTest { comparePlans(optimized2, expected2.analyze) } + test("SPARK-24696 ColumnPruning rule fails to remove extra Project") { + val input = LocalRelation('key.int, 'value.string) + val query = input.select('key).where(rand(0L) > 0.5).where('key < 10).analyze + val optimized = Optimize.execute(query) + val expected = input.where(rand(0L) > 0.5).where('key < 10).select('key).analyze + comparePlans(optimized, expected) + } + // todo: add more tests for column pruning } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index 14041747fd20e..bf569cb869428 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Coalesce, Literal, NamedExpression} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.IntegerType @@ -101,4 +101,22 @@ class LogicalPlanSuite extends SparkFunSuite { assert(TestBinaryRelation(relation, incrementalRelation).isStreaming === true) assert(TestBinaryRelation(incrementalRelation, incrementalRelation).isStreaming) } + + test("transformExpressions works with a Stream") { + val id1 = NamedExpression.newExprId + val id2 = NamedExpression.newExprId + val plan = Project(Stream( + Alias(Literal(1), "a")(exprId = id1), + Alias(Literal(2), "b")(exprId = id2)), + OneRowRelation()) + val result = plan.transformExpressions { + case Literal(v: Int, IntegerType) if v != 1 => + Literal(v + 1, IntegerType) + } + val expected = Project(Stream( + Alias(Literal(1), "a")(exprId = id1), + Alias(Literal(3), "b")(exprId = id2)), + OneRowRelation()) + assert(result.sameResult(expected)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 84d0ba7bef642..b7092f4c42d4c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -29,14 +29,14 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource, JarResource} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.dsl.expressions.DslString import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.plans.{LeftOuter, NaturalJoin} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Union} import org.apache.spark.sql.catalyst.plans.physical.{IdentityBroadcastMode, RoundRobinPartitioning, SinglePartition} -import org.apache.spark.sql.types.{BooleanType, DoubleType, FloatType, IntegerType, Metadata, NullType, StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel case class Dummy(optKey: Option[Expression]) extends Expression with CodegenFallback { @@ -574,4 +574,25 @@ class TreeNodeSuite extends SparkFunSuite { val right = JsonMethods.parse(rightJson) assert(left == right) } + + test("transform works on stream of children") { + val before = Coalesce(Stream(Literal(1), Literal(2))) + // Note it is a bit tricky to exhibit the broken behavior. Basically we want to create the + // situation in which the TreeNode.mapChildren function's change detection is not triggered. A + // stream's first element is typically materialized, so in order to not trip the TreeNode change + // detection logic, we should not change the first element in the sequence. + val result = before.transform { + case Literal(v: Int, IntegerType) if v != 1 => + Literal(v + 1, IntegerType) + } + val expected = Coalesce(Stream(Literal(1), Literal(3))) + assert(result === expected) + } + + test("withNewChildren on stream of children") { + val before = Coalesce(Stream(Literal(1), Literal(2))) + val result = before.withNewChildren(Stream(Literal(1), Literal(3))) + val expected = Coalesce(Stream(Literal(1), Literal(3))) + assert(result === expected) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala index 9d285916bcf42..229e32479082c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala @@ -104,4 +104,40 @@ class ComplexDataSuite extends SparkFunSuite { // The copied data should not be changed externally. assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a") } + + test("SPARK-24659: GenericArrayData.equals should respect element type differences") { + import scala.reflect.ClassTag + + // Expected positive cases + def arraysShouldEqual[T: ClassTag](element: T*): Unit = { + val array1 = new GenericArrayData(Array[T](element: _*)) + val array2 = new GenericArrayData(Array[T](element: _*)) + assert(array1.equals(array2)) + } + arraysShouldEqual(true, false) // Boolean + arraysShouldEqual(0.toByte, 123.toByte, Byte.MinValue, Byte.MaxValue) // Byte + arraysShouldEqual(0.toShort, 123.toShort, Short.MinValue, Short.MaxValue) // Short + arraysShouldEqual(0, 123, -65536, Int.MinValue, Int.MaxValue) // Int + arraysShouldEqual(0L, 123L, -65536L, Long.MinValue, Long.MaxValue) // Long + arraysShouldEqual(0.0F, 123.0F, Float.MinValue, Float.MaxValue, Float.MinPositiveValue, + Float.PositiveInfinity, Float.NegativeInfinity, Float.NaN) // Float + arraysShouldEqual(0.0, 123.0, Double.MinValue, Double.MaxValue, Double.MinPositiveValue, + Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN) // Double + arraysShouldEqual(Array[Byte](123.toByte), Array[Byte](), null) // SQL Binary + arraysShouldEqual(UTF8String.fromString("foo"), null) // SQL String + + // Expected negative cases + // Spark SQL considers cases like array vs array to be incompatible, + // so an underlying implementation of array type should return false in such cases. + def arraysShouldNotEqual[T: ClassTag, U: ClassTag](element1: T, element2: U): Unit = { + val array1 = new GenericArrayData(Array[T](element1)) + val array2 = new GenericArrayData(Array[U](element2)) + assert(!array1.equals(array2)) + } + arraysShouldNotEqual(true, 1) // Boolean <-> Int + arraysShouldNotEqual(123.toByte, 123) // Byte <-> Int + arraysShouldNotEqual(123.toByte, 123L) // Byte <-> Long + arraysShouldNotEqual(123.toShort, 123) // Short <-> Int + arraysShouldNotEqual(123, 123L) // Int <-> Long + } } diff --git a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt new file mode 100644 index 0000000000000..2215ed91e2018 --- /dev/null +++ b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt @@ -0,0 +1,704 @@ +================================================================================================ +Pushdown for many distinct value case +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 string row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8970 / 9122 1.8 570.3 1.0X +Parquet Vectorized (Pushdown) 471 / 491 33.4 30.0 19.0X +Native ORC Vectorized 7661 / 7853 2.1 487.0 1.2X +Native ORC Vectorized (Pushdown) 1134 / 1161 13.9 72.1 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 string row ('7864320' < value < '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9246 / 9297 1.7 587.8 1.0X +Parquet Vectorized (Pushdown) 480 / 488 32.8 30.5 19.3X +Native ORC Vectorized 7838 / 7850 2.0 498.3 1.2X +Native ORC Vectorized (Pushdown) 1054 / 1118 14.9 67.0 8.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 string row (value = '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8989 / 9100 1.7 571.5 1.0X +Parquet Vectorized (Pushdown) 448 / 467 35.1 28.5 20.1X +Native ORC Vectorized 7680 / 7768 2.0 488.3 1.2X +Native ORC Vectorized (Pushdown) 1067 / 1118 14.7 67.8 8.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 string row (value <=> '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9115 / 9266 1.7 579.5 1.0X +Parquet Vectorized (Pushdown) 466 / 492 33.7 29.7 19.5X +Native ORC Vectorized 7800 / 7914 2.0 495.9 1.2X +Native ORC Vectorized (Pushdown) 1075 / 1102 14.6 68.4 8.5X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 string row ('7864320' <= value <= '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9099 / 9237 1.7 578.5 1.0X +Parquet Vectorized (Pushdown) 462 / 475 34.1 29.3 19.7X +Native ORC Vectorized 7847 / 7925 2.0 498.9 1.2X +Native ORC Vectorized (Pushdown) 1078 / 1114 14.6 68.5 8.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all string rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 19303 / 19547 0.8 1227.3 1.0X +Parquet Vectorized (Pushdown) 19924 / 20089 0.8 1266.7 1.0X +Native ORC Vectorized 18725 / 19079 0.8 1190.5 1.0X +Native ORC Vectorized (Pushdown) 19310 / 19492 0.8 1227.7 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 int row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8117 / 8323 1.9 516.1 1.0X +Parquet Vectorized (Pushdown) 484 / 494 32.5 30.8 16.8X +Native ORC Vectorized 6811 / 7036 2.3 433.0 1.2X +Native ORC Vectorized (Pushdown) 1061 / 1082 14.8 67.5 7.6X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 int row (7864320 < value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8105 / 8140 1.9 515.3 1.0X +Parquet Vectorized (Pushdown) 478 / 505 32.9 30.4 17.0X +Native ORC Vectorized 6914 / 7211 2.3 439.6 1.2X +Native ORC Vectorized (Pushdown) 1044 / 1064 15.1 66.4 7.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 int row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7983 / 8116 2.0 507.6 1.0X +Parquet Vectorized (Pushdown) 464 / 487 33.9 29.5 17.2X +Native ORC Vectorized 6703 / 6774 2.3 426.1 1.2X +Native ORC Vectorized (Pushdown) 1017 / 1058 15.5 64.6 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 int row (value <=> 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7942 / 7983 2.0 504.9 1.0X +Parquet Vectorized (Pushdown) 468 / 479 33.6 29.7 17.0X +Native ORC Vectorized 6677 / 6779 2.4 424.5 1.2X +Native ORC Vectorized (Pushdown) 1021 / 1068 15.4 64.9 7.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 int row (7864320 <= value <= 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7909 / 7958 2.0 502.8 1.0X +Parquet Vectorized (Pushdown) 485 / 494 32.4 30.8 16.3X +Native ORC Vectorized 6751 / 6846 2.3 429.2 1.2X +Native ORC Vectorized (Pushdown) 1043 / 1077 15.1 66.3 7.6X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 int row (7864319 < value < 7864321): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8010 / 8033 2.0 509.2 1.0X +Parquet Vectorized (Pushdown) 472 / 489 33.3 30.0 17.0X +Native ORC Vectorized 6655 / 6808 2.4 423.1 1.2X +Native ORC Vectorized (Pushdown) 1015 / 1067 15.5 64.5 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% int rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8983 / 9035 1.8 571.1 1.0X +Parquet Vectorized (Pushdown) 2204 / 2231 7.1 140.1 4.1X +Native ORC Vectorized 7864 / 8011 2.0 500.0 1.1X +Native ORC Vectorized (Pushdown) 2674 / 2789 5.9 170.0 3.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% int rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 12723 / 12903 1.2 808.9 1.0X +Parquet Vectorized (Pushdown) 9112 / 9282 1.7 579.3 1.4X +Native ORC Vectorized 12090 / 12230 1.3 768.7 1.1X +Native ORC Vectorized (Pushdown) 9242 / 9372 1.7 587.6 1.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% int rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 16453 / 16678 1.0 1046.1 1.0X +Parquet Vectorized (Pushdown) 15997 / 16262 1.0 1017.0 1.0X +Native ORC Vectorized 16652 / 17070 0.9 1058.7 1.0X +Native ORC Vectorized (Pushdown) 15843 / 16112 1.0 1007.2 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all int rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 17098 / 17254 0.9 1087.1 1.0X +Parquet Vectorized (Pushdown) 17302 / 17529 0.9 1100.1 1.0X +Native ORC Vectorized 16790 / 17098 0.9 1067.5 1.0X +Native ORC Vectorized (Pushdown) 17329 / 17914 0.9 1101.7 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all int rows (value > -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 17088 / 17392 0.9 1086.4 1.0X +Parquet Vectorized (Pushdown) 17609 / 17863 0.9 1119.5 1.0X +Native ORC Vectorized 18334 / 69831 0.9 1165.7 0.9X +Native ORC Vectorized (Pushdown) 17465 / 17629 0.9 1110.4 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all int rows (value != -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 16903 / 17233 0.9 1074.6 1.0X +Parquet Vectorized (Pushdown) 16945 / 17032 0.9 1077.3 1.0X +Native ORC Vectorized 16377 / 16762 1.0 1041.2 1.0X +Native ORC Vectorized (Pushdown) 16950 / 17212 0.9 1077.7 1.0X + + +================================================================================================ +Pushdown for few distinct value case (use dictionary encoding) +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 distinct string row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7245 / 7322 2.2 460.7 1.0X +Parquet Vectorized (Pushdown) 378 / 389 41.6 24.0 19.2X +Native ORC Vectorized 6720 / 6778 2.3 427.2 1.1X +Native ORC Vectorized (Pushdown) 1009 / 1032 15.6 64.2 7.2X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 distinct string row ('100' < value < '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7627 / 7795 2.1 484.9 1.0X +Parquet Vectorized (Pushdown) 384 / 406 41.0 24.4 19.9X +Native ORC Vectorized 6724 / 7824 2.3 427.5 1.1X +Native ORC Vectorized (Pushdown) 968 / 986 16.3 61.5 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 distinct string row (value = '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7157 / 7534 2.2 455.0 1.0X +Parquet Vectorized (Pushdown) 542 / 565 29.0 34.5 13.2X +Native ORC Vectorized 6716 / 7214 2.3 427.0 1.1X +Native ORC Vectorized (Pushdown) 1212 / 1288 13.0 77.0 5.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 distinct string row (value <=> '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7368 / 7552 2.1 468.4 1.0X +Parquet Vectorized (Pushdown) 544 / 556 28.9 34.6 13.5X +Native ORC Vectorized 6740 / 6867 2.3 428.5 1.1X +Native ORC Vectorized (Pushdown) 1230 / 1426 12.8 78.2 6.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 distinct string row ('100' <= value <= '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7427 / 7734 2.1 472.2 1.0X +Parquet Vectorized (Pushdown) 556 / 568 28.3 35.4 13.3X +Native ORC Vectorized 6847 / 7059 2.3 435.3 1.1X +Native ORC Vectorized (Pushdown) 1226 / 1230 12.8 77.9 6.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all distinct string rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 16998 / 17311 0.9 1080.7 1.0X +Parquet Vectorized (Pushdown) 16977 / 17250 0.9 1079.4 1.0X +Native ORC Vectorized 18447 / 19852 0.9 1172.8 0.9X +Native ORC Vectorized (Pushdown) 16614 / 17102 0.9 1056.3 1.0X + + +================================================================================================ +Pushdown benchmark for StringStartsWith +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +StringStartsWith filter: (value like '10%'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9705 / 10814 1.6 617.0 1.0X +Parquet Vectorized (Pushdown) 3086 / 3574 5.1 196.2 3.1X +Native ORC Vectorized 10094 / 10695 1.6 641.8 1.0X +Native ORC Vectorized (Pushdown) 9611 / 9999 1.6 611.0 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +StringStartsWith filter: (value like '1000%'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8016 / 8183 2.0 509.7 1.0X +Parquet Vectorized (Pushdown) 444 / 457 35.4 28.2 18.0X +Native ORC Vectorized 6970 / 7169 2.3 443.2 1.2X +Native ORC Vectorized (Pushdown) 7447 / 7503 2.1 473.5 1.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +StringStartsWith filter: (value like '786432%'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7908 / 8046 2.0 502.8 1.0X +Parquet Vectorized (Pushdown) 408 / 429 38.6 25.9 19.4X +Native ORC Vectorized 7021 / 7100 2.2 446.4 1.1X +Native ORC Vectorized (Pushdown) 7310 / 7490 2.2 464.8 1.1X + + +================================================================================================ +Pushdown benchmark for decimal +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 decimal(9, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4546 / 4743 3.5 289.0 1.0X +Parquet Vectorized (Pushdown) 161 / 175 98.0 10.2 28.3X +Native ORC Vectorized 5721 / 5842 2.7 363.7 0.8X +Native ORC Vectorized (Pushdown) 1019 / 1070 15.4 64.8 4.5X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% decimal(9, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 6340 / 7236 2.5 403.1 1.0X +Parquet Vectorized (Pushdown) 3052 / 3164 5.2 194.1 2.1X +Native ORC Vectorized 8370 / 9214 1.9 532.1 0.8X +Native ORC Vectorized (Pushdown) 4137 / 4242 3.8 263.0 1.5X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% decimal(9, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 12976 / 13249 1.2 825.0 1.0X +Parquet Vectorized (Pushdown) 12655 / 13570 1.2 804.6 1.0X +Native ORC Vectorized 15562 / 15950 1.0 989.4 0.8X +Native ORC Vectorized (Pushdown) 15042 / 15668 1.0 956.3 0.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% decimal(9, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 14303 / 14616 1.1 909.3 1.0X +Parquet Vectorized (Pushdown) 14380 / 14649 1.1 914.3 1.0X +Native ORC Vectorized 16964 / 17358 0.9 1078.5 0.8X +Native ORC Vectorized (Pushdown) 17255 / 17874 0.9 1097.0 0.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 decimal(18, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4701 / 6416 3.3 298.9 1.0X +Parquet Vectorized (Pushdown) 128 / 164 122.8 8.1 36.7X +Native ORC Vectorized 5698 / 7904 2.8 362.3 0.8X +Native ORC Vectorized (Pushdown) 913 / 942 17.2 58.0 5.2X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% decimal(18, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 5376 / 5461 2.9 341.8 1.0X +Parquet Vectorized (Pushdown) 1479 / 1543 10.6 94.0 3.6X +Native ORC Vectorized 6640 / 6748 2.4 422.2 0.8X +Native ORC Vectorized (Pushdown) 2438 / 2479 6.5 155.0 2.2X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% decimal(18, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9224 / 9356 1.7 586.5 1.0X +Parquet Vectorized (Pushdown) 7172 / 7415 2.2 456.0 1.3X +Native ORC Vectorized 11017 / 11408 1.4 700.4 0.8X +Native ORC Vectorized (Pushdown) 8771 / 10218 1.8 557.7 1.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% decimal(18, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 13933 / 15990 1.1 885.8 1.0X +Parquet Vectorized (Pushdown) 12683 / 12942 1.2 806.4 1.1X +Native ORC Vectorized 16344 / 20196 1.0 1039.1 0.9X +Native ORC Vectorized (Pushdown) 15162 / 16627 1.0 964.0 0.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 decimal(38, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7102 / 8282 2.2 451.5 1.0X +Parquet Vectorized (Pushdown) 124 / 150 126.4 7.9 57.1X +Native ORC Vectorized 5811 / 6883 2.7 369.5 1.2X +Native ORC Vectorized (Pushdown) 1121 / 1502 14.0 71.3 6.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% decimal(38, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 6894 / 7562 2.3 438.3 1.0X +Parquet Vectorized (Pushdown) 1863 / 1980 8.4 118.4 3.7X +Native ORC Vectorized 6812 / 6848 2.3 433.1 1.0X +Native ORC Vectorized (Pushdown) 2511 / 2598 6.3 159.7 2.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% decimal(38, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 11732 / 12183 1.3 745.9 1.0X +Parquet Vectorized (Pushdown) 8912 / 9945 1.8 566.6 1.3X +Native ORC Vectorized 11499 / 12387 1.4 731.1 1.0X +Native ORC Vectorized (Pushdown) 9328 / 9382 1.7 593.1 1.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% decimal(38, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 16272 / 16328 1.0 1034.6 1.0X +Parquet Vectorized (Pushdown) 15714 / 18100 1.0 999.1 1.0X +Native ORC Vectorized 16539 / 18897 1.0 1051.5 1.0X +Native ORC Vectorized (Pushdown) 16328 / 17306 1.0 1038.1 1.0X + + +================================================================================================ +Pushdown benchmark for InSet -> InFilters +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 5, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7993 / 8104 2.0 508.2 1.0X +Parquet Vectorized (Pushdown) 507 / 532 31.0 32.2 15.8X +Native ORC Vectorized 6922 / 7163 2.3 440.1 1.2X +Native ORC Vectorized (Pushdown) 1017 / 1058 15.5 64.6 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 5, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7855 / 7963 2.0 499.4 1.0X +Parquet Vectorized (Pushdown) 503 / 516 31.3 32.0 15.6X +Native ORC Vectorized 6825 / 6954 2.3 433.9 1.2X +Native ORC Vectorized (Pushdown) 1019 / 1044 15.4 64.8 7.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 5, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7858 / 7928 2.0 499.6 1.0X +Parquet Vectorized (Pushdown) 490 / 519 32.1 31.1 16.0X +Native ORC Vectorized 7079 / 7966 2.2 450.1 1.1X +Native ORC Vectorized (Pushdown) 1276 / 1673 12.3 81.1 6.2X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 10, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8007 / 11155 2.0 509.0 1.0X +Parquet Vectorized (Pushdown) 519 / 540 30.3 33.0 15.4X +Native ORC Vectorized 6848 / 7072 2.3 435.4 1.2X +Native ORC Vectorized (Pushdown) 1026 / 1050 15.3 65.2 7.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 10, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7876 / 7956 2.0 500.7 1.0X +Parquet Vectorized (Pushdown) 521 / 535 30.2 33.1 15.1X +Native ORC Vectorized 7051 / 7368 2.2 448.3 1.1X +Native ORC Vectorized (Pushdown) 1014 / 1035 15.5 64.5 7.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 10, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7897 / 8229 2.0 502.1 1.0X +Parquet Vectorized (Pushdown) 513 / 530 30.7 32.6 15.4X +Native ORC Vectorized 6730 / 6990 2.3 427.9 1.2X +Native ORC Vectorized (Pushdown) 1003 / 1036 15.7 63.8 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 50, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7967 / 8175 2.0 506.5 1.0X +Parquet Vectorized (Pushdown) 8155 / 8434 1.9 518.5 1.0X +Native ORC Vectorized 7002 / 7107 2.2 445.2 1.1X +Native ORC Vectorized (Pushdown) 1092 / 1139 14.4 69.4 7.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 50, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8032 / 8122 2.0 510.7 1.0X +Parquet Vectorized (Pushdown) 8141 / 8908 1.9 517.6 1.0X +Native ORC Vectorized 7140 / 7387 2.2 454.0 1.1X +Native ORC Vectorized (Pushdown) 1156 / 1220 13.6 73.5 6.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 50, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8088 / 8350 1.9 514.2 1.0X +Parquet Vectorized (Pushdown) 8629 / 8702 1.8 548.6 0.9X +Native ORC Vectorized 7480 / 7886 2.1 475.6 1.1X +Native ORC Vectorized (Pushdown) 1106 / 1145 14.2 70.3 7.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 100, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8028 / 8165 2.0 510.4 1.0X +Parquet Vectorized (Pushdown) 8349 / 8674 1.9 530.8 1.0X +Native ORC Vectorized 7107 / 7354 2.2 451.8 1.1X +Native ORC Vectorized (Pushdown) 1175 / 1207 13.4 74.7 6.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 100, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8041 / 8195 2.0 511.2 1.0X +Parquet Vectorized (Pushdown) 8466 / 8604 1.9 538.2 0.9X +Native ORC Vectorized 7116 / 7286 2.2 452.4 1.1X +Native ORC Vectorized (Pushdown) 1197 / 1214 13.1 76.1 6.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 100, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7998 / 8311 2.0 508.5 1.0X +Parquet Vectorized (Pushdown) 9366 / 11257 1.7 595.5 0.9X +Native ORC Vectorized 7856 / 9273 2.0 499.5 1.0X +Native ORC Vectorized (Pushdown) 1350 / 1747 11.7 85.8 5.9X + + +================================================================================================ +Pushdown benchmark for tinyint +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 tinyint row (value = CAST(63 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 3461 / 3997 4.5 220.1 1.0X +Parquet Vectorized (Pushdown) 270 / 315 58.4 17.1 12.8X +Native ORC Vectorized 4107 / 5372 3.8 261.1 0.8X +Native ORC Vectorized (Pushdown) 778 / 1553 20.2 49.5 4.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% tinyint rows (value < CAST(12 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4771 / 6655 3.3 303.3 1.0X +Parquet Vectorized (Pushdown) 1322 / 1606 11.9 84.0 3.6X +Native ORC Vectorized 4437 / 4572 3.5 282.1 1.1X +Native ORC Vectorized (Pushdown) 1781 / 1976 8.8 113.2 2.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% tinyint rows (value < CAST(63 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7433 / 7752 2.1 472.6 1.0X +Parquet Vectorized (Pushdown) 5863 / 5913 2.7 372.8 1.3X +Native ORC Vectorized 7986 / 8084 2.0 507.7 0.9X +Native ORC Vectorized (Pushdown) 6522 / 6608 2.4 414.6 1.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% tinyint rows (value < CAST(114 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 11190 / 11519 1.4 711.4 1.0X +Parquet Vectorized (Pushdown) 10861 / 11206 1.4 690.5 1.0X +Native ORC Vectorized 11622 / 12196 1.4 738.9 1.0X +Native ORC Vectorized (Pushdown) 11377 / 11654 1.4 723.3 1.0X + + +================================================================================================ +Pushdown benchmark for Timestamp +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 timestamp stored as INT96 row (value = CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4784 / 4956 3.3 304.2 1.0X +Parquet Vectorized (Pushdown) 4838 / 4917 3.3 307.6 1.0X +Native ORC Vectorized 3923 / 4173 4.0 249.4 1.2X +Native ORC Vectorized (Pushdown) 894 / 943 17.6 56.8 5.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% timestamp stored as INT96 rows (value < CAST(1572864 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 5686 / 5901 2.8 361.5 1.0X +Parquet Vectorized (Pushdown) 5555 / 5895 2.8 353.2 1.0X +Native ORC Vectorized 4844 / 4957 3.2 308.0 1.2X +Native ORC Vectorized (Pushdown) 2141 / 2230 7.3 136.1 2.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% timestamp stored as INT96 rows (value < CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9100 / 9421 1.7 578.6 1.0X +Parquet Vectorized (Pushdown) 9122 / 9496 1.7 580.0 1.0X +Native ORC Vectorized 8365 / 8874 1.9 531.9 1.1X +Native ORC Vectorized (Pushdown) 7128 / 7376 2.2 453.2 1.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% timestamp stored as INT96 rows (value < CAST(14155776 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 12764 / 13120 1.2 811.5 1.0X +Parquet Vectorized (Pushdown) 12656 / 13003 1.2 804.7 1.0X +Native ORC Vectorized 13096 / 13233 1.2 832.6 1.0X +Native ORC Vectorized (Pushdown) 12710 / 15611 1.2 808.1 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 timestamp stored as TIMESTAMP_MICROS row (value = CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4381 / 4796 3.6 278.5 1.0X +Parquet Vectorized (Pushdown) 122 / 137 129.3 7.7 36.0X +Native ORC Vectorized 3913 / 3988 4.0 248.8 1.1X +Native ORC Vectorized (Pushdown) 905 / 945 17.4 57.6 4.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% timestamp stored as TIMESTAMP_MICROS rows (value < CAST(1572864 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 5145 / 5184 3.1 327.1 1.0X +Parquet Vectorized (Pushdown) 1426 / 1519 11.0 90.7 3.6X +Native ORC Vectorized 4827 / 4901 3.3 306.9 1.1X +Native ORC Vectorized (Pushdown) 2133 / 2210 7.4 135.6 2.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% timestamp stored as TIMESTAMP_MICROS rows (value < CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9234 / 9516 1.7 587.1 1.0X +Parquet Vectorized (Pushdown) 6752 / 7046 2.3 429.3 1.4X +Native ORC Vectorized 8418 / 8998 1.9 535.2 1.1X +Native ORC Vectorized (Pushdown) 7199 / 7314 2.2 457.7 1.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% timestamp stored as TIMESTAMP_MICROS rows (value < CAST(14155776 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 12414 / 12458 1.3 789.2 1.0X +Parquet Vectorized (Pushdown) 12094 / 12249 1.3 768.9 1.0X +Native ORC Vectorized 12198 / 13755 1.3 775.5 1.0X +Native ORC Vectorized (Pushdown) 12205 / 12431 1.3 776.0 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 timestamp stored as TIMESTAMP_MILLIS row (value = CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4369 / 4515 3.6 277.8 1.0X +Parquet Vectorized (Pushdown) 116 / 125 136.2 7.3 37.8X +Native ORC Vectorized 3965 / 4703 4.0 252.1 1.1X +Native ORC Vectorized (Pushdown) 892 / 1162 17.6 56.7 4.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% timestamp stored as TIMESTAMP_MILLIS rows (value < CAST(1572864 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 5211 / 5409 3.0 331.3 1.0X +Parquet Vectorized (Pushdown) 1427 / 1438 11.0 90.7 3.7X +Native ORC Vectorized 4719 / 4883 3.3 300.1 1.1X +Native ORC Vectorized (Pushdown) 2191 / 2228 7.2 139.3 2.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% timestamp stored as TIMESTAMP_MILLIS rows (value < CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8716 / 8953 1.8 554.2 1.0X +Parquet Vectorized (Pushdown) 6632 / 6968 2.4 421.7 1.3X +Native ORC Vectorized 8376 / 9118 1.9 532.5 1.0X +Native ORC Vectorized (Pushdown) 7218 / 7609 2.2 458.9 1.2X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% timestamp stored as TIMESTAMP_MILLIS rows (value < CAST(14155776 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 12264 / 12452 1.3 779.7 1.0X +Parquet Vectorized (Pushdown) 11766 / 11927 1.3 748.0 1.0X +Native ORC Vectorized 12101 / 12301 1.3 769.3 1.0X +Native ORC Vectorized (Pushdown) 11983 / 12651 1.3 761.9 1.0X + diff --git a/sql/core/pom.xml b/sql/core/pom.xml index f270c70fbfcf0..18ae314309d7b 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -118,7 +118,7 @@ org.apache.xbean - xbean-asm5-shaded + xbean-asm6-shaded org.scalacheck diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index c7c4c7b3e7715..c8cf44b51df77 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -20,8 +20,8 @@ import java.io.IOException; import org.apache.spark.SparkEnv; +import org.apache.spark.TaskContext; import org.apache.spark.internal.config.package$; -import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; @@ -82,7 +82,7 @@ public static boolean supportsAggregationBufferSchema(StructType schema) { * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function) * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion. * @param groupingKeySchema the schema of the grouping key, used for row conversion. - * @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures. + * @param taskContext the current task context. * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing). * @param pageSizeBytes the data page size, in bytes; limits the maximum record size. */ @@ -90,19 +90,26 @@ public UnsafeFixedWidthAggregationMap( InternalRow emptyAggregationBuffer, StructType aggregationBufferSchema, StructType groupingKeySchema, - TaskMemoryManager taskMemoryManager, + TaskContext taskContext, int initialCapacity, long pageSizeBytes) { this.aggregationBufferSchema = aggregationBufferSchema; this.currentAggregationBuffer = new UnsafeRow(aggregationBufferSchema.length()); this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema); this.groupingKeySchema = groupingKeySchema; - this.map = - new BytesToBytesMap(taskMemoryManager, initialCapacity, pageSizeBytes, true); + this.map = new BytesToBytesMap( + taskContext.taskMemoryManager(), initialCapacity, pageSizeBytes, true); // Initialize the buffer for aggregation value final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema); this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); + + // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at + // the end of the task. This is necessary to avoid memory leaks in when the downstream operator + // does not fully consume the aggregation map's output (e.g. aggregate followed by limit). + taskContext.addTaskCompletionListener(context -> { + free(); + }); } /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 4733f36174f42..6fdadde628551 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -216,12 +216,12 @@ protected UTF8String getBytesAsUTF8String(int rowId, int count) { @Override public void putShort(int rowId, short value) { - Platform.putShort(null, data + 2 * rowId, value); + Platform.putShort(null, data + 2L * rowId, value); } @Override public void putShorts(int rowId, int count, short value) { - long offset = data + 2 * rowId; + long offset = data + 2L * rowId; for (int i = 0; i < count; ++i, offset += 2) { Platform.putShort(null, offset, value); } @@ -229,20 +229,20 @@ public void putShorts(int rowId, int count, short value) { @Override public void putShorts(int rowId, int count, short[] src, int srcIndex) { - Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2, - null, data + 2 * rowId, count * 2); + Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2L, + null, data + 2L * rowId, count * 2L); } @Override public void putShorts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 2, count * 2); + null, data + rowId * 2L, count * 2L); } @Override public short getShort(int rowId) { if (dictionary == null) { - return Platform.getShort(null, data + 2 * rowId); + return Platform.getShort(null, data + 2L * rowId); } else { return (short) dictionary.decodeToInt(dictionaryIds.getDictId(rowId)); } @@ -252,7 +252,7 @@ public short getShort(int rowId) { public short[] getShorts(int rowId, int count) { assert(dictionary == null); short[] array = new short[count]; - Platform.copyMemory(null, data + rowId * 2, array, Platform.SHORT_ARRAY_OFFSET, count * 2); + Platform.copyMemory(null, data + rowId * 2L, array, Platform.SHORT_ARRAY_OFFSET, count * 2L); return array; } @@ -262,12 +262,12 @@ public short[] getShorts(int rowId, int count) { @Override public void putInt(int rowId, int value) { - Platform.putInt(null, data + 4 * rowId, value); + Platform.putInt(null, data + 4L * rowId, value); } @Override public void putInts(int rowId, int count, int value) { - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4) { Platform.putInt(null, offset, value); } @@ -275,24 +275,24 @@ public void putInts(int rowId, int count, int value) { @Override public void putInts(int rowId, int count, int[] src, int srcIndex) { - Platform.copyMemory(src, Platform.INT_ARRAY_OFFSET + srcIndex * 4, - null, data + 4 * rowId, count * 4); + Platform.copyMemory(src, Platform.INT_ARRAY_OFFSET + srcIndex * 4L, + null, data + 4L * rowId, count * 4L); } @Override public void putInts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 4, count * 4); + null, data + rowId * 4L, count * 4L); } @Override public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET, - null, data + 4 * rowId, count * 4); + null, data + 4L * rowId, count * 4L); } else { int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4, srcOffset += 4) { Platform.putInt(null, offset, java.lang.Integer.reverseBytes(Platform.getInt(src, srcOffset))); @@ -303,7 +303,7 @@ public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) @Override public int getInt(int rowId) { if (dictionary == null) { - return Platform.getInt(null, data + 4 * rowId); + return Platform.getInt(null, data + 4L * rowId); } else { return dictionary.decodeToInt(dictionaryIds.getDictId(rowId)); } @@ -313,7 +313,7 @@ public int getInt(int rowId) { public int[] getInts(int rowId, int count) { assert(dictionary == null); int[] array = new int[count]; - Platform.copyMemory(null, data + rowId * 4, array, Platform.INT_ARRAY_OFFSET, count * 4); + Platform.copyMemory(null, data + rowId * 4L, array, Platform.INT_ARRAY_OFFSET, count * 4L); return array; } @@ -325,7 +325,7 @@ public int[] getInts(int rowId, int count) { public int getDictId(int rowId) { assert(dictionary == null) : "A ColumnVector dictionary should not have a dictionary for itself."; - return Platform.getInt(null, data + 4 * rowId); + return Platform.getInt(null, data + 4L * rowId); } // @@ -334,12 +334,12 @@ public int getDictId(int rowId) { @Override public void putLong(int rowId, long value) { - Platform.putLong(null, data + 8 * rowId, value); + Platform.putLong(null, data + 8L * rowId, value); } @Override public void putLongs(int rowId, int count, long value) { - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8) { Platform.putLong(null, offset, value); } @@ -347,24 +347,24 @@ public void putLongs(int rowId, int count, long value) { @Override public void putLongs(int rowId, int count, long[] src, int srcIndex) { - Platform.copyMemory(src, Platform.LONG_ARRAY_OFFSET + srcIndex * 8, - null, data + 8 * rowId, count * 8); + Platform.copyMemory(src, Platform.LONG_ARRAY_OFFSET + srcIndex * 8L, + null, data + 8L * rowId, count * 8L); } @Override public void putLongs(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 8, count * 8); + null, data + rowId * 8L, count * 8L); } @Override public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET, - null, data + 8 * rowId, count * 8); + null, data + 8L * rowId, count * 8L); } else { int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8, srcOffset += 8) { Platform.putLong(null, offset, java.lang.Long.reverseBytes(Platform.getLong(src, srcOffset))); @@ -375,7 +375,7 @@ public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) @Override public long getLong(int rowId) { if (dictionary == null) { - return Platform.getLong(null, data + 8 * rowId); + return Platform.getLong(null, data + 8L * rowId); } else { return dictionary.decodeToLong(dictionaryIds.getDictId(rowId)); } @@ -385,7 +385,7 @@ public long getLong(int rowId) { public long[] getLongs(int rowId, int count) { assert(dictionary == null); long[] array = new long[count]; - Platform.copyMemory(null, data + rowId * 8, array, Platform.LONG_ARRAY_OFFSET, count * 8); + Platform.copyMemory(null, data + rowId * 8L, array, Platform.LONG_ARRAY_OFFSET, count * 8L); return array; } @@ -395,12 +395,12 @@ public long[] getLongs(int rowId, int count) { @Override public void putFloat(int rowId, float value) { - Platform.putFloat(null, data + rowId * 4, value); + Platform.putFloat(null, data + rowId * 4L, value); } @Override public void putFloats(int rowId, int count, float value) { - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4) { Platform.putFloat(null, offset, value); } @@ -408,18 +408,18 @@ public void putFloats(int rowId, int count, float value) { @Override public void putFloats(int rowId, int count, float[] src, int srcIndex) { - Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4, - null, data + 4 * rowId, count * 4); + Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4L, + null, data + 4L * rowId, count * 4L); } @Override public void putFloats(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 4, count * 4); + null, data + rowId * 4L, count * 4L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4) { Platform.putFloat(null, offset, bb.getFloat(srcIndex + (4 * i))); } @@ -429,7 +429,7 @@ public void putFloats(int rowId, int count, byte[] src, int srcIndex) { @Override public float getFloat(int rowId) { if (dictionary == null) { - return Platform.getFloat(null, data + rowId * 4); + return Platform.getFloat(null, data + rowId * 4L); } else { return dictionary.decodeToFloat(dictionaryIds.getDictId(rowId)); } @@ -439,7 +439,7 @@ public float getFloat(int rowId) { public float[] getFloats(int rowId, int count) { assert(dictionary == null); float[] array = new float[count]; - Platform.copyMemory(null, data + rowId * 4, array, Platform.FLOAT_ARRAY_OFFSET, count * 4); + Platform.copyMemory(null, data + rowId * 4L, array, Platform.FLOAT_ARRAY_OFFSET, count * 4L); return array; } @@ -450,12 +450,12 @@ public float[] getFloats(int rowId, int count) { @Override public void putDouble(int rowId, double value) { - Platform.putDouble(null, data + rowId * 8, value); + Platform.putDouble(null, data + rowId * 8L, value); } @Override public void putDoubles(int rowId, int count, double value) { - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8) { Platform.putDouble(null, offset, value); } @@ -463,18 +463,18 @@ public void putDoubles(int rowId, int count, double value) { @Override public void putDoubles(int rowId, int count, double[] src, int srcIndex) { - Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex * 8, - null, data + 8 * rowId, count * 8); + Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex * 8L, + null, data + 8L * rowId, count * 8L); } @Override public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 8, count * 8); + null, data + rowId * 8L, count * 8L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8) { Platform.putDouble(null, offset, bb.getDouble(srcIndex + (8 * i))); } @@ -484,7 +484,7 @@ public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { @Override public double getDouble(int rowId) { if (dictionary == null) { - return Platform.getDouble(null, data + rowId * 8); + return Platform.getDouble(null, data + rowId * 8L); } else { return dictionary.decodeToDouble(dictionaryIds.getDictId(rowId)); } @@ -494,7 +494,7 @@ public double getDouble(int rowId) { public double[] getDoubles(int rowId, int count) { assert(dictionary == null); double[] array = new double[count]; - Platform.copyMemory(null, data + rowId * 8, array, Platform.DOUBLE_ARRAY_OFFSET, count * 8); + Platform.copyMemory(null, data + rowId * 8L, array, Platform.DOUBLE_ARRAY_OFFSET, count * 8L); return array; } @@ -504,26 +504,26 @@ public double[] getDoubles(int rowId, int count) { @Override public void putArray(int rowId, int offset, int length) { assert(offset >= 0 && offset + length <= childColumns[0].capacity); - Platform.putInt(null, lengthData + 4 * rowId, length); - Platform.putInt(null, offsetData + 4 * rowId, offset); + Platform.putInt(null, lengthData + 4L * rowId, length); + Platform.putInt(null, offsetData + 4L * rowId, offset); } @Override public int getArrayLength(int rowId) { - return Platform.getInt(null, lengthData + 4 * rowId); + return Platform.getInt(null, lengthData + 4L * rowId); } @Override public int getArrayOffset(int rowId) { - return Platform.getInt(null, offsetData + 4 * rowId); + return Platform.getInt(null, offsetData + 4L * rowId); } // APIs dealing with ByteArrays @Override public int putByteArray(int rowId, byte[] value, int offset, int length) { int result = arrayData().appendBytes(length, value, offset); - Platform.putInt(null, lengthData + 4 * rowId, length); - Platform.putInt(null, offsetData + 4 * rowId, result); + Platform.putInt(null, lengthData + 4L * rowId, length); + Platform.putInt(null, offsetData + 4L * rowId, result); return result; } @@ -533,19 +533,19 @@ protected void reserveInternal(int newCapacity) { int oldCapacity = (nulls == 0L) ? 0 : capacity; if (isArray() || type instanceof MapType) { this.lengthData = - Platform.reallocateMemory(lengthData, oldCapacity * 4, newCapacity * 4); + Platform.reallocateMemory(lengthData, oldCapacity * 4L, newCapacity * 4L); this.offsetData = - Platform.reallocateMemory(offsetData, oldCapacity * 4, newCapacity * 4); + Platform.reallocateMemory(offsetData, oldCapacity * 4L, newCapacity * 4L); } else if (type instanceof ByteType || type instanceof BooleanType) { this.data = Platform.reallocateMemory(data, oldCapacity, newCapacity); } else if (type instanceof ShortType) { - this.data = Platform.reallocateMemory(data, oldCapacity * 2, newCapacity * 2); + this.data = Platform.reallocateMemory(data, oldCapacity * 2L, newCapacity * 2L); } else if (type instanceof IntegerType || type instanceof FloatType || type instanceof DateType || DecimalType.is32BitDecimalType(type)) { - this.data = Platform.reallocateMemory(data, oldCapacity * 4, newCapacity * 4); + this.data = Platform.reallocateMemory(data, oldCapacity * 4L, newCapacity * 4L); } else if (type instanceof LongType || type instanceof DoubleType || DecimalType.is64BitDecimalType(type) || type instanceof TimestampType) { - this.data = Platform.reallocateMemory(data, oldCapacity * 8, newCapacity * 8); + this.data = Platform.reallocateMemory(data, oldCapacity * 8L, newCapacity * 8L); } else if (childColumns != null) { // Nothing to store. } else { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 23dcc104e67c4..577eab6ed14c8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -231,7 +231,7 @@ public void putShorts(int rowId, int count, short[] src, int srcIndex) { @Override public void putShorts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, shortData, - Platform.SHORT_ARRAY_OFFSET + rowId * 2, count * 2); + Platform.SHORT_ARRAY_OFFSET + rowId * 2L, count * 2L); } @Override @@ -276,7 +276,7 @@ public void putInts(int rowId, int count, int[] src, int srcIndex) { @Override public void putInts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, intData, - Platform.INT_ARRAY_OFFSET + rowId * 4, count * 4); + Platform.INT_ARRAY_OFFSET + rowId * 4L, count * 4L); } @Override @@ -342,7 +342,7 @@ public void putLongs(int rowId, int count, long[] src, int srcIndex) { @Override public void putLongs(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, longData, - Platform.LONG_ARRAY_OFFSET + rowId * 8, count * 8); + Platform.LONG_ARRAY_OFFSET + rowId * 8L, count * 8L); } @Override @@ -394,7 +394,7 @@ public void putFloats(int rowId, int count, float[] src, int srcIndex) { public void putFloats(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, floatData, - Platform.DOUBLE_ARRAY_OFFSET + rowId * 4, count * 4); + Platform.DOUBLE_ARRAY_OFFSET + rowId * 4L, count * 4L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); for (int i = 0; i < count; ++i) { @@ -443,7 +443,7 @@ public void putDoubles(int rowId, int count, double[] src, int srcIndex) { public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, doubleData, - Platform.DOUBLE_ARRAY_OFFSET + rowId * 8, count * 8); + Platform.DOUBLE_ARRAY_OFFSET + rowId * 8L, count * 8L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); for (int i = 0; i < count; ++i) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java index 11bb13fd3b211..926396414816c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java @@ -22,6 +22,10 @@ /** * A mix in interface for {@link DataSourceReader}. Data source readers can implement this * interface to report statistics to Spark. + * + * Statistics are reported to the optimizer before any operator is pushed to the DataSourceReader. + * Implementations that return more accurate statistics based on pushed operators will not improve + * query performance until the planner can push operators before getting stats. */ @InterfaceStability.Evolving public interface SupportsReportStatistics extends DataSourceReader { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index 0030a9f05dba7..7eedc85a5d6f3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -64,8 +64,8 @@ public interface DataSourceWriter { DataWriterFactory createWriterFactory(); /** - * Returns whether Spark should use the commit coordinator to ensure that at most one attempt for - * each task commits. + * Returns whether Spark should use the commit coordinator to ensure that at most one task for + * each partition commits. * * @return true if commit coordinator should be used, false otherwise. */ @@ -90,9 +90,9 @@ default void onDataWriterCommit(WriterCommitMessage message) {} * is undefined and @{@link #abort(WriterCommitMessage[])} may not be able to deal with it. * * Note that speculative execution may cause multiple tasks to run for a partition. By default, - * Spark uses the commit coordinator to allow at most one attempt to commit. Implementations can + * Spark uses the commit coordinator to allow at most one task to commit. Implementations can * disable this behavior by overriding {@link #useCommitCoordinator()}. If disabled, multiple - * attempts may have committed successfully and one successful commit message per task will be + * tasks may have committed successfully and one successful commit message per task will be * passed to this commit method. The remaining commit messages are ignored by Spark. */ void commit(WriterCommitMessage[] messages); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index 39bf458298862..1626c0013e4e7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data writer returned by {@link DataWriterFactory#createDataWriter(int, int, long)} and is + * A data writer returned by {@link DataWriterFactory#createDataWriter(int, long, long)} and is * responsible for writing data for an input RDD partition. * * One Spark task has one exclusive data writer, so there is no thread-safe concern. @@ -39,14 +39,14 @@ * {@link DataSourceWriter#commit(WriterCommitMessage[])} with commit messages from other data * writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an * exception will be sent to the driver side, and Spark may retry this writing task a few times. - * In each retry, {@link DataWriterFactory#createDataWriter(int, int, long)} will receive a - * different `attemptNumber`. Spark will call {@link DataSourceWriter#abort(WriterCommitMessage[])} + * In each retry, {@link DataWriterFactory#createDataWriter(int, long, long)} will receive a + * different `taskId`. Spark will call {@link DataSourceWriter#abort(WriterCommitMessage[])} * when the configured number of retries is exhausted. * * Besides the retry mechanism, Spark may launch speculative tasks if the existing writing task * takes too long to finish. Different from retried tasks, which are launched one by one after the * previous one fails, speculative tasks are running simultaneously. It's possible that one input - * RDD partition has multiple data writers with different `attemptNumber` running at the same time, + * RDD partition has multiple data writers with different `taskId` running at the same time, * and data sources should guarantee that these data writers don't conflict and can work together. * Implementations can coordinate with driver during {@link #commit()} to make sure only one of * these data writers can commit successfully. Or implementations can allow all of them to commit diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index 7527bcc0c4027..0932ff8f8f8a7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -42,15 +42,12 @@ public interface DataWriterFactory extends Serializable { * Usually Spark processes many RDD partitions at the same time, * implementations should use the partition id to distinguish writers for * different partitions. - * @param attemptNumber Spark may launch multiple tasks with the same task id. For example, a task - * failed, Spark launches a new task wth the same task id but different - * attempt number. Or a task is too slow, Spark launches new tasks wth the - * same task id but different attempt number, which means there are multiple - * tasks with the same task id running at the same time. Implementations can - * use this attempt number to distinguish writers of different task attempts. + * @param taskId A unique identifier for a task that is performing the write of the partition + * data. Spark may run multiple tasks for the same partition (due to speculation + * or task failures, for example). * @param epochId A monotonically increasing id for streaming queries that are split in to * discrete periods of execution. For non-streaming queries, * this ID will always be 0. */ - DataWriter createDataWriter(int partitionId, int attemptNumber, long epochId); + DataWriter createDataWriter(int partitionId, long taskId, long epochId); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index de6be5f76e15a..ec9352a7fa055 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -381,6 +381,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * that should be used for parsing. *
  • `samplingRatio` (default is 1.0): defines fraction of input JSON objects used * for schema inferring.
  • + *
  • `dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or + * empty array/struct during schema inference.
  • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index f5526104690d2..c97246f30220d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -236,12 +236,10 @@ class Dataset[T] private[sql]( * @param numRows Number of rows to return * @param truncate If set to more than 0, truncates strings to `truncate` characters and * all cells will be aligned right. - * @param vertical If set to true, the rows to return do not need truncate. */ private[sql] def getRows( numRows: Int, - truncate: Int, - vertical: Boolean): Seq[Seq[String]] = { + truncate: Int): Seq[Seq[String]] = { val newDf = toDF() val castCols = newDf.logicalPlan.output.map { col => // Since binary types in top-level schema fields have a specific format to print, @@ -289,7 +287,7 @@ class Dataset[T] private[sql]( vertical: Boolean = false): String = { val numRows = _numRows.max(0).min(Int.MaxValue - 1) // Get rows represented by Seq[Seq[String]], we may get one more line if it has more data. - val tmpRows = getRows(numRows, truncate, vertical) + val tmpRows = getRows(numRows, truncate) val hasMoreData = tmpRows.length - 1 > numRows val rows = tmpRows.take(numRows + 1) @@ -1018,6 +1016,11 @@ class Dataset[T] private[sql]( catalyst.expressions.EqualTo( withPlan(plan.left).resolve(a.name), withPlan(plan.right).resolve(b.name)) + case catalyst.expressions.EqualNullSafe(a: AttributeReference, b: AttributeReference) + if a.sameRef(b) => + catalyst.expressions.EqualNullSafe( + withPlan(plan.left).resolve(a.name), + withPlan(plan.right).resolve(b.name)) }} withPlan { @@ -2964,6 +2967,7 @@ class Dataset[T] private[sql]( /** * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. + * This will not un-persist any cached data that is built upon this Dataset. * * @param blocking Whether to block until all blocks are deleted. * @@ -2971,12 +2975,13 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def unpersist(blocking: Boolean): this.type = { - sparkSession.sharedState.cacheManager.uncacheQuery(this, blocking) + sparkSession.sharedState.cacheManager.uncacheQuery(this, cascade = false, blocking) this } /** * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. + * This will not un-persist any cached data that is built upon this Dataset. * * @group basic * @since 1.6.0 @@ -3224,11 +3229,10 @@ class Dataset[T] private[sql]( private[sql] def getRowsToPython( _numRows: Int, - truncate: Int, - vertical: Boolean): Array[Any] = { + truncate: Int): Array[Any] = { EvaluatePython.registerPicklers() val numRows = _numRows.max(0).min(Int.MaxValue - 1) - val rows = getRows(numRows, truncate, vertical).map(_.toArray).toArray + val rows = getRows(numRows, truncate).map(_.toArray).toArray val toJava: (Any) => Any = EvaluatePython.toJava(_, ArrayType(ArrayType(StringType))) val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler( rows.iterator.map(toJava)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala index 86e02e98c01f3..b21c50af18433 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala @@ -20,10 +20,48 @@ package org.apache.spark.sql import org.apache.spark.annotation.InterfaceStability /** - * A class to consume data generated by a `StreamingQuery`. Typically this is used to send the - * generated data to external systems. Each partition will use a new deserialized instance, so you - * usually should do all the initialization (e.g. opening a connection or initiating a transaction) - * in the `open` method. + * The abstract class for writing custom logic to process data generated by a query. + * This is often used to write the output of a streaming query to arbitrary storage systems. + * Any implementation of this base class will be used by Spark in the following way. + * + *
      + *
    • A single instance of this class is responsible of all the data generated by a single task + * in a query. In other words, one instance is responsible for processing one partition of the + * data generated in a distributed manner. + * + *
    • Any implementation of this class must be serializable because each task will get a fresh + * serialized-deserialized copy of the provided object. Hence, it is strongly recommended that + * any initialization for writing data (e.g. opening a connection or starting a transaction) + * is done after the `open(...)` method has been called, which signifies that the task is + * ready to generate data. + * + *
    • The lifecycle of the methods are as follows. + * + *
      + *   For each partition with `partitionId`:
      + *       For each batch/epoch of streaming data (if its streaming query) with `epochId`:
      + *           Method `open(partitionId, epochId)` is called.
      + *           If `open` returns true:
      + *                For each row in the partition and batch/epoch, method `process(row)` is called.
      + *           Method `close(errorOrNull)` is called with error (if any) seen while processing rows.
      + *   
      + * + *
    + * + * Important points to note: + *
      + *
    • The `partitionId` and `epochId` can be used to deduplicate generated data when failures + * cause reprocessing of some input data. This depends on the execution mode of the query. If + * the streaming query is being executed in the micro-batch mode, then every partition + * represented by a unique tuple (partitionId, epochId) is guaranteed to have the same data. + * Hence, (partitionId, epochId) can be used to deduplicate and/or transactionally commit data + * and achieve exactly-once guarantees. However, if the streaming query is being executed in the + * continuous mode, then this guarantee does not hold and therefore should not be used for + * deduplication. + * + *
    • The `close()` method will be called if `open()` method returns successfully (irrespective + * of the return value), except if the JVM crashes in the middle. + *
    * * Scala example: * {{{ @@ -63,6 +101,7 @@ import org.apache.spark.annotation.InterfaceStability * } * }); * }}} + * * @since 2.0.0 */ @InterfaceStability.Evolving @@ -71,23 +110,18 @@ abstract class ForeachWriter[T] extends Serializable { // TODO: Move this to org.apache.spark.sql.util or consolidate this with batch API. /** - * Called when starting to process one partition of new data in the executor. The `version` is - * for data deduplication when there are failures. When recovering from a failure, some data may - * be generated multiple times but they will always have the same version. - * - * If this method finds using the `partitionId` and `version` that this partition has already been - * processed, it can return `false` to skip the further data processing. However, `close` still - * will be called for cleaning up resources. + * Called when starting to process one partition of new data in the executor. See the class + * docs for more information on how to use the `partitionId` and `epochId`. * * @param partitionId the partition id. - * @param version a unique id for data deduplication. + * @param epochId a unique id for data deduplication. * @return `true` if the corresponding partition and version id should be processed. `false` * indicates the partition should be skipped. */ - def open(partitionId: Long, version: Long): Boolean + def open(partitionId: Long, epochId: Long): Boolean /** - * Called to process the data in the executor side. This method will be called only when `open` + * Called to process the data in the executor side. This method will be called only if `open` * returns `true`. */ def process(value: T): Unit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala index b352e332bc7e0..3c39579149fff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala @@ -132,6 +132,17 @@ class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) { sqlConf.unsetConf(key) } + /** + * Indicates whether the configuration property with the given key + * is modifiable in the current session. + * + * @return `true` if the configuration property is modifiable. For static SQL, Spark Core, + * invalid (not existing) and other non-modifiable configuration properties, + * the returned value is `false`. + * @since 2.4.0 + */ + def isModifiable(key: String): Boolean = sqlConf.isModifiable(key) + /** * Returns whether a particular key is set. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 93bf91e56f1bd..39d9a95ca4710 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.expressions.SubqueryExpression -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ResolvedHint} +import org.apache.spark.sql.catalyst.plans.logical.{AnalysisBarrier, LogicalPlan, ResolvedHint} import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.storage.StorageLevel @@ -97,7 +97,7 @@ class CacheManager extends Logging { val inMemoryRelation = InMemoryRelation( sparkSession.sessionState.conf.useCompression, sparkSession.sessionState.conf.columnBatchSize, storageLevel, - sparkSession.sessionState.executePlan(planToCache).executedPlan, + sparkSession.sessionState.executePlan(AnalysisBarrier(planToCache)).executedPlan, tableName, planToCache) cachedData.add(CachedData(planToCache, inMemoryRelation)) @@ -105,24 +105,50 @@ class CacheManager extends Logging { } /** - * Un-cache all the cache entries that refer to the given plan. + * Un-cache the given plan or all the cache entries that refer to the given plan. + * @param query The [[Dataset]] to be un-cached. + * @param cascade If true, un-cache all the cache entries that refer to the given + * [[Dataset]]; otherwise un-cache the given [[Dataset]] only. + * @param blocking Whether to block until all blocks are deleted. */ - def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock { - uncacheQuery(query.sparkSession, query.logicalPlan, blocking) + def uncacheQuery( + query: Dataset[_], + cascade: Boolean, + blocking: Boolean = true): Unit = writeLock { + uncacheQuery(query.sparkSession, query.logicalPlan, cascade, blocking) } /** - * Un-cache all the cache entries that refer to the given plan. + * Un-cache the given plan or all the cache entries that refer to the given plan. + * @param spark The Spark session. + * @param plan The plan to be un-cached. + * @param cascade If true, un-cache all the cache entries that refer to the given + * plan; otherwise un-cache the given plan only. + * @param blocking Whether to block until all blocks are deleted. */ - def uncacheQuery(spark: SparkSession, plan: LogicalPlan, blocking: Boolean): Unit = writeLock { + def uncacheQuery( + spark: SparkSession, + plan: LogicalPlan, + cascade: Boolean, + blocking: Boolean): Unit = writeLock { + val shouldRemove: LogicalPlan => Boolean = + if (cascade) { + _.find(_.sameResult(plan)).isDefined + } else { + _.sameResult(plan) + } val it = cachedData.iterator() while (it.hasNext) { val cd = it.next() - if (cd.plan.find(_.sameResult(plan)).isDefined) { + if (shouldRemove(cd.plan)) { cd.cachedRepresentation.cacheBuilder.clearCache(blocking) it.remove() } } + // Re-compile dependent cached queries after removing the cached query. + if (!cascade) { + recacheByCondition(spark, _.find(_.sameResult(plan)).isDefined, clearCache = false) + } } /** @@ -132,20 +158,24 @@ class CacheManager extends Logging { recacheByCondition(spark, _.find(_.sameResult(plan)).isDefined) } - private def recacheByCondition(spark: SparkSession, condition: LogicalPlan => Boolean): Unit = { + private def recacheByCondition( + spark: SparkSession, + condition: LogicalPlan => Boolean, + clearCache: Boolean = true): Unit = { val it = cachedData.iterator() val needToRecache = scala.collection.mutable.ArrayBuffer.empty[CachedData] while (it.hasNext) { val cd = it.next() if (condition(cd.plan)) { - cd.cachedRepresentation.cacheBuilder.clearCache() + if (clearCache) { + cd.cachedRepresentation.cacheBuilder.clearCache() + } // Remove the cache entry before we create a new one, so that we can have a different // physical plan. it.remove() - val plan = spark.sessionState.executePlan(cd.plan).executedPlan + val plan = spark.sessionState.executePlan(AnalysisBarrier(cd.plan)).executedPlan val newCache = InMemoryRelation( - cacheBuilder = cd.cachedRepresentation - .cacheBuilder.copy(cachedPlan = plan)(_cachedColumnBuffers = null), + cacheBuilder = cd.cachedRepresentation.cacheBuilder.withCachedPlan(plan), logicalPlan = cd.plan) needToRecache += cd.copy(cachedRepresentation = newCache) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 1c8e4050978dc..00ff4c8ac310b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -21,7 +21,6 @@ import org.apache.spark.sql.ExperimentalMethods import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions -import org.apache.spark.sql.execution.datasources.v2.PushDownOperatorsToDataSource import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate class SparkOptimizer( @@ -32,8 +31,7 @@ class SparkOptimizer( override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ - Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+ - Batch("Push down operators to data source scan", Once, PushDownOperatorsToDataSource)) ++ + Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++ postHocOptimizationBatches :+ Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index d6951ad01fb0c..02e095b42a506 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -34,7 +35,7 @@ import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.MemoryPlanV2 import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.StreamingQuery +import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery} import org.apache.spark.sql.types.StructType /** @@ -73,12 +74,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { if limit < conf.topKSortFallbackThreshold => TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil case Limit(IntegerLiteral(limit), child) => - // With whole stage codegen, Spark releases resources only when all the output data of the - // query plan are consumed. It's possible that `CollectLimitExec` only consumes a little - // data from child plan and finishes the query without releasing resources. Here we wrap - // the child plan with `LocalLimitExec`, to stop the processing of whole stage codegen and - // trigger the resource releasing work, after we consume `limit` rows. - CollectLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil + CollectLimitExec(limit, planLater(child)) :: Nil case other => planLater(other) :: Nil } case Limit(IntegerLiteral(limit), Sort(order, true, child)) @@ -354,6 +350,29 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + /** + * Used to plan the streaming global limit operator for streams in append mode. + * We need to check for either a direct Limit or a Limit wrapped in a ReturnAnswer operator, + * following the example of the SpecialLimits Strategy above. + * Streams with limit in Append mode use the stateful StreamingGlobalLimitExec. + * Streams with limit in Complete mode use the stateless CollectLimitExec operator. + * Limit is unsupported for streams in Update mode. + */ + case class StreamingGlobalLimitStrategy(outputMode: OutputMode) extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case ReturnAnswer(rootPlan) => rootPlan match { + case Limit(IntegerLiteral(limit), child) + if plan.isStreaming && outputMode == InternalOutputModes.Append => + StreamingGlobalLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil + case _ => Nil + } + case Limit(IntegerLiteral(limit), child) + if plan.isStreaming && outputMode == InternalOutputModes.Append => + StreamingGlobalLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil + case _ => Nil + } + } + object StreamingJoinStrategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = { plan match { @@ -494,7 +513,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - // Can we automate these 'pass through' operations? object BasicOperators extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case d: DataWritingCommand => DataWritingCommandExec(d, planLater(d.query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 8c7b2c187cccd..2cac0cfce28de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -328,7 +328,7 @@ case class HashAggregateExec( initialBuffer, bufferSchema, groupingKeySchema, - TaskContext.get().taskMemoryManager(), + TaskContext.get(), 1024 * 16, // initial capacity TaskContext.get().taskMemoryManager().pageSizeBytes ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 9dc334c1ead3c..c1911235f8df3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -166,7 +166,7 @@ class TungstenAggregationIterator( initialAggregationBuffer, StructType.fromAttributes(aggregateFunctions.flatMap(_.aggBufferAttributes)), StructType.fromAttributes(groupingExpressions.map(_.toAttribute)), - TaskContext.get().taskMemoryManager(), + TaskContext.get(), 1024 * 16, // initial capacity TaskContext.get().taskMemoryManager().pageSizeBytes ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala index 6ad11bda84bf6..93c8127681b3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala @@ -23,6 +23,7 @@ import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit} import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ object ArrowUtils { @@ -120,4 +121,19 @@ object ArrowUtils { StructField(field.getName, dt, field.isNullable) }) } + + /** Return Map with conf settings to be used in ArrowPythonRunner */ + def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = { + val timeZoneConf = if (conf.pandasRespectSessionTimeZone) { + Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone) + } else { + Nil + } + val pandasColsByPosition = if (conf.pandasGroupedMapAssignColumnssByPosition) { + Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION.key -> "true") + } else { + Nil + } + Map(timeZoneConf ++ pandasColsByPosition: _*) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index da35a4734e65a..7c8faec53a828 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -74,6 +74,16 @@ case class CachedRDDBuilder( } } + def withCachedPlan(cachedPlan: SparkPlan): CachedRDDBuilder = { + new CachedRDDBuilder( + useCompression, + batchSize, + storageLevel, + cachedPlan = cachedPlan, + tableName + )(_cachedColumnBuffers) + } + private def buildBuffers(): RDD[CachedBatch] = { val output = cachedPlan.output val cached = cachedPlan.execute().mapPartitionsInternal { rowIterator => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 0b4dd76c7d860..997cf92449c68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.vectorized._ import org.apache.spark.sql.types._ @@ -169,8 +169,8 @@ case class InMemoryTableScanExec( // But the cached version could alias output, so we need to replace output. override def outputPartitioning: Partitioning = { relation.cachedPlan.outputPartitioning match { - case h: HashPartitioning => updateAttribute(h).asInstanceOf[HashPartitioning] - case _ => relation.cachedPlan.outputPartitioning + case e: Expression => updateAttribute(e).asInstanceOf[Partitioning] + case other => other } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index bf4d96fa18d0d..04bf8c6dd917f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -189,8 +189,9 @@ case class DropTableCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog + val isTempView = catalog.isTemporaryTable(tableName) - if (!catalog.isTemporaryTable(tableName) && catalog.tableExists(tableName)) { + if (!isTempView && catalog.tableExists(tableName)) { // If the command DROP VIEW is to drop a table or DROP TABLE is to drop a view // issue an exception. catalog.getTableMetadata(tableName).tableType match { @@ -204,9 +205,10 @@ case class DropTableCommand( } } - if (catalog.isTemporaryTable(tableName) || catalog.tableExists(tableName)) { + if (isTempView || catalog.tableExists(tableName)) { try { - sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName)) + sparkSession.sharedState.cacheManager.uncacheQuery( + sparkSession.table(tableName), cascade = !isTempView) } catch { case NonFatal(e) => log.warn(e.toString, e) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 44749190c79eb..ec3961f84bd8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -493,7 +493,7 @@ case class TruncateTableCommand( spark.sessionState.refreshTable(tableName.unquotedString) // Also try to drop the contents of the table from the columnar cache try { - spark.sharedState.cacheManager.uncacheQuery(spark.table(table.identifier)) + spark.sharedState.cacheManager.uncacheQuery(spark.table(table.identifier), cascade = true) } catch { case NonFatal(e) => log.warn(s"Exception when attempting to uncache table $tableIdentWithDB", e) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index f16d824201e77..0c3d9a4895fe2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -396,6 +396,7 @@ case class DataSource( hs.partitionSchema.map(_.name), "in the partition schema", equality) + DataSourceUtils.verifyReadSchema(hs.fileFormat, hs.dataSchema) case _ => SchemaUtils.checkColumnNameDuplication( relation.schema.map(_.name), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala new file mode 100644 index 0000000000000..82e99190ecf14 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.types._ + + +object DataSourceUtils { + + /** + * Verify if the schema is supported in datasource in write path. + */ + def verifyWriteSchema(format: FileFormat, schema: StructType): Unit = { + verifySchema(format, schema, isReadPath = false) + } + + /** + * Verify if the schema is supported in datasource in read path. + */ + def verifyReadSchema(format: FileFormat, schema: StructType): Unit = { + verifySchema(format, schema, isReadPath = true) + } + + /** + * Verify if the schema is supported in datasource. This verification should be done + * in a driver side. + */ + private def verifySchema(format: FileFormat, schema: StructType, isReadPath: Boolean): Unit = { + schema.foreach { field => + if (!format.supportDataType(field.dataType, isReadPath)) { + throw new AnalysisException( + s"$format data source does not support ${field.dataType.simpleString} data type.") + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala index 023e127888290..2c162e23644ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, StructType} /** @@ -57,7 +57,7 @@ trait FileFormat { dataSchema: StructType): OutputWriterFactory /** - * Returns whether this format support returning columnar batch or not. + * Returns whether this format supports returning columnar batch or not. * * TODO: we should just have different traits for the different formats. */ @@ -152,6 +152,11 @@ trait FileFormat { } } + /** + * Returns whether this format supports the given [[DataType]] in read/write path. + * By default all data types are supported. + */ + def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = true } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 52da8356ab835..7c6ab4bc922fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -96,9 +96,11 @@ object FileFormatWriter extends Logging { val caseInsensitiveOptions = CaseInsensitiveMap(options) + val dataSchema = dataColumns.toStructType + DataSourceUtils.verifyWriteSchema(fileFormat, dataSchema) // Note: prepareWrite has side effect. It sets "job". val outputWriterFactory = - fileFormat.prepareWrite(sparkSession, job, caseInsensitiveOptions, dataColumns.toStructType) + fileFormat.prepareWrite(sparkSession, job, caseInsensitiveOptions, dataSchema) val description = new WriteJobDescription( uuid = UUID.randomUUID().toString, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala index a813829d50cb1..80d7608a22891 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala @@ -38,9 +38,8 @@ case class InsertIntoDataSourceCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] val data = Dataset.ofRows(sparkSession, query) - // Apply the schema of the existing table to the new data. - val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema) - relation.insert(df, overwrite) + // Data has been casted to the target relation's schema by the PreprocessTableInsertion rule. + relation.insert(data, overwrite) // Re-cache all cached plans(including this relation itself, if it's cached) that refer to this // data source relation. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index b90275de9f40a..aeb40e5a4131d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -66,7 +66,6 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - CSVUtils.verifySchema(dataSchema) val conf = job.getConfiguration val csvOptions = new CSVOptions( options, @@ -98,7 +97,6 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { - CSVUtils.verifySchema(dataSchema) val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) @@ -153,6 +151,15 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { override def hashCode(): Int = getClass.hashCode() override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat] + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _ => false + } + } private[csv] class CsvOutputWriter( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala index 1012e774118e2..7ce65fa89b02d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala @@ -117,25 +117,6 @@ object CSVUtils { } } - /** - * Verify if the schema is supported in CSV datasource. - */ - def verifySchema(schema: StructType): Unit = { - def verifyType(dataType: DataType): Unit = dataType match { - case ByteType | ShortType | IntegerType | LongType | FloatType | - DoubleType | BooleanType | _: DecimalType | TimestampType | - DateType | StringType => - - case udt: UserDefinedType[_] => verifyType(udt.sqlType) - - case _ => - throw new UnsupportedOperationException( - s"CSV data source does not support ${dataType.simpleString} data type.") - } - - schema.foreach(field => verifyType(field.dataType)) - } - /** * Sample CSV dataset as configured by `samplingRatio`. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 5f7d5696b71a6..79143cce4a380 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -33,29 +33,49 @@ import org.apache.spark.sql.execution.datasources.FailureSafeParser import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String + +/** + * Constructs a parser for a given schema that translates CSV data to an [[InternalRow]]. + * + * @param dataSchema The CSV data schema that is specified by the user, or inferred from underlying + * data files. + * @param requiredSchema The schema of the data that should be output for each row. This should be a + * subset of the columns in dataSchema. + * @param options Configuration options for a CSV parser. + */ class UnivocityParser( dataSchema: StructType, requiredSchema: StructType, val options: CSVOptions) extends Logging { require(requiredSchema.toSet.subsetOf(dataSchema.toSet), - "requiredSchema should be the subset of schema.") + s"requiredSchema (${requiredSchema.catalogString}) should be the subset of " + + s"dataSchema (${dataSchema.catalogString}).") def this(schema: StructType, options: CSVOptions) = this(schema, schema, options) // A `ValueConverter` is responsible for converting the given value to a desired type. private type ValueConverter = String => Any + // This index is used to reorder parsed tokens + private val tokenIndexArr = + requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f))).toArray + + // When column pruning is enabled, the parser only parses the required columns based on + // their positions in the data schema. + private val parsedSchema = if (options.columnPruning) requiredSchema else dataSchema + val tokenizer = { val parserSetting = options.asParserSettings - if (options.columnPruning && requiredSchema.length < dataSchema.length) { - val tokenIndexArr = requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f))) + // When to-be-parsed schema is shorter than the to-be-read data schema, we let Univocity CSV + // parser select a sequence of fields for reading by their positions. + // if (options.columnPruning && requiredSchema.length < dataSchema.length) { + if (parsedSchema.length < dataSchema.length) { parserSetting.selectIndexes(tokenIndexArr: _*) } new CsvParser(parserSetting) } - private val schema = if (options.columnPruning) requiredSchema else dataSchema - private val row = new GenericInternalRow(schema.length) + private val row = new GenericInternalRow(requiredSchema.length) // Retrieve the raw record string. private def getCurrentInput: UTF8String = { @@ -82,7 +102,7 @@ class UnivocityParser( // // output row - ["A", 2] private val valueConverters: Array[ValueConverter] = { - schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray + requiredSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray } /** @@ -183,21 +203,35 @@ class UnivocityParser( } } + private val doParse = if (requiredSchema.nonEmpty) { + (input: String) => convert(tokenizer.parseLine(input)) + } else { + // If `columnPruning` enabled and partition attributes scanned only, + // `schema` gets empty. + (_: String) => InternalRow.empty + } + /** * Parses a single CSV string and turns it into either one resulting row or no row (if the * the record is malformed). */ - def parse(input: String): InternalRow = convert(tokenizer.parseLine(input)) + def parse(input: String): InternalRow = doParse(input) + + private val getToken = if (options.columnPruning) { + (tokens: Array[String], index: Int) => tokens(index) + } else { + (tokens: Array[String], index: Int) => tokens(tokenIndexArr(index)) + } private def convert(tokens: Array[String]): InternalRow = { - if (tokens.length != schema.length) { + if (tokens.length != parsedSchema.length) { // If the number of tokens doesn't match the schema, we should treat it as a malformed record. // However, we still have chance to parse some of the tokens, by adding extra null tokens in // the tail if the number is smaller, or by dropping extra tokens if the number is larger. - val checkedTokens = if (schema.length > tokens.length) { - tokens ++ new Array[String](schema.length - tokens.length) + val checkedTokens = if (parsedSchema.length > tokens.length) { + tokens ++ new Array[String](parsedSchema.length - tokens.length) } else { - tokens.take(schema.length) + tokens.take(parsedSchema.length) } def getPartialResult(): Option[InternalRow] = { try { @@ -214,9 +248,11 @@ class UnivocityParser( new RuntimeException("Malformed CSV record")) } else { try { + // When the length of the returned tokens is identical to the length of the parsed schema, + // we just need to convert the tokens that correspond to the required columns. var i = 0 - while (i < schema.length) { - row(i) = valueConverters(i).apply(tokens(i)) + while (i < requiredSchema.length) { + row(i) = valueConverters(i).apply(getToken(tokens, i)) i += 1 } row diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index a73a97c06fe5a..eea966d30948b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.types.StructType * Options for the JDBC data source. */ class JDBCOptions( - @transient private val parameters: CaseInsensitiveMap[String]) + @transient val parameters: CaseInsensitiveMap[String]) extends Serializable { import JDBCOptions._ @@ -65,11 +65,31 @@ class JDBCOptions( // Required parameters // ------------------------------------------------------------ require(parameters.isDefinedAt(JDBC_URL), s"Option '$JDBC_URL' is required.") - require(parameters.isDefinedAt(JDBC_TABLE_NAME), s"Option '$JDBC_TABLE_NAME' is required.") // a JDBC URL val url = parameters(JDBC_URL) - // name of table - val table = parameters(JDBC_TABLE_NAME) + // table name or a table subquery. + val tableOrQuery = (parameters.get(JDBC_TABLE_NAME), parameters.get(JDBC_QUERY_STRING)) match { + case (Some(name), Some(subquery)) => + throw new IllegalArgumentException( + s"Both '$JDBC_TABLE_NAME' and '$JDBC_QUERY_STRING' can not be specified at the same time." + ) + case (None, None) => + throw new IllegalArgumentException( + s"Option '$JDBC_TABLE_NAME' or '$JDBC_QUERY_STRING' is required." + ) + case (Some(name), None) => + if (name.isEmpty) { + throw new IllegalArgumentException(s"Option '$JDBC_TABLE_NAME' can not be empty.") + } else { + name.trim + } + case (None, Some(subquery)) => + if (subquery.isEmpty) { + throw new IllegalArgumentException(s"Option `$JDBC_QUERY_STRING` can not be empty.") + } else { + s"(${subquery}) __SPARK_GEN_JDBC_SUBQUERY_NAME_${curId.getAndIncrement()}" + } + } // ------------------------------------------------------------ // Optional parameters @@ -109,6 +129,20 @@ class JDBCOptions( s"When reading JDBC data sources, users need to specify all or none for the following " + s"options: '$JDBC_PARTITION_COLUMN', '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND', " + s"and '$JDBC_NUM_PARTITIONS'") + + require(!(parameters.get(JDBC_QUERY_STRING).isDefined && partitionColumn.isDefined), + s""" + |Options '$JDBC_QUERY_STRING' and '$JDBC_PARTITION_COLUMN' can not be specified together. + |Please define the query using `$JDBC_TABLE_NAME` option instead and make sure to qualify + |the partition columns using the supplied subquery alias to resolve any ambiguity. + |Example : + |spark.read.format("jdbc") + | .option("dbtable", "(select c1, c2 from t1) as subq") + | .option("partitionColumn", "subq.c1" + | .load() + """.stripMargin + ) + val fetchSize = { val size = parameters.getOrElse(JDBC_BATCH_FETCH_SIZE, "0").toInt require(size >= 0, @@ -149,7 +183,30 @@ class JDBCOptions( val sessionInitStatement = parameters.get(JDBC_SESSION_INIT_STATEMENT) } +class JdbcOptionsInWrite( + @transient override val parameters: CaseInsensitiveMap[String]) + extends JDBCOptions(parameters) { + + import JDBCOptions._ + + def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters)) + + def this(url: String, table: String, parameters: Map[String, String]) = { + this(CaseInsensitiveMap(parameters ++ Map( + JDBCOptions.JDBC_URL -> url, + JDBCOptions.JDBC_TABLE_NAME -> table))) + } + + require( + parameters.get(JDBC_TABLE_NAME).isDefined, + s"Option '$JDBC_TABLE_NAME' is required. " + + s"Option '$JDBC_QUERY_STRING' is not applicable while writing.") + + val table = parameters(JDBC_TABLE_NAME) +} + object JDBCOptions { + private val curId = new java.util.concurrent.atomic.AtomicLong(0L) private val jdbcOptionNames = collection.mutable.Set[String]() private def newOption(name: String): String = { @@ -159,6 +216,7 @@ object JDBCOptions { val JDBC_URL = newOption("url") val JDBC_TABLE_NAME = newOption("dbtable") + val JDBC_QUERY_STRING = newOption("query") val JDBC_DRIVER_CLASS = newOption("driver") val JDBC_PARTITION_COLUMN = newOption("partitionColumn") val JDBC_LOWER_BOUND = newOption("lowerBound") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 0bab3689e5d0e..1b3b17c75e756 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -51,7 +51,7 @@ object JDBCRDD extends Logging { */ def resolveTable(options: JDBCOptions): StructType = { val url = options.url - val table = options.table + val table = options.tableOrQuery val dialect = JdbcDialects.get(url) val conn: Connection = JdbcUtils.createConnectionFactory(options)() try { @@ -296,7 +296,7 @@ private[jdbc] class JDBCRDD( val myWhereClause = getWhereClause(part) - val sqlText = s"SELECT $columnList FROM ${options.table} $myWhereClause" + val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause" stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) stmt.setFetchSize(options.fetchSize) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index b23e5a7722004..97e2d255cb7be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -22,10 +22,12 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.Partition import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkSession, SQLContext} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils /** * Instructions on how to partition the table among workers. @@ -48,10 +50,17 @@ private[sql] object JDBCRelation extends Logging { * Null value predicate is added to the first partition where clause to include * the rows with null value for the partitions column. * + * @param schema resolved schema of a JDBC table * @param partitioning partition information to generate the where clause for each partition + * @param resolver function used to determine if two identifiers are equal + * @param jdbcOptions JDBC options that contains url * @return an array of partitions with where clause for each partition */ - def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = { + def columnPartition( + schema: StructType, + partitioning: JDBCPartitioningInfo, + resolver: Resolver, + jdbcOptions: JDBCOptions): Array[Partition] = { if (partitioning == null || partitioning.numPartitions <= 1 || partitioning.lowerBound == partitioning.upperBound) { return Array[Partition](JDBCPartition(null, 0)) @@ -78,7 +87,10 @@ private[sql] object JDBCRelation extends Logging { // Overflow and silliness can happen if you subtract then divide. // Here we get a little roundoff, but that's (hopefully) OK. val stride: Long = upperBound / numPartitions - lowerBound / numPartitions - val column = partitioning.column + + val column = verifyAndGetNormalizedColumnName( + schema, partitioning.column, resolver, jdbcOptions) + var i: Int = 0 var currentValue: Long = lowerBound val ans = new ArrayBuffer[Partition]() @@ -99,10 +111,57 @@ private[sql] object JDBCRelation extends Logging { } ans.toArray } + + // Verify column name based on the JDBC resolved schema + private def verifyAndGetNormalizedColumnName( + schema: StructType, + columnName: String, + resolver: Resolver, + jdbcOptions: JDBCOptions): String = { + val dialect = JdbcDialects.get(jdbcOptions.url) + schema.map(_.name).find { fieldName => + resolver(fieldName, columnName) || + resolver(dialect.quoteIdentifier(fieldName), columnName) + }.map(dialect.quoteIdentifier).getOrElse { + throw new AnalysisException(s"User-defined partition column $columnName not " + + s"found in the JDBC relation: ${schema.simpleString(Utils.maxNumToStringFields)}") + } + } + + /** + * Takes a (schema, table) specification and returns the table's Catalyst schema. + * If `customSchema` defined in the JDBC options, replaces the schema's dataType with the + * custom schema's type. + * + * @param resolver function used to determine if two identifiers are equal + * @param jdbcOptions JDBC options that contains url, table and other information. + * @return resolved Catalyst schema of a JDBC table + */ + def getSchema(resolver: Resolver, jdbcOptions: JDBCOptions): StructType = { + val tableSchema = JDBCRDD.resolveTable(jdbcOptions) + jdbcOptions.customSchema match { + case Some(customSchema) => JdbcUtils.getCustomSchema( + tableSchema, customSchema, resolver) + case None => tableSchema + } + } + + /** + * Resolves a Catalyst schema of a JDBC table and returns [[JDBCRelation]] with the schema. + */ + def apply( + parts: Array[Partition], + jdbcOptions: JDBCOptions)( + sparkSession: SparkSession): JDBCRelation = { + val schema = JDBCRelation.getSchema(sparkSession.sessionState.conf.resolver, jdbcOptions) + JDBCRelation(schema, parts, jdbcOptions)(sparkSession) + } } private[sql] case class JDBCRelation( - parts: Array[Partition], jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession) + override val schema: StructType, + parts: Array[Partition], + jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession) extends BaseRelation with PrunedFilteredScan with InsertableRelation { @@ -111,15 +170,6 @@ private[sql] case class JDBCRelation( override val needConversion: Boolean = false - override val schema: StructType = { - val tableSchema = JDBCRDD.resolveTable(jdbcOptions) - jdbcOptions.customSchema match { - case Some(customSchema) => JdbcUtils.getCustomSchema( - tableSchema, customSchema, sparkSession.sessionState.conf.resolver) - case None => tableSchema - } - } - // Check if JDBCRDD.compileFilter can accept input filters override def unhandledFilters(filters: Array[Filter]): Array[Filter] = { filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty) @@ -139,12 +189,12 @@ private[sql] case class JDBCRelation( override def insert(data: DataFrame, overwrite: Boolean): Unit = { data.write .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) - .jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties) + .jdbc(jdbcOptions.url, jdbcOptions.tableOrQuery, jdbcOptions.asProperties) } override def toString: String = { val partitioningInfo = if (parts.nonEmpty) s" [numPartitions=${parts.length}]" else "" // credentials should not be included in the plan output, table information is sufficient. - s"JDBCRelation(${jdbcOptions.table})" + partitioningInfo + s"JDBCRelation(${jdbcOptions.tableOrQuery})" + partitioningInfo } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index f8c5677ea0f2a..782d626c1573c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -48,8 +48,10 @@ class JdbcRelationProvider extends CreatableRelationProvider JDBCPartitioningInfo( partitionColumn.get, lowerBound.get, upperBound.get, numPartitions.get) } - val parts = JDBCRelation.columnPartition(partitionInfo) - JDBCRelation(parts, jdbcOptions)(sqlContext.sparkSession) + val resolver = sqlContext.conf.resolver + val schema = JDBCRelation.getSchema(resolver, jdbcOptions) + val parts = JDBCRelation.columnPartition(schema, partitionInfo, resolver, jdbcOptions) + JDBCRelation(schema, parts, jdbcOptions)(sqlContext.sparkSession) } override def createRelation( @@ -57,7 +59,7 @@ class JdbcRelationProvider extends CreatableRelationProvider mode: SaveMode, parameters: Map[String, String], df: DataFrame): BaseRelation = { - val options = new JDBCOptions(parameters) + val options = new JdbcOptionsInWrite(parameters) val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis val conn = JdbcUtils.createConnectionFactory(options)() @@ -84,7 +86,8 @@ class JdbcRelationProvider extends CreatableRelationProvider case SaveMode.ErrorIfExists => throw new AnalysisException( - s"Table or view '${options.table}' already exists. SaveMode: ErrorIfExists.") + s"Table or view '${options.table}' already exists. " + + s"SaveMode: ErrorIfExists.") case SaveMode.Ignore => // With `SaveMode.Ignore` mode, if table already exists, the save operation is expected diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 433443007cfd8..b81737eda475b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -67,7 +67,7 @@ object JdbcUtils extends Logging { /** * Returns true if the table already exists in the JDBC database. */ - def tableExists(conn: Connection, options: JDBCOptions): Boolean = { + def tableExists(conn: Connection, options: JdbcOptionsInWrite): Boolean = { val dialect = JdbcDialects.get(options.url) // Somewhat hacky, but there isn't a good way to identify whether a table exists for all @@ -100,7 +100,7 @@ object JdbcUtils extends Logging { /** * Truncates a table from the JDBC database without side effects. */ - def truncateTable(conn: Connection, options: JDBCOptions): Unit = { + def truncateTable(conn: Connection, options: JdbcOptionsInWrite): Unit = { val dialect = JdbcDialects.get(options.url) val statement = conn.createStatement try { @@ -255,7 +255,7 @@ object JdbcUtils extends Logging { val dialect = JdbcDialects.get(options.url) try { - val statement = conn.prepareStatement(dialect.getSchemaQuery(options.table)) + val statement = conn.prepareStatement(dialect.getSchemaQuery(options.tableOrQuery)) try { statement.setQueryTimeout(options.queryTimeout) Some(getSchema(statement.executeQuery(), dialect)) @@ -809,7 +809,7 @@ object JdbcUtils extends Logging { df: DataFrame, tableSchema: Option[StructType], isCaseSensitive: Boolean, - options: JDBCOptions): Unit = { + options: JdbcOptionsInWrite): Unit = { val url = options.url val table = options.table val dialect = JdbcDialects.get(url) @@ -838,7 +838,7 @@ object JdbcUtils extends Logging { def createTable( conn: Connection, df: DataFrame, - options: JDBCOptions): Unit = { + options: JdbcOptionsInWrite): Unit = { val strSchema = schemaString( df, options.url, options.createTableColumnTypes) val table = options.table diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 3b6df45e949e8..2fee2128ba1f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -33,7 +33,7 @@ import org.apache.spark.input.{PortableDataStream, StreamInputFormat} import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JsonInferSchema, JSONOptions} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 3b04510d29695..a9241afba537b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -26,11 +26,11 @@ import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions, JSONOptionsInRead} import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { @@ -40,7 +40,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], path: Path): Boolean = { - val parsedOptions = new JSONOptions( + val parsedOptions = new JSONOptionsInRead( options, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) @@ -52,7 +52,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { - val parsedOptions = new JSONOptions( + val parsedOptions = new JSONOptionsInRead( options, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) @@ -99,7 +99,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - val parsedOptions = new JSONOptions( + val parsedOptions = new JSONOptionsInRead( options, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) @@ -144,6 +144,23 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { override def hashCode(): Int = getClass.hashCode() override def equals(other: Any): Boolean = other.isInstanceOf[JsonFileFormat] + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) } + + case ArrayType(elementType, _) => supportDataType(elementType, isReadPath) + + case MapType(keyType, valueType, _) => + supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath) + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _: NullType => true + + case _ => false + } } private[json] class JsonOutputWriter( @@ -158,6 +175,11 @@ private[json] class JsonOutputWriter( case None => StandardCharsets.UTF_8 } + if (JSONOptionsInRead.blacklist.contains(encoding)) { + logWarning(s"The JSON file ($path) was written in the encoding ${encoding.displayName()}" + + " which can be read back by Spark only if multiLine is enabled.") + } + private val writer = CodecStreams.createOutputStreamWriter( context, new Path(path), encoding) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 1de2ca2914c44..3a8c0add8c2f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -224,4 +224,21 @@ class OrcFileFormat } } } + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) } + + case ArrayType(elementType, _) => supportDataType(elementType, isReadPath) + + case MapType(keyType, valueType, _) => + supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath) + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _: NullType => isReadPath + + case _ => false + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 60fc9ec7e1f82..295960b1c2d30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -78,7 +78,6 @@ class ParquetFileFormat job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - val parquetOptions = new ParquetOptions(options, sparkSession.sessionState.conf) val conf = ContextUtil.getConfiguration(job) @@ -335,38 +334,28 @@ class ParquetFileFormat val enableVectorizedReader: Boolean = sqlConf.parquetVectorizedReaderEnabled && resultSchema.forall(_.dataType.isInstanceOf[AtomicType]) - val enableRecordFilter: Boolean = - sparkSession.sessionState.conf.parquetRecordFilterEnabled - val timestampConversion: Boolean = - sparkSession.sessionState.conf.isParquetINT96TimestampConversion + val enableRecordFilter: Boolean = sqlConf.parquetRecordFilterEnabled + val timestampConversion: Boolean = sqlConf.isParquetINT96TimestampConversion val capacity = sqlConf.parquetVectorizedReaderBatchSize - val enableParquetFilterPushDown: Boolean = - sparkSession.sessionState.conf.parquetFilterPushDown + val enableParquetFilterPushDown: Boolean = sqlConf.parquetFilterPushDown // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) val pushDownDate = sqlConf.parquetFilterPushDownDate + val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp + val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal + val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith + val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) - // Try to push down filters when filter push-down is enabled. - val pushed = if (enableParquetFilterPushDown) { - filters - // Collects all converted Parquet filter predicates. Notice that not all predicates can be - // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` - // is used here. - .flatMap(new ParquetFilters(pushDownDate).createFilter(requiredSchema, _)) - .reduceOption(FilterApi.and) - } else { - None - } - val fileSplit = new FileSplit(new Path(new URI(file.filePath)), file.start, file.length, Array.empty) + val filePath = fileSplit.getPath val split = new org.apache.parquet.hadoop.ParquetInputSplit( - fileSplit.getPath, + filePath, fileSplit.getStart, fileSplit.getStart + fileSplit.getLength, fileSplit.getLength, @@ -374,12 +363,29 @@ class ParquetFileFormat null) val sharedConf = broadcastedHadoopConf.value.value + + // Try to push down filters when filter push-down is enabled. + val pushed = if (enableParquetFilterPushDown) { + val parquetSchema = ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS) + .getFileMetaData.getSchema + val parquetFilters = new ParquetFilters(pushDownDate, pushDownTimestamp, pushDownDecimal, + pushDownStringStartWith, pushDownInFilterThreshold) + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(parquetFilters.createFilter(parquetSchema, _)) + .reduceOption(FilterApi.and) + } else { + None + } + // PARQUET_INT96_TIMESTAMP_CONVERSION says to apply timezone conversions to int96 timestamps' // *only* if the file was created by something other than "parquet-mr", so check the actual // writer here for this file. We have to do this per-file, as each file in the table may // have different writers. def isCreatedByParquetMr(): Boolean = { - val footer = ParquetFileReader.readFooter(sharedConf, fileSplit.getPath, SKIP_ROW_GROUPS) + val footer = ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS) footer.getFileMetaData().getCreatedBy().startsWith("parquet-mr") } val convertTz = @@ -445,6 +451,21 @@ class ParquetFileFormat } } } + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) } + + case ArrayType(elementType, _) => supportDataType(elementType, isReadPath) + + case MapType(keyType, valueType, _) => + supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath) + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _ => false + } } object ParquetFileFormat extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 310626197a763..58b4a769fcb62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -17,195 +17,399 @@ package org.apache.spark.sql.execution.datasources.parquet -import java.sql.Date +import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Long => JLong} +import java.math.{BigDecimal => JBigDecimal} +import java.sql.{Date, Timestamp} + +import scala.collection.JavaConverters.asScalaBufferConverter import org.apache.parquet.filter2.predicate._ import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.io.api.Binary +import org.apache.parquet.schema.{DecimalMetadata, MessageType, OriginalType, PrimitiveComparator} +import org.apache.parquet.schema.OriginalType._ +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate import org.apache.spark.sql.sources -import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Some utility function to convert Spark data source filters to Parquet filters. */ -private[parquet] class ParquetFilters(pushDownDate: Boolean) { +private[parquet] class ParquetFilters( + pushDownDate: Boolean, + pushDownTimestamp: Boolean, + pushDownDecimal: Boolean, + pushDownStartWith: Boolean, + pushDownInFilterThreshold: Int) { + + private case class ParquetSchemaType( + originalType: OriginalType, + primitiveTypeName: PrimitiveTypeName, + length: Int, + decimalMetadata: DecimalMetadata) + + private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, 0, null) + private val ParquetByteType = ParquetSchemaType(INT_8, INT32, 0, null) + private val ParquetShortType = ParquetSchemaType(INT_16, INT32, 0, null) + private val ParquetIntegerType = ParquetSchemaType(null, INT32, 0, null) + private val ParquetLongType = ParquetSchemaType(null, INT64, 0, null) + private val ParquetFloatType = ParquetSchemaType(null, FLOAT, 0, null) + private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, 0, null) + private val ParquetStringType = ParquetSchemaType(UTF8, BINARY, 0, null) + private val ParquetBinaryType = ParquetSchemaType(null, BINARY, 0, null) + private val ParquetDateType = ParquetSchemaType(DATE, INT32, 0, null) + private val ParquetTimestampMicrosType = ParquetSchemaType(TIMESTAMP_MICROS, INT64, 0, null) + private val ParquetTimestampMillisType = ParquetSchemaType(TIMESTAMP_MILLIS, INT64, 0, null) private def dateToDays(date: Date): SQLDate = { DateTimeUtils.fromJavaDate(date) } - private val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case BooleanType => - (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) - case IntegerType => - (n: String, v: Any) => FilterApi.eq(intColumn(n), v.asInstanceOf[Integer]) - case LongType => - (n: String, v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + private def decimalToInt32(decimal: JBigDecimal): Integer = decimal.unscaledValue().intValue() + + private def decimalToInt64(decimal: JBigDecimal): JLong = decimal.unscaledValue().longValue() + + private def decimalToByteArray(decimal: JBigDecimal, numBytes: Int): Binary = { + val decimalBuffer = new Array[Byte](numBytes) + val bytes = decimal.unscaledValue().toByteArray + + val fixedLengthBytes = if (bytes.length == numBytes) { + bytes + } else { + val signByte = if (bytes.head < 0) -1: Byte else 0: Byte + java.util.Arrays.fill(decimalBuffer, 0, numBytes - bytes.length, signByte) + System.arraycopy(bytes, 0, decimalBuffer, numBytes - bytes.length, bytes.length) + decimalBuffer + } + Binary.fromConstantByteArray(fixedLengthBytes, 0, numBytes) + } + + private val makeEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetBooleanType => + (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[JBoolean]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => FilterApi.eq( + intColumn(n), + Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) + case ParquetLongType => + (n: String, v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: String, v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[JDouble]) // Binary.fromString and Binary.fromByteArray don't accept null values - case StringType => + case ParquetStringType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) - case BinaryType => + case ParquetBinaryType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case DateType if pushDownDate => + case ParquetDateType if pushDownDate => (n: String, v: Any) => FilterApi.eq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.eq( + longColumn(n), + Option(v).map(t => DateTimeUtils.fromJavaTimestamp(t.asInstanceOf[Timestamp]) + .asInstanceOf[JLong]).orNull) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.eq( + longColumn(n), + Option(v).map(_.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]).orNull) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.eq( + intColumn(n), + Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.eq( + longColumn(n), + Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.eq( + binaryColumn(n), + Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) } - private val makeNotEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case BooleanType => - (n: String, v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) - case IntegerType => - (n: String, v: Any) => FilterApi.notEq(intColumn(n), v.asInstanceOf[Integer]) - case LongType => - (n: String, v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - - case StringType => + private val makeNotEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetBooleanType => + (n: String, v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[JBoolean]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => FilterApi.notEq( + intColumn(n), + Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) + case ParquetLongType => + (n: String, v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: String, v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: String, v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) - case BinaryType => + case ParquetBinaryType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case DateType if pushDownDate => + case ParquetDateType if pushDownDate => (n: String, v: Any) => FilterApi.notEq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.notEq( + longColumn(n), + Option(v).map(t => DateTimeUtils.fromJavaTimestamp(t.asInstanceOf[Timestamp]) + .asInstanceOf[JLong]).orNull) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.notEq( + longColumn(n), + Option(v).map(_.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]).orNull) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.notEq( + intColumn(n), + Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.notEq( + longColumn(n), + Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.notEq( + binaryColumn(n), + Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) } - private val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.lt(intColumn(n), v.asInstanceOf[Integer]) - case LongType => - (n: String, v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - - case StringType => + private val makeLt: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => (n: String, v: Any) => - FilterApi.lt(binaryColumn(n), - Binary.fromString(v.asInstanceOf[String])) - case BinaryType => + FilterApi.lt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: String, v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: String, v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: String, v: Any) => + FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => (n: String, v: Any) => FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if pushDownDate => + case ParquetDateType if pushDownDate => + (n: String, v: Any) => + FilterApi.lt(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => (n: String, v: Any) => FilterApi.lt( - intColumn(n), - Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + longColumn(n), + DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong]) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.lt( + longColumn(n), + v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.lt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.lt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.lt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) } - private val makeLtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.ltEq(intColumn(n), v.asInstanceOf[java.lang.Integer]) - case LongType => - (n: String, v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - - case StringType => + private val makeLtEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => + FilterApi.ltEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: String, v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: String, v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => (n: String, v: Any) => - FilterApi.ltEq(binaryColumn(n), - Binary.fromString(v.asInstanceOf[String])) - case BinaryType => + FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => (n: String, v: Any) => FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if pushDownDate => + case ParquetDateType if pushDownDate => + (n: String, v: Any) => + FilterApi.ltEq(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => (n: String, v: Any) => FilterApi.ltEq( - intColumn(n), - Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + longColumn(n), + DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong]) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.ltEq( + longColumn(n), + v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.ltEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.ltEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.ltEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) } - private val makeGt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.gt(intColumn(n), v.asInstanceOf[java.lang.Integer]) - case LongType => - (n: String, v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - - case StringType => + private val makeGt: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => + FilterApi.gt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: String, v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: String, v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => (n: String, v: Any) => - FilterApi.gt(binaryColumn(n), - Binary.fromString(v.asInstanceOf[String])) - case BinaryType => + FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => (n: String, v: Any) => FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if pushDownDate => + case ParquetDateType if pushDownDate => + (n: String, v: Any) => + FilterApi.gt(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => (n: String, v: Any) => FilterApi.gt( - intColumn(n), - Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + longColumn(n), + DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong]) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.gt( + longColumn(n), + v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.gt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.gt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.gt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) } - private val makeGtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.gtEq(intColumn(n), v.asInstanceOf[java.lang.Integer]) - case LongType => - (n: String, v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - - case StringType => + private val makeGtEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => + FilterApi.gtEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: String, v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: String, v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => (n: String, v: Any) => - FilterApi.gtEq(binaryColumn(n), - Binary.fromString(v.asInstanceOf[String])) - case BinaryType => + FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => (n: String, v: Any) => FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if pushDownDate => + case ParquetDateType if pushDownDate => + (n: String, v: Any) => + FilterApi.gtEq(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => (n: String, v: Any) => FilterApi.gtEq( - intColumn(n), - Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + longColumn(n), + DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong]) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.gtEq( + longColumn(n), + v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.gtEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.gtEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.gtEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) } /** * Returns a map from name of the column to the data type, if predicate push down applies. */ - private def getFieldMap(dataType: DataType): Map[String, DataType] = dataType match { - case StructType(fields) => + private def getFieldMap(dataType: MessageType): Map[String, ParquetSchemaType] = dataType match { + case m: MessageType => // Here we don't flatten the fields in the nested schema but just look up through // root fields. Currently, accessing to nested fields does not push down filters // and it does not support to create filters for them. - fields.map(f => f.name -> f.dataType).toMap - case _ => Map.empty[String, DataType] + m.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f => + f.getName -> ParquetSchemaType( + f.getOriginalType, f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata) + }.toMap + case _ => Map.empty[String, ParquetSchemaType] } /** * Converts data sources filters to Parquet filter predicates. */ - def createFilter(schema: StructType, predicate: sources.Filter): Option[FilterPredicate] = { + def createFilter(schema: MessageType, predicate: sources.Filter): Option[FilterPredicate] = { val nameToType = getFieldMap(schema) + // Decimal type must make sure that filter value's scale matched the file. + // If doesn't matched, which would cause data corruption. + def isDecimalMatched(value: Any, decimalMeta: DecimalMetadata): Boolean = value match { + case decimal: JBigDecimal => + decimal.scale == decimalMeta.getScale + case _ => false + } + + // Parquet's type in the given file should be matched to the value's type + // in the pushed filter in order to push down the filter to Parquet. + def valueCanMakeFilterOn(name: String, value: Any): Boolean = { + value == null || (nameToType(name) match { + case ParquetBooleanType => value.isInstanceOf[JBoolean] + case ParquetByteType | ParquetShortType | ParquetIntegerType => value.isInstanceOf[Number] + case ParquetLongType => value.isInstanceOf[JLong] + case ParquetFloatType => value.isInstanceOf[JFloat] + case ParquetDoubleType => value.isInstanceOf[JDouble] + case ParquetStringType => value.isInstanceOf[String] + case ParquetBinaryType => value.isInstanceOf[Array[Byte]] + case ParquetDateType => value.isInstanceOf[Date] + case ParquetTimestampMicrosType | ParquetTimestampMillisType => + value.isInstanceOf[Timestamp] + case ParquetSchemaType(DECIMAL, INT32, _, decimalMeta) => + isDecimalMatched(value, decimalMeta) + case ParquetSchemaType(DECIMAL, INT64, _, decimalMeta) => + isDecimalMatched(value, decimalMeta) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, _, decimalMeta) => + isDecimalMatched(value, decimalMeta) + case _ => false + }) + } + // Parquet does not allow dots in the column name because dots are used as a column path // delimiter. Since Parquet 1.8.2 (PARQUET-389), Parquet accepts the filter predicates // with missing columns. The incorrect results could be got from Parquet when we push down // filters for the column having dots in the names. Thus, we do not push down such filters. // See SPARK-20364. - def canMakeFilterOn(name: String): Boolean = nameToType.contains(name) && !name.contains(".") + def canMakeFilterOn(name: String, value: Any): Boolean = { + nameToType.contains(name) && !name.contains(".") && valueCanMakeFilterOn(name, value) + } // NOTE: // @@ -223,29 +427,29 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean) { // Probably I missed something and obviously this should be changed. predicate match { - case sources.IsNull(name) if canMakeFilterOn(name) => + case sources.IsNull(name) if canMakeFilterOn(name, null) => makeEq.lift(nameToType(name)).map(_(name, null)) - case sources.IsNotNull(name) if canMakeFilterOn(name) => + case sources.IsNotNull(name) if canMakeFilterOn(name, null) => makeNotEq.lift(nameToType(name)).map(_(name, null)) - case sources.EqualTo(name, value) if canMakeFilterOn(name) => + case sources.EqualTo(name, value) if canMakeFilterOn(name, value) => makeEq.lift(nameToType(name)).map(_(name, value)) - case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name) => + case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) => makeNotEq.lift(nameToType(name)).map(_(name, value)) - case sources.EqualNullSafe(name, value) if canMakeFilterOn(name) => + case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) => makeEq.lift(nameToType(name)).map(_(name, value)) - case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name) => + case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) => makeNotEq.lift(nameToType(name)).map(_(name, value)) - case sources.LessThan(name, value) if canMakeFilterOn(name) => + case sources.LessThan(name, value) if canMakeFilterOn(name, value) => makeLt.lift(nameToType(name)).map(_(name, value)) - case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name) => + case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) => makeLtEq.lift(nameToType(name)).map(_(name, value)) - case sources.GreaterThan(name, value) if canMakeFilterOn(name) => + case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) => makeGt.lift(nameToType(name)).map(_(name, value)) - case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name) => + case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) => makeGtEq.lift(nameToType(name)).map(_(name, value)) case sources.And(lhs, rhs) => @@ -270,6 +474,44 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean) { case sources.Not(pred) => createFilter(schema, pred).map(FilterApi.not) + case sources.In(name, values) if canMakeFilterOn(name, values.head) + && values.distinct.length <= pushDownInFilterThreshold => + values.distinct.flatMap { v => + makeEq.lift(nameToType(name)).map(_(name, v)) + }.reduceLeftOption(FilterApi.or) + + case sources.StringStartsWith(name, prefix) + if pushDownStartWith && canMakeFilterOn(name, prefix) => + Option(prefix).map { v => + FilterApi.userDefined(binaryColumn(name), + new UserDefinedPredicate[Binary] with Serializable { + private val strToBinary = Binary.fromReusedByteArray(v.getBytes) + private val size = strToBinary.length + + override def canDrop(statistics: Statistics[Binary]): Boolean = { + val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR + val max = statistics.getMax + val min = statistics.getMin + comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) < 0 || + comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) > 0 + } + + override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = { + val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR + val max = statistics.getMax + val min = statistics.getMin + comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) == 0 && + comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) == 0 + } + + override def keep(value: Binary): Boolean = { + UTF8String.fromBytes(value.getBytes).startsWith( + UTF8String.fromBytes(strToBinary.getBytes)) + } + } + ) + } + case _ => None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index e93908da43535..8661a5395ac44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.types.{DataType, StringType, StructType} import org.apache.spark.util.SerializableConfiguration /** @@ -47,11 +47,6 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { throw new AnalysisException( s"Text data source supports only a single column, and you have ${schema.size} columns.") } - val tpe = schema(0).dataType - if (tpe != StringType) { - throw new AnalysisException( - s"Text data source supports only a string column, but you have ${tpe.simpleString}.") - } } override def isSplitable( @@ -141,6 +136,9 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { } } } + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = + dataType == StringType } class TextOutputWriter( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala index 017a6737161a6..33079d5912506 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala @@ -30,8 +30,8 @@ class DataSourcePartitioning( override val numPartitions: Int = partitioning.numPartitions() - override def satisfies(required: physical.Distribution): Boolean = { - super.satisfies(required) || { + override def satisfies0(required: physical.Distribution): Boolean = { + super.satisfies0(required) || { required match { case d: physical.ClusteredDistribution if isCandidate(d.clustering) => val attrs = d.clustering.map(_.asInstanceOf[Attribute]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 90fb5a14c9fc9..7613eb210c659 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -22,79 +22,36 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} -import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} -import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.sources.{DataSourceRegister, Filter} +import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema} -import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns, SupportsReportStatistics} +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsReportStatistics} import org.apache.spark.sql.types.StructType +/** + * A logical plan representing a data source v2 scan. + * + * @param source An instance of a [[DataSourceV2]] implementation. + * @param options The options for this scan. Used to create fresh [[DataSourceReader]]. + * @param userSpecifiedSchema The user-specified schema for this scan. Used to create fresh + * [[DataSourceReader]]. + */ case class DataSourceV2Relation( source: DataSourceV2, + output: Seq[AttributeReference], options: Map[String, String], - projection: Seq[AttributeReference], - filters: Option[Seq[Expression]] = None, - userSpecifiedSchema: Option[StructType] = None) + userSpecifiedSchema: Option[StructType]) extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { import DataSourceV2Relation._ - override def simpleString: String = "RelationV2 " + metadataString - - override lazy val schema: StructType = reader.readSchema() - - override lazy val output: Seq[AttributeReference] = { - // use the projection attributes to avoid assigning new ids. fields that are not projected - // will be assigned new ids, which is okay because they are not projected. - val attrMap = projection.map(a => a.name -> a).toMap - schema.map(f => attrMap.getOrElse(f.name, - AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())) - } + override def pushedFilters: Seq[Expression] = Seq.empty - private lazy val v2Options: DataSourceOptions = makeV2Options(options) - - // postScanFilters: filters that need to be evaluated after the scan. - // pushedFilters: filters that will be pushed down and evaluated in the underlying data sources. - // Note: postScanFilters and pushedFilters can overlap, e.g. the parquet row group filter. - lazy val ( - reader: DataSourceReader, - postScanFilters: Seq[Expression], - pushedFilters: Seq[Expression]) = { - val newReader = userSpecifiedSchema match { - case Some(s) => - source.asReadSupportWithSchema.createReader(s, v2Options) - case _ => - source.asReadSupport.createReader(v2Options) - } - - DataSourceV2Relation.pushRequiredColumns(newReader, projection.toStructType) - - val (postScanFilters, pushedFilters) = filters match { - case Some(filterSeq) => - DataSourceV2Relation.pushFilters(newReader, filterSeq) - case _ => - (Nil, Nil) - } - logInfo(s"Post-Scan Filters: ${postScanFilters.mkString(",")}") - logInfo(s"Pushed Filters: ${pushedFilters.mkString(", ")}") - - (newReader, postScanFilters, pushedFilters) - } - - override def doCanonicalize(): LogicalPlan = { - val c = super.doCanonicalize().asInstanceOf[DataSourceV2Relation] - - // override output with canonicalized output to avoid attempting to configure a reader - val canonicalOutput: Seq[AttributeReference] = this.output - .map(a => QueryPlan.normalizeExprId(a, projection)) + override def simpleString: String = "RelationV2 " + metadataString - new DataSourceV2Relation(c.source, c.options, c.projection) { - override lazy val output: Seq[AttributeReference] = canonicalOutput - } - } + def newReader(): DataSourceReader = source.createReader(options, userSpecifiedSchema) - override def computeStats(): Statistics = reader match { + override def computeStats(): Statistics = newReader match { case r: SupportsReportStatistics => Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) case _ => @@ -102,9 +59,7 @@ case class DataSourceV2Relation( } override def newInstance(): DataSourceV2Relation = { - // projection is used to maintain id assignment. - // if projection is not set, use output so the copy is not equal to the original - copy(projection = projection.map(_.newInstance())) + copy(output = output.map(_.newInstance())) } } @@ -184,77 +139,26 @@ object DataSourceV2Relation { source.getClass.getSimpleName } } - } - - private def makeV2Options(options: Map[String, String]): DataSourceOptions = { - new DataSourceOptions(options.asJava) - } - private def schema( - source: DataSourceV2, - v2Options: DataSourceOptions, - userSchema: Option[StructType]): StructType = { - val reader = userSchema match { - case Some(s) => - source.asReadSupportWithSchema.createReader(s, v2Options) - case _ => - source.asReadSupport.createReader(v2Options) + def createReader( + options: Map[String, String], + userSpecifiedSchema: Option[StructType]): DataSourceReader = { + val v2Options = new DataSourceOptions(options.asJava) + userSpecifiedSchema match { + case Some(s) => + asReadSupportWithSchema.createReader(s, v2Options) + case _ => + asReadSupport.createReader(v2Options) + } } - reader.readSchema() } def create( source: DataSourceV2, options: Map[String, String], - filters: Option[Seq[Expression]] = None, - userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = { - val projection = schema(source, makeV2Options(options), userSpecifiedSchema).toAttributes - DataSourceV2Relation(source, options, projection, filters, userSpecifiedSchema) - } - - private def pushRequiredColumns(reader: DataSourceReader, struct: StructType): Unit = { - reader match { - case projectionSupport: SupportsPushDownRequiredColumns => - projectionSupport.pruneColumns(struct) - case _ => - } - } - - private def pushFilters( - reader: DataSourceReader, - filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { - reader match { - case r: SupportsPushDownCatalystFilters => - val postScanFilters = r.pushCatalystFilters(filters.toArray) - val pushedFilters = r.pushedCatalystFilters() - (postScanFilters, pushedFilters) - - case r: SupportsPushDownFilters => - // A map from translated data source filters to original catalyst filter expressions. - val translatedFilterToExpr = scala.collection.mutable.HashMap.empty[Filter, Expression] - // Catalyst filter expression that can't be translated to data source filters. - val untranslatableExprs = scala.collection.mutable.ArrayBuffer.empty[Expression] - - for (filterExpr <- filters) { - val translated = DataSourceStrategy.translateFilter(filterExpr) - if (translated.isDefined) { - translatedFilterToExpr(translated.get) = filterExpr - } else { - untranslatableExprs += filterExpr - } - } - - // Data source filters that need to be evaluated again after scanning. which means - // the data source cannot guarantee the rows returned can pass these filters. - // As a result we must return it so Spark can plan an extra filter operator. - val postScanFilters = - r.pushFilters(translatedFilterToExpr.keys.toArray).map(translatedFilterToExpr) - // The filters which are marked as pushed to this data source - val pushedFilters = r.pushedFilters().map(translatedFilterToExpr) - - (untranslatableExprs ++ postScanFilters, pushedFilters) - - case _ => (filters, Nil) - } + userSpecifiedSchema: Option[StructType]): DataSourceV2Relation = { + val reader = source.createReader(options, userSpecifiedSchema) + DataSourceV2Relation( + source, reader.readSchema().toAttributes, options, userSpecifiedSchema) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 1b7c639f10f98..2a7f1de2c7c19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -17,15 +17,121 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.Strategy -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} +import scala.collection.mutable + +import org.apache.spark.sql.{sources, Strategy} +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression} +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} +import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns} +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader object DataSourceV2Strategy extends Strategy { + + /** + * Pushes down filters to the data source reader + * + * @return pushed filter and post-scan filters. + */ + private def pushFilters( + reader: DataSourceReader, + filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { + reader match { + case r: SupportsPushDownCatalystFilters => + val postScanFilters = r.pushCatalystFilters(filters.toArray) + val pushedFilters = r.pushedCatalystFilters() + (pushedFilters, postScanFilters) + + case r: SupportsPushDownFilters => + // A map from translated data source filters to original catalyst filter expressions. + val translatedFilterToExpr = mutable.HashMap.empty[sources.Filter, Expression] + // Catalyst filter expression that can't be translated to data source filters. + val untranslatableExprs = mutable.ArrayBuffer.empty[Expression] + + for (filterExpr <- filters) { + val translated = DataSourceStrategy.translateFilter(filterExpr) + if (translated.isDefined) { + translatedFilterToExpr(translated.get) = filterExpr + } else { + untranslatableExprs += filterExpr + } + } + + // Data source filters that need to be evaluated again after scanning. which means + // the data source cannot guarantee the rows returned can pass these filters. + // As a result we must return it so Spark can plan an extra filter operator. + val postScanFilters = r.pushFilters(translatedFilterToExpr.keys.toArray) + .map(translatedFilterToExpr) + // The filters which are marked as pushed to this data source + val pushedFilters = r.pushedFilters().map(translatedFilterToExpr) + (pushedFilters, untranslatableExprs ++ postScanFilters) + + case _ => (Nil, filters) + } + } + + /** + * Applies column pruning to the data source, w.r.t. the references of the given expressions. + * + * @return new output attributes after column pruning. + */ + // TODO: nested column pruning. + private def pruneColumns( + reader: DataSourceReader, + relation: DataSourceV2Relation, + exprs: Seq[Expression]): Seq[AttributeReference] = { + reader match { + case r: SupportsPushDownRequiredColumns => + val requiredColumns = AttributeSet(exprs.flatMap(_.references)) + val neededOutput = relation.output.filter(requiredColumns.contains) + if (neededOutput != relation.output) { + r.pruneColumns(neededOutput.toStructType) + val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap + r.readSchema().toAttributes.map { + // We have to keep the attribute id during transformation. + a => a.withExprId(nameToAttr(a.name).exprId) + } + } else { + relation.output + } + + case _ => relation.output + } + } + + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case r: DataSourceV2Relation => - DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil + case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => + val reader = relation.newReader() + // `pushedFilters` will be pushed down and evaluated in the underlying data sources. + // `postScanFilters` need to be evaluated after the scan. + // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. + val (pushedFilters, postScanFilters) = pushFilters(reader, filters) + val output = pruneColumns(reader, relation, project ++ postScanFilters) + logInfo( + s""" + |Pushing operators to ${relation.source.getClass} + |Pushed Filters: ${pushedFilters.mkString(", ")} + |Post-Scan Filters: ${postScanFilters.mkString(",")} + |Output: ${output.mkString(", ")} + """.stripMargin) + + val scan = DataSourceV2ScanExec( + output, relation.source, relation.options, pushedFilters, reader) + + val filterCondition = postScanFilters.reduceLeftOption(And) + val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan) + + val withProjection = if (withFilter.output != project) { + ProjectExec(project, withFilter) + } else { + withFilter + } + + withProjection :: Nil case r: StreamingDataSourceV2Relation => DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil @@ -36,6 +142,17 @@ object DataSourceV2Strategy extends Strategy { case WriteToContinuousDataSource(writer, query) => WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil + case Repartition(1, false, child) => + val isContinuous = child.collectFirst { + case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r + }.isDefined + + if (isContinuous) { + ContinuousCoalesceExec(1, planLater(child)) :: Nil + } else { + Nil + } + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala deleted file mode 100644 index e894f8afd6762..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.v2 - -import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet} -import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} -import org.apache.spark.sql.catalyst.rules.Rule - -object PushDownOperatorsToDataSource extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan match { - // PhysicalOperation guarantees that filters are deterministic; no need to check - case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => - assert(relation.filters.isEmpty, "data source v2 should do push down only once.") - - val projectAttrs = project.map(_.toAttribute) - val projectSet = AttributeSet(project.flatMap(_.references)) - val filterSet = AttributeSet(filters.flatMap(_.references)) - - val projection = if (filterSet.subsetOf(projectSet) && - AttributeSet(projectAttrs) == projectSet) { - // When the required projection contains all of the filter columns and column pruning alone - // can produce the required projection, push the required projection. - // A final projection may still be needed if the data source produces a different column - // order or if it cannot prune all of the nested columns. - projectAttrs - } else { - // When there are filter columns not already in the required projection or when the required - // projection is more complicated than column pruning, base column pruning on the set of - // all columns needed by both. - (projectSet ++ filterSet).toSeq - } - - val newRelation = relation.copy( - projection = projection.asInstanceOf[Seq[AttributeReference]], - filters = Some(filters)) - - // Add a Filter for any filters that need to be evaluated after scan. - val postScanFilterCond = newRelation.postScanFilters.reduceLeftOption(And) - val filtered = postScanFilterCond.map(Filter(_, newRelation)).getOrElse(newRelation) - - // Add a Project to ensure the output matches the required projection - if (newRelation.output != projectAttrs) { - Project(project, filtered) - } else { - filtered - } - - case other => other.mapChildren(apply) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index ea4bda327f36f..b1148c0f62f7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -29,10 +29,8 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.streaming.{MicroBatchExecution, StreamExecution} -import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions} +import org.apache.spark.sql.execution.streaming.MicroBatchExecution import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -109,10 +107,12 @@ object DataWritingSparkTask extends Logging { iter: Iterator[InternalRow], useCommitCoordinator: Boolean): WriterCommitMessage = { val stageId = context.stageId() + val stageAttempt = context.stageAttemptNumber() val partId = context.partitionId() + val taskId = context.taskAttemptId() val attemptId = context.attemptNumber() val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0") - val dataWriter = writeTask.createDataWriter(partId, attemptId, epochId.toLong) + val dataWriter = writeTask.createDataWriter(partId, taskId, epochId.toLong) // write the data and commit this writer. Utils.tryWithSafeFinallyAndFailureCallbacks(block = { @@ -122,12 +122,14 @@ object DataWritingSparkTask extends Logging { val msg = if (useCommitCoordinator) { val coordinator = SparkEnv.get.outputCommitCoordinator - val commitAuthorized = coordinator.canCommit(context.stageId(), partId, attemptId) + val commitAuthorized = coordinator.canCommit(stageId, stageAttempt, partId, attemptId) if (commitAuthorized) { - logInfo(s"Writer for stage $stageId, task $partId.$attemptId is authorized to commit.") + logInfo(s"Commit authorized for partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)") dataWriter.commit() } else { - val message = s"Stage $stageId, task $partId.$attemptId: driver did not authorize commit" + val message = s"Commit denied for partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)" logInfo(message) // throwing CommitDeniedException will trigger the catch block for abort throw new CommitDeniedException(message, stageId, partId, attemptId) @@ -138,15 +140,18 @@ object DataWritingSparkTask extends Logging { dataWriter.commit() } - logInfo(s"Writer for stage $stageId, task $partId.$attemptId committed.") + logInfo(s"Committed partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)") msg })(catchBlock = { // If there is an error, abort this writer - logError(s"Writer for stage $stageId, task $partId.$attemptId is aborting.") + logError(s"Aborting commit for partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)") dataWriter.abort() - logError(s"Writer for stage $stageId, task $partId.$attemptId aborted.") + logError(s"Aborted commit for partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)") }) } } @@ -157,10 +162,10 @@ class InternalRowDataWriterFactory( override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[InternalRow] = { new InternalRowDataWriter( - rowWriterFactory.createDataWriter(partitionId, attemptNumber, epochId), + rowWriterFactory.createDataWriter(partitionId, taskId, epochId), RowEncoder.apply(schema).resolveAndBind()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index c55f9b8f1a7fc..a80673c705f1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.exchange +import java.util.concurrent.TimeoutException + import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ import scala.util.control.NonFatal @@ -140,7 +142,16 @@ case class BroadcastExchangeExec( } override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { - ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]] + try { + ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]] + } catch { + case ex: TimeoutException => + logError(s"Could not execute broadcast in ${timeout.toSeconds} secs.", ex) + throw new SparkException(s"Could not execute broadcast in ${timeout.toSeconds} secs. " + + s"You can increase the timeout for broadcasts via ${SQLConf.BROADCAST_TIMEOUT.key} or " + + s"disable broadcast join by setting ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1", + ex) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index e3d28388c5470..d96ecbaa48029 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.exchange +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions._ @@ -227,9 +228,16 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { val leftKeysBuffer = ArrayBuffer[Expression]() val rightKeysBuffer = ArrayBuffer[Expression]() + val pickedIndexes = mutable.Set[Int]() + val keysAndIndexes = currentOrderOfKeys.zipWithIndex expectedOrderOfKeys.foreach(expression => { - val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression)) + val index = keysAndIndexes.find { case (e, idx) => + // As we may have the same key used many times, we need to filter out its occurrence we + // have already used. + e.semanticEquals(expression) && !pickedIndexes.contains(idx) + }.map(_._2).get + pickedIndexes += index leftKeysBuffer.append(leftKeys(index)) rightKeysBuffer.append(rightKeys(index)) }) @@ -270,14 +278,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { * partitioning of the join nodes' children. */ private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = { - plan.transformUp { - case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, - right) => - val (reorderedLeftKeys, reorderedRightKeys) = - reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) - BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition, - left, right) - + plan match { case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) => val (reorderedLeftKeys, reorderedRightKeys) = reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) @@ -288,6 +289,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { val (reorderedLeftKeys, reorderedRightKeys) = reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right) + + case other => other } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala index 09f79a2de0ba0..1a5b7599bb7d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala @@ -24,7 +24,7 @@ import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SortOrder} -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode} import org.apache.spark.sql.internal.SQLConf @@ -70,7 +70,7 @@ case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchan } override def outputPartitioning: Partitioning = child.outputPartitioning match { - case h: HashPartitioning => h.copy(expressions = h.expressions.map(updateAttr)) + case e: Expression => updateAttr(e).asInstanceOf[Partitioning] case other => other } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 77b907870d678..b4f0ae1eb1a18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.metric import java.text.NumberFormat import java.util.Locale +import java.util.concurrent.atomic.LongAdder import org.apache.spark.SparkContext import org.apache.spark.scheduler.AccumulableInfo @@ -32,40 +33,45 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils} * on the driver side must be explicitly posted using [[SQLMetrics.postDriverMetricUpdates()]]. */ class SQLMetric(val metricType: String, initValue: Long = 0L) extends AccumulatorV2[Long, Long] { + // This is a workaround for SPARK-11013. // We may use -1 as initial value of the accumulator, if the accumulator is valid, we will // update it at the end of task and the value will be at least 0. Then we can filter out the -1 // values before calculate max, min, etc. - private[this] var _value = initValue - private var _zeroValue = initValue + private[this] val _value = new LongAdder + private val _zeroValue = initValue + _value.add(initValue) override def copy(): SQLMetric = { - val newAcc = new SQLMetric(metricType, _value) - newAcc._zeroValue = initValue + val newAcc = new SQLMetric(metricType, initValue) + newAcc.add(_value.sum()) newAcc } - override def reset(): Unit = _value = _zeroValue + override def reset(): Unit = this.set(_zeroValue) override def merge(other: AccumulatorV2[Long, Long]): Unit = other match { - case o: SQLMetric => _value += o.value + case o: SQLMetric => _value.add(o.value) case _ => throw new UnsupportedOperationException( s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") } - override def isZero(): Boolean = _value == _zeroValue + override def isZero(): Boolean = _value.sum() == _zeroValue - override def add(v: Long): Unit = _value += v + override def add(v: Long): Unit = _value.add(v) // We can set a double value to `SQLMetric` which stores only long value, if it is // average metrics. def set(v: Double): Unit = SQLMetrics.setDoubleForAverageMetrics(this, v) - def set(v: Long): Unit = _value = v + def set(v: Long): Unit = { + _value.reset() + _value.add(v) + } - def +=(v: Long): Unit = _value += v + def +=(v: Long): Unit = _value.add(v) - override def value: Long = _value + override def value: Long = _value.sum() // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = { @@ -153,7 +159,7 @@ object SQLMetrics { Seq.fill(3)(0L) } else { val sorted = validValues.sorted - Seq(sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) + Seq(sorted.head, sorted(validValues.length / 2), sorted(validValues.length - 1)) } metric.map(v => numberFormat.format(v.toDouble / baseForAvgMetric)) } @@ -173,7 +179,8 @@ object SQLMetrics { Seq.fill(4)(0L) } else { val sorted = validValues.sorted - Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) + Seq(sorted.sum, sorted.head, sorted(validValues.length / 2), + sorted(validValues.length - 1)) } metric.map(strFormat) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 8e01e8e56a5bd..d00f6f042d6e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.util.Utils @@ -81,7 +82,7 @@ case class AggregateInPandasExec( val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) val sessionLocalTimeZone = conf.sessionLocalTimeZone - val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip @@ -135,10 +136,14 @@ case class AggregateInPandasExec( } val columnarBatchIter = new ArrowPythonRunner( - pyFuncs, bufferSize, reuseWorker, - PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, argOffsets, aggInputSchema, - sessionLocalTimeZone, pandasRespectSessionTimeZone) - .compute(projectedRowIter, context.partitionId(), context) + pyFuncs, + bufferSize, + reuseWorker, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, + argOffsets, + aggInputSchema, + sessionLocalTimeZone, + pythonRunnerConf).compute(projectedRowIter, context.partitionId(), context) val joinedAttributes = groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index c4de214679ae4..0bc21c0986e69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -24,6 +24,7 @@ import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.StructType /** @@ -63,7 +64,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi private val batchSize = conf.arrowMaxRecordsPerBatch private val sessionLocalTimeZone = conf.sessionLocalTimeZone - private val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) protected override def evaluate( funcs: Seq[ChainedPythonFunctions], @@ -80,10 +81,14 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) val columnarBatchIter = new ArrowPythonRunner( - funcs, bufferSize, reuseWorker, - PythonEvalType.SQL_SCALAR_PANDAS_UDF, argOffsets, schema, - sessionLocalTimeZone, pandasRespectSessionTimeZone) - .compute(batchIter, context.partitionId(), context) + funcs, + bufferSize, + reuseWorker, + PythonEvalType.SQL_SCALAR_PANDAS_UDF, + argOffsets, + schema, + sessionLocalTimeZone, + pythonRunnerConf).compute(batchIter, context.partitionId(), context) new Iterator[InternalRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 01e19bddbfb66..ca665652f204d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -45,7 +45,7 @@ class ArrowPythonRunner( argOffsets: Array[Array[Int]], schema: StructType, timeZoneId: String, - respectTimeZone: Boolean) + conf: Map[String, String]) extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( funcs, bufferSize, reuseWorker, evalType, argOffsets) { @@ -58,12 +58,15 @@ class ArrowPythonRunner( new WriterThread(env, worker, inputIterator, partitionIndex, context) { protected override def writeCommand(dataOut: DataOutputStream): Unit = { - PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) - if (respectTimeZone) { - PythonRDD.writeUTF(timeZoneId, dataOut) - } else { - dataOut.writeInt(SpecialLengths.NULL) + + // Write config for the worker as a number of key -> value pairs of strings + dataOut.writeInt(conf.size) + for ((k, v) <- conf) { + PythonRDD.writeUTF(k, dataOut) + PythonRDD.writeUTF(v, dataOut) } + + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) } protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 513e174c7733e..f5a563baf52df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.StructType /** @@ -77,7 +78,7 @@ case class FlatMapGroupsInPandasExec( val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) val sessionLocalTimeZone = conf.sessionLocalTimeZone - val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) // Deduplicate the grouping attributes. // If a grouping attribute also appears in data attributes, then we don't need to send the @@ -139,10 +140,14 @@ case class FlatMapGroupsInPandasExec( val context = TaskContext.get() val columnarBatchIter = new ArrowPythonRunner( - chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, dedupSchema, - sessionLocalTimeZone, pandasRespectSessionTimeZone) - .compute(grouped, context.partitionId(), context) + chainedFunc, + bufferSize, + reuseWorker, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + argOffsets, + dedupSchema, + sessionLocalTimeZone, + pythonRunnerConf).compute(grouped, context.partitionId(), context) columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala new file mode 100644 index 0000000000000..a58773122922f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.File +import java.util.concurrent.TimeUnit +import java.util.concurrent.locks.ReentrantLock + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python._ +import org.apache.spark.internal.Logging +import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.sql.ForeachWriter +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.{NextIterator, Utils} + +class PythonForeachWriter(func: PythonFunction, schema: StructType) + extends ForeachWriter[UnsafeRow] { + + private lazy val context = TaskContext.get() + private lazy val buffer = new PythonForeachWriter.UnsafeRowBuffer( + context.taskMemoryManager, new File(Utils.getLocalDir(SparkEnv.get.conf)), schema.fields.length) + private lazy val inputRowIterator = buffer.iterator + + private lazy val inputByteIterator = { + EvaluatePython.registerPicklers() + val objIterator = inputRowIterator.map { row => EvaluatePython.toJava(row, schema) } + new SerDeUtil.AutoBatchedPickler(objIterator) + } + + private lazy val pythonRunner = { + val conf = SparkEnv.get.conf + val bufferSize = conf.getInt("spark.buffer.size", 65536) + val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true) + PythonRunner(func, bufferSize, reuseWorker) + } + + private lazy val outputIterator = + pythonRunner.compute(inputByteIterator, context.partitionId(), context) + + override def open(partitionId: Long, version: Long): Boolean = { + outputIterator // initialize everything + TaskContext.get.addTaskCompletionListener { _ => buffer.close() } + true + } + + override def process(value: UnsafeRow): Unit = { + buffer.add(value) + } + + override def close(errorOrNull: Throwable): Unit = { + buffer.allRowsAdded() + if (outputIterator.hasNext) outputIterator.next() // to throw python exception if there was one + } +} + +object PythonForeachWriter { + + /** + * A buffer that is designed for the sole purpose of buffering UnsafeRows in PythonForeachWriter. + * It is designed to be used with only 1 writer thread (i.e. JVM task thread) and only 1 reader + * thread (i.e. PythonRunner writing thread that reads from the buffer and writes to the Python + * worker stdin). Adds to the buffer are non-blocking, and reads through the buffer's iterator + * are blocking, that is, it blocks until new data is available or all data has been added. + * + * Internally, it uses a [[HybridRowQueue]] to buffer the rows in a practically unlimited queue + * across memory and local disk. However, HybridRowQueue is designed to be used only with + * EvalPythonExec where the reader is always behind the the writer, that is, the reader does not + * try to read n+1 rows if the writer has only written n rows at any point of time. This + * assumption is not true for PythonForeachWriter where rows may be added at a different rate as + * they are consumed by the python worker. Hence, to maintain the invariant of the reader being + * behind the writer while using HybridRowQueue, the buffer does the following + * - Keeps a count of the rows in the HybridRowQueue + * - Blocks the buffer's consuming iterator when the count is 0 so that the reader does not + * try to read more rows than what has been written. + * + * The implementation of the blocking iterator (ReentrantLock, Condition, etc.) has been borrowed + * from that of ArrayBlockingQueue. + */ + class UnsafeRowBuffer(taskMemoryManager: TaskMemoryManager, tempDir: File, numFields: Int) + extends Logging { + private val queue = HybridRowQueue(taskMemoryManager, tempDir, numFields) + private val lock = new ReentrantLock() + private val unblockRemove = lock.newCondition() + + // All of these are guarded by `lock` + private var count = 0L + private var allAdded = false + private var exception: Throwable = null + + val iterator = new NextIterator[UnsafeRow] { + override protected def getNext(): UnsafeRow = { + val row = remove() + if (row == null) finished = true + row + } + override protected def close(): Unit = { } + } + + def add(row: UnsafeRow): Unit = withLock { + assert(queue.add(row), s"Failed to add row to HybridRowQueue while sending data to Python" + + s"[count = $count, allAdded = $allAdded, exception = $exception]") + count += 1 + unblockRemove.signal() + logTrace(s"Added $row, $count left") + } + + private def remove(): UnsafeRow = withLock { + while (count == 0 && !allAdded && exception == null) { + unblockRemove.await(100, TimeUnit.MILLISECONDS) + } + + // If there was any error in the adding thread, then rethrow it in the removing thread + if (exception != null) throw exception + + if (count > 0) { + val row = queue.remove() + assert(row != null, "HybridRowQueue.remove() returned null " + + s"[count = $count, allAdded = $allAdded, exception = $exception]") + count -= 1 + logTrace(s"Removed $row, $count left") + row + } else { + null + } + } + + def allRowsAdded(): Unit = withLock { + allAdded = true + unblockRemove.signal() + } + + def close(): Unit = { queue.close() } + + private def withLock[T](f: => T): T = { + lock.lockInterruptibly() + try { f } catch { + case e: Throwable => + if (exception == null) exception = e + throw e + } finally { lock.unlock() } + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala index c76832a1a3829..628029b13a6c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.util.Utils @@ -97,7 +98,7 @@ case class WindowInPandasExec( val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) val sessionLocalTimeZone = conf.sessionLocalTimeZone - val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) // Extract window expressions and window functions val expressions = windowExpression.flatMap(_.collect { case e: WindowExpression => e }) @@ -154,11 +155,14 @@ case class WindowInPandasExec( } val windowFunctionResult = new ArrowPythonRunner( - pyFuncs, bufferSize, reuseWorker, + pyFuncs, + bufferSize, + reuseWorker, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, - argOffsets, windowInputSchema, - sessionLocalTimeZone, pandasRespectSessionTimeZone) - .compute(pythonInput, context.partitionId(), context) + argOffsets, + windowInputSchema, + sessionLocalTimeZone, + pythonRunnerConf).compute(pythonInput, context.partitionId(), context) val joined = new JoinedRow val resultProj = createResultProjection(expressions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index c480b96626f84..6ae7f2869b0f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -59,7 +59,8 @@ class IncrementalExecution( StatefulAggregationStrategy :: FlatMapGroupsWithStateStrategy :: StreamingRelationStrategy :: - StreamingDeduplicationStrategy :: Nil + StreamingDeduplicationStrategy :: + StreamingGlobalLimitStrategy(outputMode) :: Nil } private[sql] val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key) @@ -134,8 +135,12 @@ class IncrementalExecution( stateWatermarkPredicates = StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates( j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition.full, - Some(offsetSeqMetadata.batchWatermarkMs)) - ) + Some(offsetSeqMetadata.batchWatermarkMs))) + + case l: StreamingGlobalLimitExec => + l.copy( + stateInfo = Some(nextStatefulOperationStateInfo), + outputMode = Some(outputMode)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 17ffa2a517312..45c43f549d24f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -61,7 +61,7 @@ class MicroBatchExecution( case _ => throw new IllegalStateException(s"Unknown type of trigger: $trigger") } - private val watermarkTracker = new WatermarkTracker() + private var watermarkTracker: WatermarkTracker = _ override lazy val logicalPlan: LogicalPlan = { assert(queryExecutionThread eq Thread.currentThread, @@ -184,6 +184,9 @@ class MicroBatchExecution( isCurrentBatchConstructed = constructNextBatch(noDataBatchesEnabled) } + // Record the trigger offset range for progress reporting *before* processing the batch + recordTriggerOffsets(from = committedOffsets, to = availableOffsets) + // Remember whether the current batch has data or not. This will be required later // for bookkeeping after running the batch, when `isNewDataAvailable` will have changed // to false as the batch would have already processed the available data. @@ -257,6 +260,7 @@ class MicroBatchExecution( OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.conf) offsetSeqMetadata = OffsetSeqMetadata( metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf) + watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf) watermarkTracker.setWatermark(metadata.batchWatermarkMs) } @@ -295,6 +299,7 @@ class MicroBatchExecution( case None => // We are starting this stream for the first time. logInfo(s"Starting new streaming query.") currentBatchId = 0 + watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 787174481ff08..1ae3f36c152cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -22,7 +22,7 @@ import org.json4s.jackson.Serialization import org.apache.spark.internal.Logging import org.apache.spark.sql.RuntimeConfig -import org.apache.spark.sql.internal.SQLConf.{SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS} +import org.apache.spark.sql.internal.SQLConf._ /** * An ordered collection of offsets, used to track the progress of processing data from one or more @@ -86,7 +86,22 @@ case class OffsetSeqMetadata( object OffsetSeqMetadata extends Logging { private implicit val format = Serialization.formats(NoTypeHints) - private val relevantSQLConfs = Seq(SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS) + private val relevantSQLConfs = Seq( + SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY) + + /** + * Default values of relevant configurations that are used for backward compatibility. + * As new configurations are added to the metadata, existing checkpoints may not have those + * confs. The values in this list ensures that the confs without recovered values are + * set to a default value that ensure the same behavior of the streaming query as it was before + * the restart. + * + * Note, that this is optional; set values here if you *have* to override existing session conf + * with a specific default value for ensuring same behavior of the query as before. + */ + private val relevantSQLConfDefaultValues = Map[String, String]( + STREAMING_MULTIPLE_WATERMARK_POLICY.key -> MultipleWatermarkPolicy.DEFAULT_POLICY_NAME + ) def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json) @@ -115,8 +130,22 @@ object OffsetSeqMetadata extends Logging { case None => // For backward compatibility, if a config was not recorded in the offset log, - // then log it, and let the existing conf value in SparkSession prevail. - logWarning (s"Conf '$confKey' was not found in the offset log, using existing value") + // then either inject a default value (if specified in `relevantSQLConfDefaultValues`) or + // let the existing conf value in SparkSession prevail. + relevantSQLConfDefaultValues.get(confKey) match { + + case Some(defaultValue) => + sessionConf.set(confKey, defaultValue) + logWarning(s"Conf '$confKey' was not found in the offset log, " + + s"using default value '$defaultValue'") + + case None => + val valueStr = sessionConf.getOption(confKey).map { v => + s" Using existing session conf value '$v'." + }.getOrElse { " No value set in session conf." } + logWarning(s"Conf '$confKey' was not found in the offset log. $valueStr") + + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 16ad3ef9a3d4a..47f4b52e6e34c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -56,8 +56,6 @@ trait ProgressReporter extends Logging { protected def logicalPlan: LogicalPlan protected def lastExecution: QueryExecution protected def newData: Map[BaseStreamingSource, LogicalPlan] - protected def availableOffsets: StreamProgress - protected def committedOffsets: StreamProgress protected def sources: Seq[BaseStreamingSource] protected def sink: BaseStreamingSink protected def offsetSeqMetadata: OffsetSeqMetadata @@ -68,8 +66,11 @@ trait ProgressReporter extends Logging { // Local timestamps and counters. private var currentTriggerStartTimestamp = -1L private var currentTriggerEndTimestamp = -1L + private var currentTriggerStartOffsets: Map[BaseStreamingSource, String] = _ + private var currentTriggerEndOffsets: Map[BaseStreamingSource, String] = _ // TODO: Restore this from the checkpoint when possible. private var lastTriggerStartTimestamp = -1L + private val currentDurationsMs = new mutable.HashMap[String, Long]() /** Flag that signals whether any error with input metrics have already been logged */ @@ -114,9 +115,20 @@ trait ProgressReporter extends Logging { lastTriggerStartTimestamp = currentTriggerStartTimestamp currentTriggerStartTimestamp = triggerClock.getTimeMillis() currentStatus = currentStatus.copy(isTriggerActive = true) + currentTriggerStartOffsets = null + currentTriggerEndOffsets = null currentDurationsMs.clear() } + /** + * Record the offsets range this trigger will process. Call this before updating + * `committedOffsets` in `StreamExecution` to make sure that the correct range is recorded. + */ + protected def recordTriggerOffsets(from: StreamProgress, to: StreamProgress): Unit = { + currentTriggerStartOffsets = from.mapValues(_.json) + currentTriggerEndOffsets = to.mapValues(_.json) + } + private def updateProgress(newProgress: StreamingQueryProgress): Unit = { progressBuffer.synchronized { progressBuffer += newProgress @@ -130,6 +142,7 @@ trait ProgressReporter extends Logging { /** Finalizes the query progress and adds it to list of recent status updates. */ protected def finishTrigger(hasNewData: Boolean): Unit = { + assert(currentTriggerStartOffsets != null && currentTriggerEndOffsets != null) currentTriggerEndTimestamp = triggerClock.getTimeMillis() val executionStats = extractExecutionStats(hasNewData) @@ -147,8 +160,8 @@ trait ProgressReporter extends Logging { val numRecords = executionStats.inputRows.getOrElse(source, 0L) new SourceProgress( description = source.toString, - startOffset = committedOffsets.get(source).map(_.json).orNull, - endOffset = availableOffsets.get(source).map(_.json).orNull, + startOffset = currentTriggerStartOffsets.get(source).orNull, + endOffset = currentTriggerEndOffsets.get(source).orNull, numInputRows = numRecords, inputRowsPerSecond = numRecords / inputTimeSec, processedRowsPerSecond = numRecords / processingTimeSec diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingGlobalLimitExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingGlobalLimitExec.scala new file mode 100644 index 0000000000000..bf4af60c8cf03 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingGlobalLimitExec.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.TimeUnit.NANOSECONDS + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, Partitioning} +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.streaming.state.StateStoreOps +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.{LongType, NullType, StructField, StructType} +import org.apache.spark.util.CompletionIterator + +/** + * A physical operator for executing a streaming limit, which makes sure no more than streamLimit + * rows are returned. This operator is meant for streams in Append mode only. + */ +case class StreamingGlobalLimitExec( + streamLimit: Long, + child: SparkPlan, + stateInfo: Option[StatefulOperatorStateInfo] = None, + outputMode: Option[OutputMode] = None) + extends UnaryExecNode with StateStoreWriter { + + private val keySchema = StructType(Array(StructField("key", NullType))) + private val valueSchema = StructType(Array(StructField("value", LongType))) + + override protected def doExecute(): RDD[InternalRow] = { + metrics // force lazy init at driver + + assert(outputMode.isDefined && outputMode.get == InternalOutputModes.Append, + "StreamingGlobalLimitExec is only valid for streams in Append output mode") + + child.execute().mapPartitionsWithStateStore( + getStateInfo, + keySchema, + valueSchema, + indexOrdinal = None, + sqlContext.sessionState, + Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => + val key = UnsafeProjection.create(keySchema)(new GenericInternalRow(Array[Any](null))) + val numOutputRows = longMetric("numOutputRows") + val numUpdatedStateRows = longMetric("numUpdatedStateRows") + val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") + val commitTimeMs = longMetric("commitTimeMs") + val updatesStartTimeNs = System.nanoTime + + val preBatchRowCount: Long = Option(store.get(key)).map(_.getLong(0)).getOrElse(0L) + var cumulativeRowCount = preBatchRowCount + + val result = iter.filter { r => + val x = cumulativeRowCount < streamLimit + if (x) { + cumulativeRowCount += 1 + } + x + } + + CompletionIterator[InternalRow, Iterator[InternalRow]](result, { + if (cumulativeRowCount > preBatchRowCount) { + numUpdatedStateRows += 1 + numOutputRows += cumulativeRowCount - preBatchRowCount + store.put(key, getValueRow(cumulativeRowCount)) + } + allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) + commitTimeMs += timeTakenMs { store.commit() } + setStoreMetrics(store) + }) + } + } + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = AllTuples :: Nil + + private def getValueRow(value: Long): UnsafeRow = { + UnsafeProjection.create(valueSchema)(new GenericInternalRow(Array[Any](value))) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index afa664eb76525..50cf971e4ec3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -167,8 +167,8 @@ case class StreamingSymmetricHashJoinExec( val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length) override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) :: - ClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil + HashClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) :: + HashClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil override def output: Seq[Attribute] = joinType match { case _: InnerLike => left.output ++ right.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala index 80865669558dd..7b30db44a2090 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala @@ -20,15 +20,68 @@ package org.apache.spark.sql.execution.streaming import scala.collection.mutable import org.apache.spark.internal.Logging +import org.apache.spark.sql.RuntimeConfig import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.internal.SQLConf -class WatermarkTracker extends Logging { +/** + * Policy to define how to choose a new global watermark value if there are + * multiple watermark operators in a streaming query. + */ +sealed trait MultipleWatermarkPolicy { + def chooseGlobalWatermark(operatorWatermarks: Seq[Long]): Long +} + +object MultipleWatermarkPolicy { + val DEFAULT_POLICY_NAME = "min" + + def apply(policyName: String): MultipleWatermarkPolicy = { + policyName.toLowerCase match { + case DEFAULT_POLICY_NAME => MinWatermark + case "max" => MaxWatermark + case _ => + throw new IllegalArgumentException(s"Could not recognize watermark policy '$policyName'") + } + } +} + +/** + * Policy to choose the *min* of the operator watermark values as the global watermark value. + * Note that this is the safe (hence default) policy as the global watermark will advance + * only if all the individual operator watermarks have advanced. In other words, in a + * streaming query with multiple input streams and watermarks defined on all of them, + * the global watermark will advance as slowly as the slowest input. So if there is watermark + * based state cleanup or late-data dropping, then this policy is the most conservative one. + */ +case object MinWatermark extends MultipleWatermarkPolicy { + def chooseGlobalWatermark(operatorWatermarks: Seq[Long]): Long = { + assert(operatorWatermarks.nonEmpty) + operatorWatermarks.min + } +} + +/** + * Policy to choose the *min* of the operator watermark values as the global watermark value. So the + * global watermark will advance if any of the individual operator watermarks has advanced. + * In other words, in a streaming query with multiple input streams and watermarks defined on all + * of them, the global watermark will advance as fast as the fastest input. So if there is watermark + * based state cleanup or late-data dropping, then this policy is the most aggressive one and + * may lead to unexpected behavior if the data of the slow stream is delayed. + */ +case object MaxWatermark extends MultipleWatermarkPolicy { + def chooseGlobalWatermark(operatorWatermarks: Seq[Long]): Long = { + assert(operatorWatermarks.nonEmpty) + operatorWatermarks.max + } +} + +/** Tracks the watermark value of a streaming query based on a given `policy` */ +case class WatermarkTracker(policy: MultipleWatermarkPolicy) extends Logging { private val operatorToWatermarkMap = mutable.HashMap[Int, Long]() - private var watermarkMs: Long = 0 - private var updated = false + private var globalWatermarkMs: Long = 0 def setWatermark(newWatermarkMs: Long): Unit = synchronized { - watermarkMs = newWatermarkMs + globalWatermarkMs = newWatermarkMs } def updateWatermark(executedPlan: SparkPlan): Unit = synchronized { @@ -37,7 +90,6 @@ class WatermarkTracker extends Logging { } if (watermarkOperators.isEmpty) return - watermarkOperators.zipWithIndex.foreach { case (e, index) if e.eventTimeStats.value.count > 0 => logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}") @@ -58,16 +110,28 @@ class WatermarkTracker extends Logging { // This is the safest option, because only the global watermark is fault-tolerant. Making // it the minimum of all individual watermarks guarantees it will never advance past where // any individual watermark operator would be if it were in a plan by itself. - val newWatermarkMs = operatorToWatermarkMap.minBy(_._2)._2 - if (newWatermarkMs > watermarkMs) { - logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms") - watermarkMs = newWatermarkMs - updated = true + val chosenGlobalWatermark = policy.chooseGlobalWatermark(operatorToWatermarkMap.values.toSeq) + if (chosenGlobalWatermark > globalWatermarkMs) { + logInfo(s"Updating event-time watermark from $globalWatermarkMs to $chosenGlobalWatermark ms") + globalWatermarkMs = chosenGlobalWatermark } else { - logDebug(s"Event time didn't move: $newWatermarkMs < $watermarkMs") - updated = false + logDebug(s"Event time watermark didn't move: $chosenGlobalWatermark < $globalWatermarkMs") } } - def currentWatermark: Long = synchronized { watermarkMs } + def currentWatermark: Long = synchronized { globalWatermarkMs } +} + +object WatermarkTracker { + def apply(conf: RuntimeConfig): WatermarkTracker = { + // If the session has been explicitly configured to use non-default policy then use it, + // otherwise use the default `min` policy as thats the safe thing to do. + // When recovering from a checkpoint location, it is expected that the `conf` will already + // be configured with the value present in the checkpoint. If there is no policy explicitly + // saved in the checkpoint (e.g., old checkpoints), then the default `min` policy is enforced + // through defaults specified in OffsetSeqMetadata.setSessionConf(). + val policyName = conf.get( + SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY, MultipleWatermarkPolicy.DEFAULT_POLICY_NAME) + new WatermarkTracker(MultipleWatermarkPolicy(policyName)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala new file mode 100644 index 0000000000000..5f60343bacfaa --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.UUID + +import org.apache.spark.{HashPartitioner, SparkEnv} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.streaming.continuous.shuffle.{ContinuousShuffleReadPartition, ContinuousShuffleReadRDD} + +/** + * Physical plan for coalescing a continuous processing plan. + * + * Currently, only coalesces to a single partition are supported. `numPartitions` must be 1. + */ +case class ContinuousCoalesceExec(numPartitions: Int, child: SparkPlan) extends SparkPlan { + override def output: Seq[Attribute] = child.output + + override def children: Seq[SparkPlan] = child :: Nil + + override def outputPartitioning: Partitioning = SinglePartition + + override def doExecute(): RDD[InternalRow] = { + assert(numPartitions == 1) + new ContinuousCoalesceRDD( + sparkContext, + numPartitions, + conf.continuousStreamingExecutorQueueSize, + sparkContext.getLocalProperty(ContinuousExecution.EPOCH_INTERVAL_KEY).toLong, + child.execute()) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala new file mode 100644 index 0000000000000..ba85b355f974f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.UUID + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.continuous.shuffle._ +import org.apache.spark.util.ThreadUtils + +case class ContinuousCoalesceRDDPartition( + index: Int, + endpointName: String, + queueSize: Int, + numShuffleWriters: Int, + epochIntervalMs: Long) + extends Partition { + // Initialized only on the executor, and only once even as we call compute() multiple times. + lazy val (reader: ContinuousShuffleReader, endpoint) = { + val env = SparkEnv.get.rpcEnv + val receiver = new RPCContinuousShuffleReader( + queueSize, numShuffleWriters, epochIntervalMs, env) + val endpoint = env.setupEndpoint(endpointName, receiver) + + TaskContext.get().addTaskCompletionListener { ctx => + env.stop(endpoint) + } + (receiver, endpoint) + } + // This flag will be flipped on the executors to indicate that the threads processing + // partitions of the write-side RDD have been started. These will run indefinitely + // asynchronously as epochs of the coalesce RDD complete on the read side. + private[continuous] var writersInitialized: Boolean = false +} + +/** + * RDD for continuous coalescing. Asynchronously writes all partitions of `prev` into a local + * continuous shuffle, and then reads them in the task thread using `reader`. + */ +class ContinuousCoalesceRDD( + context: SparkContext, + numPartitions: Int, + readerQueueSize: Int, + epochIntervalMs: Long, + prev: RDD[InternalRow]) + extends RDD[InternalRow](context, Nil) { + + // When we support more than 1 target partition, we'll need to figure out how to pass in the + // required partitioner. + private val outputPartitioner = new HashPartitioner(1) + + private val readerEndpointNames = (0 until numPartitions).map { i => + s"ContinuousCoalesceRDD-part$i-${UUID.randomUUID()}" + } + + override def getPartitions: Array[Partition] = { + (0 until numPartitions).map { partIndex => + ContinuousCoalesceRDDPartition( + partIndex, + readerEndpointNames(partIndex), + readerQueueSize, + prev.getNumPartitions, + epochIntervalMs) + }.toArray + } + + private lazy val threadPool = ThreadUtils.newDaemonFixedThreadPool( + prev.getNumPartitions, + this.name) + + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + val part = split.asInstanceOf[ContinuousCoalesceRDDPartition] + + if (!part.writersInitialized) { + val rpcEnv = SparkEnv.get.rpcEnv + + // trigger lazy initialization + part.endpoint + val endpointRefs = readerEndpointNames.map { endpointName => + rpcEnv.setupEndpointRef(rpcEnv.address, endpointName) + } + + val runnables = prev.partitions.map { prevSplit => + new Runnable() { + override def run(): Unit = { + TaskContext.setTaskContext(context) + + val writer: ContinuousShuffleWriter = new RPCContinuousShuffleWriter( + prevSplit.index, outputPartitioner, endpointRefs.toArray) + + EpochTracker.initializeCurrentEpoch( + context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong) + while (!context.isInterrupted() && !context.isCompleted()) { + writer.write(prev.compute(prevSplit, context).asInstanceOf[Iterator[UnsafeRow]]) + // Note that current epoch is a non-inheritable thread local, so each writer thread + // can properly increment its own epoch without affecting the main task thread. + EpochTracker.incrementCurrentEpoch() + } + } + } + } + + context.addTaskCompletionListener { ctx => + threadPool.shutdownNow() + } + + part.writersInitialized = true + + runnables.foreach(threadPool.execute) + } + + part.reader.read() + } + + override def clearDependencies(): Unit = { + throw new IllegalStateException("Continuous RDDs cannot be checkpointed") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala index a7ccce10b0cee..73868d5967e90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala @@ -51,11 +51,11 @@ class ContinuousDataSourceRDD( sc: SparkContext, dataQueueSize: Int, epochPollIntervalMs: Long, - @transient private val readerFactories: Seq[InputPartition[UnsafeRow]]) + private val readerInputPartitions: Seq[InputPartition[UnsafeRow]]) extends RDD[UnsafeRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { - readerFactories.zipWithIndex.map { + readerInputPartitions.zipWithIndex.map { case (inputPartition, index) => new ContinuousDataSourceRDDPartition(index, inputPartition) }.toArray } @@ -74,8 +74,7 @@ class ContinuousDataSourceRDD( val partition = split.asInstanceOf[ContinuousDataSourceRDDPartition] if (partition.queueReader == null) { partition.queueReader = - new ContinuousQueuedDataReader( - partition.inputPartition, context, dataQueueSize, epochPollIntervalMs) + new ContinuousQueuedDataReader(partition, context, dataQueueSize, epochPollIntervalMs) } partition.queueReader diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index e3d0cea608b2a..e991dbc81696d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -216,6 +216,9 @@ class ContinuousExecution( currentEpochCoordinatorId = epochCoordinatorId sparkSessionForQuery.sparkContext.setLocalProperty( ContinuousExecution.EPOCH_COORDINATOR_ID_KEY, epochCoordinatorId) + sparkSessionForQuery.sparkContext.setLocalProperty( + ContinuousExecution.EPOCH_INTERVAL_KEY, + trigger.asInstanceOf[ContinuousTrigger].intervalMs.toString) // Use the parent Spark session for the endpoint since it's where this query ID is registered. val epochEndpoint = @@ -306,7 +309,10 @@ class ContinuousExecution( def commit(epoch: Long): Unit = { assert(continuousSources.length == 1, "only one continuous source supported currently") assert(offsetLog.get(epoch).isDefined, s"offset for epoch $epoch not reported before commit") + synchronized { + // Record offsets before updating `committedOffsets` + recordTriggerOffsets(from = committedOffsets, to = availableOffsets) if (queryExecutionThread.isAlive) { commitLog.add(epoch) val offset = @@ -382,4 +388,5 @@ class ContinuousExecution( object ContinuousExecution { val START_EPOCH_KEY = "__continuous_start_epoch" val EPOCH_COORDINATOR_ID_KEY = "__epoch_coordinator_id" + val EPOCH_INTERVAL_KEY = "__continuous_epoch_interval" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala index f38577b6a9f16..8c74b8244d096 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala @@ -37,11 +37,11 @@ import org.apache.spark.util.ThreadUtils * offsets across epochs. Each compute() should call the next() method here until null is returned. */ class ContinuousQueuedDataReader( - partition: InputPartition[UnsafeRow], + partition: ContinuousDataSourceRDDPartition, context: TaskContext, dataQueueSize: Int, epochPollIntervalMs: Long) extends Closeable { - private val reader = partition.createPartitionReader() + private val reader = partition.inputPartition.createPartitionReader() // Important sequencing - we must get our starting point before the provider threads start running private var currentOffset: PartitionOffset = @@ -113,7 +113,7 @@ class ContinuousQueuedDataReader( currentEntry match { case EpochMarker => epochCoordEndpoint.send(ReportPartitionOffset( - context.partitionId(), EpochTracker.getCurrentEpoch.get, currentOffset)) + partition.index, EpochTracker.getCurrentEpoch.get, currentOffset)) null case ContinuousRow(row, offset) => currentOffset = offset diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala index ef5f0da1e7cc2..76f3f5baa8d56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala @@ -56,7 +56,7 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor val dataIterator = prev.compute(split, context) dataWriter = writeTask.createDataWriter( context.partitionId(), - context.attemptNumber(), + context.taskAttemptId(), EpochTracker.getCurrentEpoch.get) while (dataIterator.hasNext) { dataWriter.write(dataIterator.next()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index 801b28b751bee..518223f3cd008 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -21,12 +21,14 @@ import java.util.UUID import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext} import org.apache.spark.rdd.RDD +import org.apache.spark.rpc.RpcAddress import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.NextIterator case class ContinuousShuffleReadPartition( index: Int, + endpointName: String, queueSize: Int, numShuffleWriters: Int, epochIntervalMs: Long) @@ -34,8 +36,10 @@ case class ContinuousShuffleReadPartition( // Initialized only on the executor, and only once even as we call compute() multiple times. lazy val (reader: ContinuousShuffleReader, endpoint) = { val env = SparkEnv.get.rpcEnv - val receiver = new UnsafeRowReceiver(queueSize, numShuffleWriters, epochIntervalMs, env) - val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", receiver) + val receiver = new RPCContinuousShuffleReader( + queueSize, numShuffleWriters, epochIntervalMs, env) + val endpoint = env.setupEndpoint(endpointName, receiver) + TaskContext.get().addTaskCompletionListener { ctx => env.stop(endpoint) } @@ -59,12 +63,14 @@ class ContinuousShuffleReadRDD( numPartitions: Int, queueSize: Int = 1024, numShuffleWriters: Int = 1, - epochIntervalMs: Long = 1000) + epochIntervalMs: Long = 1000, + val endpointNames: Seq[String] = Seq(s"RPCContinuousShuffleReader-${UUID.randomUUID()}")) extends RDD[UnsafeRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { (0 until numPartitions).map { partIndex => - ContinuousShuffleReadPartition(partIndex, queueSize, numShuffleWriters, epochIntervalMs) + ContinuousShuffleReadPartition( + partIndex, endpointNames(partIndex), queueSize, numShuffleWriters, epochIntervalMs) }.toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala new file mode 100644 index 0000000000000..47b1f78b24505 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +/** + * Trait for writing to a continuous processing shuffle. + */ +trait ContinuousShuffleWriter { + def write(epoch: Iterator[UnsafeRow]): Unit +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala similarity index 84% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala index d81f552d56626..502ae0d4822e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala @@ -20,26 +20,24 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean -import scala.collection.mutable - import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.util.NextIterator /** - * Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an epoch marker. + * Messages for the RPCContinuousShuffleReader endpoint. Either an incoming row or an epoch marker. * * Each message comes tagged with writerId, identifying which writer the message is coming * from. The receiver will only begin the next epoch once all writers have sent an epoch * marker ending the current epoch. */ -private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable { +private[shuffle] sealed trait RPCContinuousShuffleMessage extends Serializable { def writerId: Int } private[shuffle] case class ReceiverRow(writerId: Int, row: UnsafeRow) - extends UnsafeRowReceiverMessage -private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends UnsafeRowReceiverMessage + extends RPCContinuousShuffleMessage +private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends RPCContinuousShuffleMessage /** * RPC endpoint for receiving rows into a continuous processing shuffle task. Continuous shuffle @@ -48,7 +46,7 @@ private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends UnsafeRow * TODO: Support multiple source tasks. We need to output a single epoch marker once all * source tasks have sent one. */ -private[shuffle] class UnsafeRowReceiver( +private[continuous] class RPCContinuousShuffleReader( queueSize: Int, numShuffleWriters: Int, epochIntervalMs: Long, @@ -57,7 +55,7 @@ private[shuffle] class UnsafeRowReceiver( // Note that this queue will be drained from the main task thread and populated in the RPC // response thread. private val queues = Array.fill(numShuffleWriters) { - new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize) + new ArrayBlockingQueue[RPCContinuousShuffleMessage](queueSize) } // Exposed for testing to determine if the endpoint gets stopped on task end. @@ -68,7 +66,9 @@ private[shuffle] class UnsafeRowReceiver( } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case r: UnsafeRowReceiverMessage => + case r: RPCContinuousShuffleMessage => + // Note that this will block a thread the shared RPC handler pool! + // The TCP based shuffle handler (SPARK-24541) will avoid this problem. queues(r.writerId).put(r) context.reply(()) } @@ -79,10 +79,10 @@ private[shuffle] class UnsafeRowReceiver( private val writerEpochMarkersReceived = Array.fill(numShuffleWriters)(false) private val executor = Executors.newFixedThreadPool(numShuffleWriters) - private val completion = new ExecutorCompletionService[UnsafeRowReceiverMessage](executor) + private val completion = new ExecutorCompletionService[RPCContinuousShuffleMessage](executor) - private def completionTask(writerId: Int) = new Callable[UnsafeRowReceiverMessage] { - override def call(): UnsafeRowReceiverMessage = queues(writerId).take() + private def completionTask(writerId: Int) = new Callable[RPCContinuousShuffleMessage] { + override def call(): RPCContinuousShuffleMessage = queues(writerId).take() } // Initialize by submitting tasks to read the first row from each writer. @@ -107,7 +107,7 @@ private[shuffle] class UnsafeRowReceiver( } logWarning( s"Completion service failed to make progress after $epochIntervalMs ms. Waiting " + - s"for writers $writerIdsUncommitted to send epoch markers.") + s"for writers ${writerIdsUncommitted.mkString(",")} to send epoch markers.") // The completion service guarantees this future will be available immediately. case future => future.get() match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala new file mode 100644 index 0000000000000..1c6f3ddb395e6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import scala.concurrent.Future +import scala.concurrent.duration.Duration + +import org.apache.spark.Partitioner +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.util.ThreadUtils + +/** + * A [[ContinuousShuffleWriter]] sending data to [[RPCContinuousShuffleReader]] instances. + * + * @param writerId The partition ID of this writer. + * @param outputPartitioner The partitioner on the reader side of the shuffle. + * @param endpoints The [[RPCContinuousShuffleReader]] endpoints to write to. Indexed by + * partition ID within outputPartitioner. + */ +class RPCContinuousShuffleWriter( + writerId: Int, + outputPartitioner: Partitioner, + endpoints: Array[RpcEndpointRef]) extends ContinuousShuffleWriter { + + if (outputPartitioner.numPartitions != 1) { + throw new IllegalArgumentException("multiple readers not yet supported") + } + + if (outputPartitioner.numPartitions != endpoints.length) { + throw new IllegalArgumentException(s"partitioner size ${outputPartitioner.numPartitions} did " + + s"not match endpoint count ${endpoints.length}") + } + + def write(epoch: Iterator[UnsafeRow]): Unit = { + while (epoch.hasNext) { + val row = epoch.next() + endpoints(outputPartitioner.getPartition(row)).askSync[Unit](ReceiverRow(writerId, row)) + } + + val futures = endpoints.map(_.ask[Unit](ReceiverEpochMarker(writerId))).toSeq + implicit val ec = ThreadUtils.sameThread + ThreadUtils.awaitResult(Future.sequence(futures), Duration.Inf) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index d1c3498450096..0bf90b8063326 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -23,12 +23,13 @@ import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ +import scala.collection.SortedMap import scala.collection.mutable.ListBuffer import org.json4s.NoTypeHints import org.json4s.jackson.Serialization -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.{Encoder, Row, SQLContext} import org.apache.spark.sql.execution.streaming._ @@ -184,6 +185,14 @@ class ContinuousMemoryStreamInputPartitionReader( private var currentOffset = startOffset private var current: Option[Row] = None + // Defense-in-depth against failing to propagate the task context. Since it's not inheritable, + // we have to do a bit of error prone work to get it into every thread used by continuous + // processing. We hope that some unit test will end up instantiating a continuous memory stream + // in such cases. + if (TaskContext.get() == null) { + throw new IllegalStateException("Task context was not set!") + } + override def next(): Boolean = { current = getRecord while (current.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala new file mode 100644 index 0000000000000..03c567c58d46a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.api.python.PythonException +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.streaming.Sink +import org.apache.spark.sql.streaming.DataStreamWriter + +class ForeachBatchSink[T](batchWriter: (Dataset[T], Long) => Unit, encoder: ExpressionEncoder[T]) + extends Sink { + + override def addBatch(batchId: Long, data: DataFrame): Unit = { + val resolvedEncoder = encoder.resolveAndBind( + data.logicalPlan.output, + data.sparkSession.sessionState.analyzer) + val rdd = data.queryExecution.toRdd.map[T](resolvedEncoder.fromRow)(encoder.clsTag) + val ds = data.sparkSession.createDataset(rdd)(encoder) + batchWriter(ds, batchId) + } + + override def toString(): String = "ForeachBatchSink" +} + + +/** + * Interface that is meant to be extended by Python classes via Py4J. + * Py4J allows Python classes to implement Java interfaces so that the JVM can call back + * Python objects. In this case, this allows the user-defined Python `foreachBatch` function + * to be called from JVM when the query is active. + * */ +trait PythonForeachBatchFunction { + /** Call the Python implementation of this function */ + def call(batchDF: DataFrame, batchId: Long): Unit +} + +object PythonForeachBatchHelper { + def callForeachBatch(dsw: DataStreamWriter[Row], pythonFunc: PythonForeachBatchFunction): Unit = { + dsw.foreachBatch(pythonFunc.call _) + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala index df5d69d57e36f..bc9b6d93ce7d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.sql.{Encoder, ForeachWriter, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.python.PythonForeachWriter import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter @@ -31,9 +33,14 @@ import org.apache.spark.sql.types.StructType * [[ForeachWriter]]. * * @param writer The [[ForeachWriter]] to process all data. + * @param converter An object to convert internal rows to target type T. Either it can be + * a [[ExpressionEncoder]] or a direct converter function. * @tparam T The expected type of the sink. */ -case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends StreamWriteSupport { +case class ForeachWriterProvider[T]( + writer: ForeachWriter[T], + converter: Either[ExpressionEncoder[T], InternalRow => T]) extends StreamWriteSupport { + override def createStreamWriter( queryId: String, schema: StructType, @@ -44,10 +51,16 @@ case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends S override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = { - val encoder = encoderFor[T].resolveAndBind( - schema.toAttributes, - SparkSession.getActiveSession.get.sessionState.analyzer) - ForeachWriterFactory(writer, encoder) + val rowConverter: InternalRow => T = converter match { + case Left(enc) => + val boundEnc = enc.resolveAndBind( + schema.toAttributes, + SparkSession.getActiveSession.get.sessionState.analyzer) + boundEnc.fromRow + case Right(func) => + func + } + ForeachWriterFactory(writer, rowConverter) } override def toString: String = "ForeachSink" @@ -55,29 +68,44 @@ case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends S } } -case class ForeachWriterFactory[T: Encoder]( +object ForeachWriterProvider { + def apply[T]( + writer: ForeachWriter[T], + encoder: ExpressionEncoder[T]): ForeachWriterProvider[_] = { + writer match { + case pythonWriter: PythonForeachWriter => + new ForeachWriterProvider[UnsafeRow]( + pythonWriter, Right((x: InternalRow) => x.asInstanceOf[UnsafeRow])) + case _ => + new ForeachWriterProvider[T](writer, Left(encoder)) + } + } +} + +case class ForeachWriterFactory[T]( writer: ForeachWriter[T], - encoder: ExpressionEncoder[T]) + rowConverter: InternalRow => T) extends DataWriterFactory[InternalRow] { override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): ForeachDataWriter[T] = { - new ForeachDataWriter(writer, encoder, partitionId, epochId) + new ForeachDataWriter(writer, rowConverter, partitionId, epochId) } } /** * A [[DataWriter]] which writes data in this partition to a [[ForeachWriter]]. + * * @param writer The [[ForeachWriter]] to process all data. - * @param encoder An encoder which can convert [[InternalRow]] to the required type [[T]] + * @param rowConverter A function which can convert [[InternalRow]] to the required type [[T]] * @param partitionId * @param epochId * @tparam T The type expected by the writer. */ -class ForeachDataWriter[T : Encoder]( +class ForeachDataWriter[T]( writer: ForeachWriter[T], - encoder: ExpressionEncoder[T], + rowConverter: InternalRow => T, partitionId: Int, epochId: Long) extends DataWriter[InternalRow] { @@ -89,7 +117,7 @@ class ForeachDataWriter[T : Encoder]( if (!opened) return try { - writer.process(encoder.fromRow(record)) + writer.process(rowConverter(record)) } catch { case t: Throwable => writer.close(t) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala index e07355aa37dba..b501d90c81f06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, Dat case object PackedRowWriterFactory extends DataWriterFactory[Row] { override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[Row] = { new PackedRowDataWriter() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala index fbff8db987110..b393c48baee8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -202,7 +202,7 @@ class RateStreamMicroBatchInputPartitionReader( rangeEnd: Long, localStartTimeMs: Long, relativeMsPerValue: Double) extends InputPartitionReader[Row] { - private var count = 0 + private var count: Long = 0 override def next(): Boolean = { rangeStart + partitionId + numPartitions * count < rangeEnd diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 468313bfe8c3c..f2a35a90af24a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -154,7 +154,7 @@ class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode) case class MemoryWriterFactory(outputMode: OutputMode) extends DataWriterFactory[Row] { override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[Row] = { new MemoryDataWriter(partitionId, outputMode) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index bf46bc4cf904d..a7a24ac3641b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -214,11 +214,11 @@ private[ui] abstract class ExecutionTable( } private def jobURL(request: HttpServletRequest, jobId: Long): String = - "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(request, parent.basePath), jobId) + "%s/jobs/job/?id=%s".format(UIUtils.prependBaseUri(request, parent.basePath), jobId) private def executionURL(request: HttpServletRequest, executionID: Long): String = s"${UIUtils.prependBaseUri( - request, parent.basePath)}/${parent.prefix}/execution?id=$executionID" + request, parent.basePath)}/${parent.prefix}/execution/?id=$executionID" } private[ui] class RunningExecutionTable( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index 282f7b4bb5a58..877176b030f8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -122,7 +122,7 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging } private def jobURL(request: HttpServletRequest, jobId: Long): String = - "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(request, parent.basePath), jobId) + "%s/jobs/job/?id=%s".format(UIUtils.prependBaseUri(request, parent.basePath), jobId) private def physicalPlanDescription(physicalPlanDescription: String): Seq[Node] = {
    diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 87bd7b3b0f9c6..de1d422856ba9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1031,14 +1031,6 @@ object functions { // Non-aggregate functions ////////////////////////////////////////////////////////////////////////////////////////////// - /** - * Computes the absolute value. - * - * @group normal_funcs - * @since 1.3.0 - */ - def abs(e: Column): Column = withExpr { Abs(e.expr) } - /** * Creates a new array column. The input columns must all have the same data type. * @@ -1336,7 +1328,7 @@ object functions { } /** - * Computes bitwise NOT. + * Computes bitwise NOT (~) of a number. * * @group normal_funcs * @since 1.4.0 @@ -1364,6 +1356,14 @@ object functions { // Math Functions ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * Computes the absolute value of a numeric value. + * + * @group math_funcs + * @since 1.3.0 + */ + def abs(e: Column): Column = withExpr { Abs(e.expr) } + /** * @return inverse cosine of `e` in radians, as if computed by `java.lang.Math.acos` * @@ -2934,6 +2934,17 @@ object functions { FromUTCTimestamp(ts.expr, Literal(tz)) } + /** + * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders + * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield + * '2017-07-14 03:40:00.0'. + * @group datetime_funcs + * @since 2.4.0 + */ + def from_utc_timestamp(ts: Column, tz: Column): Column = withExpr { + FromUTCTimestamp(ts.expr, tz.expr) + } + /** * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in the given time * zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield @@ -2945,6 +2956,17 @@ object functions { ToUTCTimestamp(ts.expr, Literal(tz)) } + /** + * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in the given time + * zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield + * '2017-07-14 01:40:00.0'. + * @group datetime_funcs + * @since 2.4.0 + */ + def to_utc_timestamp(ts: Column, tz: Column): Column = withExpr { + ToUTCTimestamp(ts.expr, tz.expr) + } + /** * Bucketize rows into one or more time windows given a timestamp specifying column. Window * starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window @@ -3093,7 +3115,7 @@ object functions { * @since 1.5.0 */ def array_contains(column: Column, value: Any): Column = withExpr { - ArrayContains(column.expr, Literal(value)) + ArrayContains(column.expr, lit(value).expr) } /** @@ -3157,7 +3179,7 @@ object functions { * @since 2.4.0 */ def array_position(column: Column, value: Any): Column = withExpr { - ArrayPosition(column.expr, Literal(value)) + ArrayPosition(column.expr, lit(value).expr) } /** @@ -3168,7 +3190,7 @@ object functions { * @since 2.4.0 */ def element_at(column: Column, value: Any): Column = withExpr { - ElementAt(column.expr, Literal(value)) + ElementAt(column.expr, lit(value).expr) } /** @@ -3182,11 +3204,29 @@ object functions { /** * Remove all elements that equal to element from the given array. + * * @group collection_funcs * @since 2.4.0 */ def array_remove(column: Column, element: Any): Column = withExpr { - ArrayRemove(column.expr, Literal(element)) + ArrayRemove(column.expr, lit(element).expr) + } + + /** + * Removes duplicate values from the array. + * @group collection_funcs + * @since 2.4.0 + */ + def array_distinct(e: Column): Column = withExpr { ArrayDistinct(e.expr) } + + /** + * Returns an array of the elements in the union of the given two arrays, without duplicates. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_union(col1: Column, col2: Column): Column = withExpr { + ArrayUnion(col1.expr, col2.expr) } /** @@ -3275,7 +3315,7 @@ object functions { * @since 2.2.0 */ def from_json(e: Column, schema: DataType, options: Map[String, String]): Column = withExpr { - new JsonToStructs(schema, options, e.expr) + JsonToStructs(schema, options, e.expr) } /** @@ -3369,11 +3409,53 @@ object functions { val dataType = try { DataType.fromJson(schema) } catch { - case NonFatal(_) => StructType.fromDDL(schema) + case NonFatal(_) => DataType.fromDDL(schema) } from_json(e, dataType, options) } + /** + * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. + * + * @param e a string column containing JSON data. + * @param schema the schema to use when parsing the json string + * + * @group collection_funcs + * @since 2.4.0 + */ + def from_json(e: Column, schema: Column): Column = { + from_json(e, schema, Map.empty[String, String].asJava) + } + + /** + * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. + * + * @param e a string column containing JSON data. + * @param schema the schema to use when parsing the json string + * @param options options to control how the json is parsed. accepts the same options and the + * json data source. + * + * @group collection_funcs + * @since 2.4.0 + */ + def from_json(e: Column, schema: Column, options: java.util.Map[String, String]): Column = { + withExpr(new JsonToStructs(e.expr, schema.expr, options.asScala.toMap)) + } + + /** + * Parses a column containing a JSON string and infers its schema. + * + * @param e a string column containing JSON data. + * + * @group collection_funcs + * @since 2.4.0 + */ + def schema_of_json(e: Column): Column = withExpr(new SchemaOfJson(e.expr)) + /** * (Scala-specific) Converts a column containing a `StructType`, `ArrayType` of `StructType`s, * a `MapType` or `ArrayType` of `MapType`s into a JSON string with the specified schema. @@ -3478,6 +3560,27 @@ object functions { */ def flatten(e: Column): Column = withExpr { Flatten(e.expr) } + /** + * Generate a sequence of integers from start to stop, incrementing by step. + * + * @group collection_funcs + * @since 2.4.0 + */ + def sequence(start: Column, stop: Column, step: Column): Column = withExpr { + new Sequence(start.expr, stop.expr, step.expr) + } + + /** + * Generate a sequence of integers from start to stop, + * incrementing by 1 if start is less than or equal to stop, otherwise -1. + * + * @group collection_funcs + * @since 2.4.0 + */ + def sequence(start: Column, stop: Column): Column = withExpr { + new Sequence(start.expr, stop.expr) + } + /** * Creates an array containing the left argument repeated the number of times given by the * right argument. @@ -3520,131 +3623,28 @@ object functions { def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) } /** - * Returns a merged array of structs in which the N-th struct contains all N-th values of input - * arrays. + * Returns a map created from the given array of entries. * @group collection_funcs * @since 2.4.0 */ - def arrays_zip(e: Column*): Column = withExpr { ArraysZip(e.map(_.expr)) } - - ////////////////////////////////////////////////////////////////////////////////////////////// - // Mask functions - ////////////////////////////////////////////////////////////////////////////////////////////// - /** - * Returns a string which is the masked representation of the input. - * @group mask_funcs - * @since 2.4.0 - */ - def mask(e: Column): Column = withExpr { new Mask(e.expr) } - - /** - * Returns a string which is the masked representation of the input, using `upper`, `lower` and - * `digit` as replacement characters. - * @group mask_funcs - * @since 2.4.0 - */ - def mask(e: Column, upper: String, lower: String, digit: String): Column = withExpr { - Mask(e.expr, upper, lower, digit) - } - - /** - * Returns a string with the first `n` characters masked. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_first_n(e: Column, n: Int): Column = withExpr { new MaskFirstN(e.expr, Literal(n)) } + def map_from_entries(e: Column): Column = withExpr { MapFromEntries(e.expr) } /** - * Returns a string with the first `n` characters masked, using `upper`, `lower` and `digit` as - * replacement characters. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_first_n( - e: Column, - n: Int, - upper: String, - lower: String, - digit: String): Column = withExpr { - MaskFirstN(e.expr, n, upper, lower, digit) - } - - /** - * Returns a string with the last `n` characters masked. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_last_n(e: Column, n: Int): Column = withExpr { new MaskLastN(e.expr, Literal(n)) } - - /** - * Returns a string with the last `n` characters masked, using `upper`, `lower` and `digit` as - * replacement characters. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_last_n( - e: Column, - n: Int, - upper: String, - lower: String, - digit: String): Column = withExpr { - MaskLastN(e.expr, n, upper, lower, digit) - } - - /** - * Returns a string with all but the first `n` characters masked. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_show_first_n(e: Column, n: Int): Column = withExpr { - new MaskShowFirstN(e.expr, Literal(n)) - } - - /** - * Returns a string with all but the first `n` characters masked, using `upper`, `lower` and - * `digit` as replacement characters. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_show_first_n( - e: Column, - n: Int, - upper: String, - lower: String, - digit: String): Column = withExpr { - MaskShowFirstN(e.expr, n, upper, lower, digit) - } - - /** - * Returns a string with all but the last `n` characters masked. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_show_last_n(e: Column, n: Int): Column = withExpr { - new MaskShowLastN(e.expr, Literal(n)) - } - - /** - * Returns a string with all but the last `n` characters masked, using `upper`, `lower` and - * `digit` as replacement characters. - * @group mask_funcs + * Returns a merged array of structs in which the N-th struct contains all N-th values of input + * arrays. + * @group collection_funcs * @since 2.4.0 */ - def mask_show_last_n( - e: Column, - n: Int, - upper: String, - lower: String, - digit: String): Column = withExpr { - MaskShowLastN(e.expr, n, upper, lower, digit) - } + @scala.annotation.varargs + def arrays_zip(e: Column*): Column = withExpr { ArraysZip(e.map(_.expr)) } /** - * Returns a hashed value based on the input column. - * @group mask_funcs + * Returns the union of all the given maps. + * @group collection_funcs * @since 2.4.0 */ - def mask_hash(e: Column): Column = withExpr { MaskHash(e.expr) } + @scala.annotation.varargs + def map_concat(cols: Column*): Column = withExpr { MapConcat(cols.map(_.expr)) } // scalastyle:off line.size.limit // scalastyle:off parameter.number diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 6ae307bce10c8..4698e8ab13ce3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -364,7 +364,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def dropTempView(viewName: String): Boolean = { sparkSession.sessionState.catalog.getTempView(viewName).exists { viewDef => - sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true) + sparkSession.sharedState.cacheManager.uncacheQuery( + sparkSession, viewDef, cascade = false, blocking = true) sessionCatalog.dropTempView(viewName) } } @@ -379,7 +380,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def dropGlobalTempView(viewName: String): Boolean = { sparkSession.sessionState.catalog.getGlobalTempView(viewName).exists { viewDef => - sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true) + sparkSession.sharedState.cacheManager.uncacheQuery( + sparkSession, viewDef, cascade = false, blocking = true) sessionCatalog.dropGlobalTempView(viewName) } } @@ -438,7 +440,9 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * @since 2.0.0 */ override def uncacheTable(tableName: String): Unit = { - sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName)) + val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName) + val cascade = !sessionCatalog.isTemporaryTable(tableIdent) + sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName), cascade) } /** @@ -490,7 +494,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { // cached version and make the new version cached lazily. if (isCached(table)) { // Uncache the logicalPlan. - sparkSession.sharedState.cacheManager.uncacheQuery(table, blocking = true) + sparkSession.sharedState.cacheManager.uncacheQuery(table, cascade = true, blocking = true) // Cache it again. sparkSession.sharedState.cacheManager.cacheQuery(table, Some(tableIdent.table)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index 2499e9b604f3e..bdd8c4da6bd30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -199,7 +199,7 @@ case class StringStartsWith(attribute: String, value: String) extends Filter { /** * A filter that evaluates to `true` iff the attribute evaluates to - * a string that starts with `value`. + * a string that ends with `value`. * * @since 1.3.1 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index ae93965bc50ed..ef8dc3a325a33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -270,6 +270,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * per file *
  • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator * that should be used for parsing.
  • + *
  • `dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or + * empty array/struct during schema inference.
  • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index effc1471e8e12..3b9a56ffdde4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -21,14 +21,15 @@ import java.util.Locale import scala.collection.JavaConverters._ -import org.apache.spark.annotation.InterfaceStability -import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter} +import org.apache.spark.annotation.{InterfaceStability, Since} +import org.apache.spark.api.java.function.VoidFunction2 +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger -import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2} +import org.apache.spark.sql.execution.streaming.sources._ import org.apache.spark.sql.sources.v2.StreamWriteSupport /** @@ -269,7 +270,22 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { query } else if (source == "foreach") { assertNotPartitioned("foreach") - val sink = new ForeachWriterProvider[T](foreachWriter)(ds.exprEnc) + val sink = ForeachWriterProvider[T](foreachWriter, ds.exprEnc) + df.sparkSession.sessionState.streamingQueryManager.startQuery( + extraOptions.get("queryName"), + extraOptions.get("checkpointLocation"), + df, + extraOptions.toMap, + sink, + outputMode, + useTempCheckpointLocation = true, + trigger = trigger) + } else if (source == "foreachBatch") { + assertNotPartitioned("foreachBatch") + if (trigger.isInstanceOf[ContinuousTrigger]) { + throw new AnalysisException("'foreachBatch' is not supported with continuous trigger") + } + val sink = new ForeachBatchSink[T](foreachBatchWriter, ds.exprEnc) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), @@ -307,49 +323,9 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { } /** - * Starts the execution of the streaming query, which will continually send results to the given - * `ForeachWriter` as new data arrives. The `ForeachWriter` can be used to send the data - * generated by the `DataFrame`/`Dataset` to an external system. - * - * Scala example: - * {{{ - * datasetOfString.writeStream.foreach(new ForeachWriter[String] { - * - * def open(partitionId: Long, version: Long): Boolean = { - * // open connection - * } - * - * def process(record: String) = { - * // write string to connection - * } - * - * def close(errorOrNull: Throwable): Unit = { - * // close the connection - * } - * }).start() - * }}} - * - * Java example: - * {{{ - * datasetOfString.writeStream().foreach(new ForeachWriter() { - * - * @Override - * public boolean open(long partitionId, long version) { - * // open connection - * } - * - * @Override - * public void process(String value) { - * // write string to connection - * } - * - * @Override - * public void close(Throwable errorOrNull) { - * // close the connection - * } - * }).start(); - * }}} - * + * Sets the output of the streaming query to be processed using the provided writer object. + * object. See [[org.apache.spark.sql.ForeachWriter]] for more details on the lifecycle and + * semantics. * @since 2.0.0 */ def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = { @@ -362,6 +338,45 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { this } + /** + * :: Experimental :: + * + * (Scala-specific) Sets the output of the streaming query to be processed using the provided + * function. This is supported only the in the micro-batch execution modes (that is, when the + * trigger is not continuous). In every micro-batch, the provided function will be called in + * every micro-batch with (i) the output rows as a Dataset and (ii) the batch identifier. + * The batchId can be used deduplicate and transactionally write the output + * (that is, the provided Dataset) to external systems. The output Dataset is guaranteed + * to exactly same for the same batchId (assuming all operations are deterministic in the query). + * + * @since 2.4.0 + */ + @InterfaceStability.Evolving + def foreachBatch(function: (Dataset[T], Long) => Unit): DataStreamWriter[T] = { + this.source = "foreachBatch" + if (function == null) throw new IllegalArgumentException("foreachBatch function cannot be null") + this.foreachBatchWriter = function + this + } + + /** + * :: Experimental :: + * + * (Java-specific) Sets the output of the streaming query to be processed using the provided + * function. This is supported only the in the micro-batch execution modes (that is, when the + * trigger is not continuous). In every micro-batch, the provided function will be called in + * every micro-batch with (i) the output rows as a Dataset and (ii) the batch identifier. + * The batchId can be used deduplicate and transactionally write the output + * (that is, the provided Dataset) to external systems. The output Dataset is guaranteed + * to exactly same for the same batchId (assuming all operations are deterministic in the query). + * + * @since 2.4.0 + */ + @InterfaceStability.Evolving + def foreachBatch(function: VoidFunction2[Dataset[T], Long]): DataStreamWriter[T] = { + foreachBatch((batchDs: Dataset[T], batchId: Long) => function.call(batchDs, batchId)) + } + private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols => cols.map(normalize(_, "Partition")) } @@ -398,5 +413,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { private var foreachWriter: ForeachWriter[T] = null + private var foreachBatchWriter: (Dataset[T], Long) => Unit = null + private var partitioningColumns: Option[Seq[String]] = None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 97da2b1325f58..25bb05212d66f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -24,6 +24,7 @@ import scala.collection.mutable import org.apache.hadoop.fs.Path +import org.apache.spark.SparkException import org.apache.spark.annotation.InterfaceStability import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession} @@ -32,6 +33,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger} import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.StaticSQLConf.STREAMING_QUERY_LISTENERS import org.apache.spark.sql.sources.v2.StreamWriteSupport import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -55,6 +57,19 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo @GuardedBy("awaitTerminationLock") private var lastTerminatedQuery: StreamingQuery = null + try { + sparkSession.sparkContext.conf.get(STREAMING_QUERY_LISTENERS).foreach { classNames => + Utils.loadExtensions(classOf[StreamingQueryListener], classNames, + sparkSession.sparkContext.conf).foreach(listener => { + addListener(listener) + logInfo(s"Registered listener ${listener.getClass.getName}") + }) + } + } catch { + case e: Exception => + throw new SparkException("Exception when registering StreamingQueryListener", e) + } + /** * Returns a list of active queries associated with this SQLContext * diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index c132cab1b38cf..2c695fc58fd8c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -34,6 +34,7 @@ import org.junit.*; import org.junit.rules.ExpectedException; +import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.*; import org.apache.spark.sql.*; @@ -336,6 +337,23 @@ public void testTupleEncoder() { Assert.assertEquals(data5, ds5.collectAsList()); } + @Test + public void testTupleEncoderSchema() { + Encoder>> encoder = + Encoders.tuple(Encoders.STRING(), Encoders.tuple(Encoders.STRING(), Encoders.STRING())); + List>> data = Arrays.asList(tuple2("1", tuple2("a", "b")), + tuple2("2", tuple2("c", "d"))); + Dataset ds1 = spark.createDataset(data, encoder).toDF("value1", "value2"); + + JavaPairRDD> pairRDD = jsc.parallelizePairs(data); + Dataset ds2 = spark.createDataset(JavaPairRDD.toRDD(pairRDD), encoder) + .toDF("value1", "value2"); + + Assert.assertEquals(ds1.schema(), ds2.schema()); + Assert.assertEquals(ds1.select(expr("value2._1")).collectAsList(), + ds2.select(expr("value2._1")).collectAsList()); + } + @Test public void testNestedTupleEncoder() { // test ((int, string), string) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java new file mode 100644 index 0000000000000..a19ddbdbadba2 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.execution.sort; + +import org.apache.spark.SparkConf; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.memory.TestMemoryConsumer; +import org.apache.spark.memory.TestMemoryManager; +import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.execution.RecordBinaryComparator; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UnsafeAlignedOffset; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.util.collection.unsafe.sort.*; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Test the RecordBinaryComparator, which compares two UnsafeRows by their binary form. + */ +public class RecordBinaryComparatorSuite { + + private final TaskMemoryManager memoryManager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); + private final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); + + private final int uaoSize = UnsafeAlignedOffset.getUaoSize(); + + private MemoryBlock dataPage; + private long pageCursor; + + private LongArray array; + private int pos; + + @Before + public void beforeEach() { + // Only compare between two input rows. + array = consumer.allocateArray(2); + pos = 0; + + dataPage = memoryManager.allocatePage(4096, consumer); + pageCursor = dataPage.getBaseOffset(); + } + + @After + public void afterEach() { + consumer.freePage(dataPage); + dataPage = null; + pageCursor = 0; + + consumer.freeArray(array); + array = null; + pos = 0; + } + + private void insertRow(UnsafeRow row) { + Object recordBase = row.getBaseObject(); + long recordOffset = row.getBaseOffset(); + int recordLength = row.getSizeInBytes(); + + Object baseObject = dataPage.getBaseObject(); + assert(pageCursor + recordLength <= dataPage.getBaseOffset() + dataPage.size()); + long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, pageCursor); + UnsafeAlignedOffset.putSize(baseObject, pageCursor, recordLength); + pageCursor += uaoSize; + Platform.copyMemory(recordBase, recordOffset, baseObject, pageCursor, recordLength); + pageCursor += recordLength; + + assert(pos < 2); + array.set(pos, recordAddress); + pos++; + } + + private int compare(int index1, int index2) { + Object baseObject = dataPage.getBaseObject(); + + long recordAddress1 = array.get(index1); + long baseOffset1 = memoryManager.getOffsetInPage(recordAddress1) + uaoSize; + int recordLength1 = UnsafeAlignedOffset.getSize(baseObject, baseOffset1 - uaoSize); + + long recordAddress2 = array.get(index2); + long baseOffset2 = memoryManager.getOffsetInPage(recordAddress2) + uaoSize; + int recordLength2 = UnsafeAlignedOffset.getSize(baseObject, baseOffset2 - uaoSize); + + return binaryComparator.compare(baseObject, baseOffset1, recordLength1, baseObject, + baseOffset2, recordLength2); + } + + private final RecordComparator binaryComparator = new RecordBinaryComparator(); + + // Compute the most compact size for UnsafeRow's backing data. + private int computeSizeInBytes(int originalSize) { + // All the UnsafeRows in this suite contains less than 64 columns, so the bitSetSize shall + // always be 8. + return 8 + (originalSize + 7) / 8 * 8; + } + + // Compute the relative offset of variable-length values. + private long relativeOffset(int numFields) { + // All the UnsafeRows in this suite contains less than 64 columns, so the bitSetSize shall + // always be 8. + return 8 + numFields * 8L; + } + + @Test + public void testBinaryComparatorForSingleColumnRow() throws Exception { + int numFields = 1; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + row1.setInt(0, 11); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + row2.setInt(0, 42); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) < 0); + } + + @Test + public void testBinaryComparatorForMultipleColumnRow() throws Exception { + int numFields = 5; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + for (int i = 0; i < numFields; i++) { + row1.setDouble(i, i * 3.14); + } + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + for (int i = 0; i < numFields; i++) { + row2.setDouble(i, 198.7 / (i + 1)); + } + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) < 0); + } + + @Test + public void testBinaryComparatorForArrayColumn() throws Exception { + int numFields = 1; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + UnsafeArrayData arrayData1 = UnsafeArrayData.fromPrimitiveArray(new int[]{11, 42, -1}); + row1.pointTo(data1, computeSizeInBytes(numFields * 8 + arrayData1.getSizeInBytes())); + row1.setLong(0, (relativeOffset(numFields) << 32) | (long) arrayData1.getSizeInBytes()); + Platform.copyMemory(arrayData1.getBaseObject(), arrayData1.getBaseOffset(), data1, + row1.getBaseOffset() + relativeOffset(numFields), arrayData1.getSizeInBytes()); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + UnsafeArrayData arrayData2 = UnsafeArrayData.fromPrimitiveArray(new int[]{22}); + row2.pointTo(data2, computeSizeInBytes(numFields * 8 + arrayData2.getSizeInBytes())); + row2.setLong(0, (relativeOffset(numFields) << 32) | (long) arrayData2.getSizeInBytes()); + Platform.copyMemory(arrayData2.getBaseObject(), arrayData2.getBaseOffset(), data2, + row2.getBaseOffset() + relativeOffset(numFields), arrayData2.getSizeInBytes()); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) > 0); + } + + @Test + public void testBinaryComparatorForMixedColumns() throws Exception { + int numFields = 4; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + UTF8String str1 = UTF8String.fromString("Milk tea"); + row1.pointTo(data1, computeSizeInBytes(numFields * 8 + str1.numBytes())); + row1.setInt(0, 11); + row1.setDouble(1, 3.14); + row1.setInt(2, -1); + row1.setLong(3, (relativeOffset(numFields) << 32) | (long) str1.numBytes()); + Platform.copyMemory(str1.getBaseObject(), str1.getBaseOffset(), data1, + row1.getBaseOffset() + relativeOffset(numFields), str1.numBytes()); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + UTF8String str2 = UTF8String.fromString("Java"); + row2.pointTo(data2, computeSizeInBytes(numFields * 8 + str2.numBytes())); + row2.setInt(0, 11); + row2.setDouble(1, 3.14); + row2.setInt(2, -1); + row2.setLong(3, (relativeOffset(numFields) << 32) | (long) str2.numBytes()); + Platform.copyMemory(str2.getBaseObject(), str2.getBaseOffset(), data2, + row2.getBaseOffset() + relativeOffset(numFields), str2.numBytes()); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) > 0); + } + + @Test + public void testBinaryComparatorForNullColumns() throws Exception { + int numFields = 3; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + for (int i = 0; i < numFields; i++) { + row1.setNullAt(i); + } + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + for (int i = 0; i < numFields - 1; i++) { + row2.setNullAt(i); + } + row2.setDouble(numFields - 1, 3.14); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) > 0); + } +} diff --git a/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql b/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql index 8afa3270f4de4..2e6a5f362a8fa 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql @@ -1,3 +1,8 @@ +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false + CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1) AS GROUPING(a); CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1) AS GROUPING(a); diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql index fea069eac4d48..79fdd5895e691 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql @@ -31,3 +31,11 @@ CREATE TEMPORARY VIEW jsonTable(jsonField, a) AS SELECT * FROM VALUES ('{"a": 1, SELECT json_tuple(jsonField, 'b', CAST(NULL AS STRING), a) FROM jsonTable; -- Clean up DROP VIEW IF EXISTS jsonTable; + +-- from_json - complex types +select from_json('{"a":1, "b":2}', 'map'); +select from_json('{"a":1, "b":"2"}', 'struct'); + +-- infer schema of json literal +select schema_of_json('{"c1":0, "c2":[1]}'); +select from_json('{"c1":[1, 2, 3]}', schema_of_json('{"c1":[0]}')); diff --git a/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql b/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql index 71a50157b766c..e0abeda3eb44f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql @@ -1,3 +1,8 @@ +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false + create temporary view nt1 as select * from values ("one", 1), ("two", 2), diff --git a/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql b/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql index cdc6c81e10047..ce09c21568f13 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql @@ -1,3 +1,8 @@ +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false + -- SPARK-17099: Incorrect result when HAVING clause is added to group by query CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (-234), (145), (367), (975), (298) diff --git a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql index 01dea6c81c11b..b3d53adfbebe7 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql @@ -111,3 +111,21 @@ PIVOT ( sum(earnings) FOR year IN (2012, 2013) ); + +-- pivot with complex aggregate expressions +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + ceil(sum(earnings)), avg(earnings) + 1 as a1 + FOR course IN ('dotNET', 'Java') +); + +-- pivot with invalid arguments in aggregate expressions +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(avg(earnings)) + FOR course IN ('dotNET', 'Java') +); diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-joins-and-set-ops.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-joins-and-set-ops.sql index cc4ed64affec7..cefc3fe6272ab 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-joins-and-set-ops.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-joins-and-set-ops.sql @@ -1,5 +1,9 @@ -- Tests EXISTS subquery support. Tests Exists subquery -- used in Joins (Both when joins occurs in outer and suquery blocks) +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false CREATE TEMPORARY VIEW EMP AS SELECT * FROM VALUES (100, "emp 1", date "2005-01-01", 100.00D, 10), diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql index 880175fd7add0..22f3eafd6a02d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql @@ -1,5 +1,9 @@ -- A test suite for IN JOINS in parent side, subquery, and both predicate subquery -- It includes correlated cases. +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql index e09b91f18de0a..4f8ca8bfb27c1 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql @@ -1,5 +1,9 @@ -- A test suite for not-in-joins in parent side, subquery, and both predicate subquery -- It includes correlated cases. +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/arrayJoin.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/arrayJoin.sql new file mode 100644 index 0000000000000..99729c007b104 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/arrayJoin.sql @@ -0,0 +1,11 @@ +SELECT array_join(array(true, false), ', '); +SELECT array_join(array(2Y, 1Y), ', '); +SELECT array_join(array(2S, 1S), ', '); +SELECT array_join(array(2, 1), ', '); +SELECT array_join(array(2L, 1L), ', '); +SELECT array_join(array(9223372036854775809, 9223372036854775808), ', '); +SELECT array_join(array(2.0D, 1.0D), ', '); +SELECT array_join(array(float(2.0), float(1.0)), ', '); +SELECT array_join(array(date '2016-03-14', date '2016-03-13'), ', '); +SELECT array_join(array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), ', '); +SELECT array_join(array('a', 'b'), ', '); diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql new file mode 100644 index 0000000000000..fc26397b881b5 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql @@ -0,0 +1,94 @@ +CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( + map(true, false), map(false, true), + map(1Y, 2Y), map(3Y, 4Y), + map(1S, 2S), map(3S, 4S), + map(4, 6), map(7, 8), + map(6L, 7L), map(8L, 9L), + map(9223372036854775809, 9223372036854775808), map(9223372036854775808, 9223372036854775809), + map(1.0D, 2.0D), map(3.0D, 4.0D), + map(float(1.0D), float(2.0D)), map(float(3.0D), float(4.0D)), + map(date '2016-03-14', date '2016-03-13'), map(date '2016-03-12', date '2016-03-11'), + map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + map(timestamp '2016-11-11 20:54:00.000', timestamp '2016-11-09 20:54:00.000'), + map('a', 'b'), map('c', 'd'), + map(array('a', 'b'), array('c', 'd')), map(array('e'), array('f')), + map(struct('a', 1), struct('b', 2)), map(struct('c', 3), struct('d', 4)), + map(map('a', 1), map('b', 2)), map(map('c', 3), map('d', 4)), + map('a', 1), map('c', 2), + map(1, 'a'), map(2, 'c') +) AS various_maps ( + boolean_map1, boolean_map2, + tinyint_map1, tinyint_map2, + smallint_map1, smallint_map2, + int_map1, int_map2, + bigint_map1, bigint_map2, + decimal_map1, decimal_map2, + double_map1, double_map2, + float_map1, float_map2, + date_map1, date_map2, + timestamp_map1, + timestamp_map2, + string_map1, string_map2, + array_map1, array_map2, + struct_map1, struct_map2, + map_map1, map_map2, + string_int_map1, string_int_map2, + int_string_map1, int_string_map2 +); + +-- Concatenate maps of the same type +SELECT + map_concat(boolean_map1, boolean_map2) boolean_map, + map_concat(tinyint_map1, tinyint_map2) tinyint_map, + map_concat(smallint_map1, smallint_map2) smallint_map, + map_concat(int_map1, int_map2) int_map, + map_concat(bigint_map1, bigint_map2) bigint_map, + map_concat(decimal_map1, decimal_map2) decimal_map, + map_concat(float_map1, float_map2) float_map, + map_concat(double_map1, double_map2) double_map, + map_concat(date_map1, date_map2) date_map, + map_concat(timestamp_map1, timestamp_map2) timestamp_map, + map_concat(string_map1, string_map2) string_map, + map_concat(array_map1, array_map2) array_map, + map_concat(struct_map1, struct_map2) struct_map, + map_concat(map_map1, map_map2) map_map, + map_concat(string_int_map1, string_int_map2) string_int_map, + map_concat(int_string_map1, int_string_map2) int_string_map +FROM various_maps; + +-- Concatenate maps of different types +SELECT + map_concat(tinyint_map1, smallint_map2) ts_map, + map_concat(smallint_map1, int_map2) si_map, + map_concat(int_map1, bigint_map2) ib_map, + map_concat(decimal_map1, float_map2) df_map, + map_concat(string_map1, date_map2) std_map, + map_concat(timestamp_map1, string_map2) tst_map, + map_concat(string_map1, int_map2) sti_map, + map_concat(int_string_map1, tinyint_map2) istt_map +FROM various_maps; + +-- Concatenate map of incompatible types 1 +SELECT + map_concat(tinyint_map1, map_map2) tm_map +FROM various_maps; + +-- Concatenate map of incompatible types 2 +SELECT + map_concat(boolean_map1, int_map2) bi_map +FROM various_maps; + +-- Concatenate map of incompatible types 3 +SELECT + map_concat(int_map1, struct_map2) is_map +FROM various_maps; + +-- Concatenate map of incompatible types 4 +SELECT + map_concat(map_map1, array_map2) ma_map +FROM various_maps; + +-- Concatenate map of incompatible types 5 +SELECT + map_concat(map_map1, struct_map2) ms_map +FROM various_maps; diff --git a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out index 58ed201e2a60f..8ba69c698b551 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out @@ -57,6 +57,8 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 +Created Time [not included in comparison] +Last Access [not included in comparison] # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -89,6 +91,8 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 +Created Time [not included in comparison] +Last Access [not included in comparison] Partition Statistics 1121 bytes, 3 rows # Storage Information @@ -122,6 +126,8 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 +Created Time [not included in comparison] +Last Access [not included in comparison] Partition Statistics 1121 bytes, 3 rows # Storage Information @@ -147,6 +153,8 @@ Database default Table t Partition Values [ds=2017-08-01, hr=11] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11 +Created Time [not included in comparison] +Last Access [not included in comparison] Partition Statistics 1098 bytes, 4 rows # Storage Information @@ -180,6 +188,8 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 +Created Time [not included in comparison] +Last Access [not included in comparison] Partition Statistics 1121 bytes, 3 rows # Storage Information @@ -205,6 +215,8 @@ Database default Table t Partition Values [ds=2017-08-01, hr=11] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11 +Created Time [not included in comparison] +Last Access [not included in comparison] Partition Statistics 1098 bytes, 4 rows # Storage Information @@ -230,6 +242,8 @@ Database default Table t Partition Values [ds=2017-09-01, hr=5] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-09-01/hr=5 +Created Time [not included in comparison] +Last Access [not included in comparison] Partition Statistics 1144 bytes, 2 rows # Storage Information diff --git a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out index 8c908b7625056..79390cb424444 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out @@ -282,6 +282,8 @@ Table t Partition Values [c=Us, d=1] Location [not included in comparison]sql/core/spark-warehouse/t/c=Us/d=1 Storage Properties [a=1, b=2] +Created Time [not included in comparison] +Last Access [not included in comparison] # Storage Information Num Buckets 2 @@ -311,6 +313,8 @@ Table t Partition Values [c=Us, d=1] Location [not included in comparison]sql/core/spark-warehouse/t/c=Us/d=1 Storage Properties [a=1, b=2] +Created Time [not included in comparison] +Last Access [not included in comparison] # Storage Information Num Buckets 2 diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index 14a69128ffb41..3d49323751a10 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 26 +-- Number of queries: 30 -- !query 0 @@ -183,7 +183,7 @@ select from_json('{"a":1}', 1) struct<> -- !query 17 output org.apache.spark.sql.AnalysisException -Expected a string literal instead of 1;; line 1 pos 7 +Schema should be specified in DDL format as a string literal or output of the schema_of_json function instead of 1;; line 1 pos 7 -- !query 18 @@ -258,3 +258,35 @@ DROP VIEW IF EXISTS jsonTable struct<> -- !query 25 output + + +-- !query 26 +select from_json('{"a":1, "b":2}', 'map') +-- !query 26 schema +struct> +-- !query 26 output +{"a":1,"b":2} + + +-- !query 27 +select from_json('{"a":1, "b":"2"}', 'struct') +-- !query 27 schema +struct> +-- !query 27 output +{"a":1,"b":"2"} + + +-- !query 28 +select schema_of_json('{"c1":0, "c2":[1]}') +-- !query 28 schema +struct +-- !query 28 output +struct> + + +-- !query 29 +select from_json('{"c1":[1, 2, 3]}', schema_of_json('{"c1":[0]}')) +-- !query 29 schema +struct>> +-- !query 29 output +{"c1":[1,2,3]} diff --git a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out index 85e3488990e20..922d8b9f9152c 100644 --- a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 13 +-- Number of queries: 15 -- !query 0 @@ -176,7 +176,7 @@ PIVOT ( struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -Aggregate expression required for pivot, found 'abs(earnings#x)'; +Aggregate expression required for pivot, but 'coursesales.`earnings`' did not appear in any aggregate function.; -- !query 12 @@ -192,3 +192,33 @@ struct<> -- !query 12 output org.apache.spark.sql.AnalysisException cannot resolve '`year`' given input columns: [__auto_generated_subquery_name.course, __auto_generated_subquery_name.earnings]; line 4 pos 0 + + +-- !query 13 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + ceil(sum(earnings)), avg(earnings) + 1 as a1 + FOR course IN ('dotNET', 'Java') +) +-- !query 13 schema +struct +-- !query 13 output +2012 15000 7501.0 20000 20001.0 +2013 48000 48001.0 30000 30001.0 + + +-- !query 14 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(avg(earnings)) + FOR course IN ('dotNET', 'Java') +) +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +It is not allowed to use an aggregate function in the argument of another aggregate function. Please use the inner aggregate function in a sub-query.; diff --git a/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out b/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out index 975bb06124744..abeb7e18f031e 100644 --- a/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out @@ -178,6 +178,8 @@ struct -- !query 14 output showdb show_t1 false Partition Values: [c=Us, d=1] Location [not included in comparison]sql/core/spark-warehouse/showdb.db/show_t1/c=Us/d=1 +Created Time [not included in comparison] +Last Access [not included in comparison] -- !query 15 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/arrayJoin.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/arrayJoin.sql.out new file mode 100644 index 0000000000000..b23a62dacef7c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/arrayJoin.sql.out @@ -0,0 +1,90 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 11 + + +-- !query 0 +SELECT array_join(array(true, false), ', ') +-- !query 0 schema +struct +-- !query 0 output +true, false + + +-- !query 1 +SELECT array_join(array(2Y, 1Y), ', ') +-- !query 1 schema +struct +-- !query 1 output +2, 1 + + +-- !query 2 +SELECT array_join(array(2S, 1S), ', ') +-- !query 2 schema +struct +-- !query 2 output +2, 1 + + +-- !query 3 +SELECT array_join(array(2, 1), ', ') +-- !query 3 schema +struct +-- !query 3 output +2, 1 + + +-- !query 4 +SELECT array_join(array(2L, 1L), ', ') +-- !query 4 schema +struct +-- !query 4 output +2, 1 + + +-- !query 5 +SELECT array_join(array(9223372036854775809, 9223372036854775808), ', ') +-- !query 5 schema +struct +-- !query 5 output +9223372036854775809, 9223372036854775808 + + +-- !query 6 +SELECT array_join(array(2.0D, 1.0D), ', ') +-- !query 6 schema +struct +-- !query 6 output +2.0, 1.0 + + +-- !query 7 +SELECT array_join(array(float(2.0), float(1.0)), ', ') +-- !query 7 schema +struct +-- !query 7 output +2.0, 1.0 + + +-- !query 8 +SELECT array_join(array(date '2016-03-14', date '2016-03-13'), ', ') +-- !query 8 schema +struct +-- !query 8 output +2016-03-14, 2016-03-13 + + +-- !query 9 +SELECT array_join(array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), ', ') +-- !query 9 schema +struct +-- !query 9 output +2016-11-15 20:54:00, 2016-11-12 20:54:00 + + +-- !query 10 +SELECT array_join(array('a', 'b'), ', ') +-- !query 10 schema +struct +-- !query 10 output +a, b diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out new file mode 100644 index 0000000000000..d352b7284ae87 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out @@ -0,0 +1,143 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 8 + + +-- !query 0 +CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( + map(true, false), map(false, true), + map(1Y, 2Y), map(3Y, 4Y), + map(1S, 2S), map(3S, 4S), + map(4, 6), map(7, 8), + map(6L, 7L), map(8L, 9L), + map(9223372036854775809, 9223372036854775808), map(9223372036854775808, 9223372036854775809), + map(1.0D, 2.0D), map(3.0D, 4.0D), + map(float(1.0D), float(2.0D)), map(float(3.0D), float(4.0D)), + map(date '2016-03-14', date '2016-03-13'), map(date '2016-03-12', date '2016-03-11'), + map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + map(timestamp '2016-11-11 20:54:00.000', timestamp '2016-11-09 20:54:00.000'), + map('a', 'b'), map('c', 'd'), + map(array('a', 'b'), array('c', 'd')), map(array('e'), array('f')), + map(struct('a', 1), struct('b', 2)), map(struct('c', 3), struct('d', 4)), + map(map('a', 1), map('b', 2)), map(map('c', 3), map('d', 4)), + map('a', 1), map('c', 2), + map(1, 'a'), map(2, 'c') +) AS various_maps ( + boolean_map1, boolean_map2, + tinyint_map1, tinyint_map2, + smallint_map1, smallint_map2, + int_map1, int_map2, + bigint_map1, bigint_map2, + decimal_map1, decimal_map2, + double_map1, double_map2, + float_map1, float_map2, + date_map1, date_map2, + timestamp_map1, + timestamp_map2, + string_map1, string_map2, + array_map1, array_map2, + struct_map1, struct_map2, + map_map1, map_map2, + string_int_map1, string_int_map2, + int_string_map1, int_string_map2 +) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT + map_concat(boolean_map1, boolean_map2) boolean_map, + map_concat(tinyint_map1, tinyint_map2) tinyint_map, + map_concat(smallint_map1, smallint_map2) smallint_map, + map_concat(int_map1, int_map2) int_map, + map_concat(bigint_map1, bigint_map2) bigint_map, + map_concat(decimal_map1, decimal_map2) decimal_map, + map_concat(float_map1, float_map2) float_map, + map_concat(double_map1, double_map2) double_map, + map_concat(date_map1, date_map2) date_map, + map_concat(timestamp_map1, timestamp_map2) timestamp_map, + map_concat(string_map1, string_map2) string_map, + map_concat(array_map1, array_map2) array_map, + map_concat(struct_map1, struct_map2) struct_map, + map_concat(map_map1, map_map2) map_map, + map_concat(string_int_map1, string_int_map2) string_int_map, + map_concat(int_string_map1, int_string_map2) int_string_map +FROM various_maps +-- !query 1 schema +struct,tinyint_map:map,smallint_map:map,int_map:map,bigint_map:map,decimal_map:map,float_map:map,double_map:map,date_map:map,timestamp_map:map,string_map:map,array_map:map,array>,struct_map:map,struct>,map_map:map,map>,string_int_map:map,int_string_map:map> +-- !query 1 output +{false:true,true:false} {1:2,3:4} {1:2,3:4} {4:6,7:8} {6:7,8:9} {9223372036854775808:9223372036854775809,9223372036854775809:9223372036854775808} {1.0:2.0,3.0:4.0} {1.0:2.0,3.0:4.0} {2016-03-12:2016-03-11,2016-03-14:2016-03-13} {2016-11-11 20:54:00.0:2016-11-09 20:54:00.0,2016-11-15 20:54:00.0:2016-11-12 20:54:00.0} {"a":"b","c":"d"} {["a","b"]:["c","d"],["e"]:["f"]} {{"col1":"a","col2":1}:{"col1":"b","col2":2},{"col1":"c","col2":3}:{"col1":"d","col2":4}} {{"a":1}:{"b":2},{"c":3}:{"d":4}} {"a":1,"c":2} {1:"a",2:"c"} + + +-- !query 2 +SELECT + map_concat(tinyint_map1, smallint_map2) ts_map, + map_concat(smallint_map1, int_map2) si_map, + map_concat(int_map1, bigint_map2) ib_map, + map_concat(decimal_map1, float_map2) df_map, + map_concat(string_map1, date_map2) std_map, + map_concat(timestamp_map1, string_map2) tst_map, + map_concat(string_map1, int_map2) sti_map, + map_concat(int_string_map1, tinyint_map2) istt_map +FROM various_maps +-- !query 2 schema +struct,si_map:map,ib_map:map,df_map:map,std_map:map,tst_map:map,sti_map:map,istt_map:map> +-- !query 2 output +{1:2,3:4} {1:2,7:8} {4:6,8:9} {3.0:4.0,9.223372036854776E18:9.223372036854776E18} {"2016-03-12":"2016-03-11","a":"b"} {"2016-11-15 20:54:00":"2016-11-12 20:54:00","c":"d"} {"7":"8","a":"b"} {1:"a",3:"4"} + + +-- !query 3 +SELECT + map_concat(tinyint_map1, map_map2) tm_map +FROM various_maps +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`tinyint_map1`, various_maps.`map_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map,map>]; line 2 pos 4 + + +-- !query 4 +SELECT + map_concat(boolean_map1, int_map2) bi_map +FROM various_maps +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`boolean_map1`, various_maps.`int_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map]; line 2 pos 4 + + +-- !query 5 +SELECT + map_concat(int_map1, struct_map2) is_map +FROM various_maps +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`int_map1`, various_maps.`struct_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map,struct>]; line 2 pos 4 + + +-- !query 6 +SELECT + map_concat(map_map1, array_map2) ma_map +FROM various_maps +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`map_map1`, various_maps.`array_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map,map>, map,array>]; line 2 pos 4 + + +-- !query 7 +SELECT + map_concat(map_map1, struct_map2) ms_map +FROM various_maps +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`map_map1`, various_maps.`struct_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map,map>, map,struct>]; line 2 pos 4 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/0 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/0 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/1 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/1 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/metadata new file mode 100644 index 0000000000000..d6be7fbffa9b7 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/metadata @@ -0,0 +1 @@ +{"id":"549eeb1a-d762-420c-bb44-3fd6d73a5268"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/0 new file mode 100644 index 0000000000000..43db49d052894 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/0 @@ -0,0 +1,4 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1531172902041,"conf":{"spark.sql.shuffle.partitions":"10","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +0 +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/1 new file mode 100644 index 0000000000000..8cc898e81017f --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/1 @@ -0,0 +1,4 @@ +v1 +{"batchWatermarkMs":10000,"batchTimestampMs":1531172902217,"conf":{"spark.sql.shuffle.partitions":"10","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +1 +0 \ No newline at end of file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 81b7e18773f81..60c73df88896b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec -import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.storage.{RDDBlockId, StorageLevel} @@ -83,25 +82,6 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext }.sum } - test("withColumn doesn't invalidate cached dataframe") { - var evalCount = 0 - val myUDF = udf((x: String) => { evalCount += 1; "result" }) - val df = Seq(("test", 1)).toDF("s", "i").select(myUDF($"s")) - df.cache() - - df.collect() - assert(evalCount === 1) - - df.collect() - assert(evalCount === 1) - - val df2 = df.withColumn("newColumn", lit(1)) - df2.collect() - - // We should not reevaluate the cached dataframe - assert(evalCount === 1) - } - test("cache temp table") { withTempView("tempTable") { testData.select('key).createOrReplaceTempView("tempTable") @@ -820,4 +800,69 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext } assert(cachedData.collect === Seq(1001)) } + + test("SPARK-24596 Non-cascading Cache Invalidation - uncache temporary view") { + withTempView("t1", "t2") { + sql("CACHE TABLE t1 AS SELECT * FROM testData WHERE key > 1") + sql("CACHE TABLE t2 as SELECT * FROM t1 WHERE value > 1") + + assert(spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + sql("UNCACHE TABLE t1") + assert(!spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + } + } + + test("SPARK-24596 Non-cascading Cache Invalidation - drop temporary view") { + withTempView("t1", "t2") { + sql("CACHE TABLE t1 AS SELECT * FROM testData WHERE key > 1") + sql("CACHE TABLE t2 as SELECT * FROM t1 WHERE value > 1") + + assert(spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + sql("DROP VIEW t1") + assert(spark.catalog.isCached("t2")) + } + } + + test("SPARK-24596 Non-cascading Cache Invalidation - drop persistent view") { + withTable("t") { + spark.range(1, 10).toDF("key").withColumn("value", 'key * 2) + .write.format("json").saveAsTable("t") + withView("t1") { + withTempView("t2") { + sql("CREATE VIEW t1 AS SELECT * FROM t WHERE key > 1") + + sql("CACHE TABLE t1") + sql("CACHE TABLE t2 AS SELECT * FROM t1 WHERE value > 1") + + assert(spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + sql("DROP VIEW t1") + assert(!spark.catalog.isCached("t2")) + } + } + } + } + + test("SPARK-24596 Non-cascading Cache Invalidation - uncache table") { + withTable("t") { + spark.range(1, 10).toDF("key").withColumn("value", 'key * 2) + .write.format("json").saveAsTable("t") + withTempView("t1", "t2") { + sql("CACHE TABLE t") + sql("CACHE TABLE t1 AS SELECT * FROM t WHERE key > 1") + sql("CACHE TABLE t2 AS SELECT * FROM t1 WHERE value > 1") + + assert(spark.catalog.isCached("t")) + assert(spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + sql("UNCACHE TABLE t") + assert(!spark.catalog.isCached("t")) + assert(!spark.catalog.isCached("t1")) + assert(!spark.catalog.isCached("t2")) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 4e5c1c56e2673..a39aef18d27e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -18,12 +18,15 @@ package org.apache.spark.sql import java.nio.charset.StandardCharsets +import java.sql.{Date, Timestamp} +import java.util.TimeZone import scala.util.Random import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.util.DateTimeTestUtils import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -306,113 +309,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } - test("mask functions") { - val df = Seq("TestString-123", "", null).toDF("a") - checkAnswer(df.select(mask($"a")), Seq(Row("XxxxXxxxxx-nnn"), Row(""), Row(null))) - checkAnswer(df.select(mask_first_n($"a", 4)), Seq(Row("XxxxString-123"), Row(""), Row(null))) - checkAnswer(df.select(mask_last_n($"a", 4)), Seq(Row("TestString-nnn"), Row(""), Row(null))) - checkAnswer(df.select(mask_show_first_n($"a", 4)), - Seq(Row("TestXxxxxx-nnn"), Row(""), Row(null))) - checkAnswer(df.select(mask_show_last_n($"a", 4)), - Seq(Row("XxxxXxxxxx-123"), Row(""), Row(null))) - checkAnswer(df.select(mask_hash($"a")), - Seq(Row("dd78d68ad1b23bde126812482dd70ac6"), - Row("d41d8cd98f00b204e9800998ecf8427e"), - Row(null))) - - checkAnswer(df.select(mask($"a", "U", "l", "#")), - Seq(Row("UlllUlllll-###"), Row(""), Row(null))) - checkAnswer(df.select(mask_first_n($"a", 4, "U", "l", "#")), - Seq(Row("UlllString-123"), Row(""), Row(null))) - checkAnswer(df.select(mask_last_n($"a", 4, "U", "l", "#")), - Seq(Row("TestString-###"), Row(""), Row(null))) - checkAnswer(df.select(mask_show_first_n($"a", 4, "U", "l", "#")), - Seq(Row("TestUlllll-###"), Row(""), Row(null))) - checkAnswer(df.select(mask_show_last_n($"a", 4, "U", "l", "#")), - Seq(Row("UlllUlllll-123"), Row(""), Row(null))) - - checkAnswer( - df.selectExpr("mask(a)", "mask(a, 'U')", "mask(a, 'U', 'l')", "mask(a, 'U', 'l', '#')"), - Seq(Row("XxxxXxxxxx-nnn", "UxxxUxxxxx-nnn", "UlllUlllll-nnn", "UlllUlllll-###"), - Row("", "", "", ""), - Row(null, null, null, null))) - checkAnswer(sql("select mask(null)"), Row(null)) - checkAnswer(sql("select mask('AAaa11', null, null, null)"), Row("XXxxnn")) - intercept[AnalysisException] { - checkAnswer(df.selectExpr("mask(a, a)"), Seq(Row("XxxxXxxxxx-nnn"), Row(""), Row(null))) - } - - checkAnswer( - df.selectExpr( - "mask_first_n(a)", - "mask_first_n(a, 6)", - "mask_first_n(a, 6, 'U')", - "mask_first_n(a, 6, 'U', 'l')", - "mask_first_n(a, 6, 'U', 'l', '#')"), - Seq(Row("XxxxString-123", "XxxxXxring-123", "UxxxUxring-123", "UlllUlring-123", - "UlllUlring-123"), - Row("", "", "", "", ""), - Row(null, null, null, null, null))) - checkAnswer(sql("select mask_first_n(null)"), Row(null)) - checkAnswer(sql("select mask_first_n('A1aA1a', null, null, null, null)"), Row("XnxX1a")) - intercept[AnalysisException] { - checkAnswer(spark.range(1).selectExpr("mask_first_n('A1aA1a', id)"), Row("XnxX1a")) - } - - checkAnswer( - df.selectExpr( - "mask_last_n(a)", - "mask_last_n(a, 6)", - "mask_last_n(a, 6, 'U')", - "mask_last_n(a, 6, 'U', 'l')", - "mask_last_n(a, 6, 'U', 'l', '#')"), - Seq(Row("TestString-nnn", "TestStrixx-nnn", "TestStrixx-nnn", "TestStrill-nnn", - "TestStrill-###"), - Row("", "", "", "", ""), - Row(null, null, null, null, null))) - checkAnswer(sql("select mask_last_n(null)"), Row(null)) - checkAnswer(sql("select mask_last_n('A1aA1a', null, null, null, null)"), Row("A1xXnx")) - intercept[AnalysisException] { - checkAnswer(spark.range(1).selectExpr("mask_last_n('A1aA1a', id)"), Row("A1xXnx")) - } - - checkAnswer( - df.selectExpr( - "mask_show_first_n(a)", - "mask_show_first_n(a, 6)", - "mask_show_first_n(a, 6, 'U')", - "mask_show_first_n(a, 6, 'U', 'l')", - "mask_show_first_n(a, 6, 'U', 'l', '#')"), - Seq(Row("TestXxxxxx-nnn", "TestStxxxx-nnn", "TestStxxxx-nnn", "TestStllll-nnn", - "TestStllll-###"), - Row("", "", "", "", ""), - Row(null, null, null, null, null))) - checkAnswer(sql("select mask_show_first_n(null)"), Row(null)) - checkAnswer(sql("select mask_show_first_n('A1aA1a', null, null, null, null)"), Row("A1aAnx")) - intercept[AnalysisException] { - checkAnswer(spark.range(1).selectExpr("mask_show_first_n('A1aA1a', id)"), Row("A1aAnx")) - } - - checkAnswer( - df.selectExpr( - "mask_show_last_n(a)", - "mask_show_last_n(a, 6)", - "mask_show_last_n(a, 6, 'U')", - "mask_show_last_n(a, 6, 'U', 'l')", - "mask_show_last_n(a, 6, 'U', 'l', '#')"), - Seq(Row("XxxxXxxxxx-123", "XxxxXxxxng-123", "UxxxUxxxng-123", "UlllUlllng-123", - "UlllUlllng-123"), - Row("", "", "", "", ""), - Row(null, null, null, null, null))) - checkAnswer(sql("select mask_show_last_n(null)"), Row(null)) - checkAnswer(sql("select mask_show_last_n('A1aA1a', null, null, null, null)"), Row("XnaA1a")) - intercept[AnalysisException] { - checkAnswer(spark.range(1).selectExpr("mask_show_last_n('A1aA1a', id)"), Row("XnaA1a")) - } - - checkAnswer(sql("select mask_hash(null)"), Row(null)) - } - test("sort_array/array_sort functions") { val df = Seq( (Array[Int](2, 1, 3), Array("b", "c", "a")), @@ -487,26 +383,29 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { }.getMessage().contains("only supports array input")) } - test("array size function") { + def testSizeOfArray(sizeOfNull: Any): Unit = { val df = Seq( (Seq[Int](1, 2), "x"), (Seq[Int](), "y"), (Seq[Int](1, 2, 3), "z"), (null, "empty") ).toDF("a", "b") - checkAnswer( - df.select(size($"a")), - Seq(Row(2), Row(0), Row(3), Row(-1)) - ) - checkAnswer( - df.selectExpr("size(a)"), - Seq(Row(2), Row(0), Row(3), Row(-1)) - ) - checkAnswer( - df.selectExpr("cardinality(a)"), - Seq(Row(2L), Row(0L), Row(3L), Row(-1L)) - ) + checkAnswer(df.select(size($"a")), Seq(Row(2), Row(0), Row(3), Row(sizeOfNull))) + checkAnswer(df.selectExpr("size(a)"), Seq(Row(2), Row(0), Row(3), Row(sizeOfNull))) + checkAnswer(df.selectExpr("cardinality(a)"), Seq(Row(2L), Row(0L), Row(3L), Row(sizeOfNull))) + } + + test("array size function - legacy") { + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") { + testSizeOfArray(sizeOfNull = -1) + } + } + + test("array size function") { + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "false") { + testSizeOfArray(sizeOfNull = null) + } } test("dataframe arrays_zip function") { @@ -556,21 +455,39 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df8.selectExpr("arrays_zip(v1, v2)"), expectedValue8) } - test("map size function") { + test("SPARK-24633: arrays_zip splits input processing correctly") { + Seq("true", "false").foreach { wholestageCodegenEnabled => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholestageCodegenEnabled) { + val df = spark.range(1) + val exprs = (0 to 5).map(x => array($"id" + lit(x))) + checkAnswer(df.select(arrays_zip(exprs: _*)), + Row(Seq(Row(0, 1, 2, 3, 4, 5)))) + } + } + } + + def testSizeOfMap(sizeOfNull: Any): Unit = { val df = Seq( (Map[Int, Int](1 -> 1, 2 -> 2), "x"), (Map[Int, Int](), "y"), (Map[Int, Int](1 -> 1, 2 -> 2, 3 -> 3), "z"), (null, "empty") ).toDF("a", "b") - checkAnswer( - df.select(size($"a")), - Seq(Row(2), Row(0), Row(3), Row(-1)) - ) - checkAnswer( - df.selectExpr("size(a)"), - Seq(Row(2), Row(0), Row(3), Row(-1)) - ) + + checkAnswer(df.select(size($"a")), Seq(Row(2), Row(0), Row(3), Row(sizeOfNull))) + checkAnswer(df.selectExpr("size(a)"), Seq(Row(2), Row(0), Row(3), Row(sizeOfNull))) + } + + test("map size function - legacy") { + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") { + testSizeOfMap(sizeOfNull = -1: Int) + } + } + + test("map size function") { + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "false") { + testSizeOfMap(sizeOfNull = null) + } } test("map_keys/map_values function") { @@ -633,11 +550,139 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(sdf.filter(dummyFilter('m)).select(map_entries('m)), sExpected) } + test("map_concat function") { + val df1 = Seq( + (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 4 -> 400)), + (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 1 -> 400)), + (null, Map[Int, Int](3 -> 300, 4 -> 400)) + ).toDF("map1", "map2") + + val expected1a = Seq( + Row(Map(1 -> 100, 2 -> 200, 3 -> 300, 4 -> 400)), + Row(Map(1 -> 400, 2 -> 200, 3 -> 300)), + Row(null) + ) + + checkAnswer(df1.selectExpr("map_concat(map1, map2)"), expected1a) + checkAnswer(df1.select(map_concat('map1, 'map2)), expected1a) + + val expected1b = Seq( + Row(Map(1 -> 100, 2 -> 200)), + Row(Map(1 -> 100, 2 -> 200)), + Row(null) + ) + + checkAnswer(df1.selectExpr("map_concat(map1)"), expected1b) + checkAnswer(df1.select(map_concat('map1)), expected1b) + + val df2 = Seq( + ( + Map[Array[Int], Int](Array(1) -> 100, Array(2) -> 200), + Map[String, Int]("3" -> 300, "4" -> 400) + ) + ).toDF("map1", "map2") + + val expected2 = Seq(Row(Map())) + + checkAnswer(df2.selectExpr("map_concat()"), expected2) + checkAnswer(df2.select(map_concat()), expected2) + + val df3 = { + val schema = StructType( + StructField("map1", MapType(StringType, IntegerType, true), false) :: + StructField("map2", MapType(StringType, IntegerType, false), false) :: Nil + ) + val data = Seq( + Row(Map[String, Any]("a" -> 1, "b" -> null), Map[String, Any]("c" -> 3, "d" -> 4)), + Row(Map[String, Any]("a" -> 1, "b" -> 2), Map[String, Any]("c" -> 3, "d" -> 4)) + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + + val expected3 = Seq( + Row(Map[String, Any]("a" -> 1, "b" -> null, "c" -> 3, "d" -> 4)), + Row(Map[String, Any]("a" -> 1, "b" -> 2, "c" -> 3, "d" -> 4)) + ) + + checkAnswer(df3.selectExpr("map_concat(map1, map2)"), expected3) + checkAnswer(df3.select(map_concat('map1, 'map2)), expected3) + + val expectedMessage1 = "input to function map_concat should all be the same type" + + assert(intercept[AnalysisException] { + df2.selectExpr("map_concat(map1, map2)").collect() + }.getMessage().contains(expectedMessage1)) + + assert(intercept[AnalysisException] { + df2.select(map_concat('map1, 'map2)).collect() + }.getMessage().contains(expectedMessage1)) + + val expectedMessage2 = "input to function map_concat should all be of type map" + + assert(intercept[AnalysisException] { + df2.selectExpr("map_concat(map1, 12)").collect() + }.getMessage().contains(expectedMessage2)) + + assert(intercept[AnalysisException] { + df2.select(map_concat('map1, lit(12))).collect() + }.getMessage().contains(expectedMessage2)) + } + + test("map_from_entries function") { + def dummyFilter(c: Column): Column = c.isNull || c.isNotNull + val oneRowDF = Seq(3215).toDF("i") + + // Test cases with primitive-type keys and values + val idf = Seq( + Seq((1, 10), (2, 20), (3, 10)), + Seq((1, 10), null, (2, 20)), + Seq.empty, + null + ).toDF("a") + val iExpected = Seq( + Row(Map(1 -> 10, 2 -> 20, 3 -> 10)), + Row(null), + Row(Map.empty), + Row(null)) + + checkAnswer(idf.select(map_from_entries('a)), iExpected) + checkAnswer(idf.selectExpr("map_from_entries(a)"), iExpected) + checkAnswer(idf.filter(dummyFilter('a)).select(map_from_entries('a)), iExpected) + checkAnswer( + oneRowDF.selectExpr("map_from_entries(array(struct(1, null), struct(2, null)))"), + Seq(Row(Map(1 -> null, 2 -> null))) + ) + checkAnswer( + oneRowDF.filter(dummyFilter('i)) + .selectExpr("map_from_entries(array(struct(1, null), struct(2, null)))"), + Seq(Row(Map(1 -> null, 2 -> null))) + ) + + // Test cases with non-primitive-type keys and values + val sdf = Seq( + Seq(("a", "aa"), ("b", "bb"), ("c", "aa")), + Seq(("a", "aa"), null, ("b", "bb")), + Seq(("a", null), ("b", null)), + Seq.empty, + null + ).toDF("a") + val sExpected = Seq( + Row(Map("a" -> "aa", "b" -> "bb", "c" -> "aa")), + Row(null), + Row(Map("a" -> null, "b" -> null)), + Row(Map.empty), + Row(null)) + + checkAnswer(sdf.select(map_from_entries('a)), sExpected) + checkAnswer(sdf.selectExpr("map_from_entries(a)"), sExpected) + checkAnswer(sdf.filter(dummyFilter('a)).select(map_from_entries('a)), sExpected) + } + test("array contains function") { val df = Seq( - (Seq[Int](1, 2), "x"), - (Seq[Int](), "x") - ).toDF("a", "b") + (Seq[Int](1, 2), "x", 1), + (Seq[Int](), "x", 1) + ).toDF("a", "b", "c") // Simple test cases checkAnswer( @@ -648,6 +693,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("array_contains(a, 1)"), Seq(Row(true), Row(false)) ) + checkAnswer( + df.select(array_contains(df("a"), df("c"))), + Seq(Row(true), Row(false)) + ) + checkAnswer( + df.selectExpr("array_contains(a, c)"), + Seq(Row(true), Row(false)) + ) // In hive, this errors because null has no type information intercept[AnalysisException] { @@ -736,6 +789,23 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer( df.selectExpr("array_join(x, delimiter, 'NULL')"), Seq(Row("a,b"), Row("a,NULL,b"), Row(""))) + + val idf = Seq(Seq(1, 2, 3)).toDF("x") + + checkAnswer( + idf.select(array_join(idf("x"), ", ")), + Seq(Row("1, 2, 3")) + ) + checkAnswer( + idf.selectExpr("array_join(x, ', ')"), + Seq(Row("1, 2, 3")) + ) + intercept[AnalysisException] { + idf.selectExpr("array_join(x, 1)") + } + intercept[AnalysisException] { + idf.selectExpr("array_join(x, ', ', 1)") + } } test("array_min function") { @@ -766,6 +836,59 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df.selectExpr("array_max(a)"), answer) } + test("sequence") { + checkAnswer(Seq((-2, 2)).toDF().select(sequence('_1, '_2)), Seq(Row(Array(-2, -1, 0, 1, 2)))) + checkAnswer(Seq((7, 2, -2)).toDF().select(sequence('_1, '_2, '_3)), Seq(Row(Array(7, 5, 3)))) + + checkAnswer( + spark.sql("select sequence(" + + " cast('2018-01-01 00:00:00' as timestamp)" + + ", cast('2018-01-02 00:00:00' as timestamp)" + + ", interval 12 hours)"), + Seq(Row(Array( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"))))) + + DateTimeTestUtils.withDefaultTimeZone(TimeZone.getTimeZone("UTC")) { + checkAnswer( + spark.sql("select sequence(" + + " cast('2018-01-01' as date)" + + ", cast('2018-03-01' as date)" + + ", interval 1 month)"), + Seq(Row(Array( + Date.valueOf("2018-01-01"), + Date.valueOf("2018-02-01"), + Date.valueOf("2018-03-01"))))) + } + + // test type coercion + checkAnswer( + Seq((1.toByte, 3L, 1)).toDF().select(sequence('_1, '_2, '_3)), + Seq(Row(Array(1L, 2L, 3L)))) + + checkAnswer( + spark.sql("select sequence(" + + " cast('2018-01-01' as date)" + + ", cast('2018-01-02 00:00:00' as timestamp)" + + ", interval 12 hours)"), + Seq(Row(Array( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"))))) + + // test invalid data types + intercept[AnalysisException] { + Seq((true, false)).toDF().selectExpr("sequence(_1, _2)") + } + intercept[AnalysisException] { + Seq((true, false, 42)).toDF().selectExpr("sequence(_1, _2, _3)") + } + intercept[AnalysisException] { + Seq((1, 2, 0.5)).toDF().selectExpr("sequence(_1, _2, _3)") + } + } + test("reverse function") { val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codegen on @@ -862,9 +985,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("array position function") { val df = Seq( - (Seq[Int](1, 2), "x"), - (Seq[Int](), "x") - ).toDF("a", "b") + (Seq[Int](1, 2), "x", 1), + (Seq[Int](), "x", 1) + ).toDF("a", "b", "c") checkAnswer( df.select(array_position(df("a"), 1)), @@ -874,7 +997,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("array_position(a, 1)"), Seq(Row(1L), Row(0L)) ) - + checkAnswer( + df.selectExpr("array_position(a, c)"), + Seq(Row(1L), Row(0L)) + ) + checkAnswer( + df.select(array_position(df("a"), df("c"))), + Seq(Row(1L), Row(0L)) + ) checkAnswer( df.select(array_position(df("a"), null)), Seq(Row(null), Row(null)) @@ -901,10 +1031,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("element_at function") { val df = Seq( - (Seq[String]("1", "2", "3")), - (Seq[String](null, "")), - (Seq[String]()) - ).toDF("a") + (Seq[String]("1", "2", "3"), 1), + (Seq[String](null, ""), -1), + (Seq[String](), 2) + ).toDF("a", "b") intercept[Exception] { checkAnswer( @@ -922,6 +1052,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.select(element_at(df("a"), 4)), Seq(Row(null), Row(null), Row(null)) ) + checkAnswer( + df.select(element_at(df("a"), df("b"))), + Seq(Row("1"), Row(""), Row(null)) + ) + checkAnswer( + df.selectExpr("element_at(a, b)"), + Seq(Row("1"), Row(""), Row(null)) + ) checkAnswer( df.select(element_at(df("a"), 1)), @@ -953,6 +1091,58 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { "argument 1 requires (array or map) type, however, '`_1`' is of string type")) } + test("array_union functions") { + val df1 = Seq((Array(1, 2, 3), Array(4, 2))).toDF("a", "b") + val ans1 = Row(Seq(1, 2, 3, 4)) + checkAnswer(df1.select(array_union($"a", $"b")), ans1) + checkAnswer(df1.selectExpr("array_union(a, b)"), ans1) + + val df2 = Seq((Array[Integer](1, 2, null, 4, 5), Array(-5, 4, -3, 2, -1))).toDF("a", "b") + val ans2 = Row(Seq(1, 2, null, 4, 5, -5, -3, -1)) + checkAnswer(df2.select(array_union($"a", $"b")), ans2) + checkAnswer(df2.selectExpr("array_union(a, b)"), ans2) + + val df3 = Seq((Array(1L, 2L, 3L), Array(4L, 2L))).toDF("a", "b") + val ans3 = Row(Seq(1L, 2L, 3L, 4L)) + checkAnswer(df3.select(array_union($"a", $"b")), ans3) + checkAnswer(df3.selectExpr("array_union(a, b)"), ans3) + + val df4 = Seq((Array[java.lang.Long](1L, 2L, null, 4L, 5L), Array(-5L, 4L, -3L, 2L, -1L))) + .toDF("a", "b") + val ans4 = Row(Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L)) + checkAnswer(df4.select(array_union($"a", $"b")), ans4) + checkAnswer(df4.selectExpr("array_union(a, b)"), ans4) + + val df5 = Seq((Array("b", "a", "c"), Array("b", null, "a", "g"))).toDF("a", "b") + val ans5 = Row(Seq("b", "a", "c", null, "g")) + checkAnswer(df5.select(array_union($"a", $"b")), ans5) + checkAnswer(df5.selectExpr("array_union(a, b)"), ans5) + + val df6 = Seq((null, Array("a"))).toDF("a", "b") + intercept[AnalysisException] { + df6.select(array_union($"a", $"b")) + } + intercept[AnalysisException] { + df6.selectExpr("array_union(a, b)") + } + + val df7 = Seq((null, null)).toDF("a", "b") + intercept[AnalysisException] { + df7.select(array_union($"a", $"b")) + } + intercept[AnalysisException] { + df7.selectExpr("array_union(a, b)") + } + + val df8 = Seq((Array(Array(1)), Array("a"))).toDF("a", "b") + intercept[AnalysisException] { + df8.select(array_union($"a", $"b")) + } + intercept[AnalysisException] { + df8.selectExpr("array_union(a, b)") + } + } + test("concat function - arrays") { val nseqi : Seq[Int] = null val nseqs : Seq[String] = null @@ -1189,10 +1379,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("array remove") { val df = Seq( - (Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", "")), - (Array.empty[Int], Array.empty[String], Array.empty[String]), - (null, null, null) - ).toDF("a", "b", "c") + (Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", ""), 2), + (Array.empty[Int], Array.empty[String], Array.empty[String], 2), + (null, null, null, 2) + ).toDF("a", "b", "c", "d") checkAnswer( df.select(array_remove($"a", 2), array_remove($"b", "a"), array_remove($"c", "")), Seq( @@ -1201,6 +1391,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null, null, null)) ) + checkAnswer( + df.select(array_remove($"a", $"d")), + Seq( + Row(Seq(1, 3)), + Row(Seq.empty[Int]), + Row(null)) + ) + + checkAnswer( + df.selectExpr("array_remove(a, d)"), + Seq( + Row(Seq(1, 3)), + Row(Seq.empty[Int]), + Row(null)) + ) + checkAnswer( df.selectExpr("array_remove(a, 2)", "array_remove(b, \"a\")", "array_remove(c, \"\")"), @@ -1216,6 +1422,28 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type")) } + test("array_distinct functions") { + val df = Seq( + (Array[Int](2, 1, 3, 4, 3, 5), Array("b", "c", "a", "c", "b", "", "")), + (Array.empty[Int], Array.empty[String]), + (null, null) + ).toDF("a", "b") + checkAnswer( + df.select(array_distinct($"a"), array_distinct($"b")), + Seq( + Row(Seq(2, 1, 3, 4, 5), Seq("b", "c", "a", "")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + checkAnswer( + df.selectExpr("array_distinct(a)", "array_distinct(b)"), + Seq( + Row(Seq(2, 1, 3, 4, 5), Seq("b", "c", "a", "")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { @@ -1287,6 +1515,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(errMsg.contains(s"input to function $name requires at least two arguments")) } } + + test("SPARK-24734: Fix containsNull of Concat for array type") { + val df = Seq((Seq(1), Seq[Integer](null), Seq("a", "b"))).toDF("k1", "k2", "v") + val ex = intercept[RuntimeException] { + df.select(map_from_arrays(concat($"k1", $"k2"), $"v")).show() + } + assert(ex.getMessage.contains("Cannot use null as map key")) + } } object DataFrameFunctionsSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 0d9eeabb397a1..10d9a11d2ee79 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -287,4 +287,12 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { dfOne.join(dfTwo, $"a" === $"b", "left").queryExecution.optimizedPlan } } + + test("SPARK-24385: Resolve ambiguity in self-joins with EqualNullSafe") { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") { + val df = spark.range(2) + // this throws an exception before the fix + df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 1cc8cb3874c9b..5babdf6f33b99 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1044,6 +1044,65 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { testData.select($"*").show(1000) } + test("getRows: truncate = [0, 20]") { + val longString = Array.fill(21)("1").mkString + val df = sparkContext.parallelize(Seq("1", longString)).toDF() + val expectedAnswerForFalse = Seq( + Seq("value"), + Seq("1"), + Seq("111111111111111111111")) + assert(df.getRows(10, 0) === expectedAnswerForFalse) + val expectedAnswerForTrue = Seq( + Seq("value"), + Seq("1"), + Seq("11111111111111111...")) + assert(df.getRows(10, 20) === expectedAnswerForTrue) + } + + test("getRows: truncate = [3, 17]") { + val longString = Array.fill(21)("1").mkString + val df = sparkContext.parallelize(Seq("1", longString)).toDF() + val expectedAnswerForFalse = Seq( + Seq("value"), + Seq("1"), + Seq("111")) + assert(df.getRows(10, 3) === expectedAnswerForFalse) + val expectedAnswerForTrue = Seq( + Seq("value"), + Seq("1"), + Seq("11111111111111...")) + assert(df.getRows(10, 17) === expectedAnswerForTrue) + } + + test("getRows: numRows = 0") { + val expectedAnswer = Seq(Seq("key", "value"), Seq("1", "1")) + assert(testData.select($"*").getRows(0, 20) === expectedAnswer) + } + + test("getRows: array") { + val df = Seq( + (Array(1, 2, 3), Array(1, 2, 3)), + (Array(2, 3, 4), Array(2, 3, 4)) + ).toDF() + val expectedAnswer = Seq( + Seq("_1", "_2"), + Seq("[1, 2, 3]", "[1, 2, 3]"), + Seq("[2, 3, 4]", "[2, 3, 4]")) + assert(df.getRows(10, 20) === expectedAnswer) + } + + test("getRows: binary") { + val df = Seq( + ("12".getBytes(StandardCharsets.UTF_8), "ABC.".getBytes(StandardCharsets.UTF_8)), + ("34".getBytes(StandardCharsets.UTF_8), "12346".getBytes(StandardCharsets.UTF_8)) + ).toDF() + val expectedAnswer = Seq( + Seq("_1", "_2"), + Seq("[31 32]", "[41 42 43 2E]"), + Seq("[33 34]", "[31 32 33 34 36]")) + assert(df.getRows(10, 20) === expectedAnswer) + } + test("showString: truncate = [0, 20]") { val longString = Array.fill(21)("1").mkString val df = sparkContext.parallelize(Seq("1", longString)).toDF() @@ -2261,6 +2320,64 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) } + test("SPARK-24165: CaseWhen/If - nullability of nested types") { + val rows = new java.util.ArrayList[Row]() + rows.add(Row(true, ("x", 1), Seq("x", "y"), Map(0 -> "x"))) + rows.add(Row(false, (null, 2), Seq(null, "z"), Map(0 -> null))) + val schema = StructType(Seq( + StructField("cond", BooleanType, true), + StructField("s", StructType(Seq( + StructField("val1", StringType, true), + StructField("val2", IntegerType, false) + )), false), + StructField("a", ArrayType(StringType, true)), + StructField("m", MapType(IntegerType, StringType, true)) + )) + + val sourceDF = spark.createDataFrame(rows, schema) + + val structWhenDF = sourceDF + .select(when('cond, struct(lit("a").as("val1"), lit(10).as("val2"))).otherwise('s) as "res") + .select('res.getField("val1")) + val arrayWhenDF = sourceDF + .select(when('cond, array(lit("a"), lit("b"))).otherwise('a) as "res") + .select('res.getItem(0)) + val mapWhenDF = sourceDF + .select(when('cond, map(lit(0), lit("a"))).otherwise('m) as "res") + .select('res.getItem(0)) + + val structIfDF = sourceDF + .select(expr("if(cond, struct('a' as val1, 10 as val2), s)") as "res") + .select('res.getField("val1")) + val arrayIfDF = sourceDF + .select(expr("if(cond, array('a', 'b'), a)") as "res") + .select('res.getItem(0)) + val mapIfDF = sourceDF + .select(expr("if(cond, map(0, 'a'), m)") as "res") + .select('res.getItem(0)) + + def checkResult(df: DataFrame, codegenExpected: Boolean): Unit = { + assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec] == codegenExpected) + checkAnswer(df, Seq(Row("a"), Row(null))) + } + + // without codegen + checkResult(structWhenDF, false) + checkResult(arrayWhenDF, false) + checkResult(mapWhenDF, false) + checkResult(structIfDF, false) + checkResult(arrayIfDF, false) + checkResult(mapIfDF, false) + + // with codegen + checkResult(structWhenDF.filter('cond.isNotNull), true) + checkResult(arrayWhenDF.filter('cond.isNotNull), true) + checkResult(mapWhenDF.filter('cond.isNotNull), true) + checkResult(structIfDF.filter('cond.isNotNull), true) + checkResult(arrayIfDF.filter('cond.isNotNull), true) + checkResult(mapIfDF.filter('cond.isNotNull), true) + } + test("Uuid expressions should produce same results at retries in the same DataFrame") { val df = spark.range(1).select($"id", new Column(Uuid())) checkAnswer(df, df.collect()) @@ -2270,4 +2387,29 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val mapWithBinaryKey = map(lit(Array[Byte](1.toByte)), lit(1)) checkAnswer(spark.range(1).select(mapWithBinaryKey.getItem(Array[Byte](1.toByte))), Row(1)) } + + test("SPARK-24781: Using a reference from Dataset in Filter/Sort") { + val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id") + val filter1 = df.select(df("name")).filter(df("id") === 0) + val filter2 = df.select(col("name")).filter(col("id") === 0) + checkAnswer(filter1, filter2.collect()) + + val sort1 = df.select(df("name")).orderBy(df("id")) + val sort2 = df.select(col("name")).orderBy(col("id")) + checkAnswer(sort1, sort2.collect()) + } + + test("SPARK-24781: Using a reference not in aggregation in Filter/Sort") { + withSQLConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key -> "false") { + val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id") + + val aggPlusSort1 = df.groupBy(df("name")).agg(count(df("name"))).orderBy(df("name")) + val aggPlusSort2 = df.groupBy(col("name")).agg(count(col("name"))).orderBy(col("name")) + checkAnswer(aggPlusSort1, aggPlusSort2.collect()) + + val aggPlusFilter1 = df.groupBy(df("name")).agg(count(df("name"))).filter(df("name") === 0) + val aggPlusFilter2 = df.groupBy(col("name")).agg(count(col("name"))).filter(col("name") === 0) + checkAnswer(aggPlusFilter1, aggPlusFilter2.collect()) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 3ea398aad7375..97a843978f0bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql -import java.sql.{Date, Timestamp} - -import scala.collection.mutable +import org.scalatest.Matchers.the import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} @@ -27,7 +25,6 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval /** * Window function testing for DataFrame API. @@ -624,4 +621,41 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-24575: Window functions inside WHERE and HAVING clauses") { + def checkAnalysisError(df: => DataFrame): Unit = { + val thrownException = the [AnalysisException] thrownBy { + df.queryExecution.analyzed + } + assert(thrownException.message.contains("window functions inside WHERE and HAVING clauses")) + } + + checkAnalysisError(testData2.select('a).where(rank().over(Window.orderBy('b)) === 1)) + checkAnalysisError(testData2.where('b === 2 && rank().over(Window.orderBy('b)) === 1)) + checkAnalysisError( + testData2.groupBy('a) + .agg(avg('b).as("avgb")) + .where('a > 'avgb && rank().over(Window.orderBy('a)) === 1)) + checkAnalysisError( + testData2.groupBy('a) + .agg(max('b).as("maxb"), sum('b).as("sumb")) + .where(rank().over(Window.orderBy('a)) === 1)) + checkAnalysisError( + testData2.groupBy('a) + .agg(max('b).as("maxb"), sum('b).as("sumb")) + .where('sumb === 5 && rank().over(Window.orderBy('a)) === 1)) + + checkAnalysisError(sql("SELECT a FROM testData2 WHERE RANK() OVER(ORDER BY b) = 1")) + checkAnalysisError(sql("SELECT * FROM testData2 WHERE b = 2 AND RANK() OVER(ORDER BY b) = 1")) + checkAnalysisError( + sql("SELECT * FROM testData2 GROUP BY a HAVING a > AVG(b) AND RANK() OVER(ORDER BY a) = 1")) + checkAnalysisError( + sql("SELECT a, MAX(b), SUM(b) FROM testData2 GROUP BY a HAVING RANK() OVER(ORDER BY a) = 1")) + checkAnalysisError( + sql( + s"""SELECT a, MAX(b) + |FROM testData2 + |GROUP BY a + |HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 0e7eaa9e88d57..538ea3c66c40e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -148,6 +148,41 @@ object VeryComplexResultAgg extends Aggregator[Row, String, ComplexAggData] { } +case class OptionBooleanData(name: String, isGood: Option[Boolean]) + +case class OptionBooleanAggregator(colName: String) + extends Aggregator[Row, Option[Boolean], Option[Boolean]] { + + override def zero: Option[Boolean] = None + + override def reduce(buffer: Option[Boolean], row: Row): Option[Boolean] = { + val index = row.fieldIndex(colName) + val value = if (row.isNullAt(index)) { + Option.empty[Boolean] + } else { + Some(row.getBoolean(index)) + } + merge(buffer, value) + } + + override def merge(b1: Option[Boolean], b2: Option[Boolean]): Option[Boolean] = { + if ((b1.isDefined && b1.get) || (b2.isDefined && b2.get)) { + Some(true) + } else if (b1.isDefined) { + b1 + } else { + b2 + } + } + + override def finish(reduction: Option[Boolean]): Option[Boolean] = reduction + + override def bufferEncoder: Encoder[Option[Boolean]] = OptionalBoolEncoder + override def outputEncoder: Encoder[Option[Boolean]] = OptionalBoolEncoder + + def OptionalBoolEncoder: Encoder[Option[Boolean]] = ExpressionEncoder() +} + class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -333,4 +368,29 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { df.groupBy($"i").agg(VeryComplexResultAgg.toColumn), Row(1, Row(Row(1, "a"), Row(1, "a"))) :: Row(2, Row(Row(2, "bc"), Row(2, "bc"))) :: Nil) } + + test("SPARK-24569: Aggregator with output type Option[Boolean] creates column of type Row") { + val df = Seq( + OptionBooleanData("bob", Some(true)), + OptionBooleanData("bob", Some(false)), + OptionBooleanData("bob", None)).toDF() + val group = df + .groupBy("name") + .agg(OptionBooleanAggregator("isGood").toColumn.alias("isGood")) + assert(df.schema == group.schema) + checkAnswer(group, Row("bob", true) :: Nil) + checkDataset(group.as[OptionBooleanData], OptionBooleanData("bob", Some(true))) + } + + test("SPARK-24569: groupByKey with Aggregator of output type Option[Boolean]") { + val df = Seq( + OptionBooleanData("bob", Some(true)), + OptionBooleanData("bob", Some(false)), + OptionBooleanData("bob", None)).toDF() + val grouped = df.groupByKey((r: Row) => r.getString(0)) + .agg(OptionBooleanAggregator("isGood").toColumn).toDF("name", "isGood") + + assert(grouped.schema == df.schema) + checkDataset(grouped.as[OptionBooleanData], OptionBooleanData("bob", Some(true))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index e0561ee2797a5..5c6a021d5b767 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -17,14 +17,28 @@ package org.apache.spark.sql +import org.scalatest.concurrent.TimeLimits +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.storage.StorageLevel -class DatasetCacheSuite extends QueryTest with SharedSQLContext { +class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits { import testImplicits._ + /** + * Asserts that a cached [[Dataset]] will be built using the given number of other cached results. + */ + private def assertCacheDependency(df: DataFrame, numOfCachesDependedUpon: Int = 1): Unit = { + val plan = df.queryExecution.withCachedData + assert(plan.isInstanceOf[InMemoryRelation]) + val internalPlan = plan.asInstanceOf[InMemoryRelation].cacheBuilder.cachedPlan + assert(internalPlan.find(_.isInstanceOf[InMemoryTableScanExec]).size == numOfCachesDependedUpon) + } + test("get storage level") { val ds1 = Seq("1", "2").toDS().as("a") val ds2 = Seq(2, 3).toDS().as("b") @@ -96,4 +110,100 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { agged.unpersist() assert(agged.storageLevel == StorageLevel.NONE, "The Dataset agged should not be cached.") } + + test("persist and then withColumn") { + val df = Seq(("test", 1)).toDF("s", "i") + val df2 = df.withColumn("newColumn", lit(1)) + + df.cache() + assertCached(df) + assertCached(df2) + + df.count() + assertCached(df2) + + df.unpersist() + assert(df.storageLevel == StorageLevel.NONE) + } + + test("cache UDF result correctly") { + val expensiveUDF = udf({x: Int => Thread.sleep(5000); x}) + val df = spark.range(0, 10).toDF("a").withColumn("b", expensiveUDF($"a")) + val df2 = df.agg(sum(df("b"))) + + df.cache() + df.count() + assertCached(df2) + + // udf has been evaluated during caching, and thus should not be re-evaluated here + failAfter(3 seconds) { + df2.collect() + } + + df.unpersist() + assert(df.storageLevel == StorageLevel.NONE) + } + + test("SPARK-24613 Cache with UDF could not be matched with subsequent dependent caches") { + val udf1 = udf({x: Int => x + 1}) + val df = spark.range(0, 10).toDF("a").withColumn("b", udf1($"a")) + val df2 = df.agg(sum(df("b"))) + + df.cache() + df.count() + df2.cache() + + assertCacheDependency(df2) + } + + test("SPARK-24596 Non-cascading Cache Invalidation") { + val df = Seq(("a", 1), ("b", 2)).toDF("s", "i") + val df2 = df.filter('i > 1) + val df3 = df.filter('i < 2) + + df2.cache() + df.cache() + df.count() + df3.cache() + + df.unpersist() + + // df un-cached; df2 and df3's cache plan re-compiled + assert(df.storageLevel == StorageLevel.NONE) + assertCacheDependency(df2, 0) + assertCacheDependency(df3, 0) + } + + test("SPARK-24596 Non-cascading Cache Invalidation - verify cached data reuse") { + val expensiveUDF = udf({ x: Int => Thread.sleep(5000); x }) + val df = spark.range(0, 5).toDF("a") + val df1 = df.withColumn("b", expensiveUDF($"a")) + val df2 = df1.groupBy('a).agg(sum('b)) + val df3 = df.agg(sum('a)) + + df1.cache() + df2.cache() + df2.collect() + df3.cache() + + assertCacheDependency(df2) + + df1.unpersist(blocking = true) + + // df1 un-cached; df2's cache plan re-compiled + assert(df1.storageLevel == StorageLevel.NONE) + assertCacheDependency(df1.groupBy('a).agg(sum('b)), 0) + + val df4 = df1.groupBy('a).agg(sum('b)).agg(sum("sum(b)")) + assertCached(df4) + // reuse loaded cache + failAfter(3 seconds) { + checkDataset(df4, Row(10)) + } + + val df5 = df.agg(sum('a)).filter($"sum(a)" > 1) + assertCached(df5) + // first time use, load cache + checkDataset(df5, Row(10)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index d477d78dc14e3..ce8db99d4e2f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1466,6 +1466,38 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds = Seq[(Option[Int], Option[Int])]((Some(1), None)).toDS() intercept[NullPointerException](ds.as[(Int, Int)].collect()) } + + test("SPARK-24569: Option of primitive types are mistakenly mapped to struct type") { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + val a = Seq(Some(1)).toDS + val b = Seq(Some(1.2)).toDS + val expected = Seq((Some(1), Some(1.2))).toDS + val joined = a.joinWith(b, lit(true)) + assert(joined.schema == expected.schema) + checkDataset(joined, expected.collect: _*) + } + } + + test("SPARK-24548: Dataset with tuple encoders should have correct schema") { + val encoder = Encoders.tuple(newStringEncoder, + Encoders.tuple(newStringEncoder, newStringEncoder)) + + val data = Seq(("a", ("1", "2")), ("b", ("3", "4"))) + val rdd = sparkContext.parallelize(data) + + val ds1 = spark.createDataset(rdd) + val ds2 = spark.createDataset(rdd)(encoder) + assert(ds1.schema == ds2.schema) + checkDataset(ds1.select("_2._2"), ds2.select("_2._2").collect(): _*) + } + + test("SPARK-24571: filtering of string values by char literal") { + val df = Seq("Amsterdam", "San Francisco", "X").toDF("city") + checkAnswer(df.where('city === 'X'), Seq(Row("X"))) + checkAnswer( + df.where($"city".contains(new java.lang.Character('A'))), + Seq(Row("Amsterdam"))) + } } case class TestDataUnion(x: Int, y: Int, z: Int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 237412aa692e5..3af80b36ec42c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -663,7 +663,7 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df.selectExpr("datediff(a, d)"), Seq(Row(1), Row(1))) } - test("from_utc_timestamp") { + test("from_utc_timestamp with literal zone") { val df = Seq( (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00"), (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00") @@ -680,7 +680,24 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { Row(Timestamp.valueOf("2015-07-24 17:00:00")))) } - test("to_utc_timestamp") { + test("from_utc_timestamp with column zone") { + val df = Seq( + (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00", "CET"), + (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00", "PST") + ).toDF("a", "b", "c") + checkAnswer( + df.select(from_utc_timestamp(col("a"), col("c"))), + Seq( + Row(Timestamp.valueOf("2015-07-24 02:00:00")), + Row(Timestamp.valueOf("2015-07-24 17:00:00")))) + checkAnswer( + df.select(from_utc_timestamp(col("b"), col("c"))), + Seq( + Row(Timestamp.valueOf("2015-07-24 02:00:00")), + Row(Timestamp.valueOf("2015-07-24 17:00:00")))) + } + + test("to_utc_timestamp with literal zone") { val df = Seq( (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00"), (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00") @@ -697,6 +714,23 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { Row(Timestamp.valueOf("2015-07-25 07:00:00")))) } + test("to_utc_timestamp with column zone") { + val df = Seq( + (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00", "PST"), + (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00", "CET") + ).toDF("a", "b", "c") + checkAnswer( + df.select(to_utc_timestamp(col("a"), col("c"))), + Seq( + Row(Timestamp.valueOf("2015-07-24 07:00:00")), + Row(Timestamp.valueOf("2015-07-24 22:00:00")))) + checkAnswer( + df.select(to_utc_timestamp(col("b"), col("c"))), + Seq( + Row(Timestamp.valueOf("2015-07-24 07:00:00")), + Row(Timestamp.valueOf("2015-07-24 22:00:00")))) + } + test("SPARK-23715: to/from_utc_timestamp can retain the previous behavior") { withSQLConf(SQLConf.REJECT_TIMEZONE_IN_STRING.key -> "false") { checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 06303099f5310..a7ce952b70ac1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql -import java.io.FileNotFoundException +import java.io.{File, FileNotFoundException} +import java.util.Locale import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException +import org.apache.spark.sql.TestingUDT.{IntervalData, IntervalUDT, NullData, NullUDT} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -202,4 +204,258 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo } } } + + // Text file format only supports string type + test("SPARK-24691 error handling for unsupported types - text") { + withTempDir { dir => + // write path + val textDir = new File(dir, "text").getCanonicalPath + var msg = intercept[AnalysisException] { + Seq(1).toDF.write.text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support int data type")) + + msg = intercept[AnalysisException] { + Seq(1.2).toDF.write.text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support double data type")) + + msg = intercept[AnalysisException] { + Seq(true).toDF.write.text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support boolean data type")) + + msg = intercept[AnalysisException] { + Seq(1).toDF("a").selectExpr("struct(a)").write.text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support struct data type")) + + msg = intercept[AnalysisException] { + Seq((Map("Tesla" -> 3))).toDF("cars").write.mode("overwrite").text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support map data type")) + + msg = intercept[AnalysisException] { + Seq((Array("Tesla", "Chevy", "Ford"))).toDF("brands") + .write.mode("overwrite").text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support array data type")) + + // read path + Seq("aaa").toDF.write.mode("overwrite").text(textDir) + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", IntegerType, true) :: Nil) + spark.read.schema(schema).text(textDir).collect() + }.getMessage + assert(msg.contains("Text data source does not support int data type")) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", DoubleType, true) :: Nil) + spark.read.schema(schema).text(textDir).collect() + }.getMessage + assert(msg.contains("Text data source does not support double data type")) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", BooleanType, true) :: Nil) + spark.read.schema(schema).text(textDir).collect() + }.getMessage + assert(msg.contains("Text data source does not support boolean data type")) + } + } + + // Unsupported data types of csv, json, orc, and parquet are as follows; + // csv -> R/W: Null, Array, Map, Struct + // json -> R/W: Interval + // orc -> R/W: Interval, W: Null + // parquet -> R/W: Interval, Null + test("SPARK-24204 error handling for unsupported Array/Map/Struct types - csv") { + withTempDir { dir => + val csvDir = new File(dir, "csv").getCanonicalPath + var msg = intercept[AnalysisException] { + Seq((1, "Tesla")).toDF("a", "b").selectExpr("struct(a, b)").write.csv(csvDir) + }.getMessage + assert(msg.contains("CSV data source does not support struct data type")) + + msg = intercept[AnalysisException] { + val schema = StructType.fromDDL("a struct") + spark.range(1).write.mode("overwrite").csv(csvDir) + spark.read.schema(schema).csv(csvDir).collect() + }.getMessage + assert(msg.contains("CSV data source does not support struct data type")) + + msg = intercept[AnalysisException] { + Seq((1, Map("Tesla" -> 3))).toDF("id", "cars").write.mode("overwrite").csv(csvDir) + }.getMessage + assert(msg.contains("CSV data source does not support map data type")) + + msg = intercept[AnalysisException] { + val schema = StructType.fromDDL("a map") + spark.range(1).write.mode("overwrite").csv(csvDir) + spark.read.schema(schema).csv(csvDir).collect() + }.getMessage + assert(msg.contains("CSV data source does not support map data type")) + + msg = intercept[AnalysisException] { + Seq((1, Array("Tesla", "Chevy", "Ford"))).toDF("id", "brands") + .write.mode("overwrite").csv(csvDir) + }.getMessage + assert(msg.contains("CSV data source does not support array data type")) + + msg = intercept[AnalysisException] { + val schema = StructType.fromDDL("a array") + spark.range(1).write.mode("overwrite").csv(csvDir) + spark.read.schema(schema).csv(csvDir).collect() + }.getMessage + assert(msg.contains("CSV data source does not support array data type")) + + msg = intercept[AnalysisException] { + Seq((1, new UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))).toDF("id", "vectors") + .write.mode("overwrite").csv(csvDir) + }.getMessage + assert(msg.contains("CSV data source does not support mydensevector data type")) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", new UDT.MyDenseVectorUDT(), true) :: Nil) + spark.range(1).write.mode("overwrite").csv(csvDir) + spark.read.schema(schema).csv(csvDir).collect() + }.getMessage + assert(msg.contains("CSV data source does not support mydensevector data type.")) + } + } + + test("SPARK-24204 error handling for unsupported Interval data types - csv, json, parquet, orc") { + withTempDir { dir => + val tempDir = new File(dir, "files").getCanonicalPath + + // write path + Seq("csv", "json", "parquet", "orc").foreach { format => + var msg = intercept[AnalysisException] { + sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.contains("Cannot save interval data type into external storage.")) + + msg = intercept[AnalysisException] { + spark.udf.register("testType", () => new IntervalData()) + sql("select testType()").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support interval data type.")) + } + + // read path + Seq("parquet", "csv").foreach { format => + var msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support calendarinterval data type.")) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support interval data type.")) + } + } + } + + test("SPARK-24204 error handling for unsupported Null data types - csv, parquet, orc") { + withTempDir { dir => + val tempDir = new File(dir, "files").getCanonicalPath + + Seq("orc").foreach { format => + // write path + var msg = intercept[AnalysisException] { + sql("select null").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + + msg = intercept[AnalysisException] { + spark.udf.register("testType", () => new NullData()) + sql("select testType()").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + + // read path + // We expect the types below should be passed for backward-compatibility + + // Null type + var schema = StructType(StructField("a", NullType, true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + + // UDT having null data + schema = StructType(StructField("a", new NullUDT(), true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + } + + Seq("parquet", "csv").foreach { format => + // write path + var msg = intercept[AnalysisException] { + sql("select null").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + + msg = intercept[AnalysisException] { + spark.udf.register("testType", () => new NullData()) + sql("select testType()").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + + // read path + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", NullType, true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", new NullUDT(), true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + } + } + } +} + +object TestingUDT { + + @SQLUserDefinedType(udt = classOf[IntervalUDT]) + class IntervalData extends Serializable + + class IntervalUDT extends UserDefinedType[IntervalData] { + + override def sqlType: DataType = CalendarIntervalType + override def serialize(obj: IntervalData): Any = + throw new NotImplementedError("Not implemented") + override def deserialize(datum: Any): IntervalData = + throw new NotImplementedError("Not implemented") + override def userClass: Class[IntervalData] = classOf[IntervalData] + } + + @SQLUserDefinedType(udt = classOf[NullUDT]) + private[sql] class NullData extends Serializable + + private[sql] class NullUDT extends UserDefinedType[NullData] { + + override def sqlType: DataType = NullType + override def serialize(obj: NullData): Any = throw new NotImplementedError("Not implemented") + override def deserialize(datum: Any): NullData = + throw new NotImplementedError("Not implemented") + override def userClass: Class[NullData] = classOf[NullData] + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala deleted file mode 100644 index c6dd7dadc9d93..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala +++ /dev/null @@ -1,243 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.io.File - -import scala.util.{Random, Try} - -import org.apache.spark.SparkConf -import org.apache.spark.sql.functions.monotonically_increasing_id -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.{Benchmark, Utils} - - -/** - * Benchmark to measure read performance with Filter pushdown. - */ -object FilterPushdownBenchmark { - val conf = new SparkConf() - conf.set("orc.compression", "snappy") - conf.set("spark.sql.parquet.compression.codec", "snappy") - - private val spark = SparkSession.builder() - .master("local[1]") - .appName("FilterPushdownBenchmark") - .config(conf) - .getOrCreate() - - def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - - def withTempTable(tableNames: String*)(f: => Unit): Unit = { - try f finally tableNames.foreach(spark.catalog.dropTempView) - } - - def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) - (keys, values).zipped.foreach(spark.conf.set) - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => spark.conf.set(key, value) - case (key, None) => spark.conf.unset(key) - } - } - } - - private def prepareTable(dir: File, numRows: Int, width: Int): Unit = { - import spark.implicits._ - val selectExpr = (1 to width).map(i => s"CAST(value AS STRING) c$i") - val df = spark.range(numRows).map(_ => Random.nextLong).selectExpr(selectExpr: _*) - .withColumn("id", monotonically_increasing_id()) - - val dirORC = dir.getCanonicalPath + "/orc" - val dirParquet = dir.getCanonicalPath + "/parquet" - - df.write.mode("overwrite").orc(dirORC) - df.write.mode("overwrite").parquet(dirParquet) - - spark.read.orc(dirORC).createOrReplaceTempView("orcTable") - spark.read.parquet(dirParquet).createOrReplaceTempView("parquetTable") - } - - def filterPushDownBenchmark( - values: Int, - title: String, - whereExpr: String, - selectExpr: String = "*"): Unit = { - val benchmark = new Benchmark(title, values, minNumIters = 5) - - Seq(false, true).foreach { pushDownEnabled => - val name = s"Parquet Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" - benchmark.addCase(name) { _ => - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { - spark.sql(s"SELECT $selectExpr FROM parquetTable WHERE $whereExpr").collect() - } - } - } - - Seq(false, true).foreach { pushDownEnabled => - val name = s"Native ORC Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" - benchmark.addCase(name) { _ => - withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { - spark.sql(s"SELECT $selectExpr FROM orcTable WHERE $whereExpr").collect() - } - } - } - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz - - Select 0 row (id IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7882 / 7957 2.0 501.1 1.0X - Parquet Vectorized (Pushdown) 55 / 60 285.2 3.5 142.9X - Native ORC Vectorized 5592 / 5627 2.8 355.5 1.4X - Native ORC Vectorized (Pushdown) 66 / 70 237.2 4.2 118.9X - - Select 0 row (7864320 < id < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7884 / 7909 2.0 501.2 1.0X - Parquet Vectorized (Pushdown) 739 / 752 21.3 47.0 10.7X - Native ORC Vectorized 5614 / 5646 2.8 356.9 1.4X - Native ORC Vectorized (Pushdown) 81 / 83 195.2 5.1 97.8X - - Select 1 row (id = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7905 / 8027 2.0 502.6 1.0X - Parquet Vectorized (Pushdown) 740 / 766 21.2 47.1 10.7X - Native ORC Vectorized 5684 / 5738 2.8 361.4 1.4X - Native ORC Vectorized (Pushdown) 78 / 81 202.4 4.9 101.7X - - Select 1 row (id <=> 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7928 / 7993 2.0 504.1 1.0X - Parquet Vectorized (Pushdown) 747 / 772 21.0 47.5 10.6X - Native ORC Vectorized 5728 / 5753 2.7 364.2 1.4X - Native ORC Vectorized (Pushdown) 76 / 78 207.9 4.8 104.8X - - Select 1 row (7864320 <= id <= 7864320):Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7939 / 8021 2.0 504.8 1.0X - Parquet Vectorized (Pushdown) 746 / 770 21.1 47.4 10.6X - Native ORC Vectorized 5690 / 5734 2.8 361.7 1.4X - Native ORC Vectorized (Pushdown) 76 / 79 206.7 4.8 104.3X - - Select 1 row (7864319 < id < 7864321): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7972 / 8019 2.0 506.9 1.0X - Parquet Vectorized (Pushdown) 742 / 764 21.2 47.2 10.7X - Native ORC Vectorized 5704 / 5743 2.8 362.6 1.4X - Native ORC Vectorized (Pushdown) 76 / 78 207.9 4.8 105.4X - - Select 10% rows (id < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 8733 / 8808 1.8 555.2 1.0X - Parquet Vectorized (Pushdown) 2213 / 2267 7.1 140.7 3.9X - Native ORC Vectorized 6420 / 6463 2.4 408.2 1.4X - Native ORC Vectorized (Pushdown) 1313 / 1331 12.0 83.5 6.7X - - Select 50% rows (id < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 11518 / 11591 1.4 732.3 1.0X - Parquet Vectorized (Pushdown) 7962 / 7991 2.0 506.2 1.4X - Native ORC Vectorized 8927 / 8985 1.8 567.6 1.3X - Native ORC Vectorized (Pushdown) 6102 / 6160 2.6 387.9 1.9X - - Select 90% rows (id < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 14255 / 14389 1.1 906.3 1.0X - Parquet Vectorized (Pushdown) 13564 / 13594 1.2 862.4 1.1X - Native ORC Vectorized 11442 / 11608 1.4 727.5 1.2X - Native ORC Vectorized (Pushdown) 10991 / 11029 1.4 698.8 1.3X - - Select all rows (id IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 14917 / 14938 1.1 948.4 1.0X - Parquet Vectorized (Pushdown) 14910 / 14964 1.1 948.0 1.0X - Native ORC Vectorized 11986 / 12069 1.3 762.0 1.2X - Native ORC Vectorized (Pushdown) 12037 / 12123 1.3 765.3 1.2X - - Select all rows (id > -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 14951 / 14976 1.1 950.6 1.0X - Parquet Vectorized (Pushdown) 14934 / 15016 1.1 949.5 1.0X - Native ORC Vectorized 12000 / 12156 1.3 763.0 1.2X - Native ORC Vectorized (Pushdown) 12079 / 12113 1.3 767.9 1.2X - - Select all rows (id != -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 14930 / 14972 1.1 949.3 1.0X - Parquet Vectorized (Pushdown) 15015 / 15047 1.0 954.6 1.0X - Native ORC Vectorized 12090 / 12259 1.3 768.7 1.2X - Native ORC Vectorized (Pushdown) 12021 / 12096 1.3 764.2 1.2X - */ - benchmark.run() - } - - def main(args: Array[String]): Unit = { - val numRows = 1024 * 1024 * 15 - val width = 5 - val mid = numRows / 2 - - withTempPath { dir => - withTempTable("orcTable", "patquetTable") { - prepareTable(dir, numRows, width) - - Seq("id IS NULL", s"$mid < id AND id < $mid").foreach { whereExpr => - val title = s"Select 0 row ($whereExpr)".replace("id AND id", "id") - filterPushDownBenchmark(numRows, title, whereExpr) - } - - Seq( - s"id = $mid", - s"id <=> $mid", - s"$mid <= id AND id <= $mid", - s"${mid - 1} < id AND id < ${mid + 1}" - ).foreach { whereExpr => - val title = s"Select 1 row ($whereExpr)".replace("id AND id", "id") - filterPushDownBenchmark(numRows, title, whereExpr) - } - - val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(id)") - - Seq(10, 50, 90).foreach { percent => - filterPushDownBenchmark( - numRows, - s"Select $percent% rows (id < ${numRows * percent / 100})", - s"id < ${numRows * percent / 100}", - selectExpr - ) - } - - Seq("id IS NOT NULL", "id > -1", "id != -1").foreach { whereExpr => - filterPushDownBenchmark( - numRows, - s"Select all rows ($whereExpr)", - whereExpr, - selectExpr) - } - } - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 8fa747465cb1a..44767dfc92497 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -882,4 +882,15 @@ class JoinSuite extends QueryTest with SharedSQLContext { checkAnswer(df, Row(3, 8, 7, 2) :: Row(3, 8, 4, 2) :: Nil) } } + + test("SPARK-24495: Join may return wrong result when having duplicated equal-join keys") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", + SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df1 = spark.range(0, 100, 1, 2) + val df2 = spark.range(100).select($"id".as("b1"), (- $"id").as("b2")) + val res = df1.join(df2, $"id" === $"b1" && $"id" === $"b2").select($"b1", $"b2", $"id") + checkAnswer(res, Row(0, 0, 0)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 055e1fc5640f3..d3b2701f2558e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.sql.functions.{from_json, lit, map, struct, to_json} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -311,7 +311,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { val errMsg1 = intercept[AnalysisException] { df3.selectExpr("from_json(value, 1)") } - assert(errMsg1.getMessage.startsWith("Expected a string literal instead of")) + assert(errMsg1.getMessage.startsWith("Schema should be specified in DDL format as a string")) val errMsg2 = intercept[AnalysisException] { df3.selectExpr("""from_json(value, 'time InvalidType')""") } @@ -354,8 +354,8 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { test("SPARK-24027: from_json - map>") { val in = Seq("""{"a": {"b": 1}}""").toDS() - val schema = MapType(StringType, MapType(StringType, IntegerType)) - val out = in.select(from_json($"value", schema)) + val schema = "map>" + val out = in.select(from_json($"value", schema, Map.empty[String, String])) checkAnswer(out, Row(Map("a" -> Map("b" -> 1)))) } @@ -392,4 +392,17 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(Seq("""{"{"f": 1}": "a"}""").toDS().select(from_json($"value", schema)), Row(null)) } + + test("SPARK-24709: infers schemas of json strings and pass them to from_json") { + val in = Seq("""{"a": [1, 2, 3]}""").toDS() + val out = in.select(from_json('value, schema_of_json(lit("""{"a": [1]}"""))) as "parsed") + val expected = StructType(StructField( + "parsed", + StructType(StructField( + "a", + ArrayType(LongType, true), true) :: Nil), + true) :: Nil) + + assert(out.schema == expected) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala index cfe2e9f2dbc44..cdcea09ad9758 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala @@ -54,4 +54,18 @@ class RuntimeConfigSuite extends SparkFunSuite { conf.get("k1") } } + + test("SPARK-24761: is a config parameter modifiable") { + val conf = newConf() + + // SQL configs + assert(!conf.isModifiable("spark.sql.sources.schemaStringLengthThreshold")) + assert(conf.isModifiable("spark.sql.streaming.checkpointLocation")) + // Core configs + assert(!conf.isModifiable("spark.task.cpus")) + assert(!conf.isModifiable("spark.executor.cores")) + // Invalid config parameters + assert(!conf.isModifiable("")) + assert(!conf.isModifiable("invalid config parameter")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 640affc10ee58..dfb9c137b74f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2792,4 +2792,25 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-24696 ColumnPruning rule fails to remove extra Project") { + withTable("fact_stats", "dim_stats") { + val factData = Seq((1, 1, 99, 1), (2, 2, 99, 2), (3, 1, 99, 3), (4, 2, 99, 4)) + val storeData = Seq((1, "BW", "DE"), (2, "AZ", "US")) + spark.udf.register("filterND", udf((value: Int) => value > 2).asNondeterministic) + factData.toDF("date_id", "store_id", "product_id", "units_sold") + .write.mode("overwrite").partitionBy("store_id").format("parquet").saveAsTable("fact_stats") + storeData.toDF("store_id", "state_province", "country") + .write.mode("overwrite").format("parquet").saveAsTable("dim_stats") + val df = sql( + """ + |SELECT f.date_id, f.product_id, f.store_id FROM + |(SELECT date_id, product_id, store_id + | FROM fact_stats WHERE filterND(date_id)) AS f + |JOIN dim_stats s + |ON f.store_id = s.store_id WHERE s.country = 'DE' + """.stripMargin) + checkAnswer(df, Seq(Row(3, 99, 1))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index beac9699585d5..826408c7161e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -54,6 +54,7 @@ import org.apache.spark.sql.types.StructType * The format for input files is simple: * 1. A list of SQL queries separated by semicolon. * 2. Lines starting with -- are treated as comments and ignored. + * 3. Lines starting with --SET are used to run the file with the following set of configs. * * For example: * {{{ @@ -138,18 +139,58 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { private def runTest(testCase: TestCase): Unit = { val input = fileToString(new File(testCase.inputFile)) + val (comments, code) = input.split("\n").partition(_.startsWith("--")) + val configSets = { + val configLines = comments.filter(_.startsWith("--SET")).map(_.substring(5)) + val configs = configLines.map(_.split(",").map { confAndValue => + val (conf, value) = confAndValue.span(_ != '=') + conf.trim -> value.substring(1).trim + }) + // When we are regenerating the golden files we don't need to run all the configs as they + // all need to return the same result + if (regenerateGoldenFiles && configs.nonEmpty) { + configs.take(1) + } else { + configs + } + } // List of SQL queries to run - val queries: Seq[String] = { - val cleaned = input.split("\n").filterNot(_.startsWith("--")).mkString("\n") - // note: this is not a robust way to split queries using semicolon, but works for now. - cleaned.split("(?<=[^\\\\]);").map(_.trim).filter(_ != "").toSeq + // note: this is not a robust way to split queries using semicolon, but works for now. + val queries = code.mkString("\n").split("(?<=[^\\\\]);").map(_.trim).filter(_ != "").toSeq + + if (configSets.isEmpty) { + runQueries(queries, testCase.resultFile, None) + } else { + configSets.foreach { configSet => + try { + runQueries(queries, testCase.resultFile, Some(configSet)) + } catch { + case e: Throwable => + val configs = configSet.map { + case (k, v) => s"$k=$v" + } + logError(s"Error using configs: ${configs.mkString(",")}") + throw e + } + } } + } + private def runQueries( + queries: Seq[String], + resultFileName: String, + configSet: Option[Seq[(String, String)]]): Unit = { // Create a local SparkSession to have stronger isolation between different test cases. // This does not isolate catalog changes. val localSparkSession = spark.newSession() loadTestData(localSparkSession) + if (configSet.isDefined) { + // Execute the list of set operation in order to add the desired configs + val setOperations = configSet.get.map { case (key, value) => s"set $key=$value" } + logInfo(s"Setting configs: ${setOperations.mkString(", ")}") + setOperations.foreach(localSparkSession.sql) + } // Run the SQL queries preparing them for comparison. val outputs: Seq[QueryOutput] = queries.map { sql => val (schema, output) = getNormalizedResult(localSparkSession, sql) @@ -167,7 +208,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { s"-- Number of queries: ${outputs.size}\n\n\n" + outputs.zipWithIndex.map{case (qr, i) => qr.toString(i)}.mkString("\n\n\n") + "\n" } - val resultFile = new File(testCase.resultFile) + val resultFile = new File(resultFileName) val parent = resultFile.getParentFile if (!parent.exists()) { assert(parent.mkdirs(), "Could not create directory: " + parent) @@ -177,7 +218,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { // Read back the golden file. val expectedOutputs: Seq[QueryOutput] = { - val goldenOutput = fileToString(new File(testCase.resultFile)) + val goldenOutput = fileToString(new File(resultFileName)) val segments = goldenOutput.split("-- !query.+\n") // each query has 3 segments, plus the header diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 98a50fbd52b4d..d254345e8fa54 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{execution, Row} +import org.apache.spark.sql.{execution, DataFrame, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort, Union} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ @@ -679,6 +679,90 @@ class PlannerSuite extends SharedSQLContext { } assert(rangeExecInZeroPartition.head.outputPartitioning == UnknownPartitioning(0)) } + + test("SPARK-24495: EnsureRequirements can return wrong plan when reusing the same key in join") { + val plan1 = DummySparkPlan(outputOrdering = Seq(orderingA), + outputPartitioning = HashPartitioning(exprA :: exprA :: Nil, 5)) + val plan2 = DummySparkPlan(outputOrdering = Seq(orderingB), + outputPartitioning = HashPartitioning(exprB :: Nil, 5)) + val smjExec = SortMergeJoinExec( + exprA :: exprA :: Nil, exprB :: exprC :: Nil, Inner, None, plan1, plan2) + + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec) + outputPlan match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, _, _) => + assert(leftKeys == Seq(exprA, exprA)) + assert(rightKeys == Seq(exprB, exprC)) + case _ => fail() + } + } + + test("SPARK-24500: create union with stream of children") { + val df = Union(Stream( + Range(1, 1, 1, 1), + Range(1, 2, 1, 1))) + df.queryExecution.executedPlan.execute() + } + + test("SPARK-24556: always rewrite output partitioning in ReusedExchangeExec " + + "and InMemoryTableScanExec") { + def checkOutputPartitioningRewrite( + plans: Seq[SparkPlan], + expectedPartitioningClass: Class[_]): Unit = { + assert(plans.size == 1) + val plan = plans.head + val partitioning = plan.outputPartitioning + assert(partitioning.getClass == expectedPartitioningClass) + val partitionedAttrs = partitioning.asInstanceOf[Expression].references + assert(partitionedAttrs.subsetOf(plan.outputSet)) + } + + def checkReusedExchangeOutputPartitioningRewrite( + df: DataFrame, + expectedPartitioningClass: Class[_]): Unit = { + val reusedExchange = df.queryExecution.executedPlan.collect { + case r: ReusedExchangeExec => r + } + checkOutputPartitioningRewrite(reusedExchange, expectedPartitioningClass) + } + + def checkInMemoryTableScanOutputPartitioningRewrite( + df: DataFrame, + expectedPartitioningClass: Class[_]): Unit = { + val inMemoryScan = df.queryExecution.executedPlan.collect { + case m: InMemoryTableScanExec => m + } + checkOutputPartitioningRewrite(inMemoryScan, expectedPartitioningClass) + } + + // ReusedExchange is HashPartitioning + val df1 = Seq(1 -> "a").toDF("i", "j").repartition($"i") + val df2 = Seq(1 -> "a").toDF("i", "j").repartition($"i") + checkReusedExchangeOutputPartitioningRewrite(df1.union(df2), classOf[HashPartitioning]) + + // ReusedExchange is RangePartitioning + val df3 = Seq(1 -> "a").toDF("i", "j").orderBy($"i") + val df4 = Seq(1 -> "a").toDF("i", "j").orderBy($"i") + checkReusedExchangeOutputPartitioningRewrite(df3.union(df4), classOf[RangePartitioning]) + + // InMemoryTableScan is HashPartitioning + Seq(1 -> "a").toDF("i", "j").repartition($"i").persist() + checkInMemoryTableScanOutputPartitioningRewrite( + Seq(1 -> "a").toDF("i", "j").repartition($"i"), classOf[HashPartitioning]) + + // InMemoryTableScan is RangePartitioning + spark.range(1, 100, 1, 10).toDF().persist() + checkInMemoryTableScanOutputPartitioningRewrite( + spark.range(1, 100, 1, 10).toDF(), classOf[RangePartitioning]) + + // InMemoryTableScan is PartitioningCollection + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + Seq(1 -> "a").toDF("i", "j").join(Seq(1 -> "a").toDF("m", "n"), $"i" === $"m").persist() + checkInMemoryTableScanOutputPartitioningRewrite( + Seq(1 -> "a").toDF("i", "j").join(Seq(1 -> "a").toDF("m", "n"), $"i" === $"m"), + classOf[PartitioningCollection]) + } + } } // Used for unit-testing EnsureRequirements diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 3e31d22e15c0e..5c15ecd42fa0c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -23,6 +23,7 @@ import scala.collection.mutable import scala.util.{Random, Try} import scala.util.control.NonFatal +import org.mockito.Mockito._ import org.scalatest.Matchers import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl} @@ -54,6 +55,8 @@ class UnsafeFixedWidthAggregationMapSuite private var memoryManager: TestMemoryManager = null private var taskMemoryManager: TaskMemoryManager = null + private var taskContext: TaskContext = null + def testWithMemoryLeakDetection(name: String)(f: => Unit) { def cleanup(): Unit = { if (taskMemoryManager != null) { @@ -67,6 +70,8 @@ class UnsafeFixedWidthAggregationMapSuite val conf = new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false") memoryManager = new TestMemoryManager(conf) taskMemoryManager = new TaskMemoryManager(memoryManager, 0) + taskContext = mock(classOf[TaskContext]) + when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, @@ -111,7 +116,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 1024, // initial capacity, PAGE_SIZE_BYTES ) @@ -124,7 +129,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 1024, // initial capacity PAGE_SIZE_BYTES ) @@ -151,7 +156,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -176,7 +181,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -223,7 +228,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -263,7 +268,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, StructType(Nil), StructType(Nil), - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -307,7 +312,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity pageSize ) @@ -344,7 +349,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity pageSize ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala index fc6d8abc03c09..8711f5a8fa1ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala @@ -39,9 +39,11 @@ import org.apache.spark.util.{Benchmark, Utils} object DataSourceReadBenchmark { val conf = new SparkConf() .setAppName("DataSourceReadBenchmark") - .setIfMissing("spark.master", "local[1]") + // Since `spark.master` always exists, overrides this value + .set("spark.master", "local[1]") .setIfMissing("spark.driver.memory", "3g") .setIfMissing("spark.executor.memory", "3g") + .setIfMissing("spark.ui.enabled", "false") val spark = SparkSession.builder.config(conf).getOrCreate() @@ -154,73 +156,73 @@ object DataSourceReadBenchmark { } } - /* - Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 15231 / 15267 1.0 968.3 1.0X - SQL Json 8476 / 8498 1.9 538.9 1.8X - SQL Parquet Vectorized 121 / 127 130.0 7.7 125.9X - SQL Parquet MR 1515 / 1543 10.4 96.3 10.1X - SQL ORC Vectorized 164 / 171 95.9 10.4 92.9X - SQL ORC Vectorized with copy 228 / 234 69.0 14.5 66.8X - SQL ORC MR 1297 / 1309 12.1 82.5 11.7X + SQL CSV 22964 / 23096 0.7 1460.0 1.0X + SQL Json 8469 / 8593 1.9 538.4 2.7X + SQL Parquet Vectorized 164 / 177 95.8 10.4 139.9X + SQL Parquet MR 1687 / 1706 9.3 107.2 13.6X + SQL ORC Vectorized 191 / 197 82.3 12.2 120.2X + SQL ORC Vectorized with copy 215 / 219 73.2 13.7 106.9X + SQL ORC MR 1392 / 1412 11.3 88.5 16.5X SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 16344 / 16374 1.0 1039.1 1.0X - SQL Json 8634 / 8648 1.8 548.9 1.9X - SQL Parquet Vectorized 172 / 177 91.5 10.9 95.1X - SQL Parquet MR 1744 / 1746 9.0 110.9 9.4X - SQL ORC Vectorized 189 / 194 83.1 12.0 86.4X - SQL ORC Vectorized with copy 244 / 250 64.5 15.5 67.0X - SQL ORC MR 1341 / 1386 11.7 85.3 12.2X + SQL CSV 24090 / 24097 0.7 1531.6 1.0X + SQL Json 8791 / 8813 1.8 558.9 2.7X + SQL Parquet Vectorized 204 / 212 77.0 13.0 117.9X + SQL Parquet MR 1813 / 1850 8.7 115.3 13.3X + SQL ORC Vectorized 226 / 230 69.7 14.4 106.7X + SQL ORC Vectorized with copy 295 / 298 53.3 18.8 81.6X + SQL ORC MR 1526 / 1549 10.3 97.1 15.8X SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 17874 / 17875 0.9 1136.4 1.0X - SQL Json 9190 / 9204 1.7 584.3 1.9X - SQL Parquet Vectorized 141 / 160 111.2 9.0 126.4X - SQL Parquet MR 1930 / 2049 8.2 122.7 9.3X - SQL ORC Vectorized 259 / 264 60.7 16.5 69.0X - SQL ORC Vectorized with copy 265 / 272 59.4 16.8 67.5X - SQL ORC MR 1528 / 1569 10.3 97.2 11.7X + SQL CSV 25637 / 25791 0.6 1629.9 1.0X + SQL Json 9532 / 9570 1.7 606.0 2.7X + SQL Parquet Vectorized 181 / 191 86.8 11.5 141.5X + SQL Parquet MR 2210 / 2227 7.1 140.5 11.6X + SQL ORC Vectorized 309 / 317 50.9 19.6 83.0X + SQL ORC Vectorized with copy 316 / 322 49.8 20.1 81.2X + SQL ORC MR 1650 / 1680 9.5 104.9 15.5X SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 22812 / 22839 0.7 1450.4 1.0X - SQL Json 12026 / 12054 1.3 764.6 1.9X - SQL Parquet Vectorized 222 / 227 70.8 14.1 102.6X - SQL Parquet MR 2199 / 2204 7.2 139.8 10.4X - SQL ORC Vectorized 331 / 335 47.6 21.0 69.0X - SQL ORC Vectorized with copy 338 / 343 46.6 21.5 67.6X - SQL ORC MR 1618 / 1622 9.7 102.9 14.1X + SQL CSV 31617 / 31764 0.5 2010.1 1.0X + SQL Json 12440 / 12451 1.3 790.9 2.5X + SQL Parquet Vectorized 284 / 315 55.4 18.0 111.4X + SQL Parquet MR 2382 / 2390 6.6 151.5 13.3X + SQL ORC Vectorized 398 / 403 39.5 25.3 79.5X + SQL ORC Vectorized with copy 410 / 413 38.3 26.1 77.1X + SQL ORC MR 1783 / 1813 8.8 113.4 17.7X SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 18703 / 18740 0.8 1189.1 1.0X - SQL Json 11779 / 11869 1.3 748.9 1.6X - SQL Parquet Vectorized 143 / 145 110.1 9.1 130.9X - SQL Parquet MR 1954 / 1963 8.0 124.2 9.6X - SQL ORC Vectorized 347 / 355 45.3 22.1 53.8X - SQL ORC Vectorized with copy 356 / 359 44.1 22.7 52.5X - SQL ORC MR 1570 / 1598 10.0 99.8 11.9X + SQL CSV 26679 / 26742 0.6 1696.2 1.0X + SQL Json 12490 / 12541 1.3 794.1 2.1X + SQL Parquet Vectorized 174 / 183 90.4 11.1 153.3X + SQL Parquet MR 2201 / 2223 7.1 140.0 12.1X + SQL ORC Vectorized 415 / 429 37.9 26.4 64.3X + SQL ORC Vectorized with copy 422 / 428 37.2 26.9 63.2X + SQL ORC MR 1767 / 1773 8.9 112.3 15.1X SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 23832 / 23838 0.7 1515.2 1.0X - SQL Json 16204 / 16226 1.0 1030.2 1.5X - SQL Parquet Vectorized 242 / 306 65.1 15.4 98.6X - SQL Parquet MR 2462 / 2482 6.4 156.5 9.7X - SQL ORC Vectorized 419 / 451 37.6 26.6 56.9X - SQL ORC Vectorized with copy 426 / 447 36.9 27.1 55.9X - SQL ORC MR 1885 / 1931 8.3 119.8 12.6X + SQL CSV 34223 / 34324 0.5 2175.8 1.0X + SQL Json 17784 / 17785 0.9 1130.7 1.9X + SQL Parquet Vectorized 277 / 283 56.7 17.6 123.4X + SQL Parquet MR 2356 / 2386 6.7 149.8 14.5X + SQL ORC Vectorized 533 / 536 29.5 33.9 64.2X + SQL ORC Vectorized with copy 541 / 546 29.1 34.4 63.3X + SQL ORC MR 2166 / 2177 7.3 137.7 15.8X */ sqlBenchmark.run() @@ -294,41 +296,42 @@ object DataSourceReadBenchmark { } /* - Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - ParquetReader Vectorized 187 / 201 84.2 11.9 1.0X - ParquetReader Vectorized -> Row 101 / 103 156.4 6.4 1.9X + ParquetReader Vectorized 198 / 202 79.4 12.6 1.0X + ParquetReader Vectorized -> Row 119 / 121 132.3 7.6 1.7X Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - ParquetReader Vectorized 272 / 288 57.8 17.3 1.0X - ParquetReader Vectorized -> Row 213 / 219 73.7 13.6 1.3X + ParquetReader Vectorized 282 / 287 55.8 17.9 1.0X + ParquetReader Vectorized -> Row 246 / 247 64.0 15.6 1.1X Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - ParquetReader Vectorized 252 / 288 62.5 16.0 1.0X - ParquetReader Vectorized -> Row 232 / 246 67.7 14.8 1.1X + ParquetReader Vectorized 258 / 262 60.9 16.4 1.0X + ParquetReader Vectorized -> Row 259 / 260 60.8 16.5 1.0X Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - ParquetReader Vectorized 415 / 454 37.9 26.4 1.0X - ParquetReader Vectorized -> Row 407 / 432 38.6 25.9 1.0X + ParquetReader Vectorized 361 / 369 43.6 23.0 1.0X + ParquetReader Vectorized -> Row 361 / 371 43.6 22.9 1.0X Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - ParquetReader Vectorized 251 / 302 62.7 16.0 1.0X - ParquetReader Vectorized -> Row 220 / 234 71.5 14.0 1.1X + ParquetReader Vectorized 253 / 261 62.2 16.1 1.0X + ParquetReader Vectorized -> Row 254 / 256 61.9 16.2 1.0X Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - ParquetReader Vectorized 432 / 436 36.4 27.5 1.0X - ParquetReader Vectorized -> Row 414 / 422 38.0 26.4 1.0X + ParquetReader Vectorized 357 / 364 44.0 22.7 1.0X + ParquetReader Vectorized -> Row 358 / 366 44.0 22.7 1.0X */ parquetReaderBenchmark.run() } @@ -382,16 +385,17 @@ object DataSourceReadBenchmark { } /* - Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 19172 / 19173 0.5 1828.4 1.0X - SQL Json 12799 / 12873 0.8 1220.6 1.5X - SQL Parquet Vectorized 2558 / 2564 4.1 244.0 7.5X - SQL Parquet MR 4514 / 4583 2.3 430.4 4.2X - SQL ORC Vectorized 2561 / 2697 4.1 244.3 7.5X - SQL ORC Vectorized with copy 3076 / 3110 3.4 293.4 6.2X - SQL ORC MR 4197 / 4283 2.5 400.2 4.6X + SQL CSV 27145 / 27158 0.4 2588.7 1.0X + SQL Json 12969 / 13337 0.8 1236.8 2.1X + SQL Parquet Vectorized 2419 / 2448 4.3 230.7 11.2X + SQL Parquet MR 4631 / 4633 2.3 441.7 5.9X + SQL ORC Vectorized 2412 / 2465 4.3 230.0 11.3X + SQL ORC Vectorized with copy 2633 / 2675 4.0 251.1 10.3X + SQL ORC MR 4280 / 4350 2.4 408.2 6.3X */ benchmark.run() } @@ -445,16 +449,17 @@ object DataSourceReadBenchmark { } /* - Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 10889 / 10924 1.0 1038.5 1.0X - SQL Json 7903 / 7931 1.3 753.7 1.4X - SQL Parquet Vectorized 777 / 799 13.5 74.1 14.0X - SQL Parquet MR 1682 / 1708 6.2 160.4 6.5X - SQL ORC Vectorized 532 / 534 19.7 50.7 20.5X - SQL ORC Vectorized with copy 742 / 743 14.1 70.7 14.7X - SQL ORC MR 1996 / 2002 5.3 190.4 5.5X + SQL CSV 17345 / 17424 0.6 1654.1 1.0X + SQL Json 8639 / 8664 1.2 823.9 2.0X + SQL Parquet Vectorized 839 / 854 12.5 80.0 20.7X + SQL Parquet MR 1771 / 1775 5.9 168.9 9.8X + SQL ORC Vectorized 550 / 569 19.1 52.4 31.6X + SQL ORC Vectorized with copy 785 / 849 13.4 74.9 22.1X + SQL ORC MR 2168 / 2202 4.8 206.7 8.0X */ benchmark.run() } @@ -574,30 +579,31 @@ object DataSourceReadBenchmark { } /* - Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - Data column - CSV 25428 / 25454 0.6 1616.7 1.0X - Data column - Json 12689 / 12774 1.2 806.7 2.0X - Data column - Parquet Vectorized 222 / 231 70.7 14.1 114.3X - Data column - Parquet MR 3355 / 3397 4.7 213.3 7.6X - Data column - ORC Vectorized 332 / 338 47.4 21.1 76.6X - Data column - ORC Vectorized with copy 338 / 341 46.5 21.5 75.2X - Data column - ORC MR 2329 / 2356 6.8 148.0 10.9X - Partition column - CSV 17465 / 17502 0.9 1110.4 1.5X - Partition column - Json 10865 / 10876 1.4 690.8 2.3X - Partition column - Parquet Vectorized 48 / 52 325.4 3.1 526.1X - Partition column - Parquet MR 1695 / 1696 9.3 107.8 15.0X - Partition column - ORC Vectorized 49 / 54 319.9 3.1 517.2X - Partition column - ORC Vectorized with copy 49 / 52 324.1 3.1 524.0X - Partition column - ORC MR 1548 / 1549 10.2 98.4 16.4X - Both columns - CSV 25568 / 25595 0.6 1625.6 1.0X - Both columns - Json 13658 / 13673 1.2 868.4 1.9X - Both columns - Parquet Vectorized 270 / 296 58.3 17.1 94.3X - Both columns - Parquet MR 3501 / 3521 4.5 222.6 7.3X - Both columns - ORC Vectorized 377 / 380 41.7 24.0 67.4X - Both column - ORC Vectorized with copy 447 / 448 35.2 28.4 56.9X - Both columns - ORC MR 2440 / 2446 6.4 155.2 10.4X + Data column - CSV 32613 / 32841 0.5 2073.4 1.0X + Data column - Json 13343 / 13469 1.2 848.3 2.4X + Data column - Parquet Vectorized 302 / 318 52.1 19.2 108.0X + Data column - Parquet MR 2908 / 2924 5.4 184.9 11.2X + Data column - ORC Vectorized 412 / 425 38.1 26.2 79.1X + Data column - ORC Vectorized with copy 442 / 446 35.6 28.1 73.8X + Data column - ORC MR 2390 / 2396 6.6 152.0 13.6X + Partition column - CSV 9626 / 9683 1.6 612.0 3.4X + Partition column - Json 10909 / 10923 1.4 693.6 3.0X + Partition column - Parquet Vectorized 69 / 76 228.4 4.4 473.6X + Partition column - Parquet MR 1898 / 1933 8.3 120.7 17.2X + Partition column - ORC Vectorized 67 / 74 236.0 4.2 489.4X + Partition column - ORC Vectorized with copy 65 / 72 241.9 4.1 501.6X + Partition column - ORC MR 1743 / 1749 9.0 110.8 18.7X + Both columns - CSV 35523 / 35552 0.4 2258.5 0.9X + Both columns - Json 13676 / 13681 1.2 869.5 2.4X + Both columns - Parquet Vectorized 317 / 326 49.5 20.2 102.7X + Both columns - Parquet MR 3333 / 3336 4.7 211.9 9.8X + Both columns - ORC Vectorized 441 / 446 35.6 28.1 73.9X + Both column - ORC Vectorized with copy 517 / 524 30.4 32.9 63.1X + Both columns - ORC MR 2574 / 2577 6.1 163.6 12.7X */ benchmark.run() } @@ -684,41 +690,42 @@ object DataSourceReadBenchmark { } /* - Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 13518 / 13529 0.8 1289.2 1.0X - SQL Json 10895 / 10926 1.0 1039.0 1.2X - SQL Parquet Vectorized 1539 / 1581 6.8 146.8 8.8X - SQL Parquet MR 3746 / 3811 2.8 357.3 3.6X - ParquetReader Vectorized 1070 / 1112 9.8 102.0 12.6X - SQL ORC Vectorized 1389 / 1408 7.6 132.4 9.7X - SQL ORC Vectorized with copy 1736 / 1750 6.0 165.6 7.8X - SQL ORC MR 3799 / 3892 2.8 362.3 3.6X + SQL CSV 14875 / 14920 0.7 1418.6 1.0X + SQL Json 10974 / 10992 1.0 1046.5 1.4X + SQL Parquet Vectorized 1711 / 1750 6.1 163.2 8.7X + SQL Parquet MR 3838 / 3884 2.7 366.0 3.9X + ParquetReader Vectorized 1155 / 1168 9.1 110.2 12.9X + SQL ORC Vectorized 1341 / 1380 7.8 127.9 11.1X + SQL ORC Vectorized with copy 1659 / 1716 6.3 158.2 9.0X + SQL ORC MR 3594 / 3634 2.9 342.7 4.1X String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 10854 / 10892 1.0 1035.2 1.0X - SQL Json 8129 / 8138 1.3 775.3 1.3X - SQL Parquet Vectorized 1053 / 1104 10.0 100.4 10.3X - SQL Parquet MR 2840 / 2854 3.7 270.8 3.8X - ParquetReader Vectorized 978 / 1008 10.7 93.2 11.1X - SQL ORC Vectorized 1312 / 1387 8.0 125.1 8.3X - SQL ORC Vectorized with copy 1764 / 1772 5.9 168.2 6.2X - SQL ORC MR 3435 / 3445 3.1 327.6 3.2X + SQL CSV 17219 / 17264 0.6 1642.1 1.0X + SQL Json 8843 / 8864 1.2 843.3 1.9X + SQL Parquet Vectorized 1169 / 1178 9.0 111.4 14.7X + SQL Parquet MR 2676 / 2697 3.9 255.2 6.4X + ParquetReader Vectorized 1068 / 1071 9.8 101.8 16.1X + SQL ORC Vectorized 1319 / 1319 7.9 125.8 13.1X + SQL ORC Vectorized with copy 1638 / 1639 6.4 156.2 10.5X + SQL ORC MR 3230 / 3257 3.2 308.1 5.3X String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 8043 / 8048 1.3 767.1 1.0X - SQL Json 4911 / 4923 2.1 468.4 1.6X - SQL Parquet Vectorized 206 / 209 51.0 19.6 39.1X - SQL Parquet MR 1528 / 1537 6.9 145.8 5.3X - ParquetReader Vectorized 216 / 219 48.6 20.6 37.2X - SQL ORC Vectorized 462 / 466 22.7 44.1 17.4X - SQL ORC Vectorized with copy 568 / 572 18.5 54.2 14.2X - SQL ORC MR 1647 / 1649 6.4 157.1 4.9X + SQL CSV 13976 / 14053 0.8 1332.8 1.0X + SQL Json 5166 / 5176 2.0 492.6 2.7X + SQL Parquet Vectorized 274 / 282 38.2 26.2 50.9X + SQL Parquet MR 1553 / 1555 6.8 148.1 9.0X + ParquetReader Vectorized 241 / 246 43.5 23.0 57.9X + SQL ORC Vectorized 476 / 479 22.0 45.4 29.3X + SQL ORC Vectorized with copy 584 / 588 17.9 55.7 23.9X + SQL ORC MR 1720 / 1734 6.1 164.1 8.1X */ benchmark.run() } @@ -773,38 +780,39 @@ object DataSourceReadBenchmark { } /* - Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Single Column Scan from 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 3663 / 3665 0.3 3493.2 1.0X - SQL Json 3122 / 3160 0.3 2977.5 1.2X - SQL Parquet Vectorized 40 / 42 26.2 38.2 91.5X - SQL Parquet MR 189 / 192 5.5 180.2 19.4X - SQL ORC Vectorized 48 / 51 21.6 46.2 75.6X - SQL ORC Vectorized with copy 49 / 52 21.4 46.7 74.9X - SQL ORC MR 280 / 289 3.7 267.1 13.1X + SQL CSV 3478 / 3481 0.3 3316.4 1.0X + SQL Json 2646 / 2654 0.4 2523.6 1.3X + SQL Parquet Vectorized 67 / 72 15.8 63.5 52.2X + SQL Parquet MR 207 / 214 5.1 197.6 16.8X + SQL ORC Vectorized 69 / 76 15.2 66.0 50.3X + SQL ORC Vectorized with copy 70 / 76 15.0 66.5 49.9X + SQL ORC MR 299 / 303 3.5 285.1 11.6X Single Column Scan from 50 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 11420 / 11505 0.1 10891.1 1.0X - SQL Json 11905 / 12120 0.1 11353.6 1.0X - SQL Parquet Vectorized 50 / 54 20.9 47.8 227.7X - SQL Parquet MR 195 / 199 5.4 185.8 58.6X - SQL ORC Vectorized 61 / 65 17.3 57.8 188.3X - SQL ORC Vectorized with copy 62 / 65 17.0 58.8 185.2X - SQL ORC MR 847 / 865 1.2 807.4 13.5X + SQL CSV 9214 / 9236 0.1 8786.7 1.0X + SQL Json 9943 / 9978 0.1 9482.7 0.9X + SQL Parquet Vectorized 77 / 86 13.6 73.3 119.8X + SQL Parquet MR 229 / 235 4.6 218.6 40.2X + SQL ORC Vectorized 84 / 96 12.5 80.0 109.9X + SQL ORC Vectorized with copy 83 / 91 12.6 79.4 110.7X + SQL ORC MR 843 / 854 1.2 804.0 10.9X - Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + Single Column Scan from 100 columns Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 21278 / 21404 0.0 20292.4 1.0X - SQL Json 22455 / 22625 0.0 21414.7 0.9X - SQL Parquet Vectorized 73 / 75 14.4 69.3 292.8X - SQL Parquet MR 220 / 226 4.8 209.7 96.8X - SQL ORC Vectorized 82 / 86 12.8 78.2 259.4X - SQL ORC Vectorized with copy 82 / 90 12.7 78.7 258.0X - SQL ORC MR 1568 / 1582 0.7 1495.4 13.6X + SQL CSV 16503 / 16622 0.1 15738.9 1.0X + SQL Json 19109 / 19184 0.1 18224.2 0.9X + SQL Parquet Vectorized 99 / 108 10.6 94.3 166.8X + SQL Parquet MR 253 / 264 4.1 241.6 65.1X + SQL ORC Vectorized 107 / 114 9.8 101.6 154.8X + SQL ORC Vectorized with copy 107 / 118 9.8 102.1 154.1X + SQL ORC MR 1526 / 1529 0.7 1455.3 10.8X */ benchmark.run() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala new file mode 100644 index 0000000000000..bdb60b44750c7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala @@ -0,0 +1,412 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import java.io.{File, FileOutputStream, OutputStream} + +import scala.util.{Random, Try} + +import org.scalatest.{BeforeAndAfterEachTestData, Suite, TestData} + +import org.apache.spark.SparkConf +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.functions.monotonically_increasing_id +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType +import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType, TimestampType} +import org.apache.spark.util.{Benchmark, Utils} + +/** + * Benchmark to measure read performance with Filter pushdown. + * To run this: + * build/sbt "sql/test-only *FilterPushdownBenchmark" + * + * Results will be written to "benchmarks/FilterPushdownBenchmark-results.txt". + */ +class FilterPushdownBenchmark extends SparkFunSuite with BenchmarkBeforeAndAfterEachTest { + private val conf = new SparkConf() + .setAppName(this.getClass.getSimpleName) + // Since `spark.master` always exists, overrides this value + .set("spark.master", "local[1]") + .setIfMissing("spark.driver.memory", "3g") + .setIfMissing("spark.executor.memory", "3g") + .setIfMissing("spark.ui.enabled", "false") + .setIfMissing("orc.compression", "snappy") + .setIfMissing("spark.sql.parquet.compression.codec", "snappy") + + private val numRows = 1024 * 1024 * 15 + private val width = 5 + private val mid = numRows / 2 + private val blockSize = 1048576 + + private val spark = SparkSession.builder().config(conf).getOrCreate() + + private var out: OutputStream = _ + + override def beforeAll() { + super.beforeAll() + out = new FileOutputStream(new File("benchmarks/FilterPushdownBenchmark-results.txt")) + } + + override def beforeEach(td: TestData) { + super.beforeEach(td) + val separator = "=" * 96 + val testHeader = (separator + '\n' + td.name + '\n' + separator + '\n' + '\n').getBytes + out.write(testHeader) + } + + override def afterEach(td: TestData) { + out.write('\n') + super.afterEach(td) + } + + override def afterAll() { + try { + out.close() + } finally { + super.afterAll() + } + } + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) + } + } + } + + private def prepareTable( + dir: File, numRows: Int, width: Int, useStringForValue: Boolean): Unit = { + import spark.implicits._ + val selectExpr = (1 to width).map(i => s"CAST(value AS STRING) c$i") + val valueCol = if (useStringForValue) { + monotonically_increasing_id().cast("string") + } else { + monotonically_increasing_id() + } + val df = spark.range(numRows).map(_ => Random.nextLong).selectExpr(selectExpr: _*) + .withColumn("value", valueCol) + .sort("value") + + saveAsTable(df, dir) + } + + private def prepareStringDictTable( + dir: File, numRows: Int, numDistinctValues: Int, width: Int): Unit = { + val selectExpr = (0 to width).map { + case 0 => s"CAST(id % $numDistinctValues AS STRING) AS value" + case i => s"CAST(rand() AS STRING) c$i" + } + val df = spark.range(numRows).selectExpr(selectExpr: _*).sort("value") + + saveAsTable(df, dir) + } + + private def saveAsTable(df: DataFrame, dir: File): Unit = { + val orcPath = dir.getCanonicalPath + "/orc" + val parquetPath = dir.getCanonicalPath + "/parquet" + + // To always turn on dictionary encoding, we set 1.0 at the threshold (the default is 0.8) + df.write.mode("overwrite") + .option("orc.dictionary.key.threshold", 1.0) + .option("orc.stripe.size", blockSize).orc(orcPath) + spark.read.orc(orcPath).createOrReplaceTempView("orcTable") + + df.write.mode("overwrite") + .option("parquet.block.size", blockSize).parquet(parquetPath) + spark.read.parquet(parquetPath).createOrReplaceTempView("parquetTable") + } + + def filterPushDownBenchmark( + values: Int, + title: String, + whereExpr: String, + selectExpr: String = "*"): Unit = { + val benchmark = new Benchmark(title, values, minNumIters = 5, output = Some(out)) + + Seq(false, true).foreach { pushDownEnabled => + val name = s"Parquet Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" + benchmark.addCase(name) { _ => + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { + spark.sql(s"SELECT $selectExpr FROM parquetTable WHERE $whereExpr").collect() + } + } + } + + Seq(false, true).foreach { pushDownEnabled => + val name = s"Native ORC Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" + benchmark.addCase(name) { _ => + withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { + spark.sql(s"SELECT $selectExpr FROM orcTable WHERE $whereExpr").collect() + } + } + } + + benchmark.run() + } + + private def runIntBenchmark(numRows: Int, width: Int, mid: Int): Unit = { + Seq("value IS NULL", s"$mid < value AND value < $mid").foreach { whereExpr => + val title = s"Select 0 int row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + Seq( + s"value = $mid", + s"value <=> $mid", + s"$mid <= value AND value <= $mid", + s"${mid - 1} < value AND value < ${mid + 1}" + ).foreach { whereExpr => + val title = s"Select 1 int row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% int rows (value < ${numRows * percent / 100})", + s"value < ${numRows * percent / 100}", + selectExpr + ) + } + + Seq("value IS NOT NULL", "value > -1", "value != -1").foreach { whereExpr => + filterPushDownBenchmark( + numRows, + s"Select all int rows ($whereExpr)", + whereExpr, + selectExpr) + } + } + + private def runStringBenchmark( + numRows: Int, width: Int, searchValue: Int, colType: String): Unit = { + Seq("value IS NULL", s"'$searchValue' < value AND value < '$searchValue'") + .foreach { whereExpr => + val title = s"Select 0 $colType row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + Seq( + s"value = '$searchValue'", + s"value <=> '$searchValue'", + s"'$searchValue' <= value AND value <= '$searchValue'" + ).foreach { whereExpr => + val title = s"Select 1 $colType row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + + Seq("value IS NOT NULL").foreach { whereExpr => + filterPushDownBenchmark( + numRows, + s"Select all $colType rows ($whereExpr)", + whereExpr, + selectExpr) + } + } + + ignore("Pushdown for many distinct value case") { + withTempPath { dir => + withTempTable("orcTable", "patquetTable") { + Seq(true, false).foreach { useStringForValue => + prepareTable(dir, numRows, width, useStringForValue) + if (useStringForValue) { + runStringBenchmark(numRows, width, mid, "string") + } else { + runIntBenchmark(numRows, width, mid) + } + } + } + } + } + + ignore("Pushdown for few distinct value case (use dictionary encoding)") { + withTempPath { dir => + val numDistinctValues = 200 + + withTempTable("orcTable", "patquetTable") { + prepareStringDictTable(dir, numRows, numDistinctValues, width) + runStringBenchmark(numRows, width, numDistinctValues / 2, "distinct string") + } + } + } + + ignore("Pushdown benchmark for StringStartsWith") { + withTempPath { dir => + withTempTable("orcTable", "patquetTable") { + prepareTable(dir, numRows, width, true) + Seq( + "value like '10%'", + "value like '1000%'", + s"value like '${mid.toString.substring(0, mid.toString.length - 1)}%'" + ).foreach { whereExpr => + val title = s"StringStartsWith filter: ($whereExpr)" + filterPushDownBenchmark(numRows, title, whereExpr) + } + } + } + } + + ignore(s"Pushdown benchmark for ${DecimalType.simpleString}") { + withTempPath { dir => + Seq( + s"decimal(${Decimal.MAX_INT_DIGITS}, 2)", + s"decimal(${Decimal.MAX_LONG_DIGITS}, 2)", + s"decimal(${DecimalType.MAX_PRECISION}, 2)" + ).foreach { dt => + val columns = (1 to width).map(i => s"CAST(id AS string) c$i") + val valueCol = if (dt.equalsIgnoreCase(s"decimal(${Decimal.MAX_INT_DIGITS}, 2)")) { + monotonically_increasing_id() % 9999999 + } else { + monotonically_increasing_id() + } + val df = spark.range(numRows).selectExpr(columns: _*).withColumn("value", valueCol.cast(dt)) + withTempTable("orcTable", "patquetTable") { + saveAsTable(df, dir) + + Seq(s"value = $mid").foreach { whereExpr => + val title = s"Select 1 $dt row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% $dt rows (value < ${numRows * percent / 100})", + s"value < ${numRows * percent / 100}", + selectExpr + ) + } + } + } + } + } + + ignore("Pushdown benchmark for InSet -> InFilters") { + withTempPath { dir => + withTempTable("orcTable", "patquetTable") { + prepareTable(dir, numRows, width, false) + Seq(5, 10, 50, 100).foreach { count => + Seq(10, 50, 90).foreach { distribution => + val filter = + Range(0, count).map(r => scala.util.Random.nextInt(numRows * distribution / 100)) + val whereExpr = s"value in(${filter.mkString(",")})" + val title = s"InSet -> InFilters (values count: $count, distribution: $distribution)" + filterPushDownBenchmark(numRows, title, whereExpr) + } + } + } + } + } + + ignore(s"Pushdown benchmark for ${ByteType.simpleString}") { + withTempPath { dir => + val columns = (1 to width).map(i => s"CAST(id AS string) c$i") + val df = spark.range(numRows).selectExpr(columns: _*) + .withColumn("value", (monotonically_increasing_id() % Byte.MaxValue).cast(ByteType)) + .orderBy("value") + withTempTable("orcTable", "patquetTable") { + saveAsTable(df, dir) + + Seq(s"value = CAST(${Byte.MaxValue / 2} AS ${ByteType.simpleString})") + .foreach { whereExpr => + val title = s"Select 1 ${ByteType.simpleString} row ($whereExpr)" + .replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% ${ByteType.simpleString} rows " + + s"(value < CAST(${Byte.MaxValue * percent / 100} AS ${ByteType.simpleString}))", + s"value < CAST(${Byte.MaxValue * percent / 100} AS ${ByteType.simpleString})", + selectExpr + ) + } + } + } + } + + ignore(s"Pushdown benchmark for Timestamp") { + withTempPath { dir => + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED.key -> true.toString) { + ParquetOutputTimestampType.values.toSeq.map(_.toString).foreach { fileType => + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> fileType) { + val columns = (1 to width).map(i => s"CAST(id AS string) c$i") + val df = spark.range(numRows).selectExpr(columns: _*) + .withColumn("value", monotonically_increasing_id().cast(TimestampType)) + withTempTable("orcTable", "patquetTable") { + saveAsTable(df, dir) + + Seq(s"value = CAST($mid AS timestamp)").foreach { whereExpr => + val title = s"Select 1 timestamp stored as $fileType row ($whereExpr)" + .replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% timestamp stored as $fileType rows " + + s"(value < CAST(${numRows * percent / 100} AS timestamp))", + s"value < CAST(${numRows * percent / 100} as timestamp)", + selectExpr + ) + } + } + } + } + } + } + } +} + +trait BenchmarkBeforeAndAfterEachTest extends BeforeAndAfterEachTestData { this: Suite => + + override def beforeEach(td: TestData) { + super.beforeEach(td) + } + + override def afterEach(td: TestData) { + super.afterEach(td) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 3998ceca38b30..ca95aad3976e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -441,6 +441,24 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("rename a managed table with existing empty directory") { + val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier("tab2"))) + try { + withTable("tab1") { + sql(s"CREATE TABLE tab1 USING $dataSource AS SELECT 1, 'a'") + tableLoc.mkdir() + val ex = intercept[AnalysisException] { + sql("ALTER TABLE tab1 RENAME TO tab2") + }.getMessage + val expectedMsg = "Can not rename the managed table('`tab1`'). The associated location" + assert(ex.contains(expectedMsg)) + } + } finally { + waitForTasksToFinish() + Utils.deleteRecursively(tableLoc) + } + } + private def checkSchemaInCreatedDataSourceTable( path: File, userSpecifiedSchema: Option[String], @@ -2495,7 +2513,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { test("alter datasource table add columns - text format not supported") { withTable("t1") { - sql("CREATE TABLE t1 (c1 int) USING text") + sql("CREATE TABLE t1 (c1 string) USING text") val e = intercept[AnalysisException] { sql("ALTER TABLE t1 ADD COLUMNS (c2 int)") }.getMessage diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 8764f0c42cf9f..bceaf1a9ec061 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path, RawLocalFileSystem} import org.apache.hadoop.mapreduce.Job -import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.BucketSpec diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaSuite.scala new file mode 100644 index 0000000000000..23c58e175fe5e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaSuite.scala @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.internal.SQLConf + +/** + * Read schema suites have the following hierarchy and aims to guarantee users + * a backward-compatible read-schema change coverage on file-based data sources, and + * to prevent future regressions. + * + * ReadSchemaSuite + * -> CSVReadSchemaSuite + * -> HeaderCSVReadSchemaSuite + * + * -> JsonReadSchemaSuite + * + * -> OrcReadSchemaSuite + * -> VectorizedOrcReadSchemaSuite + * + * -> ParquetReadSchemaSuite + * -> VectorizedParquetReadSchemaSuite + * -> MergedParquetReadSchemaSuite + */ + +/** + * All file-based data sources supports column addition and removal at the end. + */ +abstract class ReadSchemaSuite + extends AddColumnTest + with HideColumnAtTheEndTest { + + var originalConf: Boolean = _ +} + +class CSVReadSchemaSuite + extends ReadSchemaSuite + with IntegralTypeTest + with ToDoubleTypeTest + with ToDecimalTypeTest + with ToStringTypeTest { + + override val format: String = "csv" +} + +class HeaderCSVReadSchemaSuite + extends ReadSchemaSuite + with IntegralTypeTest + with ToDoubleTypeTest + with ToDecimalTypeTest + with ToStringTypeTest { + + override val format: String = "csv" + + override val options = Map("header" -> "true") +} + +class JsonReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest + with IntegralTypeTest + with ToDoubleTypeTest + with ToDecimalTypeTest + with ToStringTypeTest { + + override val format: String = "json" +} + +class OrcReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest { + + override val format: String = "orc" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.ORC_VECTORIZED_READER_ENABLED) + spark.conf.set(SQLConf.ORC_VECTORIZED_READER_ENABLED.key, "false") + } + + override def afterAll() { + spark.conf.set(SQLConf.ORC_VECTORIZED_READER_ENABLED.key, originalConf) + super.afterAll() + } +} + +class VectorizedOrcReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest + with BooleanTypeTest + with IntegralTypeTest + with ToDoubleTypeTest { + + override val format: String = "orc" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.ORC_VECTORIZED_READER_ENABLED) + spark.conf.set(SQLConf.ORC_VECTORIZED_READER_ENABLED.key, "true") + } + + override def afterAll() { + spark.conf.set(SQLConf.ORC_VECTORIZED_READER_ENABLED.key, originalConf) + super.afterAll() + } +} + +class ParquetReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest { + + override val format: String = "parquet" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED) + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "false") + } + + override def afterAll() { + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, originalConf) + super.afterAll() + } +} + +class VectorizedParquetReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest { + + override val format: String = "parquet" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED) + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") + } + + override def afterAll() { + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, originalConf) + super.afterAll() + } +} + +class MergedParquetReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest { + + override val format: String = "parquet" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED) + spark.conf.set(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key, "true") + } + + override def afterAll() { + spark.conf.set(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key, originalConf) + super.afterAll() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaTest.scala new file mode 100644 index 0000000000000..2a5457e00b4ef --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaTest.scala @@ -0,0 +1,493 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.File + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} + +/** + * The reader schema is said to be evolved (or projected) when it changed after the data is + * written by writers. The followings are supported in file-based data sources. + * Note that partition columns are not maintained in files. Here, `column` means non-partition + * column. + * + * 1. Add a column + * 2. Hide a column + * 3. Change a column position + * 4. Change a column type (Upcast) + * + * Here, we consider safe changes without data loss. For example, data type changes should be + * from small types to larger types like `int`-to-`long`, not vice versa. + * + * So far, file-based data sources have the following coverages. + * + * | File Format | Coverage | Note | + * | ------------ | ------------ | ------------------------------------------------------ | + * | TEXT | N/A | Schema consists of a single string column. | + * | CSV | 1, 2, 4 | | + * | JSON | 1, 2, 3, 4 | | + * | ORC | 1, 2, 3, 4 | Native vectorized ORC reader has the widest coverage. | + * | PARQUET | 1, 2, 3 | | + * + * This aims to provide an explicit test coverage for reader schema change on file-based data + * sources. Since a file format has its own coverage, we need a test suite for each file-based + * data source with corresponding supported test case traits. + * + * The following is a hierarchy of test traits. + * + * ReadSchemaTest + * -> AddColumnTest + * -> HideColumnTest + * -> ChangePositionTest + * -> BooleanTypeTest + * -> IntegralTypeTest + * -> ToDoubleTypeTest + * -> ToDecimalTypeTest + */ + +trait ReadSchemaTest extends QueryTest with SQLTestUtils with SharedSQLContext { + val format: String + val options: Map[String, String] = Map.empty[String, String] +} + +/** + * Add column (Case 1). + * This test suite assumes that the missing column should be `null`. + */ +trait AddColumnTest extends ReadSchemaTest { + import testImplicits._ + + test("append column at the end") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = Seq("a", "b").toDF("col1") + val df2 = df1.withColumn("col2", lit("x")) + val df3 = df2.withColumn("col3", lit("y")) + + val dir1 = s"$path${File.separator}part=one" + val dir2 = s"$path${File.separator}part=two" + val dir3 = s"$path${File.separator}part=three" + + df1.write.format(format).options(options).save(dir1) + df2.write.format(format).options(options).save(dir2) + df3.write.format(format).options(options).save(dir3) + + val df = spark.read + .schema(df3.schema) + .format(format) + .options(options) + .load(path) + + checkAnswer(df, Seq( + Row("a", null, null, "one"), + Row("b", null, null, "one"), + Row("a", "x", null, "two"), + Row("b", "x", null, "two"), + Row("a", "x", "y", "three"), + Row("b", "x", "y", "three"))) + } + } +} + +/** + * Hide column (Case 2-1). + */ +trait HideColumnAtTheEndTest extends ReadSchemaTest { + import testImplicits._ + + test("hide column at the end") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = Seq(("1", "a"), ("2", "b")).toDF("col1", "col2") + val df2 = df1.withColumn("col3", lit("y")) + + val dir1 = s"$path${File.separator}part=two" + val dir2 = s"$path${File.separator}part=three" + + df1.write.format(format).options(options).save(dir1) + df2.write.format(format).options(options).save(dir2) + + val df = spark.read + .schema(df1.schema) + .format(format) + .options(options) + .load(path) + + checkAnswer(df, Seq( + Row("1", "a", "two"), + Row("2", "b", "two"), + Row("1", "a", "three"), + Row("2", "b", "three"))) + + val df3 = spark.read + .schema("col1 string") + .format(format) + .options(options) + .load(path) + + checkAnswer(df3, Seq( + Row("1", "two"), + Row("2", "two"), + Row("1", "three"), + Row("2", "three"))) + } + } +} + +/** + * Hide column in the middle (Case 2-2). + */ +trait HideColumnInTheMiddleTest extends ReadSchemaTest { + import testImplicits._ + + test("hide column in the middle") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = Seq(("1", "a"), ("2", "b")).toDF("col1", "col2") + val df2 = df1.withColumn("col3", lit("y")) + + val dir1 = s"$path${File.separator}part=two" + val dir2 = s"$path${File.separator}part=three" + + df1.write.format(format).options(options).save(dir1) + df2.write.format(format).options(options).save(dir2) + + val df = spark.read + .schema("col2 string") + .format(format) + .options(options) + .load(path) + + checkAnswer(df, Seq( + Row("a", "two"), + Row("b", "two"), + Row("a", "three"), + Row("b", "three"))) + } + } +} + +/** + * Change column positions (Case 3). + * This suite assumes that all data set have the same number of columns. + */ +trait ChangePositionTest extends ReadSchemaTest { + import testImplicits._ + + test("change column position") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = Seq(("1", "a"), ("2", "b"), ("3", "c")).toDF("col1", "col2") + val df2 = Seq(("d", "4"), ("e", "5"), ("f", "6")).toDF("col2", "col1") + val unionDF = df1.unionByName(df2) + + val dir1 = s"$path${File.separator}part=one" + val dir2 = s"$path${File.separator}part=two" + + df1.write.format(format).options(options).save(dir1) + df2.write.format(format).options(options).save(dir2) + + val df = spark.read + .schema(unionDF.schema) + .format(format) + .options(options) + .load(path) + .select("col1", "col2") + + checkAnswer(df, unionDF) + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. + */ +trait BooleanTypeTest extends ReadSchemaTest { + import testImplicits._ + + test("change column type from boolean to byte/short/int/long") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val values = (1 to 10).map(_ % 2) + val booleanDF = (1 to 10).map(_ % 2 == 1).toDF("col1") + val byteDF = values.map(_.toByte).toDF("col1") + val shortDF = values.map(_.toShort).toDF("col1") + val intDF = values.toDF("col1") + val longDF = values.map(_.toLong).toDF("col1") + + booleanDF.write.mode("overwrite").format(format).options(options).save(path) + + Seq( + ("col1 byte", byteDF), + ("col1 short", shortDF), + ("col1 int", intDF), + ("col1 long", longDF)).foreach { case (schema, answerDF) => + checkAnswer(spark.read.schema(schema).format(format).options(options).load(path), answerDF) + } + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. + */ +trait ToStringTypeTest extends ReadSchemaTest { + import testImplicits._ + + test("read as string") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val byteDF = (Byte.MaxValue - 2 to Byte.MaxValue).map(_.toByte).toDF("col1") + val shortDF = (Short.MaxValue - 2 to Short.MaxValue).map(_.toShort).toDF("col1") + val intDF = (Int.MaxValue - 2 to Int.MaxValue).toDF("col1") + val longDF = (Long.MaxValue - 2 to Long.MaxValue).toDF("col1") + val unionDF = byteDF.union(shortDF).union(intDF).union(longDF) + .selectExpr("cast(col1 AS STRING) col1") + + val byteDir = s"$path${File.separator}part=byte" + val shortDir = s"$path${File.separator}part=short" + val intDir = s"$path${File.separator}part=int" + val longDir = s"$path${File.separator}part=long" + + byteDF.write.format(format).options(options).save(byteDir) + shortDF.write.format(format).options(options).save(shortDir) + intDF.write.format(format).options(options).save(intDir) + longDF.write.format(format).options(options).save(longDir) + + val df = spark.read + .schema("col1 string") + .format(format) + .options(options) + .load(path) + .select("col1") + + checkAnswer(df, unionDF) + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. + */ +trait IntegralTypeTest extends ReadSchemaTest { + + import testImplicits._ + + private lazy val values = 1 to 10 + private lazy val byteDF = values.map(_.toByte).toDF("col1") + private lazy val shortDF = values.map(_.toShort).toDF("col1") + private lazy val intDF = values.toDF("col1") + private lazy val longDF = values.map(_.toLong).toDF("col1") + + test("change column type from byte to short/int/long") { + withTempPath { dir => + val path = dir.getCanonicalPath + + byteDF.write.format(format).options(options).save(path) + + Seq( + ("col1 short", shortDF), + ("col1 int", intDF), + ("col1 long", longDF)).foreach { case (schema, answerDF) => + checkAnswer(spark.read.schema(schema).format(format).options(options).load(path), answerDF) + } + } + } + + test("change column type from short to int/long") { + withTempPath { dir => + val path = dir.getCanonicalPath + + shortDF.write.format(format).options(options).save(path) + + Seq(("col1 int", intDF), ("col1 long", longDF)).foreach { case (schema, answerDF) => + checkAnswer(spark.read.schema(schema).format(format).options(options).load(path), answerDF) + } + } + } + + test("change column type from int to long") { + withTempPath { dir => + val path = dir.getCanonicalPath + + intDF.write.format(format).options(options).save(path) + + Seq(("col1 long", longDF)).foreach { case (schema, answerDF) => + checkAnswer(spark.read.schema(schema).format(format).options(options).load(path), answerDF) + } + } + } + + test("read byte, int, short, long together") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val byteDF = (Byte.MaxValue - 2 to Byte.MaxValue).map(_.toByte).toDF("col1") + val shortDF = (Short.MaxValue - 2 to Short.MaxValue).map(_.toShort).toDF("col1") + val intDF = (Int.MaxValue - 2 to Int.MaxValue).toDF("col1") + val longDF = (Long.MaxValue - 2 to Long.MaxValue).toDF("col1") + val unionDF = byteDF.union(shortDF).union(intDF).union(longDF) + + val byteDir = s"$path${File.separator}part=byte" + val shortDir = s"$path${File.separator}part=short" + val intDir = s"$path${File.separator}part=int" + val longDir = s"$path${File.separator}part=long" + + byteDF.write.format(format).options(options).save(byteDir) + shortDF.write.format(format).options(options).save(shortDir) + intDF.write.format(format).options(options).save(intDir) + longDF.write.format(format).options(options).save(longDir) + + val df = spark.read + .schema(unionDF.schema) + .format(format) + .options(options) + .load(path) + .select("col1") + + checkAnswer(df, unionDF) + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. + */ +trait ToDoubleTypeTest extends ReadSchemaTest { + import testImplicits._ + + private lazy val values = 1 to 10 + private lazy val floatDF = values.map(_.toFloat).toDF("col1") + private lazy val doubleDF = values.map(_.toDouble).toDF("col1") + private lazy val unionDF = floatDF.union(doubleDF) + + test("change column type from float to double") { + withTempPath { dir => + val path = dir.getCanonicalPath + + floatDF.write.format(format).options(options).save(path) + + val df = spark.read.schema("col1 double").format(format).options(options).load(path) + + checkAnswer(df, doubleDF) + } + } + + test("read float and double together") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val floatDir = s"$path${File.separator}part=float" + val doubleDir = s"$path${File.separator}part=double" + + floatDF.write.format(format).options(options).save(floatDir) + doubleDF.write.format(format).options(options).save(doubleDir) + + val df = spark.read + .schema(unionDF.schema) + .format(format) + .options(options) + .load(path) + .select("col1") + + checkAnswer(df, unionDF) + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. + */ +trait ToDecimalTypeTest extends ReadSchemaTest { + import testImplicits._ + + private lazy val values = 1 to 10 + private lazy val floatDF = values.map(_.toFloat).toDF("col1") + private lazy val doubleDF = values.map(_.toDouble).toDF("col1") + private lazy val decimalDF = values.map(BigDecimal(_)).toDF("col1") + private lazy val unionDF = floatDF.union(doubleDF).union(decimalDF) + + test("change column type from float to decimal") { + withTempPath { dir => + val path = dir.getCanonicalPath + + floatDF.write.format(format).options(options).save(path) + + val df = spark.read + .schema("col1 decimal(38,18)") + .format(format) + .options(options) + .load(path) + + checkAnswer(df, decimalDF) + } + } + + test("change column type from double to decimal") { + withTempPath { dir => + val path = dir.getCanonicalPath + + doubleDF.write.format(format).options(options).save(path) + + val df = spark.read + .schema("col1 decimal(38,18)") + .format(format) + .options(options) + .load(path) + + checkAnswer(df, decimalDF) + } + } + + test("read float, double, decimal together") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val floatDir = s"$path${File.separator}part=float" + val doubleDir = s"$path${File.separator}part=double" + val decimalDir = s"$path${File.separator}part=decimal" + + floatDF.write.format(format).options(options).save(floatDir) + doubleDF.write.format(format).options(options).save(doubleDir) + decimalDF.write.format(format).options(options).save(decimalDir) + + val df = spark.read + .schema(unionDF.schema) + .format(format) + .options(options) + .load(path) + .select("col1") + + checkAnswer(df, unionDF) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index d2f166c7d1877..63cc5985040c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -60,10 +60,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te private val unescapedQuotesFile = "test-data/unescaped-quotes.csv" private val valueMalformedFile = "test-data/value-malformed.csv" - private def testFile(fileName: String): String = { - Thread.currentThread().getContextClassLoader.getResource(fileName).toString - } - /** Verifies data and schema. */ private def verifyCars( df: DataFrame, @@ -740,39 +736,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te assert(numbers.count() == 8) } - test("error handling for unsupported data types.") { - withTempDir { dir => - val csvDir = new File(dir, "csv").getCanonicalPath - var msg = intercept[UnsupportedOperationException] { - Seq((1, "Tesla")).toDF("a", "b").selectExpr("struct(a, b)").write.csv(csvDir) - }.getMessage - assert(msg.contains("CSV data source does not support struct data type")) - - msg = intercept[UnsupportedOperationException] { - Seq((1, Map("Tesla" -> 3))).toDF("id", "cars").write.csv(csvDir) - }.getMessage - assert(msg.contains("CSV data source does not support map data type")) - - msg = intercept[UnsupportedOperationException] { - Seq((1, Array("Tesla", "Chevy", "Ford"))).toDF("id", "brands").write.csv(csvDir) - }.getMessage - assert(msg.contains("CSV data source does not support array data type")) - - msg = intercept[UnsupportedOperationException] { - Seq((1, new UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))).toDF("id", "vectors") - .write.csv(csvDir) - }.getMessage - assert(msg.contains("CSV data source does not support array data type")) - - msg = intercept[UnsupportedOperationException] { - val schema = StructType(StructField("a", new UDT.MyDenseVectorUDT(), true) :: Nil) - spark.range(1).write.csv(csvDir) - spark.read.schema(schema).csv(csvDir).collect() - }.getMessage - assert(msg.contains("CSV data source does not support array data type.")) - } - } - test("SPARK-15585 turn off quotations") { val cars = spark.read .format("csv") @@ -1602,4 +1565,43 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te assert(testAppender2.events.asScala .exists(msg => msg.getRenderedMessage.contains("CSV header does not conform to the schema"))) } + + test("SPARK-24645 skip parsing when columnPruning enabled and partitions scanned only") { + withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "true") { + withTempPath { path => + val dir = path.getAbsolutePath + spark.range(10).selectExpr("id % 2 AS p", "id").write.partitionBy("p").csv(dir) + checkAnswer(spark.read.csv(dir).selectExpr("sum(p)"), Row(5)) + } + } + } + + test("SPARK-24676 project required data from parsed data when columnPruning disabled") { + withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "false") { + withTempPath { path => + val dir = path.getAbsolutePath + spark.range(10).selectExpr("id % 2 AS p", "id AS c0", "id AS c1").write.partitionBy("p") + .option("header", "true").csv(dir) + val df1 = spark.read.option("header", true).csv(dir).selectExpr("sum(p)", "count(c0)") + checkAnswer(df1, Row(5, 10)) + + // empty required column case + val df2 = spark.read.option("header", true).csv(dir).selectExpr("sum(p)") + checkAnswer(df2, Row(5)) + } + + // the case where tokens length != parsedSchema length + withTempPath { path => + val dir = path.getAbsolutePath + Seq("1,2").toDF().write.text(dir) + // more tokens + val df1 = spark.read.schema("c0 int").format("csv").option("mode", "permissive").load(dir) + checkAnswer(df1, Row(1)) + // less tokens + val df2 = spark.read.schema("c0 int, c1 int, c2 int").format("csv") + .option("mode", "permissive").load(dir) + checkAnswer(df2, Row(1, 2, null)) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 4b3921c61a000..655f40ad549e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.datasources.json -import java.io.{File, FileOutputStream, StringWriter} -import java.nio.charset.{StandardCharsets, UnsupportedCharsetException} -import java.nio.file.{Files, Paths, StandardOpenOption} +import java.io._ +import java.nio.charset.{Charset, StandardCharsets, UnsupportedCharsetException} +import java.nio.file.Files import java.sql.{Date, Timestamp} import java.util.Locale @@ -31,11 +31,11 @@ import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.{SparkException, TestUtils} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{functions => F, _} -import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JsonInferSchema, JSONOptions} +import org.apache.spark.sql.catalyst.json.JsonInferSchema.compatibleType import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.ExternalRDD import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.datasources.json.JsonInferSchema.compatibleType import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -48,10 +48,6 @@ class TestFileFilter extends PathFilter { class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { import testImplicits._ - def testFile(fileName: String): String = { - Thread.currentThread().getContextClassLoader.getResource(fileName).toString - } - test("Type promotion") { def checkTypePromotion(expected: Any, actual: Any) { assert(expected.getClass == actual.getClass, @@ -2262,7 +2258,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { withTempPath { path => val df = spark.createDataset(Seq(("Dog", 42))) df.write - .options(Map("encoding" -> encoding, "lineSep" -> "\n")) + .options(Map("encoding" -> encoding)) .json(path.getCanonicalPath) checkEncoding( @@ -2286,16 +2282,22 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("SPARK-23723: wrong output encoding") { val encoding = "UTF-128" - val exception = intercept[UnsupportedCharsetException] { + val exception = intercept[SparkException] { withTempPath { path => val df = spark.createDataset(Seq((0))) df.write - .options(Map("encoding" -> encoding, "lineSep" -> "\n")) + .options(Map("encoding" -> encoding)) .json(path.getCanonicalPath) } } - assert(exception.getMessage == encoding) + val baos = new ByteArrayOutputStream() + val ps = new PrintStream(baos, true, "UTF-8") + exception.printStackTrace(ps) + ps.flush() + + assert(baos.toString.contains( + "java.nio.charset.UnsupportedCharsetException: UTF-128")) } test("SPARK-23723: read back json in UTF-16LE") { @@ -2316,18 +2318,17 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("SPARK-23723: write json in UTF-16/32 with multiline off") { Seq("UTF-16", "UTF-32").foreach { encoding => withTempPath { path => - val ds = spark.createDataset(Seq( - ("a", 1), ("b", 2), ("c", 3)) - ).repartition(2) - val e = intercept[IllegalArgumentException] { - ds.write - .option("encoding", encoding) - .option("multiline", "false") - .format("json").mode("overwrite") - .save(path.getCanonicalPath) - }.getMessage - assert(e.contains( - s"$encoding encoding in the blacklist is not allowed when multiLine is disabled")) + val ds = spark.createDataset(Seq(("a", 1))).repartition(1) + ds.write + .option("encoding", encoding) + .option("multiline", false) + .json(path.getCanonicalPath) + val jsonFiles = path.listFiles().filter(_.getName.endsWith("json")) + jsonFiles.foreach { jsonFile => + val readback = Files.readAllBytes(jsonFile.toPath) + val expected = ("""{"_1":"a","_2":1}""" + "\n").getBytes(Charset.forName(encoding)) + assert(readback === expected) + } } } } @@ -2427,4 +2428,66 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { spark.read.option("mode", "PERMISSIVE").option("encoding", "UTF-8").json(Seq(badJson).toDS()), Row(badJson)) } + + test("SPARK-23772 ignore column of all null values or empty array during schema inference") { + withTempPath { tempDir => + val path = tempDir.getAbsolutePath + + // primitive types + Seq( + """{"a":null, "b":1, "c":3.0}""", + """{"a":null, "b":null, "c":"string"}""", + """{"a":null, "b":null, "c":null}""") + .toDS().write.text(path) + var df = spark.read.format("json") + .option("dropFieldIfAllNull", true) + .load(path) + var expectedSchema = new StructType() + .add("b", LongType).add("c", StringType) + assert(df.schema === expectedSchema) + checkAnswer(df, Row(1, "3.0") :: Row(null, "string") :: Row(null, null) :: Nil) + + // arrays + Seq( + """{"a":[2, 1], "b":[null, null], "c":null, "d":[[], [null]], "e":[[], null, [[]]]}""", + """{"a":[null], "b":[null], "c":[], "d":[null, []], "e":null}""", + """{"a":null, "b":null, "c":[], "d":null, "e":[null, [], null]}""") + .toDS().write.mode("overwrite").text(path) + df = spark.read.format("json") + .option("dropFieldIfAllNull", true) + .load(path) + expectedSchema = new StructType() + .add("a", ArrayType(LongType)) + assert(df.schema === expectedSchema) + checkAnswer(df, Row(Array(2, 1)) :: Row(Array(null)) :: Row(null) :: Nil) + + // structs + Seq( + """{"a":{"a1": 1, "a2":"string"}, "b":{}}""", + """{"a":{"a1": 2, "a2":null}, "b":{"b1":[null]}}""", + """{"a":null, "b":null}""") + .toDS().write.mode("overwrite").text(path) + df = spark.read.format("json") + .option("dropFieldIfAllNull", true) + .load(path) + expectedSchema = new StructType() + .add("a", StructType(StructField("a1", LongType) :: StructField("a2", StringType) + :: Nil)) + assert(df.schema === expectedSchema) + checkAnswer(df, Row(Row(1, "string")) :: Row(Row(2, null)) :: Row(null) :: Nil) + } + } + + test("SPARK-24190: restrictions for JSONOptions in read") { + for (encoding <- Set("UTF-16", "UTF-32")) { + val exception = intercept[IllegalArgumentException] { + spark.read + .option("encoding", encoding) + .option("multiLine", false) + .json(testFile("test-data/utf16LE.json")) + .count() + } + assert(exception.getMessage.contains("encoding must not be included in the blacklist")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 90da7eb8c4fb5..be4f498c921ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.execution.datasources.parquet +import java.math.{BigDecimal => JBigDecimal} import java.nio.charset.StandardCharsets -import java.sql.Date +import java.sql.{Date, Timestamp} -import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} +import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Operators} import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.filter2.predicate.Operators.{Column => _, _} @@ -31,6 +32,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} @@ -55,7 +57,10 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} */ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext { - private lazy val parquetFilters = new ParquetFilters(conf.parquetFilterPushDownDate) + private lazy val parquetFilters = + new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownTimestamp, + conf.parquetFilterPushDownDecimal, conf.parquetFilterPushDownStringStartWith, + conf.parquetFilterPushDownInFilterThreshold) override def beforeEach(): Unit = { super.beforeEach() @@ -82,6 +87,9 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex withSQLConf( SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true", SQLConf.PARQUET_FILTER_PUSHDOWN_DATE_ENABLED.key -> "true", + SQLConf.PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED.key -> "true", + SQLConf.PARQUET_FILTER_PUSHDOWN_DECIMAL_ENABLED.key -> "true", + SQLConf.PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED.key -> "true", SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { val query = df .select(output.map(e => Column(e)): _*) @@ -101,7 +109,8 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex assert(selectedFilters.nonEmpty, "No filter is pushed down") selectedFilters.foreach { pred => - val maybeFilter = parquetFilters.createFilter(df.schema, pred) + val maybeFilter = parquetFilters.createFilter( + new SparkToParquetSchemaConverter(conf).convert(df.schema), pred) assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") // Doesn't bother checking type parameters here (e.g. `Eq[Integer]`) maybeFilter.exists(_.getClass === filterClass) @@ -140,6 +149,71 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df) } + private def testTimestampPushdown(data: Seq[Timestamp]): Unit = { + assert(data.size === 4) + val ts1 = data.head + val ts2 = data(1) + val ts3 = data(2) + val ts4 = data(3) + + withParquetDataFrame(data.map(i => Tuple1(i))) { implicit df => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], data.map(i => Row.apply(i))) + + checkFilterPredicate('_1 === ts1, classOf[Eq[_]], ts1) + checkFilterPredicate('_1 <=> ts1, classOf[Eq[_]], ts1) + checkFilterPredicate('_1 =!= ts1, classOf[NotEq[_]], + Seq(ts2, ts3, ts4).map(i => Row.apply(i))) + + checkFilterPredicate('_1 < ts2, classOf[Lt[_]], ts1) + checkFilterPredicate('_1 > ts1, classOf[Gt[_]], Seq(ts2, ts3, ts4).map(i => Row.apply(i))) + checkFilterPredicate('_1 <= ts1, classOf[LtEq[_]], ts1) + checkFilterPredicate('_1 >= ts4, classOf[GtEq[_]], ts4) + + checkFilterPredicate(Literal(ts1) === '_1, classOf[Eq[_]], ts1) + checkFilterPredicate(Literal(ts1) <=> '_1, classOf[Eq[_]], ts1) + checkFilterPredicate(Literal(ts2) > '_1, classOf[Lt[_]], ts1) + checkFilterPredicate(Literal(ts3) < '_1, classOf[Gt[_]], ts4) + checkFilterPredicate(Literal(ts1) >= '_1, classOf[LtEq[_]], ts1) + checkFilterPredicate(Literal(ts4) <= '_1, classOf[GtEq[_]], ts4) + + checkFilterPredicate(!('_1 < ts4), classOf[GtEq[_]], ts4) + checkFilterPredicate('_1 < ts2 || '_1 > ts3, classOf[Operators.Or], Seq(Row(ts1), Row(ts4))) + } + } + + private def testDecimalPushDown(data: DataFrame)(f: DataFrame => Unit): Unit = { + withTempPath { file => + data.write.parquet(file.getCanonicalPath) + readParquetFile(file.toString)(f) + } + } + + // This function tests that exactly go through the `canDrop` and `inverseCanDrop`. + private def testStringStartsWith(dataFrame: DataFrame, filter: String): Unit = { + withTempPath { dir => + val path = dir.getCanonicalPath + dataFrame.write.option("parquet.block.size", 512).parquet(path) + Seq(true, false).foreach { pushDown => + withSQLConf( + SQLConf.PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED.key -> pushDown.toString) { + val accu = new NumRowGroupsAcc + sparkContext.register(accu) + + val df = spark.read.parquet(path).filter(filter) + df.foreachPartition((it: Iterator[Row]) => it.foreach(v => accu.add(0))) + if (pushDown) { + assert(accu.value == 0) + } else { + assert(accu.value > 0) + } + + AccumulatorContext.remove(accu.id) + } + } + } + } + test("filter pushdown - boolean") { withParquetDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) @@ -151,6 +225,62 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } + test("filter pushdown - tinyint") { + withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toByte)))) { implicit df => + assert(df.schema.head.dataType === ByteType) + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 === 1.toByte, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1.toByte, classOf[Eq[_]], 1) + checkFilterPredicate('_1 =!= 1.toByte, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 < 2.toByte, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3.toByte, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1.toByte, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4.toByte, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1.toByte) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1.toByte) <=> '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2.toByte) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3.toByte) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1.toByte) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4.toByte) <= '_1, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('_1 < 4.toByte), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2.toByte || '_1 > 3.toByte, + classOf[Operators.Or], Seq(Row(1), Row(4))) + } + } + + test("filter pushdown - smallint") { + withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit df => + assert(df.schema.head.dataType === ShortType) + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 === 1.toShort, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1.toShort, classOf[Eq[_]], 1) + checkFilterPredicate('_1 =!= 1.toShort, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 < 2.toShort, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3.toShort, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1.toShort, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4.toShort, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1.toShort) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1.toShort) <=> '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2.toShort) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3.toShort) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1.toShort) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4.toShort) <= '_1, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('_1 < 4.toShort), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2.toShort || '_1 > 3.toShort, + classOf[Operators.Or], Seq(Row(1), Row(4))) + } + } + test("filter pushdown - integer") { withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) @@ -359,6 +489,117 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } + test("filter pushdown - timestamp") { + // spark.sql.parquet.outputTimestampType = TIMESTAMP_MILLIS + val millisData = Seq(Timestamp.valueOf("2018-06-14 08:28:53.123"), + Timestamp.valueOf("2018-06-15 08:28:53.123"), + Timestamp.valueOf("2018-06-16 08:28:53.123"), + Timestamp.valueOf("2018-06-17 08:28:53.123")) + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> + ParquetOutputTimestampType.TIMESTAMP_MILLIS.toString) { + testTimestampPushdown(millisData) + } + + // spark.sql.parquet.outputTimestampType = TIMESTAMP_MICROS + val microsData = Seq(Timestamp.valueOf("2018-06-14 08:28:53.123456"), + Timestamp.valueOf("2018-06-15 08:28:53.123456"), + Timestamp.valueOf("2018-06-16 08:28:53.123456"), + Timestamp.valueOf("2018-06-17 08:28:53.123456")) + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> + ParquetOutputTimestampType.TIMESTAMP_MICROS.toString) { + testTimestampPushdown(microsData) + } + + // spark.sql.parquet.outputTimestampType = INT96 doesn't support pushdown + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> + ParquetOutputTimestampType.INT96.toString) { + withParquetDataFrame(millisData.map(i => Tuple1(i))) { implicit df => + assertResult(None) { + parquetFilters.createFilter( + new SparkToParquetSchemaConverter(conf).convert(df.schema), sources.IsNull("_1")) + } + } + } + } + + test("filter pushdown - decimal") { + Seq(true, false).foreach { legacyFormat => + withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> legacyFormat.toString) { + Seq( + s"a decimal(${Decimal.MAX_INT_DIGITS}, 2)", // 32BitDecimalType + s"a decimal(${Decimal.MAX_LONG_DIGITS}, 2)", // 64BitDecimalType + "a decimal(38, 18)" // ByteArrayDecimalType + ).foreach { schemaDDL => + val schema = StructType.fromDDL(schemaDDL) + val rdd = + spark.sparkContext.parallelize((1 to 4).map(i => Row(new java.math.BigDecimal(i)))) + val dataFrame = spark.createDataFrame(rdd, schema) + testDecimalPushDown(dataFrame) { implicit df => + assert(df.schema === schema) + checkFilterPredicate('a.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('a.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('a === 1, classOf[Eq[_]], 1) + checkFilterPredicate('a <=> 1, classOf[Eq[_]], 1) + checkFilterPredicate('a =!= 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('a < 2, classOf[Lt[_]], 1) + checkFilterPredicate('a > 3, classOf[Gt[_]], 4) + checkFilterPredicate('a <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('a >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === 'a, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1) <=> 'a, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > 'a, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < 'a, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= 'a, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= 'a, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('a < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('a < 2 || 'a > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + } + } + } + } + } + + test("Ensure that filter value matched the parquet file schema") { + val scale = 2 + val schema = StructType(Seq( + StructField("cint", IntegerType), + StructField("cdecimal1", DecimalType(Decimal.MAX_INT_DIGITS, scale)), + StructField("cdecimal2", DecimalType(Decimal.MAX_LONG_DIGITS, scale)), + StructField("cdecimal3", DecimalType(DecimalType.MAX_PRECISION, scale)) + )) + + val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema) + + val decimal = new JBigDecimal(10).setScale(scale) + val decimal1 = new JBigDecimal(10).setScale(scale + 1) + assert(decimal.scale() === scale) + assert(decimal1.scale() === scale + 1) + + assertResult(Some(lt(intColumn("cdecimal1"), 1000: Integer))) { + parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal1", decimal)) + } + assertResult(None) { + parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal1", decimal1)) + } + + assertResult(Some(lt(longColumn("cdecimal2"), 1000L: java.lang.Long))) { + parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal2", decimal)) + } + assertResult(None) { + parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal2", decimal1)) + } + + assert(parquetFilters.createFilter( + parquetSchema, sources.LessThan("cdecimal3", decimal)).isDefined) + assertResult(None) { + parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal3", decimal1)) + } + } + test("SPARK-6554: don't push down predicates which reference partition columns") { import testImplicits._ @@ -515,12 +756,14 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex StructField("c", DoubleType, nullable = true) )) + val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema) + assertResult(Some(and( lt(intColumn("a"), 10: Integer), gt(doubleColumn("c"), 1.5: java.lang.Double))) ) { parquetFilters.createFilter( - schema, + parquetSchema, sources.And( sources.LessThan("a", 10), sources.GreaterThan("c", 1.5D))) @@ -528,7 +771,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex assertResult(None) { parquetFilters.createFilter( - schema, + parquetSchema, sources.And( sources.LessThan("a", 10), sources.StringContains("b", "prefix"))) @@ -536,7 +779,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex assertResult(None) { parquetFilters.createFilter( - schema, + parquetSchema, sources.Not( sources.And( sources.GreaterThan("a", 1), @@ -574,7 +817,6 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex val df = spark.read.parquet(path).filter("a < 100") df.foreachPartition((it: Iterator[Row]) => it.foreach(v => accu.add(0))) - df.collect if (enablePushDown) { assert(accu.value == 0) @@ -589,21 +831,25 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } test("SPARK-17213: Broken Parquet filter push-down for string columns") { - withTempPath { dir => - import testImplicits._ + Seq(true, false).foreach { vectorizedEnabled => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorizedEnabled.toString) { + withTempPath { dir => + import testImplicits._ - val path = dir.getCanonicalPath - // scalastyle:off nonascii - Seq("a", "é").toDF("name").write.parquet(path) - // scalastyle:on nonascii + val path = dir.getCanonicalPath + // scalastyle:off nonascii + Seq("a", "é").toDF("name").write.parquet(path) + // scalastyle:on nonascii - assert(spark.read.parquet(path).where("name > 'a'").count() == 1) - assert(spark.read.parquet(path).where("name >= 'a'").count() == 2) + assert(spark.read.parquet(path).where("name > 'a'").count() == 1) + assert(spark.read.parquet(path).where("name >= 'a'").count() == 2) - // scalastyle:off nonascii - assert(spark.read.parquet(path).where("name < 'é'").count() == 1) - assert(spark.read.parquet(path).where("name <= 'é'").count() == 2) - // scalastyle:on nonascii + // scalastyle:off nonascii + assert(spark.read.parquet(path).where("name < 'é'").count() == 1) + assert(spark.read.parquet(path).where("name <= 'é'").count() == 2) + // scalastyle:on nonascii + } + } } } @@ -660,6 +906,121 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex assert(df.where("col > 0").count() === 2) } } + + test("filter pushdown - StringStartsWith") { + withParquetDataFrame((1 to 4).map(i => Tuple1(i + "str" + i))) { implicit df => + checkFilterPredicate( + '_1.startsWith("").asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq("1str1", "2str2", "3str3", "4str4").map(Row(_))) + + Seq("2", "2s", "2st", "2str", "2str2").foreach { prefix => + checkFilterPredicate( + '_1.startsWith(prefix).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + "2str2") + } + + Seq("2S", "null", "2str22").foreach { prefix => + checkFilterPredicate( + '_1.startsWith(prefix).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq.empty[Row]) + } + + checkFilterPredicate( + !'_1.startsWith("").asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq().map(Row(_))) + + Seq("2", "2s", "2st", "2str", "2str2").foreach { prefix => + checkFilterPredicate( + !'_1.startsWith(prefix).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq("1str1", "3str3", "4str4").map(Row(_))) + } + + Seq("2S", "null", "2str22").foreach { prefix => + checkFilterPredicate( + !'_1.startsWith(prefix).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq("1str1", "2str2", "3str3", "4str4").map(Row(_))) + } + + assertResult(None) { + parquetFilters.createFilter( + new SparkToParquetSchemaConverter(conf).convert(df.schema), + sources.StringStartsWith("_1", null)) + } + } + + import testImplicits._ + // Test canDrop() has taken effect + testStringStartsWith(spark.range(1024).map(_.toString).toDF(), "value like 'a%'") + // Test inverseCanDrop() has taken effect + testStringStartsWith(spark.range(1024).map(c => "100").toDF(), "value not like '10%'") + } + + test("SPARK-17091: Convert IN predicate to Parquet filter push-down") { + val schema = StructType(Seq( + StructField("a", IntegerType, nullable = false) + )) + + val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema) + + assertResult(Some(FilterApi.eq(intColumn("a"), null: Integer))) { + parquetFilters.createFilter(parquetSchema, sources.In("a", Array(null))) + } + + assertResult(Some(FilterApi.eq(intColumn("a"), 10: Integer))) { + parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10))) + } + + // Remove duplicates + assertResult(Some(FilterApi.eq(intColumn("a"), 10: Integer))) { + parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10, 10))) + } + + assertResult(Some(or(or( + FilterApi.eq(intColumn("a"), 10: Integer), + FilterApi.eq(intColumn("a"), 20: Integer)), + FilterApi.eq(intColumn("a"), 30: Integer))) + ) { + parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10, 20, 30))) + } + + assert(parquetFilters.createFilter(parquetSchema, sources.In("a", + Range(0, conf.parquetFilterPushDownInFilterThreshold).toArray)).isDefined) + assert(parquetFilters.createFilter(parquetSchema, sources.In("a", + Range(0, conf.parquetFilterPushDownInFilterThreshold + 1).toArray)).isEmpty) + + import testImplicits._ + withTempPath { path => + val data = 0 to 1024 + data.toDF("a").selectExpr("if (a = 1024, null, a) AS a") // convert 1024 to null + .coalesce(1).write.option("parquet.block.size", 512) + .parquet(path.getAbsolutePath) + val df = spark.read.parquet(path.getAbsolutePath) + Seq(true, false).foreach { pushEnabled => + withSQLConf( + SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> pushEnabled.toString) { + Seq(1, 5, 10, 11).foreach { count => + val filter = s"a in(${Range(0, count).mkString(",")})" + assert(df.where(filter).count() === count) + val actual = stripSparkFilter(df.where(filter)).collect().length + if (pushEnabled && count <= conf.parquetFilterPushDownInFilterThreshold) { + assert(actual > 1 && actual < data.length) + } else { + assert(actual === data.length) + } + } + assert(df.where("a in(null)").count() === 0) + assert(df.where("a = null").count() === 0) + assert(df.where("a is null").count() === 1) + } + } + } + } } class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala index fff0f82f9bc2b..a302d67b5cbf7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala @@ -21,10 +21,10 @@ import java.io.File import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types.{StringType, StructType} -class WholeTextFileSuite extends QueryTest with SharedSQLContext { +class WholeTextFileSuite extends QueryTest with SharedSQLContext with SQLTestUtils { // Hadoop's FileSystem caching does not use the Configuration as part of its cache key, which // can cause Filesystem.get(Configuration) to return a cached instance created with a different @@ -35,13 +35,10 @@ class WholeTextFileSuite extends QueryTest with SharedSQLContext { protected override def sparkConf = super.sparkConf.set("spark.hadoop.fs.file.impl.disable.cache", "true") - private def testFile: String = { - Thread.currentThread().getContextClassLoader.getResource("test-data/text-suite.txt").toString - } - test("reading text file with option wholetext=true") { val df = spark.read.option("wholetext", "true") - .format("text").load(testFile) + .format("text") + .load(testFile("test-data/text-suite.txt")) // schema assert(df.schema == new StructType().add("value", StringType)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index a3a3f3851e21c..8263c9c81c49e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.execution.metric import java.io.File +import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future} import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.execution.ui.SQLAppStatusStore import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -504,4 +504,38 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared test("writing data out metrics with dynamic partition: parquet") { testMetricsDynamicPartition("parquet", "parquet", "t1") } + + test("writing metrics from single thread") { + val nAdds = 10 + val acc = new SQLMetric("test", -10) + assert(acc.isZero()) + acc.set(0) + for (i <- 1 to nAdds) acc.add(1) + assert(!acc.isZero()) + assert(nAdds === acc.value) + acc.reset() + assert(acc.isZero()) + } + + test("writing metrics from multiple threads") { + implicit val ec: ExecutionContextExecutor = ExecutionContext.global + val nFutures = 1000 + val nAdds = 100 + val acc = new SQLMetric("test", -10) + assert(acc.isZero() === true) + acc.set(0) + val l = for ( i <- 1 to nFutures ) yield { + Future { + for (j <- 1 to nAdds) acc.add(1) + i + } + } + for (futures <- Future.sequence(l)) { + assert(nFutures === futures.length) + assert(!acc.isZero()) + assert(nFutures * nAdds === acc.value) + acc.reset() + assert(acc.isZero()) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala new file mode 100644 index 0000000000000..07e6034770127 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.concurrent.Eventually +import org.scalatest.time.SpanSugar._ + +import org.apache.spark._ +import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.execution.python.PythonForeachWriter.UnsafeRowBuffer +import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.util.Utils + +class PythonForeachWriterSuite extends SparkFunSuite with Eventually { + + testWithBuffer("UnsafeRowBuffer: iterator blocks when no data is available") { b => + b.assertIteratorBlocked() + + b.add(Seq(1)) + b.assertOutput(Seq(1)) + b.assertIteratorBlocked() + + b.add(2 to 100) + b.assertOutput(1 to 100) + b.assertIteratorBlocked() + } + + testWithBuffer("UnsafeRowBuffer: iterator unblocks when all data added") { b => + b.assertIteratorBlocked() + b.add(Seq(1)) + b.assertIteratorBlocked() + + b.allAdded() + b.assertThreadTerminated() + b.assertOutput(Seq(1)) + } + + testWithBuffer( + "UnsafeRowBuffer: handles more data than memory", + memBytes = 5, + sleepPerRowReadMs = 1) { b => + + b.assertIteratorBlocked() + b.add(1 to 2000) + b.assertOutput(1 to 2000) + } + + def testWithBuffer( + name: String, + memBytes: Long = 4 << 10, + sleepPerRowReadMs: Int = 0 + )(f: BufferTester => Unit): Unit = { + + test(name) { + var tester: BufferTester = null + try { + tester = new BufferTester(memBytes, sleepPerRowReadMs) + f(tester) + } finally { + if (tester == null) tester.close() + } + } + } + + + class BufferTester(memBytes: Long, sleepPerRowReadMs: Int) { + private val buffer = { + val mem = new TestMemoryManager(new SparkConf()) + mem.limit(memBytes) + val taskM = new TaskMemoryManager(mem, 0) + new UnsafeRowBuffer(taskM, Utils.createTempDir(), 1) + } + private val iterator = buffer.iterator + private val outputBuffer = new ArrayBuffer[Int] + private val testTimeout = timeout(20.seconds) + private val intProj = UnsafeProjection.create(Array[DataType](IntegerType)) + private val thread = new Thread() { + override def run(): Unit = { + while (iterator.hasNext) { + outputBuffer.synchronized { + outputBuffer += iterator.next().getInt(0) + } + Thread.sleep(sleepPerRowReadMs) + } + } + } + thread.start() + + def add(ints: Seq[Int]): Unit = { + ints.foreach { i => buffer.add(intProj.apply(new GenericInternalRow(Array[Any](i)))) } + } + + def allAdded(): Unit = { buffer.allRowsAdded() } + + def assertOutput(expectedOutput: Seq[Int]): Unit = { + eventually(testTimeout) { + val output = outputBuffer.synchronized { outputBuffer.toArray }.toSeq + assert(output == expectedOutput) + } + } + + def assertIteratorBlocked(): Unit = { + import Thread.State._ + eventually(testTimeout) { + assert(thread.isAlive) + assert(thread.getState == TIMED_WAITING || thread.getState == WAITING) + } + } + + def assertThreadTerminated(): Unit = { + eventually(testTimeout) { assert(!thread.isAlive) } + } + + def close(): Unit = { + thread.interrupt() + thread.join() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala new file mode 100644 index 0000000000000..a4233e15e4ffd --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import scala.collection.mutable + +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming._ + +case class KV(key: Int, value: Long) + +class ForeachBatchSinkSuite extends StreamTest { + import testImplicits._ + + test("foreachBatch with non-stateful query") { + val mem = MemoryStream[Int] + val ds = mem.toDS.map(_ + 1) + + val tester = new ForeachBatchTester[Int](mem) + val writer = (ds: Dataset[Int], batchId: Long) => tester.record(batchId, ds.map(_ + 1)) + + import tester._ + testWriter(ds, writer)( + check(in = 1, 2, 3)(out = 3, 4, 5), // out = in + 2 (i.e. 1 in query, 1 in writer) + check(in = 5, 6, 7)(out = 7, 8, 9)) + } + + test("foreachBatch with stateful query in update mode") { + val mem = MemoryStream[Int] + val ds = mem.toDF() + .select($"value" % 2 as "key") + .groupBy("key") + .agg(count("*") as "value") + .toDF.as[KV] + + val tester = new ForeachBatchTester[KV](mem) + val writer = (batchDS: Dataset[KV], batchId: Long) => tester.record(batchId, batchDS) + + import tester._ + testWriter(ds, writer, outputMode = OutputMode.Update)( + check(in = 0)(out = (0, 1L)), + check(in = 1)(out = (1, 1L)), + check(in = 2, 3)(out = (0, 2L), (1, 2L))) + } + + test("foreachBatch with stateful query in complete mode") { + val mem = MemoryStream[Int] + val ds = mem.toDF() + .select($"value" % 2 as "key") + .groupBy("key") + .agg(count("*") as "value") + .toDF.as[KV] + + val tester = new ForeachBatchTester[KV](mem) + val writer = (batchDS: Dataset[KV], batchId: Long) => tester.record(batchId, batchDS) + + import tester._ + testWriter(ds, writer, outputMode = OutputMode.Complete)( + check(in = 0)(out = (0, 1L)), + check(in = 1)(out = (0, 1L), (1, 1L)), + check(in = 2)(out = (0, 2L), (1, 1L))) + } + + test("foreachBatchSink does not affect metric generation") { + val mem = MemoryStream[Int] + val ds = mem.toDS.map(_ + 1) + + val tester = new ForeachBatchTester[Int](mem) + val writer = (ds: Dataset[Int], batchId: Long) => tester.record(batchId, ds.map(_ + 1)) + + import tester._ + testWriter(ds, writer)( + check(in = 1, 2, 3)(out = 3, 4, 5), + checkMetrics) + } + + test("throws errors in invalid situations") { + val ds = MemoryStream[Int].toDS + val ex1 = intercept[IllegalArgumentException] { + ds.writeStream.foreachBatch(null.asInstanceOf[(Dataset[Int], Long) => Unit]).start() + } + assert(ex1.getMessage.contains("foreachBatch function cannot be null")) + val ex2 = intercept[AnalysisException] { + ds.writeStream.foreachBatch((_, _) => {}).trigger(Trigger.Continuous("1 second")).start() + } + assert(ex2.getMessage.contains("'foreachBatch' is not supported with continuous trigger")) + val ex3 = intercept[AnalysisException] { + ds.writeStream.foreachBatch((_, _) => {}).partitionBy("value").start() + } + assert(ex3.getMessage.contains("'foreachBatch' does not support partitioning")) + } + + // ============== Helper classes and methods ================= + + private class ForeachBatchTester[T: Encoder](memoryStream: MemoryStream[Int]) { + trait Test + private case class Check(in: Seq[Int], out: Seq[T]) extends Test + private case object CheckMetrics extends Test + + private val recordedOutput = new mutable.HashMap[Long, Seq[T]] + + def testWriter( + ds: Dataset[T], + outputBatchWriter: (Dataset[T], Long) => Unit, + outputMode: OutputMode = OutputMode.Append())(tests: Test*): Unit = { + try { + var expectedBatchId = -1 + val query = ds.writeStream.outputMode(outputMode).foreachBatch(outputBatchWriter).start() + + tests.foreach { + case Check(in, out) => + expectedBatchId += 1 + memoryStream.addData(in) + query.processAllAvailable() + assert(recordedOutput.contains(expectedBatchId)) + val ds: Dataset[T] = spark.createDataset[T](recordedOutput(expectedBatchId)) + checkDataset[T](ds, out: _*) + case CheckMetrics => + assert(query.recentProgress.exists(_.numInputRows > 0)) + } + } finally { + sqlContext.streams.active.foreach(_.stop()) + } + } + + def check(in: Int*)(out: T*): Test = Check(in, out) + def checkMetrics: Test = CheckMetrics + def record(batchId: Long, ds: Dataset[T]): Unit = recordedOutput.put(batchId, ds.collect()) + implicit def conv(x: (Int, Long)): KV = KV(x._1, x._2) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala index 3dd0712e02448..855fe4f4523f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.internal import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.test.SQLTestUtils class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { @@ -40,16 +40,24 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { spark = null } + override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + pairs.foreach { case (k, v) => + SQLConf.get.setConfString(k, v) + } + try f finally { + pairs.foreach { case (k, _) => + SQLConf.get.unsetConf(k) + } + } + } + test("ReadOnlySQLConf is correctly created at the executor side") { - SQLConf.get.setConfString("spark.sql.x", "a") - try { - val checks = spark.range(10).mapPartitions { it => + withSQLConf("spark.sql.x" -> "a") { + val checks = spark.range(10).mapPartitions { _ => val conf = SQLConf.get Iterator(conf.isInstanceOf[ReadOnlySQLConf] && conf.getConfString("spark.sql.x") == "a") }.collect() assert(checks.forall(_ == true)) - } finally { - SQLConf.get.unsetConf("spark.sql.x") } } @@ -63,4 +71,15 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { } } } + + test("SPARK-24727 CODEGEN_CACHE_MAX_ENTRIES is correctly referenced at the executor side") { + withSQLConf(StaticSQLConf.CODEGEN_CACHE_MAX_ENTRIES.key -> "300") { + val checks = spark.range(10).mapPartitions { _ => + val conf = SQLConf.get + Iterator(conf.isInstanceOf[ReadOnlySQLConf] && + conf.getConfString(StaticSQLConf.CODEGEN_CACHE_MAX_ENTRIES.key) == "300") + }.collect() + assert(checks.forall(_ == true)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index bc2aca65e803f..0389273d6cdfa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -25,20 +25,21 @@ import org.h2.jdbc.JdbcSQLException import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.sql.{AnalysisException, DataFrame, Row} +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils} +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRDD, JDBCRelation, JdbcUtils} import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JDBCSuite extends SparkFunSuite +class JDBCSuite extends QueryTest with BeforeAndAfter with PrivateMethodTester with SharedSQLContext { import testImplicits._ @@ -238,6 +239,11 @@ class JDBCSuite extends SparkFunSuite |OPTIONS (url '$url', dbtable 'TEST."mixedCaseCols"', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) + conn.prepareStatement("CREATE TABLE test.partition (THEID INTEGER, `THE ID` INTEGER) " + + "AS SELECT 1, 1") + .executeUpdate() + conn.commit() + // Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types. } @@ -1093,7 +1099,7 @@ class JDBCSuite extends SparkFunSuite test("SPARK-19318: Connection properties keys should be case-sensitive.") { def testJdbcOptions(options: JDBCOptions): Unit = { // Spark JDBC data source options are case-insensitive - assert(options.table == "t1") + assert(options.tableOrQuery == "t1") // When we convert it to properties, it should be case-sensitive. assert(options.asProperties.size == 3) assert(options.asProperties.get("customkey") == null) @@ -1206,4 +1212,135 @@ class JDBCSuite extends SparkFunSuite }.getMessage assert(errMsg.contains("Statement was canceled or the session timed out")) } + + test("SPARK-24327 verify and normalize a partition column based on a JDBC resolved schema") { + def testJdbcParitionColumn(partColName: String, expectedColumnName: String): Unit = { + val df = spark.read.format("jdbc") + .option("url", urlWithUserAndPass) + .option("dbtable", "TEST.PARTITION") + .option("partitionColumn", partColName) + .option("lowerBound", 1) + .option("upperBound", 4) + .option("numPartitions", 3) + .load() + + val quotedPrtColName = testH2Dialect.quoteIdentifier(expectedColumnName) + df.logicalPlan match { + case LogicalRelation(JDBCRelation(_, parts, _), _, _, _) => + val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet + assert(whereClauses === Set( + s"$quotedPrtColName < 2 or $quotedPrtColName is null", + s"$quotedPrtColName >= 2 AND $quotedPrtColName < 3", + s"$quotedPrtColName >= 3")) + } + } + + testJdbcParitionColumn("THEID", "THEID") + testJdbcParitionColumn("\"THEID\"", "THEID") + withSQLConf("spark.sql.caseSensitive" -> "false") { + testJdbcParitionColumn("ThEiD", "THEID") + } + testJdbcParitionColumn("THE ID", "THE ID") + + def testIncorrectJdbcPartitionColumn(partColName: String): Unit = { + val errMsg = intercept[AnalysisException] { + testJdbcParitionColumn(partColName, "THEID") + }.getMessage + assert(errMsg.contains(s"User-defined partition column $partColName not found " + + "in the JDBC relation:")) + } + + testIncorrectJdbcPartitionColumn("NoExistingColumn") + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + testIncorrectJdbcPartitionColumn(testH2Dialect.quoteIdentifier("ThEiD")) + } + } + + test("query JDBC option - negative tests") { + val query = "SELECT * FROM test.people WHERE theid = 1" + // load path + val e1 = intercept[RuntimeException] { + val df = spark.read.format("jdbc") + .option("Url", urlWithUserAndPass) + .option("query", query) + .option("dbtable", "test.people") + .load() + }.getMessage + assert(e1.contains("Both 'dbtable' and 'query' can not be specified at the same time.")) + + // jdbc api path + val properties = new Properties() + properties.setProperty(JDBCOptions.JDBC_QUERY_STRING, query) + val e2 = intercept[RuntimeException] { + spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", properties).collect() + }.getMessage + assert(e2.contains("Both 'dbtable' and 'query' can not be specified at the same time.")) + + val e3 = intercept[RuntimeException] { + sql( + s""" + |CREATE OR REPLACE TEMPORARY VIEW queryOption + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', query '$query', dbtable 'TEST.PEOPLE', + | user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + }.getMessage + assert(e3.contains("Both 'dbtable' and 'query' can not be specified at the same time.")) + + val e4 = intercept[RuntimeException] { + val df = spark.read.format("jdbc") + .option("Url", urlWithUserAndPass) + .option("query", "") + .load() + }.getMessage + assert(e4.contains("Option `query` can not be empty.")) + + // Option query and partitioncolumn are not allowed together. + val expectedErrorMsg = + s""" + |Options 'query' and 'partitionColumn' can not be specified together. + |Please define the query using `dbtable` option instead and make sure to qualify + |the partition columns using the supplied subquery alias to resolve any ambiguity. + |Example : + |spark.read.format("jdbc") + | .option("dbtable", "(select c1, c2 from t1) as subq") + | .option("partitionColumn", "subq.c1" + | .load() + """.stripMargin + val e5 = intercept[RuntimeException] { + sql( + s""" + |CREATE OR REPLACE TEMPORARY VIEW queryOption + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', query '$query', user 'testUser', password 'testPass', + | partitionColumn 'THEID', lowerBound '1', upperBound '4', numPartitions '3') + """.stripMargin.replaceAll("\n", " ")) + }.getMessage + assert(e5.contains(expectedErrorMsg)) + } + + test("query JDBC option") { + val query = "SELECT name, theid FROM test.people WHERE theid = 1" + // query option to pass on the query string. + val df = spark.read.format("jdbc") + .option("Url", urlWithUserAndPass) + .option("query", query) + .load() + checkAnswer( + df, + Row("fred", 1) :: Nil) + + // query option in the create table path. + sql( + s""" + |CREATE OR REPLACE TEMPORARY VIEW queryOption + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', query '$query', user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + + checkAnswer( + sql("select name, theid from queryOption"), + Row("fred", 1) :: Nil) + + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 1c2c92d1f0737..b751ec2de4825 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -293,13 +293,23 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { test("save errors if dbtable is not specified") { val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) - val e = intercept[RuntimeException] { + val e1 = intercept[RuntimeException] { + df.write.format("jdbc") + .option("url", url1) + .options(properties.asScala) + .save() + }.getMessage + assert(e1.contains("Option 'dbtable' or 'query' is required")) + + val e2 = intercept[RuntimeException] { df.write.format("jdbc") .option("url", url1) .options(properties.asScala) + .option("query", "select * from TEST.SAVETEST") .save() }.getMessage - assert(e.contains("Option 'dbtable' is required")) + val msg = "Option 'dbtable' is required. Option 'query' is not applicable while writing." + assert(e2.contains(msg)) } test("save errors if wrong user/password combination") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index fef01c860db6e..438d5d8176b8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -20,12 +20,36 @@ package org.apache.spark.sql.sources import java.io.File import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +class SimpleInsertSource extends SchemaRelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String], + schema: StructType): BaseRelation = { + SimpleInsert(schema)(sqlContext.sparkSession) + } +} + +case class SimpleInsert(userSpecifiedSchema: StructType)(@transient val sparkSession: SparkSession) + extends BaseRelation with InsertableRelation { + + override def sqlContext: SQLContext = sparkSession.sqlContext + + override def schema: StructType = userSpecifiedSchema + + override def insert(input: DataFrame, overwrite: Boolean): Unit = { + input.collect + } +} + class InsertSuite extends DataSourceTest with SharedSQLContext { import testImplicits._ @@ -520,4 +544,29 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { } } } + + test("SPARK-24583 Wrong schema type in InsertIntoDataSourceCommand") { + withTable("test_table") { + val schema = new StructType() + .add("i", LongType, false) + .add("s", StringType, false) + val newTable = CatalogTable( + identifier = TableIdentifier("test_table", None), + tableType = CatalogTableType.EXTERNAL, + storage = CatalogStorageFormat( + locationUri = None, + inputFormat = None, + outputFormat = None, + serde = None, + compressed = false, + properties = Map.empty), + schema = schema, + provider = Some(classOf[SimpleInsertSource].getName)) + + spark.sessionState.catalog.createTable(newTable, false) + + sql("INSERT INTO TABLE test_table SELECT 1, 'a'") + sql("INSERT INTO TABLE test_table SELECT 2, null") + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 505a3f3465c02..e96cd4500458d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -323,21 +323,22 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } test("SPARK-23315: get output from canonicalized data source v2 related plans") { - def checkCanonicalizedOutput(df: DataFrame, numOutput: Int): Unit = { + def checkCanonicalizedOutput( + df: DataFrame, logicalNumOutput: Int, physicalNumOutput: Int): Unit = { val logical = df.queryExecution.optimizedPlan.collect { case d: DataSourceV2Relation => d }.head - assert(logical.canonicalized.output.length == numOutput) + assert(logical.canonicalized.output.length == logicalNumOutput) val physical = df.queryExecution.executedPlan.collect { case d: DataSourceV2ScanExec => d }.head - assert(physical.canonicalized.output.length == numOutput) + assert(physical.canonicalized.output.length == physicalNumOutput) } val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() - checkCanonicalizedOutput(df, 2) - checkCanonicalizedOutput(df.select('i), 1) + checkCanonicalizedOutput(df, 2, 2) + checkCanonicalizedOutput(df.select('i), 2, 1) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index 694bb3b95b0f0..1334cf71ae988 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -209,10 +209,10 @@ class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: Serializable override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[Row] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) - val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") + val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId") val fs = filePath.getFileSystem(conf.value) new SimpleCSVDataWriter(fs, filePath) } @@ -245,10 +245,10 @@ class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: Seriali override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[InternalRow] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) - val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") + val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId") val fs = filePath.getFileSystem(conf.value) new InternalRowCSVDataWriter(fs, filePath) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index 7e8fde1ff8e56..58ed9790ea123 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -18,18 +18,22 @@ package org.apache.spark.sql.streaming import java.{util => ju} +import java.io.File import java.text.SimpleDateFormat import java.util.{Calendar, Date} +import org.apache.commons.io.FileUtils import org.scalatest.{BeforeAndAfter, Matchers} import org.apache.spark.internal.Logging -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.{AnalysisException, Dataset} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode._ +import org.apache.spark.util.Utils class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matchers with Logging { @@ -484,6 +488,136 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testWithFlag(false) } + test("MultipleWatermarkPolicy: max") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "max") { + testStream(dfWithMultipleWatermarks(input1, input2))( + MultiAddData(input1, 20)(input2, 30), + CheckLastBatch(20, 30), + checkWatermark(input1, 15), // max(20 - 10, 30 - 15) = 15 + StopStream, + StartStream(), + checkWatermark(input1, 15), // watermark recovered correctly + MultiAddData(input1, 120)(input2, 130), + CheckLastBatch(120, 130), + checkWatermark(input1, 115), // max(120 - 10, 130 - 15) = 115, policy recovered correctly + AddData(input1, 150), + CheckLastBatch(150), + checkWatermark(input1, 140) // should advance even if one of the input has data + ) + } + } + + test("MultipleWatermarkPolicy: min") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "min") { + testStream(dfWithMultipleWatermarks(input1, input2))( + MultiAddData(input1, 20)(input2, 30), + CheckLastBatch(20, 30), + checkWatermark(input1, 10), // min(20 - 10, 30 - 15) = 10 + StopStream, + StartStream(), + checkWatermark(input1, 10), // watermark recovered correctly + MultiAddData(input1, 120)(input2, 130), + CheckLastBatch(120, 130), + checkWatermark(input2, 110), // min(120 - 10, 130 - 15) = 110, policy recovered correctly + AddData(input2, 150), + CheckLastBatch(150), + checkWatermark(input2, 110) // does not advance when only one of the input has data + ) + } + } + + test("MultipleWatermarkPolicy: recovery from checkpoints ignores session conf") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val checkpointDir = Utils.createTempDir().getCanonicalFile + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "max") { + testStream(dfWithMultipleWatermarks(input1, input2))( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + MultiAddData(input1, 20)(input2, 30), + CheckLastBatch(20, 30), + checkWatermark(input1, 15) // max(20 - 10, 30 - 15) = 15 + ) + } + + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "min") { + testStream(dfWithMultipleWatermarks(input1, input2))( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + checkWatermark(input1, 15), // watermark recovered correctly + MultiAddData(input1, 120)(input2, 130), + CheckLastBatch(120, 130), + checkWatermark(input1, 115), // max(120 - 10, 130 - 15) = 115, policy recovered correctly + AddData(input1, 150), + CheckLastBatch(150), + checkWatermark(input1, 140) // should advance even if one of the input has data + ) + } + } + + test("MultipleWatermarkPolicy: recovery from Spark ver 2.3.1 checkpoints ensures min policy") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/").toURI + + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. + FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + + input1.addData(20) + input2.addData(30) + input1.addData(10) + + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "max") { + testStream(dfWithMultipleWatermarks(input1, input2))( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + Execute { _.processAllAvailable() }, + MultiAddData(input1, 120)(input2, 130), + CheckLastBatch(120, 130), + checkWatermark(input2, 110), // should calculate 'min' even if session conf has 'max' policy + AddData(input2, 150), + CheckLastBatch(150), + checkWatermark(input2, 110) + ) + } + } + + test("MultipleWatermarkPolicy: fail on incorrect conf values") { + val invalidValues = Seq("", "random") + invalidValues.foreach { value => + val e = intercept[IllegalArgumentException] { + spark.conf.set(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key, value) + } + assert(e.getMessage.toLowerCase.contains("valid values are 'min' and 'max'")) + } + } + + private def dfWithMultipleWatermarks( + input1: MemoryStream[Int], + input2: MemoryStream[Int]): Dataset[_] = { + val df1 = input1.toDF + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + val df2 = input2.toDF + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "15 seconds") + df1.union(df2).select($"eventTime".cast("int")) + } + + private def checkWatermark(input: MemoryStream[Int], watermark: Long) = Execute { q => + input.addData(1) + q.processAllAvailable() + assert(q.lastProgress.eventTime.get("watermark") == formatTimestamp(watermark)) + } + private def assertNumStateRows(numTotalRows: Long): AssertOnQuery = AssertOnQuery { q => q.processAllAvailable() val progressWithData = q.recentProgress.lastOption.get diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index c1ec1eba69fb2..ca38f04136c7d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -805,6 +805,114 @@ class StreamSuite extends StreamTest { } } + test("streaming limit without state") { + val inputData1 = MemoryStream[Int] + testStream(inputData1.toDF().limit(0))( + AddData(inputData1, 1 to 8: _*), + CheckAnswer()) + + val inputData2 = MemoryStream[Int] + testStream(inputData2.toDF().limit(4))( + AddData(inputData2, 1 to 8: _*), + CheckAnswer(1 to 4: _*)) + } + + test("streaming limit with state") { + val inputData = MemoryStream[Int] + testStream(inputData.toDF().limit(4))( + AddData(inputData, 1 to 2: _*), + CheckAnswer(1 to 2: _*), + AddData(inputData, 3 to 6: _*), + CheckAnswer(1 to 4: _*), + AddData(inputData, 7 to 9: _*), + CheckAnswer(1 to 4: _*)) + } + + test("streaming limit with other operators") { + val inputData = MemoryStream[Int] + testStream(inputData.toDF().where("value % 2 = 1").limit(4))( + AddData(inputData, 1 to 5: _*), + CheckAnswer(1, 3, 5), + AddData(inputData, 6 to 9: _*), + CheckAnswer(1, 3, 5, 7), + AddData(inputData, 10 to 12: _*), + CheckAnswer(1, 3, 5, 7)) + } + + test("streaming limit with multiple limits") { + val inputData1 = MemoryStream[Int] + testStream(inputData1.toDF().limit(4).limit(2))( + AddData(inputData1, 1), + CheckAnswer(1), + AddData(inputData1, 2 to 8: _*), + CheckAnswer(1, 2)) + + val inputData2 = MemoryStream[Int] + testStream(inputData2.toDF().limit(4).limit(100).limit(3))( + AddData(inputData2, 1, 2), + CheckAnswer(1, 2), + AddData(inputData2, 3 to 8: _*), + CheckAnswer(1 to 3: _*)) + } + + test("streaming limit in complete mode") { + val inputData = MemoryStream[Int] + val limited = inputData.toDF().limit(5).groupBy("value").count() + testStream(limited, OutputMode.Complete())( + AddData(inputData, 1 to 3: _*), + CheckAnswer(Row(1, 1), Row(2, 1), Row(3, 1)), + AddData(inputData, 1 to 9: _*), + CheckAnswer(Row(1, 2), Row(2, 2), Row(3, 2), Row(4, 1), Row(5, 1))) + } + + test("streaming limits in complete mode") { + val inputData = MemoryStream[Int] + val limited = inputData.toDF().limit(4).groupBy("value").count().orderBy("value").limit(3) + testStream(limited, OutputMode.Complete())( + AddData(inputData, 1 to 9: _*), + CheckAnswer(Row(1, 1), Row(2, 1), Row(3, 1)), + AddData(inputData, 2 to 6: _*), + CheckAnswer(Row(1, 1), Row(2, 2), Row(3, 2))) + } + + test("streaming limit in update mode") { + val inputData = MemoryStream[Int] + val e = intercept[AnalysisException] { + testStream(inputData.toDF().limit(5), OutputMode.Update())( + AddData(inputData, 1 to 3: _*) + ) + } + assert(e.getMessage.contains( + "Limits are not supported on streaming DataFrames/Datasets in Update output mode")) + } + + test("streaming limit in multiple partitions") { + val inputData = MemoryStream[Int] + testStream(inputData.toDF().repartition(2).limit(7))( + AddData(inputData, 1 to 10: _*), + CheckAnswerRowsByFunc( + rows => assert(rows.size == 7 && rows.forall(r => r.getInt(0) <= 10)), + false), + AddData(inputData, 11 to 20: _*), + CheckAnswerRowsByFunc( + rows => assert(rows.size == 7 && rows.forall(r => r.getInt(0) <= 10)), + false)) + } + + test("streaming limit in multiple partitions by column") { + val inputData = MemoryStream[(Int, Int)] + val df = inputData.toDF().repartition(2, $"_2").limit(7) + testStream(df)( + AddData(inputData, (1, 0), (2, 0), (3, 1), (4, 1)), + CheckAnswerRowsByFunc( + rows => assert(rows.size == 4 && rows.forall(r => r.getInt(0) <= 4)), + false), + AddData(inputData, (5, 0), (6, 0), (7, 1), (8, 1)), + CheckAnswerRowsByFunc( + rows => assert(rows.size == 7 && rows.forall(r => r.getInt(0) <= 8)), + false)) + } + for (e <- Seq( new InterruptedException, new InterruptedIOException, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 1f62357e6d09e..c5cc8df4356a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -404,6 +404,20 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(input3, 5, 10), CheckNewAnswer((5, 10, 5, 15, 5, 25))) } + + test("streaming join should require HashClusteredDistribution from children") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val df1 = input1.toDF.select('value as 'a, 'value * 2 as 'b) + val df2 = input2.toDF.select('value as 'a, 'value * 2 as 'b).repartition('b) + val joined = df1.join(df2, Seq("a", "b")).select('a) + + testStream(joined)( + AddData(input1, 1.to(1000): _*), + AddData(input2, 1.to(1000): _*), + CheckAnswer(1.to(1000): _*)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala new file mode 100644 index 0000000000000..1aaf8a9aa2d55 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import scala.language.reflectiveCalls + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkConf +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.streaming.StreamingQueryListener._ + + +class StreamingQueryListenersConfSuite extends StreamTest with BeforeAndAfter { + + import testImplicits._ + + + override protected def sparkConf: SparkConf = + super.sparkConf.set("spark.sql.streaming.streamingQueryListeners", + "org.apache.spark.sql.streaming.TestListener") + + test("test if the configured query lister is loaded") { + testStream(MemoryStream[Int].toDS)( + StartStream(), + StopStream + ) + + assert(TestListener.queryStartedEvent != null) + assert(TestListener.queryTerminatedEvent != null) + } + +} + +object TestListener { + @volatile var queryStartedEvent: QueryStartedEvent = null + @volatile var queryTerminatedEvent: QueryTerminatedEvent = null +} + +class TestListener(sparkConf: SparkConf) extends StreamingQueryListener { + + override def onQueryStarted(event: QueryStartedEvent): Unit = { + TestListener.queryStartedEvent = event + } + + override def onQueryProgress(event: QueryProgressEvent): Unit = {} + + override def onQueryTerminated(event: QueryTerminatedEvent): Unit = { + TestListener.queryTerminatedEvent = event + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index dcf6cb5d609ee..936a076d647b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -335,8 +335,8 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.sources.length === 1) assert(progress.sources(0).description contains "MemoryStream") - assert(progress.sources(0).startOffset === "0") - assert(progress.sources(0).endOffset !== null) + assert(progress.sources(0).startOffset === null) // no prior offset + assert(progress.sources(0).endOffset === "0") assert(progress.sources(0).processedRowsPerSecond === 4.0) // 2 rows processed in 500 ms assert(progress.stateOperators.length === 1) @@ -362,6 +362,8 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(query.lastProgress.batchId === 1) assert(query.lastProgress.inputRowsPerSecond === 2.0) assert(query.lastProgress.sources(0).inputRowsPerSecond === 2.0) + assert(query.lastProgress.sources(0).startOffset === "0") + assert(query.lastProgress.sources(0).endOffset === "1") true }, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala index b7ef637f5270e..0223812600961 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala @@ -31,7 +31,8 @@ class ContinuousAggregationSuite extends ContinuousSuiteBase { testStream(input.toDF().agg(max('value)), OutputMode.Complete)() } - assert(ex.getMessage.contains("Continuous processing does not support Aggregate operations")) + assert(ex.getMessage.contains( + "In continuous processing mode, coalesce(1) must be called before aggregate operation")) } test("basic") { @@ -50,6 +51,66 @@ class ContinuousAggregationSuite extends ContinuousSuiteBase { } } + test("multiple partitions with coalesce") { + val input = ContinuousMemoryStream[Int] + + val df = input.toDF().coalesce(1).agg(max('value)) + + testStream(df, OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(2), + StopStream, + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(5), + AddData(input, -1, -2, -3), + CheckAnswer(5)) + } + + test("multiple partitions with coalesce - multiple transformations") { + val input = ContinuousMemoryStream[Int] + + // We use a barrier to make sure predicates both before and after coalesce work + val df = input.toDF() + .select('value as 'copy, 'value) + .where('copy =!= 1) + .planWithBarrier + .coalesce(1) + .where('copy =!= 2) + .agg(max('value)) + + testStream(df, OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(0), + StopStream, + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(5), + AddData(input, -1, -2, -3), + CheckAnswer(5)) + } + + test("multiple partitions with multiple coalesce") { + val input = ContinuousMemoryStream[Int] + + val df = input.toDF() + .coalesce(1) + .planWithBarrier + .coalesce(1) + .select('value as 'copy, 'value) + .agg(max('value)) + + testStream(df, OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(2), + StopStream, + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(5), + AddData(input, -1, -2, -3), + CheckAnswer(5)) + } + test("repeated restart") { withSQLConf(("spark.sql.streaming.unsupportedOperationCheck", "false")) { val input = ContinuousMemoryStream.singlePartition[Int] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala index e663fa8312da4..0e7e6febb53df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -92,7 +92,7 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { } } val reader = new ContinuousQueuedDataReader( - factory, + new ContinuousDataSourceRDDPartition(0, factory), mockContext, dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize, epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala similarity index 64% rename from sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala index 2e4d607a403ca..f84f3d49707bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala @@ -17,29 +17,16 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle -import org.apache.spark.{TaskContext, TaskContextImpl} +import java.util.UUID + +import org.apache.spark.{HashPartitioner, Partition, TaskContext, TaskContextImpl} import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.types.{DataType, IntegerType, StringType} import org.apache.spark.unsafe.types.UTF8String -class ContinuousShuffleReadSuite extends StreamTest { - - private def unsafeRow(value: Int) = { - UnsafeProjection.create(Array(IntegerType : DataType))( - new GenericInternalRow(Array(value: Any))) - } - - private def unsafeRow(value: String) = { - UnsafeProjection.create(Array(StringType : DataType))( - new GenericInternalRow(Array(UTF8String.fromString(value): Any))) - } - - private def send(endpoint: RpcEndpointRef, messages: UnsafeRowReceiverMessage*) = { - messages.foreach(endpoint.askSync[Unit](_)) - } - +class ContinuousShuffleSuite extends StreamTest { // In this unit test, we emulate that we're in the task thread where // ContinuousShuffleReadRDD.compute() will be evaluated. This requires a task context // thread local to be set. @@ -58,39 +45,29 @@ class ContinuousShuffleReadSuite extends StreamTest { super.afterEach() } - test("receiver stopped with row last") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - send( - endpoint, - ReceiverEpochMarker(0), - ReceiverRow(0, unsafeRow(111)) - ) + private implicit def unsafeRow(value: Int) = { + UnsafeProjection.create(Array(IntegerType : DataType))( + new GenericInternalRow(Array(value: Any))) + } - ctx.markTaskCompleted(None) - val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader - eventually(timeout(streamingTimeout)) { - assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get()) - } + private def unsafeRow(value: String) = { + UnsafeProjection.create(Array(StringType : DataType))( + new GenericInternalRow(Array(UTF8String.fromString(value): Any))) } - test("receiver stopped with marker last") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - send( - endpoint, - ReceiverRow(0, unsafeRow(111)), - ReceiverEpochMarker(0) - ) + private def send(endpoint: RpcEndpointRef, messages: RPCContinuousShuffleMessage*) = { + messages.foreach(endpoint.askSync[Unit](_)) + } - ctx.markTaskCompleted(None) - val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader - eventually(timeout(streamingTimeout)) { - assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get()) - } + private def readRDDEndpoint(rdd: ContinuousShuffleReadRDD) = { + rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint } - test("one epoch") { + private def readEpoch(rdd: ContinuousShuffleReadRDD) = { + rdd.compute(rdd.partitions(0), ctx).toSeq.map(_.getInt(0)) + } + + test("reader - one epoch") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint send( @@ -105,7 +82,7 @@ class ContinuousShuffleReadSuite extends StreamTest { assert(iter.toSeq.map(_.getInt(0)) == Seq(111, 222, 333)) } - test("multiple epochs") { + test("reader - multiple epochs") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint send( @@ -124,7 +101,7 @@ class ContinuousShuffleReadSuite extends StreamTest { assert(secondEpoch.toSeq.map(_.getInt(0)) == Seq(222, 333)) } - test("empty epochs") { + test("reader - empty epochs") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint @@ -148,8 +125,11 @@ class ContinuousShuffleReadSuite extends StreamTest { assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) } - test("multiple partitions") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 5) + test("reader - multiple partitions") { + val rdd = new ContinuousShuffleReadRDD( + sparkContext, + numPartitions = 5, + endpointNames = Seq.fill(5)(s"endpt-${UUID.randomUUID()}")) // Send all data before processing to ensure there's no crossover. for (p <- rdd.partitions) { val part = p.asInstanceOf[ContinuousShuffleReadPartition] @@ -169,7 +149,7 @@ class ContinuousShuffleReadSuite extends StreamTest { } } - test("blocks waiting for new rows") { + test("reader - blocks waiting for new rows") { val rdd = new ContinuousShuffleReadRDD( sparkContext, numPartitions = 1, epochIntervalMs = Long.MaxValue) val epoch = rdd.compute(rdd.partitions(0), ctx) @@ -195,7 +175,7 @@ class ContinuousShuffleReadSuite extends StreamTest { } } - test("multiple writers") { + test("reader - multiple writers") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint send( @@ -213,7 +193,7 @@ class ContinuousShuffleReadSuite extends StreamTest { Set("writer0-row0", "writer1-row0", "writer2-row0")) } - test("epoch only ends when all writers send markers") { + test("reader - epoch only ends when all writers send markers") { val rdd = new ContinuousShuffleReadRDD( sparkContext, numPartitions = 1, numShuffleWriters = 3, epochIntervalMs = Long.MaxValue) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint @@ -233,6 +213,7 @@ class ContinuousShuffleReadSuite extends StreamTest { // After checking the right rows, block until we get an epoch marker indicating there's no next. // (Also fail the assertion if for some reason we get a row.) + val readEpochMarkerThread = new Thread { override def run(): Unit = { assert(!epoch.hasNext) @@ -251,10 +232,10 @@ class ContinuousShuffleReadSuite extends StreamTest { } // Join to pick up assertion failures. - readEpochMarkerThread.join() + readEpochMarkerThread.join(streamingTimeout.toMillis) } - test("writer epochs non aligned") { + test("reader - writer epochs non aligned") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint // We send multiple epochs for 0, then multiple for 1, then multiple for 2. The receiver should @@ -288,4 +269,153 @@ class ContinuousShuffleReadSuite extends StreamTest { val thirdEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet assert(thirdEpoch == Set("writer1-row1", "writer2-row0")) } + + test("one epoch") { + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + + writer.write(Iterator(1, 2, 3)) + + assert(readEpoch(reader) == Seq(1, 2, 3)) + } + + test("multiple epochs") { + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + + writer.write(Iterator(1, 2, 3)) + writer.write(Iterator(4, 5, 6)) + + assert(readEpoch(reader) == Seq(1, 2, 3)) + assert(readEpoch(reader) == Seq(4, 5, 6)) + } + + test("empty epochs") { + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + + writer.write(Iterator()) + writer.write(Iterator(1, 2)) + writer.write(Iterator()) + writer.write(Iterator()) + writer.write(Iterator(3, 4)) + writer.write(Iterator()) + + assert(readEpoch(reader) == Seq()) + assert(readEpoch(reader) == Seq(1, 2)) + assert(readEpoch(reader) == Seq()) + assert(readEpoch(reader) == Seq()) + assert(readEpoch(reader) == Seq(3, 4)) + assert(readEpoch(reader) == Seq()) + } + + test("blocks waiting for writer") { + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + + val readerEpoch = reader.compute(reader.partitions(0), ctx) + + val readRowThread = new Thread { + override def run(): Unit = { + assert(readerEpoch.toSeq.map(_.getInt(0)) == Seq(1)) + } + } + readRowThread.start() + + eventually(timeout(streamingTimeout)) { + assert(readRowThread.getState == Thread.State.TIMED_WAITING) + } + + // Once we write the epoch the thread should stop waiting and succeed. + writer.write(Iterator(1)) + readRowThread.join(streamingTimeout.toMillis) + } + + test("multiple writer partitions") { + val numWriterPartitions = 3 + + val reader = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions) + val writers = (0 until 3).map { idx => + new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + } + + writers(0).write(Iterator(1, 4, 7)) + writers(1).write(Iterator(2, 5)) + writers(2).write(Iterator(3, 6)) + + writers(0).write(Iterator(4, 7, 10)) + writers(1).write(Iterator(5, 8)) + writers(2).write(Iterator(6, 9)) + + // Since there are multiple asynchronous writers, the original row sequencing is not guaranteed. + // The epochs should be deterministically preserved, however. + assert(readEpoch(reader).toSet == Seq(1, 2, 3, 4, 5, 6, 7).toSet) + assert(readEpoch(reader).toSet == Seq(4, 5, 6, 7, 8, 9, 10).toSet) + } + + test("reader epoch only ends when all writer partitions write it") { + val numWriterPartitions = 3 + + val reader = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions) + val writers = (0 until 3).map { idx => + new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + } + + writers(1).write(Iterator()) + writers(2).write(Iterator()) + + val readerEpoch = reader.compute(reader.partitions(0), ctx) + + val readEpochMarkerThread = new Thread { + override def run(): Unit = { + assert(!readerEpoch.hasNext) + } + } + + readEpochMarkerThread.start() + eventually(timeout(streamingTimeout)) { + assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING) + } + + writers(0).write(Iterator()) + readEpochMarkerThread.join(streamingTimeout.toMillis) + } + + test("receiver stopped with row last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow(111)) + ) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader + eventually(timeout(streamingTimeout)) { + assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get()) + } + } + + test("receiver stopped with marker last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(0, unsafeRow(111)), + ReceiverEpochMarker(0) + ) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader + eventually(timeout(streamingTimeout)) { + assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get()) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index bc4a120f7042f..e562be83822e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -391,6 +391,13 @@ private[sql] trait SQLTestUtilsBase val fs = hadoopPath.getFileSystem(spark.sessionState.newHadoopConf()) fs.makeQualified(hadoopPath).toUri } + + /** + * Returns full path to the given file in the resouce folder + */ + protected def testFile(fileName: String): String = { + Thread.currentThread().getContextClassLoader.getResource(fileName).toString + } } private[sql] object SQLTestUtils { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index 0950b30126773..771104ceb8842 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -76,7 +76,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => - [{id}] @@ -147,7 +147,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" val headerRow = Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", "Total Execute") def generateDataRow(session: SessionInfo): Seq[Node] = { - val sessionLink = "%s/%s/session?id=%s".format( + val sessionLink = "%s/%s/session/?id=%s".format( UIUtils.prependBaseUri(request, parent.basePath), parent.prefix, session.sessionId)
    diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index c884aa0ecbdf8..163eb43aabc72 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -86,7 +86,7 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => - [{id}] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 011a3ba553cb2..44480cecf0039 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -114,7 +114,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * should interpret these special data source properties and restore the original table metadata * before returning it. */ - private[hive] def getRawTable(db: String, table: String): CatalogTable = withClient { + private[hive] def getRawTable(db: String, table: String): CatalogTable = { client.getTable(db, table) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 94ddeae1bf547..6bbfdd84ba808 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DecimalType, DoubleType} +import org.apache.spark.util.Utils private[sql] class HiveSessionCatalog( @@ -141,8 +142,10 @@ private[sql] class HiveSessionCatalog( // built-in function. // Hive is case insensitive. val functionName = funcName.unquotedString.toLowerCase(Locale.ROOT) + logWarning(s"Encounter a failure during looking up function:" + + s" ${Utils.exceptionString(error)}") if (!hiveFunctions.contains(functionName)) { - failFunctionLookup(funcName) + failFunctionLookup(funcName, Some(Utils.exceptionString(error))) } // TODO: Remove this fallback path once we implement the list of fallback functions @@ -150,12 +153,12 @@ private[sql] class HiveSessionCatalog( val functionInfo = { try { Option(HiveFunctionRegistry.getFunctionInfo(functionName)).getOrElse( - failFunctionLookup(funcName)) + failFunctionLookup(funcName, Some(Utils.exceptionString(error)))) } catch { // If HiveFunctionRegistry.getFunctionInfo throws an exception, // we are failing to load a Hive builtin function, which means that // the given function is not a Hive builtin function. - case NonFatal(e) => failFunctionLookup(funcName) + case NonFatal(e) => failFunctionLookup(funcName, Some(Utils.exceptionString(e))) } } val className = functionInfo.getFunctionClass.getName @@ -175,6 +178,10 @@ private[sql] class HiveSessionCatalog( super.functionExists(name) || hiveFunctions.contains(name.funcName) } + override def isPersistentFunction(name: FunctionIdentifier): Boolean = { + super.isPersistentFunction(name) || hiveFunctions.contains(name.funcName) + } + /** List of functions we pass over to Hive. Note that over time this list should go to 0. */ // We have a list of Hive built-in functions that we do not support. So, we will check // Hive's function registry and lazily load needed functions into our own function registry. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index da9fe2d3088b4..db8fd5a43d842 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -353,15 +353,19 @@ private[hive] class HiveClientImpl( client.getDatabasesByPattern(pattern).asScala } + private def getRawTableOption(dbName: String, tableName: String): Option[HiveTable] = { + Option(client.getTable(dbName, tableName, false /* do not throw exception */)) + } + override def tableExists(dbName: String, tableName: String): Boolean = withHiveState { - Option(client.getTable(dbName, tableName, false /* do not throw exception */)).nonEmpty + getRawTableOption(dbName, tableName).nonEmpty } override def getTableOption( dbName: String, tableName: String): Option[CatalogTable] = withHiveState { logDebug(s"Looking up $dbName.$tableName") - Option(client.getTable(dbName, tableName, false)).map { h => + getRawTableOption(dbName, tableName).map { h => // Note: Hive separates partition columns and the schema, but for us the // partition columns are part of the schema val cols = h.getCols.asScala.map(fromHiveColumn) @@ -995,6 +999,8 @@ private[hive] object HiveClientImpl { tpart.setTableName(ht.getTableName) tpart.setValues(partValues.asJava) tpart.setSd(storageDesc) + tpart.setCreateTime((p.createTime / 1000).toInt) + tpart.setLastAccessTime((p.lastAccessTime / 1000).toInt) tpart.setParameters(mutable.Map(p.parameters.toSeq: _*).asJava) new HivePartition(ht, tpart) } @@ -1019,6 +1025,8 @@ private[hive] object HiveClientImpl { compressed = apiPartition.getSd.isCompressed, properties = Option(apiPartition.getSd.getSerdeInfo.getParameters) .map(_.asScala.toMap).orNull), + createTime = apiPartition.getCreateTime.toLong * 1000, + lastAccessTime = apiPartition.getLastAccessTime.toLong * 1000, parameters = properties, stats = readHiveStats(properties)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 130e258e78ca2..933384ed43e98 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -45,7 +45,7 @@ import org.apache.spark.sql.catalyst.analysis.NoSuchPermanentFunctionException import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, CatalogTablePartition, CatalogUtils, FunctionResource, FunctionResourceType} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{IntegralType, StringType} +import org.apache.spark.sql.types.{AtomicType, IntegralType, StringType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -342,7 +342,7 @@ private[client] class Shim_v0_12 extends Shim with Logging { } override def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long = { - conf.getIntVar(HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY) * 1000 + conf.getIntVar(HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY) * 1000L } override def loadPartition( @@ -660,7 +660,8 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { def unapply(expr: Expression): Option[Attribute] = { expr match { case attr: Attribute => Some(attr) - case Cast(child, dt, _) if !Cast.mayTruncate(child.dataType, dt) => unapply(child) + case Cast(child @ AtomicType(), dt: AtomicType, _) + if Cast.canSafeCast(child.dataType.asInstanceOf[AtomicType], dt) => unapply(child) case _ => None } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 2f34f69b5cf48..6a90c44a2633d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -182,6 +182,7 @@ private[hive] class IsolatedClientLoader( name.startsWith("org.slf4j") || name.startsWith("org.apache.log4j") || // log4j1.x name.startsWith("org.apache.logging.log4j") || // log4j2 + name.startsWith("org.apache.derby.") || name.startsWith("org.apache.spark.") || (sharesHadoopClasses && isHadoopClass) || name.startsWith("scala.") || diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 237ed9bc05988..20090696ec3fc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.orc.OrcOptions import org.apache.spark.sql.hive.{HiveInspectors, HiveShim} import org.apache.spark.sql.sources.{Filter, _} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration /** @@ -72,6 +72,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { + val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf) val configuration = job.getConfiguration @@ -121,6 +122,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + if (sparkSession.sessionState.conf.orcFilterPushDown) { // Sets pushed predicates OrcFilters.createFilter(requiredSchema, filters.toArray).foreach { f => @@ -174,6 +176,23 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable } } } + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) } + + case ArrayType(elementType, _) => supportDataType(elementType, isReadPath) + + case MapType(keyType, valueType, _) => + supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath) + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _: NullType => isReadPath + + case _ => false + } } private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index 0a522b6a11c80..1de258f060943 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -113,4 +113,10 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite { catalog.createDatabase(newDb("dbWithNullDesc").copy(description = null), ignoreIfExists = false) assert(catalog.getDatabase("dbWithNullDesc").description == "") } + + test("SPARK-23831: Add org.apache.derby to IsolatedClientLoader") { + val client1 = HiveUtils.newClientForMetadata(new SparkConf, new Configuration) + val client2 = HiveUtils.newClientForMetadata(new SparkConf, new Configuration) + assert(!client1.equals(client2)) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 6f904c937348d..f8212684d5335 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -56,14 +56,21 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { } private def tryDownloadSpark(version: String, path: String): Unit = { - // Try mirrors a few times until one succeeds - for (i <- 0 until 3) { - // we don't retry on a failure to get mirror url. If we can't get a mirror url, - // the test fails (getStringFromUrl will throw an exception) - val preferredMirror = - getStringFromUrl("https://www.apache.org/dyn/closer.lua?preferred=true") + // Try a few mirrors first; fall back to Apache archive + val mirrors = + (0 until 2).flatMap { _ => + try { + Some(getStringFromUrl("https://www.apache.org/dyn/closer.lua?preferred=true")) + } catch { + // If we can't get a mirror URL, skip it. No retry. + case _: Exception => None + } + } + val sites = mirrors.distinct :+ "https://archive.apache.org/dist" + logInfo(s"Trying to download Spark $version from $sites") + for (site <- sites) { val filename = s"spark-$version-bin-hadoop2.7.tgz" - val url = s"$preferredMirror/spark/spark-$version/$filename" + val url = s"$site/spark/spark-$version/$filename" logInfo(s"Downloading Spark $version from $url") try { getFileFromUrl(url, path, filename) @@ -83,7 +90,8 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { Seq("rm", "-rf", targetDir).! } } catch { - case ex: Exception => logWarning(s"Failed to download Spark $version from $url", ex) + case ex: Exception => + logWarning(s"Failed to download Spark $version from $url: ${ex.getMessage}") } } fail(s"Unable to download Spark $version") @@ -195,7 +203,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. - val testingVersions = Seq("2.0.2", "2.1.2", "2.2.1", "2.3.0") + val testingVersions = Seq("2.0.2", "2.1.2", "2.2.1", "2.3.1") protected var spark: SparkSession = _ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index 55275f6b37945..fa9f753795f65 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType} // TODO: Refactor this to `HivePartitionFilteringSuite` class HiveClientSuite(version: String) @@ -122,6 +122,22 @@ class HiveClientSuite(version: String) "aa" :: Nil) } + test("getPartitionsByFilter: cast(chunk as int)=1 (not a valid partition predicate)") { + testMetastorePartitionFiltering( + attr("chunk").cast(IntegerType) === 1, + 20170101 to 20170103, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + + test("getPartitionsByFilter: cast(chunk as boolean)=true (not a valid partition predicate)") { + testMetastorePartitionFiltering( + attr("chunk").cast(BooleanType) === true, + 20170101 to 20170103, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + test("getPartitionsByFilter: 20170101=ds") { testMetastorePartitionFiltering( Literal(20170101) === attr("ds"), @@ -138,7 +154,7 @@ class HiveClientSuite(version: String) "aa" :: "ab" :: "ba" :: "bb" :: Nil) } - test("getPartitionsByFilter: chunk in cast(ds as long)=20170101L") { + test("getPartitionsByFilter: cast(ds as long)=20170101L and h=10") { testMetastorePartitionFiltering( attr("ds").cast(LongType) === 20170101L && attr("h") === 10, 20170101 to 20170101, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala index d556a030e2186..fb4957ed943a7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala @@ -19,11 +19,13 @@ package org.apache.spark.sql.hive.orc import java.io.File -import org.apache.spark.sql.Row +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.TestingUDT.{IntervalData, IntervalUDT} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.datasources.orc.OrcSuite import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.HiveSerDe +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { @@ -133,4 +135,42 @@ class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { Utils.deleteRecursively(location) } } + + test("SPARK-24204 error handling for unsupported data types") { + withTempDir { dir => + val orcDir = new File(dir, "orc").getCanonicalPath + + // write path + var msg = intercept[AnalysisException] { + sql("select interval 1 days").write.mode("overwrite").orc(orcDir) + }.getMessage + assert(msg.contains("Cannot save interval data type into external storage.")) + + msg = intercept[AnalysisException] { + sql("select null").write.mode("overwrite").orc(orcDir) + }.getMessage + assert(msg.contains("ORC data source does not support null data type.")) + + msg = intercept[AnalysisException] { + spark.udf.register("testType", () => new IntervalData()) + sql("select testType()").write.mode("overwrite").orc(orcDir) + }.getMessage + assert(msg.contains("ORC data source does not support interval data type.")) + + // read path + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) + spark.range(1).write.mode("overwrite").orc(orcDir) + spark.read.schema(schema).orc(orcDir).collect() + }.getMessage + assert(msg.contains("ORC data source does not support calendarinterval data type.")) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) + spark.range(1).write.mode("overwrite").orc(orcDir) + spark.read.schema(schema).orc(orcDir).collect() + }.getMessage + assert(msg.contains("ORC data source does not support interval data type.")) + } + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index ca9da6139649a..884d21d0afdd3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -109,7 +109,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { flatMap(info => info.failureReason).headOption.getOrElse("") val formattedDuration = duration.map(d => SparkUIUtils.formatDuration(d)).getOrElse("-") val detailUrl = s"${SparkUIUtils.prependBaseUri( - request, parent.basePath)}/jobs/job?id=${sparkJob.jobId}" + request, parent.basePath)}/jobs/job/?id=${sparkJob.jobId}" // In the first row, output op id and its information needs to be shown. In other rows, these // cells will be taken up due to "rowspan". diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala index 9d1b82a6341b1..25e71258b9369 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala @@ -49,7 +49,7 @@ private[spark] class StreamingTab(val ssc: StreamingContext) def detach() { getSparkUI(ssc).detachTab(this) - getSparkUI(ssc).removeStaticHandler("/static/streaming") + getSparkUI(ssc).detachHandler("/static/streaming") } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index ab7c8558321c8..2e8599026ea1d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -222,7 +222,7 @@ private[streaming] class FileBasedWriteAheadLog( pastLogs += LogInfo(currentLogWriterStartTime, currentLogWriterStopTime, _) } currentLogWriterStartTime = currentTime - currentLogWriterStopTime = currentTime + (rollingIntervalSecs * 1000) + currentLogWriterStopTime = currentTime + (rollingIntervalSecs * 1000L) val newLogPath = new Path(logDirectory, timeToLogFile(currentLogWriterStartTime, currentLogWriterStopTime)) currentLogPath = Some(newLogPath.toString)
    keyvalue
    {session.userName}