diff --git a/.gitignore b/.gitignore index e4c44d0590d59..19db7ac277944 100644 --- a/.gitignore +++ b/.gitignore @@ -77,6 +77,7 @@ target/ unit-tests.log work/ docs/.jekyll-metadata +*.crc # For Hive TempStatsStore/ diff --git a/LICENSE b/LICENSE index 6f5d9452e800d..b771bd552b762 100644 --- a/LICENSE +++ b/LICENSE @@ -201,103 +201,61 @@ limitations under the License. -======================================================================= -Apache Spark Subcomponents: - -The Apache Spark project contains subcomponents with separate copyright -notices and license terms. Your use of the source code for the these -subcomponents is subject to the terms and conditions of the following -licenses. - - -======================================================================== -For heapq (pyspark/heapq3.py): -======================================================================== - -See license/LICENSE-heapq.txt - -======================================================================== -For SnapTree: -======================================================================== - -See license/LICENSE-SnapTree.txt - -======================================================================== -For jbcrypt: -======================================================================== - -See license/LICENSE-jbcrypt.txt - -======================================================================== -BSD-style licenses -======================================================================== - -The following components are provided under a BSD-style license. See project link for details. -The text of each license is also included at licenses/LICENSE-[project].txt. - - (BSD 3 Clause) netlib core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core) - (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.2.7 - https://github.com/jpmml/jpmml-model) - (BSD 3 Clause) jmock (org.jmock:jmock-junit4:2.8.4 - http://jmock.org/) - (BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/) - (BSD License) ANTLR 4.5.2-1 (org.antlr:antlr4:4.5.2-1 - http://wwww.antlr.org/) - (BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org) - (BSD licence) ANTLR StringTemplate (org.antlr:stringtemplate:3.2.1 - http://www.stringtemplate.org) - (BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org) - (BSD) JLine (jline:jline:2.14.3 - https://github.com/jline/jline2) - (BSD) ParaNamer Core (com.thoughtworks.paranamer:paranamer:2.3 - http://paranamer.codehaus.org/paranamer) - (BSD) ParaNamer Core (com.thoughtworks.paranamer:paranamer:2.6 - http://paranamer.codehaus.org/paranamer) - (BSD 3 Clause) Scala (http://www.scala-lang.org/download/#License) - (Interpreter classes (all .scala files in repl/src/main/scala - except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala), - and for SerializableMapWrapper in JavaUtils.scala) - (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.11.12 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.11.12 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.11.12 - http://www.scala-lang.org/) - (BSD-like) Scala Library (org.scala-lang:scala-library:2.11.12 - http://www.scala-lang.org/) - (BSD-like) Scalap (org.scala-lang:scalap:2.11.12 - http://www.scala-lang.org/) - (BSD-style) scalacheck (org.scalacheck:scalacheck_2.11:1.10.0 - http://www.scalacheck.org) - (BSD-style) spire (org.spire-math:spire_2.11:0.7.1 - 
http://spire-math.org) - (BSD-style) spire-macros (org.spire-math:spire-macros_2.11:0.7.1 - http://spire-math.org) - (New BSD License) Kryo (com.esotericsoftware:kryo:3.0.3 - https://github.com/EsotericSoftware/kryo) - (New BSD License) MinLog (com.esotericsoftware:minlog:1.3.0 - https://github.com/EsotericSoftware/minlog) - (New BSD license) Protocol Buffer Java API (com.google.protobuf:protobuf-java:2.5.0 - http://code.google.com/p/protobuf) - (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf) - (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net) - (The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net) - (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.7 - http://py4j.sourceforge.net/) - (Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/) - (BSD licence) sbt and sbt-launch-lib.bash - (BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE) - (BSD 3 Clause) DPark (https://github.com/douban/dpark/blob/master/LICENSE) - (BSD 3 Clause) CloudPickle (https://github.com/cloudpipe/cloudpickle/blob/master/LICENSE) - (BSD 2 Clause) Zstd-jni (https://github.com/luben/zstd-jni/blob/master/LICENSE) - (BSD license) Zstd (https://github.com/facebook/zstd/blob/v1.3.1/LICENSE) - -======================================================================== -MIT licenses -======================================================================== - -The following components are provided under the MIT License. See project link for details. -The text of each license is also included at licenses/LICENSE-[project].txt. 
- - (MIT License) JCL 1.1.1 implemented over SLF4J (org.slf4j:jcl-over-slf4j:1.7.5 - http://www.slf4j.org) - (MIT License) JUL to SLF4J bridge (org.slf4j:jul-to-slf4j:1.7.5 - http://www.slf4j.org) - (MIT License) SLF4J API Module (org.slf4j:slf4j-api:1.7.5 - http://www.slf4j.org) - (MIT License) SLF4J LOG4J-12 Binding (org.slf4j:slf4j-log4j12:1.7.5 - http://www.slf4j.org) - (MIT License) pyrolite (org.spark-project:pyrolite:2.0.1 - http://pythonhosted.org/Pyro4/) - (MIT License) scopt (com.github.scopt:scopt_2.11:3.2.0 - https://github.com/scopt/scopt) - (The MIT License) Mockito (org.mockito:mockito-core:1.9.5 - http://www.mockito.org) - (MIT License) jquery (https://jquery.org/license/) - (MIT License) AnchorJS (https://github.com/bryanbraun/anchorjs) - (MIT License) graphlib-dot (https://github.com/cpettitt/graphlib-dot) - (MIT License) dagre-d3 (https://github.com/cpettitt/dagre-d3) - (MIT License) sorttable (https://github.com/stuartlangridge/sorttable) - (MIT License) boto (https://github.com/boto/boto/blob/develop/LICENSE) - (MIT License) datatables (http://datatables.net/license) - (MIT License) mustache (https://github.com/mustache/mustache/blob/master/LICENSE) - (MIT License) cookies (http://code.google.com/p/cookies/wiki/License) - (MIT License) blockUI (http://jquery.malsup.com/block/) - (MIT License) RowsGroup (http://datatables.net/license/mit) - (MIT License) jsonFormatter (http://www.jqueryscript.net/other/jQuery-Plugin-For-Pretty-JSON-Formatting-jsonFormatter.html) - (MIT License) modernizr (https://github.com/Modernizr/Modernizr/blob/master/LICENSE) - (MIT License) machinist (https://github.com/typelevel/machinist) +------------------------------------------------------------------------------------ +This product bundles various third-party components under other open source licenses. +This section summarizes those components and their licenses. See licenses/ +for text of these licenses. + + +Apache Software Foundation License 2.0 +-------------------------------------- + +common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java +core/src/main/java/org/apache/spark/util/collection/TimSort.java +core/src/main/resources/org/apache/spark/ui/static/bootstrap* +core/src/main/resources/org/apache/spark/ui/static/jsonFormatter* +core/src/main/resources/org/apache/spark/ui/static/vis* +docs/js/vendor/bootstrap.js + + +Python Software Foundation License +---------------------------------- + +pyspark/heapq3.py + + +BSD 3-Clause +------------ + +python/lib/py4j-*-src.zip +python/pyspark/cloudpickle.py +python/pyspark/join.py +core/src/main/resources/org/apache/spark/ui/static/d3.min.js + +The CSS style for the navigation sidebar of the documentation was originally +submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project +is distributed under the 3-Clause BSD license. 
+ + +MIT License +----------- + +core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js +core/src/main/resources/org/apache/spark/ui/static/*dataTables* +core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js +core/src/main/resources/org/apache/spark/ui/static/jquery* +core/src/main/resources/org/apache/spark/ui/static/sorttable.js +docs/js/vendor/anchor.min.js +docs/js/vendor/jquery* +docs/js/vendor/modernizr* + + +Creative Commons CC0 1.0 Universal Public Domain Dedication +----------------------------------------------------------- +(see LICENSE-CC0.txt) + +data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg +data/mllib/images/kittens/54893.jpg +data/mllib/images/kittens/DP153539.jpg +data/mllib/images/kittens/DP802813.jpg +data/mllib/images/multi-channel/chr30.4.184.jpg \ No newline at end of file diff --git a/LICENSE-binary b/LICENSE-binary new file mode 100644 index 0000000000000..b94ea90de08be --- /dev/null +++ b/LICENSE-binary @@ -0,0 +1,518 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner.
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +------------------------------------------------------------------------------------ +This project bundles some components that are also licensed under the Apache +License Version 2.0: + +commons-beanutils:commons-beanutils +org.apache.zookeeper:zookeeper +oro:oro +commons-configuration:commons-configuration +commons-digester:commons-digester +com.chuusai:shapeless_2.11 +com.googlecode.javaewah:JavaEWAH +com.twitter:chill-java +com.twitter:chill_2.11 +com.univocity:univocity-parsers +javax.jdo:jdo-api +joda-time:joda-time +net.sf.opencsv:opencsv +org.apache.derby:derby +org.objenesis:objenesis +org.roaringbitmap:RoaringBitmap +org.scalanlp:breeze-macros_2.11 +org.scalanlp:breeze_2.11 +org.typelevel:macro-compat_2.11 +org.yaml:snakeyaml +org.apache.xbean:xbean-asm5-shaded +com.squareup.okhttp3:logging-interceptor +com.squareup.okhttp3:okhttp +com.squareup.okio:okio +org.apache.spark:spark-catalyst_2.11 +org.apache.spark:spark-kvstore_2.11 +org.apache.spark:spark-launcher_2.11 +org.apache.spark:spark-mllib-local_2.11 +org.apache.spark:spark-network-common_2.11 +org.apache.spark:spark-network-shuffle_2.11 +org.apache.spark:spark-sketch_2.11 +org.apache.spark:spark-tags_2.11 +org.apache.spark:spark-unsafe_2.11 +commons-httpclient:commons-httpclient +com.vlkan:flatbuffers +com.ning:compress-lzf +io.airlift:aircompressor +io.dropwizard.metrics:metrics-core +io.dropwizard.metrics:metrics-ganglia +io.dropwizard.metrics:metrics-graphite +io.dropwizard.metrics:metrics-json +io.dropwizard.metrics:metrics-jvm +org.iq80.snappy:snappy +com.clearspring.analytics:stream +com.jamesmurty.utils:java-xmlbuilder +commons-codec:commons-codec +commons-collections:commons-collections +io.fabric8:kubernetes-client +io.fabric8:kubernetes-model +io.netty:netty +io.netty:netty-all +net.hydromatic:eigenbase-properties +net.sf.supercsv:super-csv +org.apache.arrow:arrow-format +org.apache.arrow:arrow-memory +org.apache.arrow:arrow-vector +org.apache.calcite:calcite-avatica +org.apache.calcite:calcite-core +org.apache.calcite:calcite-linq4j +org.apache.commons:commons-crypto +org.apache.commons:commons-lang3 +org.apache.hadoop:hadoop-annotations +org.apache.hadoop:hadoop-auth +org.apache.hadoop:hadoop-client +org.apache.hadoop:hadoop-common +org.apache.hadoop:hadoop-hdfs +org.apache.hadoop:hadoop-mapreduce-client-app +org.apache.hadoop:hadoop-mapreduce-client-common 
+org.apache.hadoop:hadoop-mapreduce-client-core +org.apache.hadoop:hadoop-mapreduce-client-jobclient +org.apache.hadoop:hadoop-mapreduce-client-shuffle +org.apache.hadoop:hadoop-yarn-api +org.apache.hadoop:hadoop-yarn-client +org.apache.hadoop:hadoop-yarn-common +org.apache.hadoop:hadoop-yarn-server-common +org.apache.hadoop:hadoop-yarn-server-web-proxy +org.apache.httpcomponents:httpclient +org.apache.httpcomponents:httpcore +org.apache.orc:orc-core +org.apache.orc:orc-mapreduce +org.mortbay.jetty:jetty +org.mortbay.jetty:jetty-util +com.jolbox:bonecp +org.json4s:json4s-ast_2.11 +org.json4s:json4s-core_2.11 +org.json4s:json4s-jackson_2.11 +org.json4s:json4s-scalap_2.11 +com.carrotsearch:hppc +com.fasterxml.jackson.core:jackson-annotations +com.fasterxml.jackson.core:jackson-core +com.fasterxml.jackson.core:jackson-databind +com.fasterxml.jackson.dataformat:jackson-dataformat-yaml +com.fasterxml.jackson.module:jackson-module-jaxb-annotations +com.fasterxml.jackson.module:jackson-module-paranamer +com.fasterxml.jackson.module:jackson-module-scala_2.11 +com.github.mifmif:generex +com.google.code.findbugs:jsr305 +com.google.code.gson:gson +com.google.inject:guice +com.google.inject.extensions:guice-servlet +com.twitter:parquet-hadoop-bundle +commons-beanutils:commons-beanutils-core +commons-cli:commons-cli +commons-dbcp:commons-dbcp +commons-io:commons-io +commons-lang:commons-lang +commons-logging:commons-logging +commons-net:commons-net +commons-pool:commons-pool +io.fabric8:zjsonpatch +javax.inject:javax.inject +javax.validation:validation-api +log4j:apache-log4j-extras +log4j:log4j +net.sf.jpam:jpam +org.apache.avro:avro +org.apache.avro:avro-ipc +org.apache.avro:avro-mapred +org.apache.commons:commons-compress +org.apache.commons:commons-math3 +org.apache.curator:curator-client +org.apache.curator:curator-framework +org.apache.curator:curator-recipes +org.apache.directory.api:api-asn1-api +org.apache.directory.api:api-util +org.apache.directory.server:apacheds-i18n +org.apache.directory.server:apacheds-kerberos-codec +org.apache.htrace:htrace-core +org.apache.ivy:ivy +org.apache.mesos:mesos +org.apache.parquet:parquet-column +org.apache.parquet:parquet-common +org.apache.parquet:parquet-encoding +org.apache.parquet:parquet-format +org.apache.parquet:parquet-hadoop +org.apache.parquet:parquet-jackson +org.apache.thrift:libfb303 +org.apache.thrift:libthrift +org.codehaus.jackson:jackson-core-asl +org.codehaus.jackson:jackson-mapper-asl +org.datanucleus:datanucleus-api-jdo +org.datanucleus:datanucleus-core +org.datanucleus:datanucleus-rdbms +org.lz4:lz4-java +org.spark-project.hive:hive-beeline +org.spark-project.hive:hive-cli +org.spark-project.hive:hive-exec +org.spark-project.hive:hive-jdbc +org.spark-project.hive:hive-metastore +org.xerial.snappy:snappy-java +stax:stax-api +xerces:xercesImpl +org.codehaus.jackson:jackson-jaxrs +org.codehaus.jackson:jackson-xc +org.eclipse.jetty:jetty-client +org.eclipse.jetty:jetty-continuation +org.eclipse.jetty:jetty-http +org.eclipse.jetty:jetty-io +org.eclipse.jetty:jetty-jndi +org.eclipse.jetty:jetty-plus +org.eclipse.jetty:jetty-proxy +org.eclipse.jetty:jetty-security +org.eclipse.jetty:jetty-server +org.eclipse.jetty:jetty-servlet +org.eclipse.jetty:jetty-servlets +org.eclipse.jetty:jetty-util +org.eclipse.jetty:jetty-webapp +org.eclipse.jetty:jetty-xml + +core/src/main/java/org/apache/spark/util/collection/TimSort.java +core/src/main/resources/org/apache/spark/ui/static/bootstrap* 
+core/src/main/resources/org/apache/spark/ui/static/jsonFormatter* +core/src/main/resources/org/apache/spark/ui/static/vis* +docs/js/vendor/bootstrap.js + + +------------------------------------------------------------------------------------ +This product bundles various third-party components under other open source licenses. +This section summarizes those components and their licenses. See licenses-binary/ +for text of these licenses. + + +BSD 2-Clause +------------ + +com.github.luben:zstd-jni +javolution:javolution +com.esotericsoftware:kryo-shaded +com.esotericsoftware:minlog +com.esotericsoftware:reflectasm +com.google.protobuf:protobuf-java +org.codehaus.janino:commons-compiler +org.codehaus.janino:janino +jline:jline +org.jodd:jodd-core + + +BSD 3-Clause +------------ + +dk.brics.automaton:automaton +org.antlr:antlr-runtime +org.antlr:ST4 +org.antlr:stringtemplate +org.antlr:antlr4-runtime +antlr:antlr +com.github.fommil.netlib:core +com.thoughtworks.paranamer:paranamer +org.scala-lang:scala-compiler +org.scala-lang:scala-library +org.scala-lang:scala-reflect +org.scala-lang.modules:scala-parser-combinators_2.11 +org.scala-lang.modules:scala-xml_2.11 +org.fusesource.leveldbjni:leveldbjni-all +net.sourceforge.f2j:arpack_combined_all +xmlenc:xmlenc +net.sf.py4j:py4j +org.jpmml:pmml-model +org.jpmml:pmml-schema + +python/lib/py4j-*-src.zip +python/pyspark/cloudpickle.py +python/pyspark/join.py +core/src/main/resources/org/apache/spark/ui/static/d3.min.js + +The CSS style for the navigation sidebar of the documentation was originally +submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project +is distributed under the 3-Clause BSD license. + + +MIT License +----------- + +org.spire-math:spire-macros_2.11 +org.spire-math:spire_2.11 +org.typelevel:machinist_2.11 +net.razorvine:pyrolite +org.slf4j:jcl-over-slf4j +org.slf4j:jul-to-slf4j +org.slf4j:slf4j-api +org.slf4j:slf4j-log4j12 +com.github.scopt:scopt_2.11 + +core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js +core/src/main/resources/org/apache/spark/ui/static/*dataTables* +core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js +core/src/main/resources/org/apache/spark/ui/static/jquery* +core/src/main/resources/org/apache/spark/ui/static/sorttable.js +docs/js/vendor/anchor.min.js +docs/js/vendor/jquery* +docs/js/vendor/modernizr* + + +Common Development and Distribution License (CDDL) 1.0 +------------------------------------------------------ + +javax.activation:activation http://www.oracle.com/technetwork/java/javase/tech/index-jsp-138795.html +javax.xml.stream:stax-api https://jcp.org/en/jsr/detail?id=173 + + +Common Development and Distribution License (CDDL) 1.1 +------------------------------------------------------ + +javax.annotation:javax.annotation-api https://jcp.org/en/jsr/detail?id=250 +javax.servlet:javax.servlet-api https://javaee.github.io/servlet-spec/ +javax.transaction:jta http://www.oracle.com/technetwork/java/index.html +javax.ws.rs:javax.ws.rs-api https://github.com/jax-rs +javax.xml.bind:jaxb-api https://github.com/javaee/jaxb-v2 +org.glassfish.hk2:hk2-api https://github.com/javaee/glassfish +org.glassfish.hk2:hk2-locator (same) +org.glassfish.hk2:hk2-utils +org.glassfish.hk2:osgi-resource-locator +org.glassfish.hk2.external:aopalliance-repackaged +org.glassfish.hk2.external:javax.inject +org.glassfish.jersey.bundles.repackaged:jersey-guava +org.glassfish.jersey.containers:jersey-container-servlet +org.glassfish.jersey.containers:jersey-container-servlet-core
+org.glassfish.jersey.core:jersey-client +org.glassfish.jersey.core:jersey-common +org.glassfish.jersey.core:jersey-server +org.glassfish.jersey.media:jersey-media-jaxb + + +Mozilla Public License (MPL) 1.1 +-------------------------------- + +com.github.rwl:jtransforms https://sourceforge.net/projects/jtransforms/ + + +Python Software Foundation License +---------------------------------- + +pyspark/heapq3.py + + +Public Domain +------------- + +aopalliance:aopalliance +net.iharder:base64 +org.tukaani:xz + + +Creative Commons CC0 1.0 Universal Public Domain Dedication +----------------------------------------------------------- +(see LICENSE-CC0.txt) + +data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg +data/mllib/images/kittens/54893.jpg +data/mllib/images/kittens/DP153539.jpg +data/mllib/images/kittens/DP802813.jpg +data/mllib/images/multi-channel/chr30.4.184.jpg diff --git a/NOTICE b/NOTICE index 6ec240efbf12e..fefe08b38afc5 100644 --- a/NOTICE +++ b/NOTICE @@ -5,663 +5,24 @@ This product includes software developed at The Apache Software Foundation (http://www.apache.org/). -======================================================================== -Common Development and Distribution License 1.0 -======================================================================== - -The following components are provided under the Common Development and Distribution License 1.0. See project link for details. - - (CDDL 1.0) Glassfish Jasper (org.mortbay.jetty:jsp-2.1:6.1.14 - http://jetty.mortbay.org/project/modules/jsp-2.1) - (CDDL 1.0) JAX-RS (https://jax-rs-spec.java.net/) - (CDDL 1.0) Servlet Specification 2.5 API (org.mortbay.jetty:servlet-api-2.5:6.1.14 - http://jetty.mortbay.org/project/modules/servlet-api-2.5) - (CDDL 1.0) (GPL2 w/ CPE) javax.annotation API (https://glassfish.java.net/nonav/public/CDDL+GPL.html) - (COMMON DEVELOPMENT AND DISTRIBUTION LICENSE (CDDL) Version 1.0) (GNU General Public Library) Streaming API for XML (javax.xml.stream:stax-api:1.0-2 - no url defined) - (Common Development and Distribution License (CDDL) v1.0) JavaBeans Activation Framework (JAF) (javax.activation:activation:1.1 - http://java.sun.com/products/javabeans/jaf/index.jsp) - -======================================================================== -Common Development and Distribution License 1.1 -======================================================================== - -The following components are provided under the Common Development and Distribution License 1.1. See project link for details. - - (CDDL 1.1) (GPL2 w/ CPE) org.glassfish.hk2 (https://hk2.java.net) - (CDDL 1.1) (GPL2 w/ CPE) JAXB API bundle for GlassFish V3 (javax.xml.bind:jaxb-api:2.2.2 - https://jaxb.dev.java.net/) - (CDDL 1.1) (GPL2 w/ CPE) JAXB RI (com.sun.xml.bind:jaxb-impl:2.2.3-1 - http://jaxb.java.net/) - (CDDL 1.1) (GPL2 w/ CPE) Jersey 2 (https://jersey.java.net) - -======================================================================== -Common Public License 1.0 -======================================================================== - -The following components are provided under the Common Public 1.0 License. See project link for details. 
- - (Common Public License Version 1.0) JUnit (junit:junit-dep:4.10 - http://junit.org) - (Common Public License Version 1.0) JUnit (junit:junit:3.8.1 - http://junit.org) - (Common Public License Version 1.0) JUnit (junit:junit:4.8.2 - http://junit.org) - -======================================================================== -Eclipse Public License 1.0 -======================================================================== - -The following components are provided under the Eclipse Public License 1.0. See project link for details. - - (Eclipse Public License v1.0) Eclipse JDT Core (org.eclipse.jdt:core:3.1.1 - http://www.eclipse.org/jdt/) - -======================================================================== -Mozilla Public License 1.0 -======================================================================== - -The following components are provided under the Mozilla Public License 1.0. See project link for details. - - (GPL) (LGPL) (MPL) JTransforms (com.github.rwl:jtransforms:2.4.0 - http://sourceforge.net/projects/jtransforms/) - (Mozilla Public License Version 1.1) jamon-runtime (org.jamon:jamon-runtime:2.3.1 - http://www.jamon.org/jamon-runtime/) - - - -======================================================================== -NOTICE files -======================================================================== - -The following NOTICEs are pertain to software distributed with this project. - - -// ------------------------------------------------------------------ -// NOTICE file corresponding to the section 4d of The Apache License, -// Version 2.0, in this case for -// ------------------------------------------------------------------ - -Apache Avro -Copyright 2009-2013 The Apache Software Foundation - -This product includes software developed at -The Apache Software Foundation (http://www.apache.org/). - -Apache Commons Codec -Copyright 2002-2009 The Apache Software Foundation - -This product includes software developed by -The Apache Software Foundation (http://www.apache.org/). - --------------------------------------------------------------------------------- -src/test/org/apache/commons/codec/language/DoubleMetaphoneTest.java contains -test data from http://aspell.sourceforge.net/test/batch0.tab. - -Copyright (C) 2002 Kevin Atkinson (kevina@gnu.org). Verbatim copying -and distribution of this entire article is permitted in any medium, -provided this notice is preserved. --------------------------------------------------------------------------------- - -Apache HttpComponents HttpClient -Copyright 1999-2011 The Apache Software Foundation - -This project contains annotations derived from JCIP-ANNOTATIONS -Copyright (c) 2005 Brian Goetz and Tim Peierls. 
See http://www.jcip.net - -Apache HttpComponents HttpCore -Copyright 2005-2011 The Apache Software Foundation - -Curator Recipes -Copyright 2011-2014 The Apache Software Foundation - -Curator Framework -Copyright 2011-2014 The Apache Software Foundation - -Curator Client -Copyright 2011-2014 The Apache Software Foundation - -Apache Geronimo -Copyright 2003-2008 The Apache Software Foundation - -Activation 1.1 -Copyright 2003-2007 The Apache Software Foundation - -Apache Commons Lang -Copyright 2001-2014 The Apache Software Foundation - -This product includes software from the Spring Framework, -under the Apache License 2.0 (see: StringUtils.containsWhitespace()) - -Apache log4j -Copyright 2007 The Apache Software Foundation - -# Compress LZF - -This library contains efficient implementation of LZF compression format, -as well as additional helper classes that build on JDK-provided gzip (deflat) -codec. - -## Licensing - -Library is licensed under Apache License 2.0, as per accompanying LICENSE file. - -## Credit - -Library has been written by Tatu Saloranta (tatu.saloranta@iki.fi). -It was started at Ning, inc., as an official Open Source process used by -platform backend, but after initial versions has been developed outside of -Ning by supporting community. - -Other contributors include: - -* Jon Hartlaub (first versions of streaming reader/writer; unit tests) -* Cedrik Lime: parallel LZF implementation - -Various community members have contributed bug reports, and suggested minor -fixes; these can be found from file "VERSION.txt" in SCM. - -Objenesis -Copyright 2006-2009 Joe Walnes, Henri Tremblay, Leonardo Mesquita - -Apache Commons Net -Copyright 2001-2010 The Apache Software Foundation - - The Netty Project - ================= - -Please visit the Netty web site for more information: - - * http://netty.io/ - -Copyright 2011 The Netty Project - -The Netty Project licenses this file to you under the Apache License, -version 2.0 (the "License"); you may not use this file except in compliance -with the License. You may obtain a copy of the License at: - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -License for the specific language governing permissions and limitations -under the License. - -Also, please refer to each LICENSE..txt file, which is located in -the 'license' directory of the distribution file, for the license terms of the -components that this product depends on. - -------------------------------------------------------------------------------- -This product contains the extensions to Java Collections Framework which has -been derived from the works by JSR-166 EG, Doug Lea, and Jason T. 
Greene: - - * LICENSE: - * license/LICENSE.jsr166y.txt (Public Domain) - * HOMEPAGE: - * http://gee.cs.oswego.edu/cgi-bin/viewcvs.cgi/jsr166/ - * http://viewvc.jboss.org/cgi-bin/viewvc.cgi/jbosscache/experimental/jsr166/ - -This product contains a modified version of Robert Harder's Public Domain -Base64 Encoder and Decoder, which can be obtained at: - - * LICENSE: - * license/LICENSE.base64.txt (Public Domain) - * HOMEPAGE: - * http://iharder.sourceforge.net/current/java/base64/ - -This product contains a modified version of 'JZlib', a re-implementation of -zlib in pure Java, which can be obtained at: - - * LICENSE: - * license/LICENSE.jzlib.txt (BSD Style License) - * HOMEPAGE: - * http://www.jcraft.com/jzlib/ - -This product optionally depends on 'Protocol Buffers', Google's data -interchange format, which can be obtained at: - - * LICENSE: - * license/LICENSE.protobuf.txt (New BSD License) - * HOMEPAGE: - * http://code.google.com/p/protobuf/ - -This product optionally depends on 'SLF4J', a simple logging facade for Java, -which can be obtained at: - - * LICENSE: - * license/LICENSE.slf4j.txt (MIT License) - * HOMEPAGE: - * http://www.slf4j.org/ - -This product optionally depends on 'Apache Commons Logging', a logging -framework, which can be obtained at: - - * LICENSE: - * license/LICENSE.commons-logging.txt (Apache License 2.0) - * HOMEPAGE: - * http://commons.apache.org/logging/ - -This product optionally depends on 'Apache Log4J', a logging framework, -which can be obtained at: - - * LICENSE: - * license/LICENSE.log4j.txt (Apache License 2.0) - * HOMEPAGE: - * http://logging.apache.org/log4j/ - -This product optionally depends on 'JBoss Logging', a logging framework, -which can be obtained at: - - * LICENSE: - * license/LICENSE.jboss-logging.txt (GNU LGPL 2.1) - * HOMEPAGE: - * http://anonsvn.jboss.org/repos/common/common-logging-spi/ - -This product optionally depends on 'Apache Felix', an open source OSGi -framework implementation, which can be obtained at: - - * LICENSE: - * license/LICENSE.felix.txt (Apache License 2.0) - * HOMEPAGE: - * http://felix.apache.org/ - -This product optionally depends on 'Webbit', a Java event based -WebSocket and HTTP server: - - * LICENSE: - * license/LICENSE.webbit.txt (BSD License) - * HOMEPAGE: - * https://github.com/joewalnes/webbit - -# Jackson JSON processor - -Jackson is a high-performance, Free/Open Source JSON processing library. -It was originally written by Tatu Saloranta (tatu.saloranta@iki.fi), and has -been in development since 2007. -It is currently developed by a community of developers, as well as supported -commercially by FasterXML.com. - -Jackson core and extension components may be licensed under different licenses. -To find the details that apply to this artifact see the accompanying LICENSE file. -For more information, including possible other licensing options, contact -FasterXML.com (http://fasterxml.com). - -## Credits - -A list of contributors may be found from CREDITS file, which is included -in some artifacts (usually source distributions); but is always available -from the source code management (SCM) system project uses. - -Jackson core and extension components may licensed under different licenses. -To find the details that apply to this artifact see the accompanying LICENSE file. -For more information, including possible other licensing options, contact -FasterXML.com (http://fasterxml.com). - -mesos -Copyright 2014 The Apache Software Foundation - -Apache Thrift -Copyright 2006-2010 The Apache Software Foundation. 
- - Apache Ant - Copyright 1999-2013 The Apache Software Foundation - - The task is based on code Copyright (c) 2002, Landmark - Graphics Corp that has been kindly donated to the Apache Software - Foundation. - -Apache Commons IO -Copyright 2002-2012 The Apache Software Foundation - -Apache Commons Math -Copyright 2001-2013 The Apache Software Foundation - -=============================================================================== - -The inverse error function implementation in the Erf class is based on CUDA -code developed by Mike Giles, Oxford-Man Institute of Quantitative Finance, -and published in GPU Computing Gems, volume 2, 2010. -=============================================================================== - -The BracketFinder (package org.apache.commons.math3.optimization.univariate) -and PowellOptimizer (package org.apache.commons.math3.optimization.general) -classes are based on the Python code in module "optimize.py" (version 0.5) -developed by Travis E. Oliphant for the SciPy library (http://www.scipy.org/) -Copyright © 2003-2009 SciPy Developers. -=============================================================================== - -The LinearConstraint, LinearObjectiveFunction, LinearOptimizer, -RelationShip, SimplexSolver and SimplexTableau classes in package -org.apache.commons.math3.optimization.linear include software developed by -Benjamin McCann (http://www.benmccann.com) and distributed with -the following copyright: Copyright 2009 Google Inc. -=============================================================================== - -This product includes software developed by the -University of Chicago, as Operator of Argonne National -Laboratory. -The LevenbergMarquardtOptimizer class in package -org.apache.commons.math3.optimization.general includes software -translated from the lmder, lmpar and qrsolv Fortran routines -from the Minpack package -Minpack Copyright Notice (1999) University of Chicago. All rights reserved -=============================================================================== - -The GraggBulirschStoerIntegrator class in package -org.apache.commons.math3.ode.nonstiff includes software translated -from the odex Fortran routine developed by E. Hairer and G. Wanner. -Original source copyright: -Copyright (c) 2004, Ernst Hairer -=============================================================================== - -The EigenDecompositionImpl class in package -org.apache.commons.math3.linear includes software translated -from some LAPACK Fortran routines. Original source copyright: -Copyright (c) 1992-2008 The University of Tennessee. All rights reserved. -=============================================================================== - -The MersenneTwister class in package org.apache.commons.math3.random -includes software translated from the 2002-01-26 version of -the Mersenne-Twister generator written in C by Makoto Matsumoto and Takuji -Nishimura. Original source copyright: -Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura, -All rights reserved -=============================================================================== - -The LocalizedFormatsTest class in the unit tests is an adapted version of -the OrekitMessagesTest class from the orekit library distributed under the -terms of the Apache 2 licence. 
Original source copyright: -Copyright 2010 CS Systèmes d'Information -=============================================================================== - -The HermiteInterpolator class and its corresponding test have been imported from -the orekit library distributed under the terms of the Apache 2 licence. Original -source copyright: -Copyright 2010-2012 CS Systèmes d'Information -=============================================================================== - -The creation of the package "o.a.c.m.analysis.integration.gauss" was inspired -by an original code donated by Sébastien Brisard. -=============================================================================== - -The complete text of licenses and disclaimers associated with the the original -sources enumerated above at the time of code translation are in the LICENSE.txt -file. - -This product currently only contains code developed by authors -of specific components, as identified by the source code files; -if such notes are missing files have been created by -Tatu Saloranta. - -For additional credits (generally to people who reported problems) -see CREDITS file. - -Apache Commons Lang -Copyright 2001-2011 The Apache Software Foundation - -Apache Commons Compress -Copyright 2002-2012 The Apache Software Foundation - -Apache Commons CLI -Copyright 2001-2009 The Apache Software Foundation - -Google Guice - Extensions - Servlet -Copyright 2006-2011 Google, Inc. - -Google Guice - Core Library -Copyright 2006-2011 Google, Inc. - -Apache Jakarta HttpClient -Copyright 1999-2007 The Apache Software Foundation - -Apache Hive -Copyright 2008-2013 The Apache Software Foundation - -This product includes software developed by The Apache Software -Foundation (http://www.apache.org/). - -This product includes software developed by The JDBM Project -(http://jdbm.sourceforge.net/). - -This product includes/uses ANTLR (http://www.antlr.org/), -Copyright (c) 2003-2011, Terrence Parr. - -This product includes/uses StringTemplate (http://www.stringtemplate.org/), -Copyright (c) 2011, Terrence Parr. - -This product includes/uses ASM (http://asm.ow2.org/), -Copyright (c) 2000-2007 INRIA, France Telecom. - -This product includes/uses JLine (http://jline.sourceforge.net/), -Copyright (c) 2002-2006, Marc Prud'hommeaux . - -This product includes/uses SQLLine (http://sqlline.sourceforge.net), -Copyright (c) 2002, 2003, 2004, 2005 Marc Prud'hommeaux . - -This product includes/uses SLF4J (http://www.slf4j.org/), -Copyright (c) 2004-2010 QOS.ch - -This product includes/uses Bootstrap (http://twitter.github.com/bootstrap/), -Copyright (c) 2012 Twitter, Inc. - -This product includes/uses Glyphicons (http://glyphicons.com/), -Copyright (c) 2010 - 2012 Jan Kovarík - -This product includes DataNucleus (http://www.datanucleus.org/) -Copyright 2008-2008 DataNucleus - -This product includes Guava (http://code.google.com/p/guava-libraries/) -Copyright (C) 2006 Google Inc. - -This product includes JavaEWAH (http://code.google.com/p/javaewah/) -Copyright (C) 2011 Google Inc. - -Apache Commons Pool -Copyright 1999-2009 The Apache Software Foundation - -This product includes/uses Kubernetes & OpenShift 3 Java Client (https://github.com/fabric8io/kubernetes-client) -Copyright (C) 2015 Red Hat, Inc. 
- -This product includes/uses OkHttp (https://github.com/square/okhttp) -Copyright (C) 2012 The Android Open Source Project - -========================================================================= -== NOTICE file corresponding to section 4(d) of the Apache License, == -== Version 2.0, in this case for the DataNucleus distribution. == -========================================================================= - -=================================================================== -This product includes software developed by many individuals, -including the following: -=================================================================== -Erik Bengtson -Andy Jefferson - -=================================================================== -This product has included contributions from some individuals, -including the following: -=================================================================== - -=================================================================== -This product has included contributions from some individuals, -including the following: -=================================================================== -Joerg von Frantzius -Thomas Marti -Barry Haddow -Marco Schulze -Ralph Ullrich -David Ezzio -Brendan de Beer -David Eaves -Martin Taal -Tony Lai -Roland Szabo -Marcus Mennemeier -Xuan Baldauf -Eric Sultan - -=================================================================== -This product also includes software developed by the TJDO project -(http://tjdo.sourceforge.net/). -=================================================================== - -=================================================================== -This product includes software developed by many individuals, -including the following: -=================================================================== -Andy Jefferson -Erik Bengtson -Joerg von Frantzius -Marco Schulze - -=================================================================== -This product has included contributions from some individuals, -including the following: -=================================================================== -Barry Haddow -Ralph Ullrich -David Ezzio -Brendan de Beer -David Eaves -Martin Taal -Tony Lai -Roland Szabo -Anton Troshin (Timesten) - -=================================================================== -This product also includes software developed by the Apache Commons project -(http://commons.apache.org/). -=================================================================== - -Apache Java Data Objects (JDO) -Copyright 2005-2006 The Apache Software Foundation - -========================================================================= -== NOTICE file corresponding to section 4(d) of the Apache License, == -== Version 2.0, in this case for the Apache Derby distribution. == -========================================================================= - -Apache Derby -Copyright 2004-2008 The Apache Software Foundation - -Portions of Derby were originally developed by -International Business Machines Corporation and are -licensed to the Apache Software Foundation under the -"Software Grant and Corporate Contribution License Agreement", -informally known as the "Derby CLA". -The following copyright notice(s) were affixed to portions of the code -with which this file is now or was at one time distributed -and are placed here unaltered. - -(C) Copyright 1997,2004 International Business Machines Corporation. All rights reserved. - -(C) Copyright IBM Corp. 2003. 
- -The portion of the functionTests under 'nist' was originally -developed by the National Institute of Standards and Technology (NIST), -an agency of the United States Department of Commerce, and adapted by -International Business Machines Corporation in accordance with the NIST -Software Acknowledgment and Redistribution document at -http://www.itl.nist.gov/div897/ctg/sql_form.htm - -Apache Commons Collections -Copyright 2001-2008 The Apache Software Foundation - -Apache Commons Configuration -Copyright 2001-2008 The Apache Software Foundation - -Apache Jakarta Commons Digester -Copyright 2001-2006 The Apache Software Foundation - -Apache Commons BeanUtils -Copyright 2000-2008 The Apache Software Foundation - -Apache Avro Mapred API -Copyright 2009-2013 The Apache Software Foundation - -Apache Avro IPC -Copyright 2009-2013 The Apache Software Foundation - - -Vis.js -Copyright 2010-2015 Almende B.V. - -Vis.js is dual licensed under both - - * The Apache 2.0 License - http://www.apache.org/licenses/LICENSE-2.0 - - and - - * The MIT License - http://opensource.org/licenses/MIT - -Vis.js may be distributed under either license. - - -Vis.js uses and redistributes the following third-party libraries: - -- component-emitter - https://github.com/component/emitter - The MIT License - -- hammer.js - http://hammerjs.github.io/ - The MIT License - -- moment.js - http://momentjs.com/ - The MIT License - -- keycharm - https://github.com/AlexDM0/keycharm - The MIT License - -=============================================================================== - -The CSS style for the navigation sidebar of the documentation was originally -submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project -is distributed under the 3-Clause BSD license. -=============================================================================== - -For CSV functionality: - -/* - * Copyright 2014 Databricks - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Copyright 2015 Ayasdi Inc - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -=============================================================================== -For dev/sparktestsupport/toposort.py: - -Copyright 2014 True Blade Systems, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. +Export Control Notice +--------------------- + +This distribution includes cryptographic software. The country in which you currently reside may have +restrictions on the import, possession, use, and/or re-export to another country, of encryption software. +BEFORE using any encryption software, please check your country's laws, regulations and policies concerning +the import, possession, or use, and re-export of encryption software, to see if this is permitted. See +<http://www.wassenaar.org/> for more information. + +The U.S. Government Department of Commerce, Bureau of Industry and Security (BIS), has classified this +software as Export Commodity Control Number (ECCN) 5D002.C.1, which includes information security software +using or performing cryptographic functions with asymmetric algorithms. The form and manner of this Apache +Software Foundation distribution makes it eligible for export under the License Exception ENC Technology +Software Unrestricted (TSU) exception (see the BIS Export Administration Regulations, Section 740.13) for +both object code and source code. + +The following provides more details on the included cryptographic software: + +This software uses Apache Commons Crypto (https://commons.apache.org/proper/commons-crypto/) to +support authentication, and encryption and decryption of data sent across the network between +services. diff --git a/NOTICE-binary b/NOTICE-binary new file mode 100644 index 0000000000000..b707c436983f7 --- /dev/null +++ b/NOTICE-binary @@ -0,0 +1,1174 @@ +Apache Spark +Copyright 2014 and onwards The Apache Software Foundation. + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +Export Control Notice +--------------------- + +This distribution includes cryptographic software. The country in which you currently reside may have +restrictions on the import, possession, use, and/or re-export to another country, of encryption software. +BEFORE using any encryption software, please check your country's laws, regulations and policies concerning +the import, possession, or use, and re-export of encryption software, to see if this is permitted. See +<http://www.wassenaar.org/> for more information. + +The U.S. Government Department of Commerce, Bureau of Industry and Security (BIS), has classified this +software as Export Commodity Control Number (ECCN) 5D002.C.1, which includes information security software +using or performing cryptographic functions with asymmetric algorithms. The form and manner of this Apache +Software Foundation distribution makes it eligible for export under the License Exception ENC Technology +Software Unrestricted (TSU) exception (see the BIS Export Administration Regulations, Section 740.13) for +both object code and source code. + +The following provides more details on the included cryptographic software: + +This software uses Apache Commons Crypto (https://commons.apache.org/proper/commons-crypto/) to +support authentication, and encryption and decryption of data sent across the network between +services.
+ + +// ------------------------------------------------------------------ +// NOTICE file corresponding to the section 4d of The Apache License, +// Version 2.0, in this case for +// ------------------------------------------------------------------ + +Hive Beeline +Copyright 2016 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + +Apache Avro +Copyright 2009-2014 The Apache Software Foundation + +This product currently only contains code developed by authors +of specific components, as identified by the source code files; +if such notes are missing files have been created by +Tatu Saloranta. + +For additional credits (generally to people who reported problems) +see CREDITS file. + +Apache Commons Compress +Copyright 2002-2012 The Apache Software Foundation + +This product includes software developed by +The Apache Software Foundation (http://www.apache.org/). + +Apache Avro Mapred API +Copyright 2009-2014 The Apache Software Foundation + +Apache Avro IPC +Copyright 2009-2014 The Apache Software Foundation + +Objenesis +Copyright 2006-2013 Joe Walnes, Henri Tremblay, Leonardo Mesquita + +Apache XBean :: ASM 5 shaded (repackaged) +Copyright 2005-2015 The Apache Software Foundation + +-------------------------------------- + +This product includes software developed at +OW2 Consortium (http://asm.ow2.org/) + +This product includes software developed by The Apache Software +Foundation (http://www.apache.org/). + +The binary distribution of this product bundles binaries of +org.iq80.leveldb:leveldb-api (https://github.com/dain/leveldb), which has the +following notices: +* Copyright 2011 Dain Sundstrom +* Copyright 2011 FuseSource Corp. http://fusesource.com + +The binary distribution of this product bundles binaries of +org.fusesource.hawtjni:hawtjni-runtime (https://github.com/fusesource/hawtjni), +which has the following notices: +* This product includes software developed by FuseSource Corp. + http://fusesource.com +* This product includes software developed at + Progress Software Corporation and/or its subsidiaries or affiliates. +* This product includes software developed by IBM Corporation and others. + +The binary distribution of this product bundles binaries of +Gson 2.2.4, +which has the following notices: + + The Netty Project + ================= + +Please visit the Netty web site for more information: + + * http://netty.io/ + +Copyright 2014 The Netty Project + +The Netty Project licenses this file to you under the Apache License, +version 2.0 (the "License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at: + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +License for the specific language governing permissions and limitations +under the License. + +Also, please refer to each LICENSE.<component>.txt file, which is located in +the 'license' directory of the distribution file, for the license terms of the +components that this product depends on. + +------------------------------------------------------------------------------- +This product contains the extensions to Java Collections Framework which has +been derived from the works by JSR-166 EG, Doug Lea, and Jason T.
Greene: + + * LICENSE: + * license/LICENSE.jsr166y.txt (Public Domain) + * HOMEPAGE: + * http://gee.cs.oswego.edu/cgi-bin/viewcvs.cgi/jsr166/ + * http://viewvc.jboss.org/cgi-bin/viewvc.cgi/jbosscache/experimental/jsr166/ + +This product contains a modified version of Robert Harder's Public Domain +Base64 Encoder and Decoder, which can be obtained at: + + * LICENSE: + * license/LICENSE.base64.txt (Public Domain) + * HOMEPAGE: + * http://iharder.sourceforge.net/current/java/base64/ + +This product contains a modified portion of 'Webbit', an event based +WebSocket and HTTP server, which can be obtained at: + + * LICENSE: + * license/LICENSE.webbit.txt (BSD License) + * HOMEPAGE: + * https://github.com/joewalnes/webbit + +This product contains a modified portion of 'SLF4J', a simple logging +facade for Java, which can be obtained at: + + * LICENSE: + * license/LICENSE.slf4j.txt (MIT License) + * HOMEPAGE: + * http://www.slf4j.org/ + +This product contains a modified portion of 'ArrayDeque', written by Josh +Bloch of Google, Inc: + + * LICENSE: + * license/LICENSE.deque.txt (Public Domain) + +This product contains a modified portion of 'Apache Harmony', an open source +Java SE, which can be obtained at: + + * LICENSE: + * license/LICENSE.harmony.txt (Apache License 2.0) + * HOMEPAGE: + * http://archive.apache.org/dist/harmony/ + +This product contains a modified version of Roland Kuhn's ASL2 +AbstractNodeQueue, which is based on Dmitriy Vyukov's non-intrusive MPSC queue. +It can be obtained at: + + * LICENSE: + * license/LICENSE.abstractnodequeue.txt (Public Domain) + * HOMEPAGE: + * https://github.com/akka/akka/blob/wip-2.2.3-for-scala-2.11/akka-actor/src/main/java/akka/dispatch/AbstractNodeQueue.java + +This product contains a modified portion of 'jbzip2', a Java bzip2 compression +and decompression library written by Matthew J. Francis. It can be obtained at: + + * LICENSE: + * license/LICENSE.jbzip2.txt (MIT License) + * HOMEPAGE: + * https://code.google.com/p/jbzip2/ + +This product contains a modified portion of 'libdivsufsort', a C API library to construct +the suffix array and the Burrows-Wheeler transformed string for any input string of +a constant-size alphabet written by Yuta Mori. It can be obtained at: + + * LICENSE: + * license/LICENSE.libdivsufsort.txt (MIT License) + * HOMEPAGE: + * https://code.google.com/p/libdivsufsort/ + +This product contains a modified portion of Nitsan Wakart's 'JCTools', Java Concurrency Tools for the JVM, + which can be obtained at: + + * LICENSE: + * license/LICENSE.jctools.txt (ASL2 License) + * HOMEPAGE: + * https://github.com/JCTools/JCTools + +This product optionally depends on 'JZlib', a re-implementation of zlib in +pure Java, which can be obtained at: + + * LICENSE: + * license/LICENSE.jzlib.txt (BSD style License) + * HOMEPAGE: + * http://www.jcraft.com/jzlib/ + +This product optionally depends on 'Compress-LZF', a Java library for encoding and +decoding data in LZF format, written by Tatu Saloranta. It can be obtained at: + + * LICENSE: + * license/LICENSE.compress-lzf.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/ning/compress + +This product optionally depends on 'lz4', a LZ4 Java compression +and decompression library written by Adrien Grand. 
It can be obtained at: + + * LICENSE: + * license/LICENSE.lz4.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/jpountz/lz4-java + +This product optionally depends on 'lzma-java', a LZMA Java compression +and decompression library, which can be obtained at: + + * LICENSE: + * license/LICENSE.lzma-java.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/jponge/lzma-java + +This product contains a modified portion of 'jfastlz', a Java port of FastLZ compression +and decompression library written by William Kinney. It can be obtained at: + + * LICENSE: + * license/LICENSE.jfastlz.txt (MIT License) + * HOMEPAGE: + * https://code.google.com/p/jfastlz/ + +This product contains a modified portion of and optionally depends on 'Protocol Buffers', Google's data +interchange format, which can be obtained at: + + * LICENSE: + * license/LICENSE.protobuf.txt (New BSD License) + * HOMEPAGE: + * http://code.google.com/p/protobuf/ + +This product optionally depends on 'Bouncy Castle Crypto APIs' to generate +a temporary self-signed X.509 certificate when the JVM does not provide the +equivalent functionality. It can be obtained at: + + * LICENSE: + * license/LICENSE.bouncycastle.txt (MIT License) + * HOMEPAGE: + * http://www.bouncycastle.org/ + +This product optionally depends on 'Snappy', a compression library produced +by Google Inc, which can be obtained at: + + * LICENSE: + * license/LICENSE.snappy.txt (New BSD License) + * HOMEPAGE: + * http://code.google.com/p/snappy/ + +This product optionally depends on 'JBoss Marshalling', an alternative Java +serialization API, which can be obtained at: + + * LICENSE: + * license/LICENSE.jboss-marshalling.txt (GNU LGPL 2.1) + * HOMEPAGE: + * http://www.jboss.org/jbossmarshalling + +This product optionally depends on 'Caliper', Google's micro- +benchmarking framework, which can be obtained at: + + * LICENSE: + * license/LICENSE.caliper.txt (Apache License 2.0) + * HOMEPAGE: + * http://code.google.com/p/caliper/ + +This product optionally depends on 'Apache Commons Logging', a logging +framework, which can be obtained at: + + * LICENSE: + * license/LICENSE.commons-logging.txt (Apache License 2.0) + * HOMEPAGE: + * http://commons.apache.org/logging/ + +This product optionally depends on 'Apache Log4J', a logging framework, which +can be obtained at: + + * LICENSE: + * license/LICENSE.log4j.txt (Apache License 2.0) + * HOMEPAGE: + * http://logging.apache.org/log4j/ + +This product optionally depends on 'Aalto XML', an ultra-high performance +non-blocking XML processor, which can be obtained at: + + * LICENSE: + * license/LICENSE.aalto-xml.txt (Apache License 2.0) + * HOMEPAGE: + * http://wiki.fasterxml.com/AaltoHome + +This product contains a modified version of 'HPACK', a Java implementation of +the HTTP/2 HPACK algorithm written by Twitter. 
It can be obtained at: + + * LICENSE: + * license/LICENSE.hpack.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/twitter/hpack + +This product contains a modified portion of 'Apache Commons Lang', a Java library +that provides utilities for the java.lang API, which can be obtained at: + + * LICENSE: + * license/LICENSE.commons-lang.txt (Apache License 2.0) + * HOMEPAGE: + * https://commons.apache.org/proper/commons-lang/ + +The binary distribution of this product bundles binaries of +Commons Codec 1.4, +which has the following notices: + * src/test/org/apache/commons/codec/language/DoubleMetaphoneTest.java contains test data from http://aspell.net/test/orig/batch0.tab. Copyright (C) 2002 Kevin Atkinson (kevina@gnu.org) + =============================================================================== + The content of package org.apache.commons.codec.language.bm has been translated + from the original php source code available at http://stevemorse.org/phoneticinfo.htm + with permission from the original authors. + Original source copyright: Copyright (c) 2008 Alexander Beider & Stephen P. Morse. + +The binary distribution of this product bundles binaries of +Commons Lang 2.6, +which has the following notices: + * This product includes software from the Spring Framework, under the Apache License 2.0 (see: StringUtils.containsWhitespace()) + +The binary distribution of this product bundles binaries of +Apache Log4j 1.2.17, +which has the following notices: + * ResolverUtil.java + Copyright 2005-2006 Tim Fennell + Dumbster SMTP test server + Copyright 2004 Jason Paul Kitchen + TypeUtil.java + Copyright 2002-2012 Ramnivas Laddad, Juergen Hoeller, Chris Beams + +The binary distribution of this product bundles binaries of +Jetty 6.1.26, +which has the following notices: + * ============================================================== + Jetty Web Container + Copyright 1995-2016 Mort Bay Consulting Pty Ltd. + ============================================================== + + The Jetty Web Container is Copyright Mort Bay Consulting Pty Ltd + unless otherwise noted. + + Jetty is dual licensed under both + + * The Apache 2.0 License + http://www.apache.org/licenses/LICENSE-2.0.html + + and + + * The Eclipse Public 1.0 License + http://www.eclipse.org/legal/epl-v10.html + + Jetty may be distributed under either license. + + ------ + Eclipse + + The following artifacts are EPL. + * org.eclipse.jetty.orbit:org.eclipse.jdt.core + + The following artifacts are EPL and ASL2. + * org.eclipse.jetty.orbit:javax.security.auth.message + + The following artifacts are EPL and CDDL 1.0. + * org.eclipse.jetty.orbit:javax.mail.glassfish + + ------ + Oracle + + The following artifacts are CDDL + GPLv2 with classpath exception. + https://glassfish.dev.java.net/nonav/public/CDDL+GPL.html + + * javax.servlet:javax.servlet-api + * javax.annotation:javax.annotation-api + * javax.transaction:javax.transaction-api + * javax.websocket:javax.websocket-api + + ------ + Oracle OpenJDK + + If ALPN is used to negotiate HTTP/2 connections, then the following + artifacts may be included in the distribution or downloaded when ALPN + module is selected. + + * java.sun.security.ssl + + These artifacts replace/modify OpenJDK classes. The modifications + are hosted at github and both modified and original are under GPL v2 with + classpath exceptions.
+ http://openjdk.java.net/legal/gplv2+ce.html + + ------ + OW2 + + The following artifacts are licensed by the OW2 Foundation according to the + terms of http://asm.ow2.org/license.html + + org.ow2.asm:asm-commons + org.ow2.asm:asm + + ------ + Apache + + The following artifacts are ASL2 licensed. + + org.apache.taglibs:taglibs-standard-spec + org.apache.taglibs:taglibs-standard-impl + + ------ + MortBay + + The following artifacts are ASL2 licensed. Based on selected classes from + following Apache Tomcat jars, all ASL2 licensed. + + org.mortbay.jasper:apache-jsp + org.apache.tomcat:tomcat-jasper + org.apache.tomcat:tomcat-juli + org.apache.tomcat:tomcat-jsp-api + org.apache.tomcat:tomcat-el-api + org.apache.tomcat:tomcat-jasper-el + org.apache.tomcat:tomcat-api + org.apache.tomcat:tomcat-util-scan + org.apache.tomcat:tomcat-util + + org.mortbay.jasper:apache-el + org.apache.tomcat:tomcat-jasper-el + org.apache.tomcat:tomcat-el-api + + ------ + Mortbay + + The following artifacts are CDDL + GPLv2 with classpath exception. + + https://glassfish.dev.java.net/nonav/public/CDDL+GPL.html + + org.eclipse.jetty.toolchain:jetty-schemas + + ------ + Assorted + + The UnixCrypt.java code implements the one way cryptography used by + Unix systems for simple password protection. Copyright 1996 Aki Yoshida, + modified April 2001 by Iris Van den Broeke, Daniel Deville. + Permission to use, copy, modify and distribute UnixCrypt + for non-commercial or commercial purposes and without fee is + granted provided that the copyright notice appears in all copies. + +The binary distribution of this product bundles binaries of +Snappy for Java 1.0.4.1, +which has the following notices: + * This product includes software developed by Google + Snappy: http://code.google.com/p/snappy/ (New BSD License) + + This product includes software developed by Apache + PureJavaCrc32C from apache-hadoop-common http://hadoop.apache.org/ + (Apache 2.0 license) + + This library contains statically linked libstdc++. This inclusion is allowed by + "GCC Runtime Library Exception" + http://gcc.gnu.org/onlinedocs/libstdc++/manual/license.html + + == Contributors == + * Tatu Saloranta + * Providing benchmark suite + * Alec Wysoker + * Performance and memory usage improvement + +The binary distribution of this product bundles binaries of +Xerces2 Java Parser 2.9.1, +which has the following notices: + * ========================================================================= + == NOTICE file corresponding to section 4(d) of the Apache License, == + == Version 2.0, in this case for the Apache Xerces Java distribution. == + ========================================================================= + + Apache Xerces Java + Copyright 1999-2007 The Apache Software Foundation + + This product includes software developed at + The Apache Software Foundation (http://www.apache.org/). + + Portions of this software were originally based on the following: + - software copyright (c) 1999, IBM Corporation., http://www.ibm.com. + - software copyright (c) 1999, Sun Microsystems., http://www.sun.com. + - voluntary contributions made by Paul Eng on behalf of the + Apache Software Foundation that were originally developed at iClick, Inc., + software copyright (c) 1999.
+ +Apache Commons Collections +Copyright 2001-2015 The Apache Software Foundation + +Apache Commons Configuration +Copyright 2001-2008 The Apache Software Foundation + +Apache Jakarta Commons Digester +Copyright 2001-2006 The Apache Software Foundation + +Apache Commons BeanUtils +Copyright 2000-2008 The Apache Software Foundation + +ApacheDS Protocol Kerberos Codec +Copyright 2003-2013 The Apache Software Foundation + +ApacheDS I18n +Copyright 2003-2013 The Apache Software Foundation + +Apache Directory API ASN.1 API +Copyright 2003-2013 The Apache Software Foundation + +Apache Directory LDAP API Utilities +Copyright 2003-2013 The Apache Software Foundation + +Curator Client +Copyright 2011-2015 The Apache Software Foundation + +htrace-core +Copyright 2015 The Apache Software Foundation + + ========================================================================= + == NOTICE file corresponding to section 4(d) of the Apache License, == + == Version 2.0, in this case for the Apache Xerces Java distribution. == + ========================================================================= + + Portions of this software were originally based on the following: + - software copyright (c) 1999, IBM Corporation., http://www.ibm.com. + - software copyright (c) 1999, Sun Microsystems., http://www.sun.com. + - voluntary contributions made by Paul Eng on behalf of the + Apache Software Foundation that were originally developed at iClick, Inc., + software copyright (c) 1999. + +# Jackson JSON processor + +Jackson is a high-performance, Free/Open Source JSON processing library. +It was originally written by Tatu Saloranta (tatu.saloranta@iki.fi), and has +been in development since 2007. +It is currently developed by a community of developers, as well as supported +commercially by FasterXML.com. + +## Licensing + +Jackson core and extension components may be licensed under different licenses. +To find the details that apply to this artifact see the accompanying LICENSE file. +For more information, including possible other licensing options, contact +FasterXML.com (http://fasterxml.com). + +## Credits + +A list of contributors may be found from CREDITS file, which is included +in some artifacts (usually source distributions); but is always available +from the source code management (SCM) system the project uses. + +Apache HttpCore +Copyright 2005-2017 The Apache Software Foundation + +Curator Recipes +Copyright 2011-2015 The Apache Software Foundation + +Curator Framework +Copyright 2011-2015 The Apache Software Foundation + +Apache Commons Lang +Copyright 2001-2016 The Apache Software Foundation + +This product includes software from the Spring Framework, +under the Apache License 2.0 (see: StringUtils.containsWhitespace()) + +Apache Commons Math +Copyright 2001-2015 The Apache Software Foundation + +This product includes software developed for Orekit by +CS Systèmes d'Information (http://www.c-s.fr/) +Copyright 2010-2012 CS Systèmes d'Information + +Apache log4j +Copyright 2007 The Apache Software Foundation + +# Compress LZF + +This library contains efficient implementation of LZF compression format, +as well as additional helper classes that build on JDK-provided gzip (deflate) +codec. + +Library is licensed under Apache License 2.0, as per accompanying LICENSE file. + +## Credit + +Library has been written by Tatu Saloranta (tatu.saloranta@iki.fi).
+It was started at Ning, inc., as an official Open Source process used by +platform backend, but after initial versions has been developed outside of +Ning by supporting community. + +Other contributors include: + +* Jon Hartlaub (first versions of streaming reader/writer; unit tests) +* Cedrik Lime: parallel LZF implementation + +Various community members have contributed bug reports, and suggested minor +fixes; these can be found from file "VERSION.txt" in SCM. + +Apache Commons Net +Copyright 2001-2012 The Apache Software Foundation + +Copyright 2011 The Netty Project + +http://www.apache.org/licenses/LICENSE-2.0 + +This product contains a modified version of 'JZlib', a re-implementation of +zlib in pure Java, which can be obtained at: + + * LICENSE: + * license/LICENSE.jzlib.txt (BSD Style License) + * HOMEPAGE: + * http://www.jcraft.com/jzlib/ + +This product contains a modified version of 'Webbit', a Java event based +WebSocket and HTTP server: + +This product optionally depends on 'Protocol Buffers', Google's data +interchange format, which can be obtained at: + +This product optionally depends on 'SLF4J', a simple logging facade for Java, +which can be obtained at: + +This product optionally depends on 'Apache Log4J', a logging framework, +which can be obtained at: + +This product optionally depends on 'JBoss Logging', a logging framework, +which can be obtained at: + + * LICENSE: + * license/LICENSE.jboss-logging.txt (GNU LGPL 2.1) + * HOMEPAGE: + * http://anonsvn.jboss.org/repos/common/common-logging-spi/ + +This product optionally depends on 'Apache Felix', an open source OSGi +framework implementation, which can be obtained at: + + * LICENSE: + * license/LICENSE.felix.txt (Apache License 2.0) + * HOMEPAGE: + * http://felix.apache.org/ + +Jackson core and extension components may be licensed under different licenses. +To find the details that apply to this artifact see the accompanying LICENSE file. +For more information, including possible other licensing options, contact +FasterXML.com (http://fasterxml.com). + +Apache Ivy (TM) +Copyright 2007-2014 The Apache Software Foundation + +Portions of Ivy were originally developed at +Jayasoft SARL (http://www.jayasoft.fr/) +and are licensed to the Apache Software Foundation under the +"Software Grant License Agreement" + +SSH and SFTP support is provided by the JCraft JSch package, +which is open source software, available under +the terms of a BSD style license. +The original software and related information is available +at http://www.jcraft.com/jsch/. + + +ORC Core +Copyright 2013-2018 The Apache Software Foundation + +Apache Commons Lang +Copyright 2001-2011 The Apache Software Foundation + +ORC MapReduce +Copyright 2013-2018 The Apache Software Foundation + +Apache Parquet Format +Copyright 2017 The Apache Software Foundation + +Arrow Vectors +Copyright 2017 The Apache Software Foundation + +Arrow Format +Copyright 2017 The Apache Software Foundation + +Arrow Memory +Copyright 2017 The Apache Software Foundation + +Apache Commons CLI +Copyright 2001-2009 The Apache Software Foundation + +Google Guice - Extensions - Servlet +Copyright 2006-2011 Google, Inc. + +Apache Commons IO +Copyright 2002-2012 The Apache Software Foundation + +Google Guice - Core Library +Copyright 2006-2011 Google, Inc. 
+ +mesos +Copyright 2017 The Apache Software Foundation + +Apache Parquet Hadoop Bundle (Incubating) +Copyright 2015 The Apache Software Foundation + +Hive Query Language +Copyright 2016 The Apache Software Foundation + +Apache Extras Companion for log4j 1.2. +Copyright 2007 The Apache Software Foundation + +Hive Metastore +Copyright 2016 The Apache Software Foundation + +Apache Commons Logging +Copyright 2003-2013 The Apache Software Foundation + +========================================================================= +== NOTICE file corresponding to section 4(d) of the Apache License, == +== Version 2.0, in this case for the DataNucleus distribution. == +========================================================================= + +=================================================================== +This product includes software developed by many individuals, +including the following: +=================================================================== +Erik Bengtson +Andy Jefferson + +=================================================================== +This product has included contributions from some individuals, +including the following: +=================================================================== + +=================================================================== +This product includes software developed by many individuals, +including the following: +=================================================================== +Andy Jefferson +Erik Bengtson +Joerg von Frantzius +Marco Schulze + +=================================================================== +This product has included contributions from some individuals, +including the following: +=================================================================== +Barry Haddow +Ralph Ullrich +David Ezzio +Brendan de Beer +David Eaves +Martin Taal +Tony Lai +Roland Szabo +Anton Troshin (Timesten) + +=================================================================== +This product also includes software developed by the TJDO project +(http://tjdo.sourceforge.net/). +=================================================================== + +=================================================================== +This product also includes software developed by the Apache Commons project +(http://commons.apache.org/). +=================================================================== + +Apache Commons Pool +Copyright 1999-2009 The Apache Software Foundation + +Apache Commons DBCP +Copyright 2001-2010 The Apache Software Foundation + +Apache Java Data Objects (JDO) +Copyright 2005-2006 The Apache Software Foundation + +Apache Jakarta HttpClient +Copyright 1999-2007 The Apache Software Foundation + +Calcite Avatica +Copyright 2012-2015 The Apache Software Foundation + +Calcite Core +Copyright 2012-2015 The Apache Software Foundation + +Calcite Linq4j +Copyright 2012-2015 The Apache Software Foundation + +Apache HttpClient +Copyright 1999-2017 The Apache Software Foundation + +Apache Commons Codec +Copyright 2002-2014 The Apache Software Foundation + +src/test/org/apache/commons/codec/language/DoubleMetaphoneTest.java +contains test data from http://aspell.net/test/orig/batch0.tab. +Copyright (C) 2002 Kevin Atkinson (kevina@gnu.org) + +=============================================================================== + +The content of package org.apache.commons.codec.language.bm has been translated +from the original php source code available at http://stevemorse.org/phoneticinfo.htm +with permission from the original authors. 
+Original source copyright: +Copyright (c) 2008 Alexander Beider & Stephen P. Morse. + +============================================================================= += NOTICE file corresponding to section 4d of the Apache License Version 2.0 = +============================================================================= +This product includes software developed by +Joda.org (http://www.joda.org/). + +=================================================================== +This product has included contributions from some individuals, +including the following: +=================================================================== +Joerg von Frantzius +Thomas Marti +Barry Haddow +Marco Schulze +Ralph Ullrich +David Ezzio +Brendan de Beer +David Eaves +Martin Taal +Tony Lai +Roland Szabo +Marcus Mennemeier +Xuan Baldauf +Eric Sultan + +Apache Thrift +Copyright 2006-2010 The Apache Software Foundation. + +========================================================================= +== NOTICE file corresponding to section 4(d) of the Apache License, +== Version 2.0, in this case for the Apache Derby distribution. +== +== DO NOT EDIT THIS FILE DIRECTLY. IT IS GENERATED +== BY THE buildnotice TARGET IN THE TOP LEVEL build.xml FILE. +== +========================================================================= + +Apache Derby +Copyright 2004-2015 The Apache Software Foundation + +========================================================================= + +Portions of Derby were originally developed by +International Business Machines Corporation and are +licensed to the Apache Software Foundation under the +"Software Grant and Corporate Contribution License Agreement", +informally known as the "Derby CLA". +The following copyright notice(s) were affixed to portions of the code +with which this file is now or was at one time distributed +and are placed here unaltered. + +(C) Copyright 1997,2004 International Business Machines Corporation. All rights reserved. + +(C) Copyright IBM Corp. 2003. + +The portion of the functionTests under 'nist' was originally +developed by the National Institute of Standards and Technology (NIST), +an agency of the United States Department of Commerce, and adapted by +International Business Machines Corporation in accordance with the NIST +Software Acknowledgment and Redistribution document at +http://www.itl.nist.gov/div897/ctg/sql_form.htm + +The JDBC apis for small devices and JDBC3 (under java/stubs/jsr169 and +java/stubs/jdbc3) were produced by trimming sources supplied by the +Apache Harmony project. In addition, the Harmony SerialBlob and +SerialClob implementations are used. The following notice covers the Harmony sources: + +Portions of Harmony were originally developed by +Intel Corporation and are licensed to the Apache Software +Foundation under the "Software Grant and Corporate Contribution +License Agreement", informally known as the "Intel Harmony CLA". + +The Derby build relies on source files supplied by the Apache Felix +project. The following notice covers the Felix files: + + Apache Felix Main + Copyright 2008 The Apache Software Foundation + + I. Included Software + + This product includes software developed at + The Apache Software Foundation (http://www.apache.org/). + Licensed under the Apache License 2.0. + + This product includes software developed at + The OSGi Alliance (http://www.osgi.org/). + Copyright (c) OSGi Alliance (2000, 2007). + Licensed under the Apache License 2.0. + + This product includes software from http://kxml.sourceforge.net. 
+ Copyright (c) 2002,2003, Stefan Haustein, Oberhausen, Rhld., Germany. + Licensed under BSD License. + + II. Used Software + + This product uses software developed at + The OSGi Alliance (http://www.osgi.org/). + Copyright (c) OSGi Alliance (2000, 2007). + Licensed under the Apache License 2.0. + + III. License Summary + - Apache License 2.0 + - BSD License + +The Derby build relies on jar files supplied by the Apache Lucene +project. The following notice covers the Lucene files: + +Apache Lucene +Copyright 2013 The Apache Software Foundation + +Includes software from other Apache Software Foundation projects, +including, but not limited to: + - Apache Ant + - Apache Jakarta Regexp + - Apache Commons + - Apache Xerces + +ICU4J, (under analysis/icu) is licensed under an MIT-style license +and Copyright (c) 1995-2008 International Business Machines Corporation and others + +Some data files (under analysis/icu/src/data) are derived from Unicode data such +as the Unicode Character Database. See http://unicode.org/copyright.html for more +details. + +Brics Automaton (under core/src/java/org/apache/lucene/util/automaton) is +BSD-licensed, created by Anders Møller. See http://www.brics.dk/automaton/ + +The levenshtein automata tables (under core/src/java/org/apache/lucene/util/automaton) were +automatically generated with the moman/finenight FSA library, created by +Jean-Philippe Barrette-LaPierre. This library is available under an MIT license, +see http://sites.google.com/site/rrettesite/moman and +http://bitbucket.org/jpbarrette/moman/overview/ + +The class org.apache.lucene.util.WeakIdentityMap was derived from +the Apache CXF project and is Apache License 2.0. + +The Google Code Prettify is Apache License 2.0. +See http://code.google.com/p/google-code-prettify/ + +JUnit (junit-4.10) is licensed under the Common Public License v. 1.0 +See http://junit.sourceforge.net/cpl-v10.html + +This product includes code (JaspellTernarySearchTrie) from Java Spelling Checking +Package (jaspell): http://jaspell.sourceforge.net/ +License: The BSD License (http://www.opensource.org/licenses/bsd-license.php) + +The snowball stemmers in + analysis/common/src/java/net/sf/snowball +were developed by Martin Porter and Richard Boulton. +The snowball stopword lists in + analysis/common/src/resources/org/apache/lucene/analysis/snowball +were developed by Martin Porter and Richard Boulton. +The full snowball package is available from + http://snowball.tartarus.org/ + +The KStem stemmer in + analysis/common/src/org/apache/lucene/analysis/en +was developed by Bob Krovetz and Sergio Guzman-Lara (CIIR-UMass Amherst) +under the BSD-license. + +The Arabic, Persian, Romanian, Bulgarian, and Hindi analyzers (common) come with a default +stopword list that is BSD-licensed created by Jacques Savoy. These files reside in: +analysis/common/src/resources/org/apache/lucene/analysis/ar/stopwords.txt, +analysis/common/src/resources/org/apache/lucene/analysis/fa/stopwords.txt, +analysis/common/src/resources/org/apache/lucene/analysis/ro/stopwords.txt, +analysis/common/src/resources/org/apache/lucene/analysis/bg/stopwords.txt, +analysis/common/src/resources/org/apache/lucene/analysis/hi/stopwords.txt +See http://members.unine.ch/jacques.savoy/clef/index.html. + +The German, Spanish, Finnish, French, Hungarian, Italian, Portuguese, Russian and Swedish light stemmers +(common) are based on BSD-licensed reference implementations created by Jacques Savoy and +Ljiljana Dolamic.
These files reside in: +analysis/common/src/java/org/apache/lucene/analysis/de/GermanLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/de/GermanMinimalStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/es/SpanishLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/fi/FinnishLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/fr/FrenchLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/fr/FrenchMinimalStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/hu/HungarianLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/it/ItalianLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/pt/PortugueseLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/ru/RussianLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/sv/SwedishLightStemmer.java + +The Stempel analyzer (stempel) includes BSD-licensed software developed +by the Egothor project http://egothor.sf.net/, created by Leo Galambos, Martin Kvapil, +and Edmond Nolan. + +The Polish analyzer (stempel) comes with a default +stopword list that is BSD-licensed created by the Carrot2 project. The file resides +in stempel/src/resources/org/apache/lucene/analysis/pl/stopwords.txt. +See http://project.carrot2.org/license.html. + +The SmartChineseAnalyzer source code (smartcn) was +provided by Xiaoping Gao and copyright 2009 by www.imdict.net. + +WordBreakTestUnicode_*.java (under modules/analysis/common/src/test/) +is derived from Unicode data such as the Unicode Character Database. +See http://unicode.org/copyright.html for more details. + +The Morfologik analyzer (morfologik) includes BSD-licensed software +developed by Dawid Weiss and Marcin Miłkowski (http://morfologik.blogspot.com/). + +Morfologik uses data from Polish ispell/myspell dictionary +(http://www.sjp.pl/slownik/en/) licenced on the terms of (inter alia) +LGPL and Creative Commons ShareAlike. + +Morfologic includes data from BSD-licensed dictionary of Polish (SGJP) +(http://sgjp.pl/morfeusz/) + +Servlet-api.jar and javax.servlet-*.jar are under the CDDL license, the original +source code for this can be found at http://www.eclipse.org/jetty/downloads.php + +=========================================================================== +Kuromoji Japanese Morphological Analyzer - Apache Lucene Integration +=========================================================================== + +This software includes a binary and/or source version of data from + + mecab-ipadic-2.7.0-20070801 + +which can be obtained from + + http://atilika.com/releases/mecab-ipadic/mecab-ipadic-2.7.0-20070801.tar.gz + +or + + http://jaist.dl.sourceforge.net/project/mecab/mecab-ipadic/2.7.0-20070801/mecab-ipadic-2.7.0-20070801.tar.gz + +=========================================================================== +mecab-ipadic-2.7.0-20070801 Notice +=========================================================================== + +Nara Institute of Science and Technology (NAIST), +the copyright holders, disclaims all warranties with regard to this +software, including all implied warranties of merchantability and +fitness, in no event shall NAIST be liable for +any special, indirect or consequential damages or any damages +whatsoever resulting from loss of use, data or profits, whether in an +action of contract, negligence or other tortuous action, arising out +of or in connection with the use or performance of this software. 
+ +A large portion of the dictionary entries +originate from ICOT Free Software. The following conditions for ICOT +Free Software applies to the current dictionary as well. + +Each User may also freely distribute the Program, whether in its +original form or modified, to any third party or parties, PROVIDED +that the provisions of Section 3 ("NO WARRANTY") will ALWAYS appear +on, or be attached to, the Program, which is distributed substantially +in the same form as set out herein and that such intended +distribution, if actually made, will neither violate or otherwise +contravene any of the laws and regulations of the countries having +jurisdiction over the User or the intended distribution itself. + +NO WARRANTY + +The program was produced on an experimental basis in the course of the +research and development conducted during the project and is provided +to users as so produced on an experimental basis. Accordingly, the +program is provided without any warranty whatsoever, whether express, +implied, statutory or otherwise. The term "warranty" used herein +includes, but is not limited to, any warranty of the quality, +performance, merchantability and fitness for a particular purpose of +the program and the nonexistence of any infringement or violation of +any right of any third party. + +Each user of the program will agree and understand, and be deemed to +have agreed and understood, that there is no warranty whatsoever for +the program and, accordingly, the entire risk arising from or +otherwise connected with the program is assumed by the user. + +Therefore, neither ICOT, the copyright holder, or any other +organization that participated in or was otherwise related to the +development of the program and their respective officials, directors, +officers and other employees shall be held liable for any and all +damages, including, without limitation, general, special, incidental +and consequential damages, arising out of or otherwise in connection +with the use or inability to use the program or any product, material +or result produced or otherwise obtained by using the program, +regardless of whether they have been advised of, or otherwise had +knowledge of, the possibility of such damages at any time during the +project or thereafter. Each user will be deemed to have agreed to the +foregoing by his or her commencement of use of the program. The term +"use" as used herein includes, but is not limited to, the use, +modification, copying and distribution of the program and the +production of secondary products from the program. + +In the case where the program, whether in its original form or +modified, was distributed or delivered to or received by a user from +any person, organization or entity other than ICOT, unless it makes or +grants independently of ICOT any specific warranty to the user in +writing, such person, organization or entity, will also be exempted +from and not be held liable to the user for any such damages as noted +above as far as the program is concerned. + +The Derby build relies on a jar file supplied by the JSON Simple +project, hosted at https://code.google.com/p/json-simple/. +The JSON simple jar file is licensed under the Apache 2.0 License. + +Hive CLI +Copyright 2016 The Apache Software Foundation + +Hive JDBC +Copyright 2016 The Apache Software Foundation + + +Chill is a set of Scala extensions for Kryo. +Copyright 2012 Twitter, Inc. 
+ +Third Party Dependencies: + +Kryo 2.17 +BSD 3-Clause License +http://code.google.com/p/kryo + +Commons-Codec 1.7 +Apache Public License 2.0 +http://hadoop.apache.org + + + +Breeze is distributed under an Apache License V2.0 (See LICENSE) + +=============================================================================== + +Proximal algorithms outlined in Proximal.scala (package breeze.optimize.proximal) +are based on https://github.com/cvxgrp/proximal (see LICENSE for details) and distributed with +Copyright (c) 2014 by Debasish Das (Verizon), all rights reserved. + +=============================================================================== + +QuadraticMinimizer class in package breeze.optimize.proximal is distributed with Copyright (c) +2014, Debasish Das (Verizon), all rights reserved. + +=============================================================================== + +NonlinearMinimizer class in package breeze.optimize.proximal is distributed with Copyright (c) +2015, Debasish Das (Verizon), all rights reserved. + + +stream-lib +Copyright 2016 AddThis + +This product includes software developed by AddThis. + +This product also includes code adapted from: + +Apache Solr (http://lucene.apache.org/solr/) +Copyright 2014 The Apache Software Foundation + +Apache Mahout (http://mahout.apache.org/) +Copyright 2014 The Apache Software Foundation diff --git a/R/README.md b/R/README.md index 1152b1e8e5f9f..d77a1ecffc99c 100644 --- a/R/README.md +++ b/R/README.md @@ -17,7 +17,7 @@ export R_HOME=/home/username/R #### Build Spark -Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-Psparkr` profile to build the R package. For example to use the default Hadoop versions you can run +Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#buildmvn) and include the `-Psparkr` profile to build the R package. For example to use the default Hadoop versions you can run ```bash build/mvn -DskipTests -Psparkr package diff --git a/R/WINDOWS.md b/R/WINDOWS.md index 124bc631be9cd..da668a69b8679 100644 --- a/R/WINDOWS.md +++ b/R/WINDOWS.md @@ -14,7 +14,7 @@ directory in Maven in `PATH`. 4. Set `MAVEN_OPTS` as described in [Building Spark](http://spark.apache.org/docs/latest/building-spark.html). -5. Open a command shell (`cmd`) in the Spark directory and build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-Psparkr` profile to build the R package. For example to use the default Hadoop versions you can run +5. Open a command shell (`cmd`) in the Spark directory and build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#buildmvn) and include the `-Psparkr` profile to build the R package. 
For example to use the default Hadoop versions you can run ```bash mvn.cmd -DskipTests -Psparkr package diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 9696f6987ad78..96ff389faf4a0 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -117,6 +117,7 @@ exportMethods("arrange", "dropna", "dtypes", "except", + "exceptAll", "explain", "fillna", "filter", @@ -131,6 +132,7 @@ exportMethods("arrange", "hint", "insertInto", "intersect", + "intersectAll", "isLocal", "isStreaming", "join", @@ -201,13 +203,19 @@ exportMethods("%<=>%", "approxCountDistinct", "approxQuantile", "array_contains", + "array_distinct", + "array_except", + "array_intersect", "array_join", "array_max", "array_min", "array_position", + "array_remove", "array_repeat", "array_sort", "arrays_overlap", + "array_union", + "arrays_zip", "asc", "ascii", "asin", @@ -306,6 +314,7 @@ exportMethods("%<=>%", "lpad", "ltrim", "map_entries", + "map_from_arrays", "map_keys", "map_values", "max", @@ -349,6 +358,7 @@ exportMethods("%<=>%", "shiftLeft", "shiftRight", "shiftRightUnsigned", + "shuffle", "sd", "sign", "signum", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 70eb7a874b75c..4f2d4c7c002d4 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -588,7 +588,7 @@ setMethod("cache", #' \url{http://spark.apache.org/docs/latest/rdd-programming-guide.html#rdd-persistence}. #' #' @param x the SparkDataFrame to persist. -#' @param newLevel storage level chosen for the persistance. See available options in +#' @param newLevel storage level chosen for the persistence. See available options in #' the description. #' #' @family SparkDataFrame functions @@ -2848,6 +2848,35 @@ setMethod("intersect", dataFrame(intersected) }) +#' intersectAll +#' +#' Return a new SparkDataFrame containing rows in both this SparkDataFrame +#' and another SparkDataFrame while preserving the duplicates. +#' This is equivalent to \code{INTERSECT ALL} in SQL. Also as standard in +#' SQL, this function resolves columns by position (not by name). +#' +#' @param x a SparkDataFrame. +#' @param y a SparkDataFrame. +#' @return A SparkDataFrame containing the result of the intersect all operation. +#' @family SparkDataFrame functions +#' @aliases intersectAll,SparkDataFrame,SparkDataFrame-method +#' @rdname intersectAll +#' @name intersectAll +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df1 <- read.json(path) +#' df2 <- read.json(path2) +#' intersectAllDF <- intersectAll(df1, df2) +#' } +#' @note intersectAll since 2.4.0 +setMethod("intersectAll", + signature(x = "SparkDataFrame", y = "SparkDataFrame"), + function(x, y) { + intersected <- callJMethod(x@sdf, "intersectAll", y@sdf) + dataFrame(intersected) + }) + #' except #' #' Return a new SparkDataFrame containing rows in this SparkDataFrame @@ -2867,7 +2896,6 @@ setMethod("intersect", #' df2 <- read.json(path2) #' exceptDF <- except(df, df2) #' } -#' @rdname except #' @note except since 1.4.0 setMethod("except", signature(x = "SparkDataFrame", y = "SparkDataFrame"), @@ -2876,6 +2904,35 @@ setMethod("except", dataFrame(excepted) }) +#' exceptAll +#' +#' Return a new SparkDataFrame containing rows in this SparkDataFrame +#' but not in another SparkDataFrame while preserving the duplicates. +#' This is equivalent to \code{EXCEPT ALL} in SQL. Also as standard in +#' SQL, this function resolves columns by position (not by name). +#' +#' @param x a SparkDataFrame. +#' @param y a SparkDataFrame. +#' @return A SparkDataFrame containing the result of the except all operation. 
+#' @family SparkDataFrame functions +#' @aliases exceptAll,SparkDataFrame,SparkDataFrame-method +#' @rdname exceptAll +#' @name exceptAll +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df1 <- read.json(path) +#' df2 <- read.json(path2) +#' exceptAllDF <- exceptAll(df1, df2) +#' } +#' @note exceptAll since 2.4.0 +setMethod("exceptAll", + signature(x = "SparkDataFrame", y = "SparkDataFrame"), + function(x, y) { + excepted <- callJMethod(x@sdf, "exceptAll", y@sdf) + dataFrame(excepted) + }) + #' Save the contents of SparkDataFrame to a data source. #' #' The data source is specified by the \code{source} and a set of options (...). diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 429dd5d565492..c819a7d14ae98 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -351,7 +351,7 @@ setMethod("toDF", signature(x = "RDD"), read.json.default <- function(path, ...) { sparkSession <- getSparkSession() options <- varargsToStrEnv(...) - # Allow the user to have a more flexible definiton of the text file path + # Allow the user to have a more flexible definition of the text file path paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) @@ -421,7 +421,7 @@ jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { read.orc <- function(path, ...) { sparkSession <- getSparkSession() options <- varargsToStrEnv(...) - # Allow the user to have a more flexible definiton of the ORC file path + # Allow the user to have a more flexible definition of the ORC file path path <- suppressWarnings(normalizePath(path)) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) @@ -442,7 +442,7 @@ read.orc <- function(path, ...) { read.parquet.default <- function(path, ...) { sparkSession <- getSparkSession() options <- varargsToStrEnv(...) - # Allow the user to have a more flexible definiton of the Parquet file path + # Allow the user to have a more flexible definition of the Parquet file path paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) @@ -492,7 +492,7 @@ parquetFile <- function(x, ...) { read.text.default <- function(path, ...) { sparkSession <- getSparkSession() options <- varargsToStrEnv(...) - # Allow the user to have a more flexible definiton of the text file path + # Allow the user to have a more flexible definition of the text file path paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 4c87f64e7f0e1..660f0864403e0 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -71,15 +71,20 @@ checkJavaVersion <- function() { # If java is missing from PATH, we get an error in Unix and a warning in Windows javaVersionOut <- tryCatch( - launchScript(javaBin, "-version", wait = TRUE, stdout = TRUE, stderr = TRUE), - error = function(e) { - stop("Java version check failed. Please make sure Java is installed", - " and set JAVA_HOME to point to the installation directory.", e) - }, - warning = function(w) { - stop("Java version check failed. 
Please make sure Java is installed", - " and set JAVA_HOME to point to the installation directory.", w) - }) + if (is_windows()) { + # See SPARK-24535 + system2(javaBin, "-version", wait = TRUE, stdout = TRUE, stderr = TRUE) + } else { + launchScript(javaBin, "-version", wait = TRUE, stdout = TRUE, stderr = TRUE) + }, + error = function(e) { + stop("Java version check failed. Please make sure Java is installed", + " and set JAVA_HOME to point to the installation directory.", e) + }, + warning = function(w) { + stop("Java version check failed. Please make sure Java is installed", + " and set JAVA_HOME to point to the installation directory.", w) + }) javaVersionFilter <- Filter( function(x) { grepl(" version", x) @@ -93,6 +98,7 @@ checkJavaVersion <- function() { stop(paste("Java version", sparkJavaVersion, "is required for this package; found version:", javaVersionStr)) } + return(javaVersionNum) } launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 8ec727dd042bc..f168ca76b6007 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -43,7 +43,7 @@ getMinPartitions <- function(sc, minPartitions) { #' lines <- textFile(sc, "myfile.txt") #'} textFile <- function(sc, path, minPartitions = NULL) { - # Allow the user to have a more flexible definiton of the text file path + # Allow the user to have a more flexible definition of the text file path path <- suppressWarnings(normalizePath(path)) # Convert a string vector of paths to a string containing comma separated paths path <- paste(path, collapse = ",") @@ -71,7 +71,7 @@ textFile <- function(sc, path, minPartitions = NULL) { #' rdd <- objectFile(sc, "myfile") #'} objectFile <- function(sc, path, minPartitions = NULL) { - # Allow the user to have a more flexible definiton of the text file path + # Allow the user to have a more flexible definition of the text file path path <- suppressWarnings(normalizePath(path)) # Convert a string vector of paths to a string containing comma separated paths path <- paste(path, collapse = ",") @@ -138,11 +138,10 @@ parallelize <- function(sc, coll, numSlices = 1) { sizeLimit <- getMaxAllocationLimit(sc) objectSize <- object.size(coll) + len <- length(coll) # For large objects we make sure the size of each slice is also smaller than sizeLimit - numSerializedSlices <- max(numSlices, ceiling(objectSize / sizeLimit)) - if (numSerializedSlices > length(coll)) - numSerializedSlices <- length(coll) + numSerializedSlices <- min(len, max(numSlices, ceiling(objectSize / sizeLimit))) # Generate the slice ids to put each row # For instance, for numSerializedSlices of 22, length of 50 @@ -153,8 +152,8 @@ parallelize <- function(sc, coll, numSlices = 1) { splits <- if (numSerializedSlices > 0) { unlist(lapply(0: (numSerializedSlices - 1), function(x) { # nolint start - start <- trunc((x * length(coll)) / numSerializedSlices) - end <- trunc(((x + 1) * length(coll)) / numSerializedSlices) + start <- trunc((as.numeric(x) * len) / numSerializedSlices) + end <- trunc(((as.numeric(x) + 1) * len) / numSerializedSlices) # nolint end rep(start, end - start) })) @@ -305,6 +304,8 @@ setCheckpointDirSC <- function(sc, dirName) { #' Currently directories are only supported for Hadoop-supported filesystems. #' Refer Hadoop-supported filesystems at \url{https://wiki.apache.org/hadoop/HCFS}. #' +#' Note: A path can be added only once. Subsequent additions of the same path are ignored. 
+#' #' @rdname spark.addFile #' @param path The path of the file to be added #' @param recursive Whether to add files recursively from the path. Default is FALSE. diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 3bff633fbc1ff..572dee50127b8 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -194,10 +194,12 @@ NULL #' \itemize{ #' \item \code{array_contains}: a value to be checked if contained in the column. #' \item \code{array_position}: a value to locate in the given array. +#' \item \code{array_remove}: a value to remove in the given array. #' } #' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains #' additional named properties to control how it is converted, accepts the same -#' options as the JSON data source. +#' options as the JSON data source. In \code{arrays_zip}, this contains additional +#' Columns of arrays to be merged. #' @name column_collection_functions #' @rdname column_collection_functions #' @family collection functions @@ -206,10 +208,10 @@ NULL #' # Dataframe used throughout this doc #' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) -#' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) -#' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1))) +#' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1), shuffle(tmp$v1))) +#' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1), array_distinct(tmp$v1))) #' head(select(tmp, array_position(tmp$v1, 21), array_repeat(df$mpg, 3), array_sort(tmp$v1))) -#' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1))) +#' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1), array_remove(tmp$v1, 21))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) #' head(select(tmp, posexplode(tmp$v1))) @@ -221,6 +223,9 @@ NULL #' head(select(tmp3, element_at(tmp3$v3, "Valiant"))) #' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$cyl, df$hp)) #' head(select(tmp4, concat(tmp4$v4, tmp4$v5), arrays_overlap(tmp4$v4, tmp4$v5))) +#' head(select(tmp4, array_except(tmp4$v4, tmp4$v5), array_intersect(tmp4$v4, tmp4$v5))) +#' head(select(tmp4, array_union(tmp4$v4, tmp4$v5))) +#' head(select(tmp4, arrays_zip(tmp4$v4, tmp4$v5), map_from_arrays(tmp4$v4, tmp4$v5))) #' head(select(tmp, concat(df$mpg, df$cyl, df$hp))) #' tmp5 <- mutate(df, v6 = create_array(df$model, df$model)) #' head(select(tmp5, array_join(tmp5$v6, "#"), array_join(tmp5$v6, "#", "NULL")))} @@ -1694,8 +1699,8 @@ setMethod("to_date", }) #' @details -#' \code{to_json}: Converts a column containing a \code{structType}, array of \code{structType}, -#' a \code{mapType} or array of \code{mapType} into a Column of JSON string. +#' \code{to_json}: Converts a column containing a \code{structType}, a \code{mapType} +#' or an \code{arrayType} into a Column of JSON string. #' Resolving the Column can fail if an unsupported type is encountered. #' #' @rdname column_collection_functions @@ -1978,7 +1983,7 @@ setMethod("levenshtein", signature(y = "Column"), }) #' @details -#' \code{months_between}: Returns number of months between dates \code{y} and \code{x}. +#' \code{months_between}: Returns number of months between dates \code{y} and \code{x}. #' If \code{y} is later than \code{x}, then the result is positive. If \code{y} and \code{x} #' are on the same day of month, or both are the last day of month, time of day will be ignored. #' Otherwise, the difference is calculated based on 31 days per month, and rounded to 8 digits. 
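The hunks below add the SparkR wrappers for the new 2.4.0 collection functions (array_distinct, array_except, array_intersect, array_union, array_remove, arrays_zip, shuffle, map_from_arrays), which pair with the exceptAll/intersectAll DataFrame methods added earlier in this patch. As a quick illustration of how these pieces compose, here is a minimal sketch, assuming a local SparkR session built from a tree with this patch applied; the mtcars-based setup mirrors the roxygen examples above, and the small data frames for the *All methods are hypothetical:

```r
library(SparkR)
sparkR.session()  # assumed: a local Spark build that includes this patch

# Two array columns to exercise the new collection functions
df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))
tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl), v2 = create_array(df$cyl, df$hp))

# Set-style operations on arrays, added by the hunks below
head(select(tmp, array_union(tmp$v1, tmp$v2), array_except(tmp$v1, tmp$v2)))
head(select(tmp, array_intersect(tmp$v1, tmp$v2), array_distinct(tmp$v1)))

# Element-level helpers: drop a value, random permutation, zip, map construction
head(select(tmp, array_remove(tmp$v1, 21), shuffle(tmp$v1)))
head(select(tmp, arrays_zip(tmp$v1, tmp$v2), map_from_arrays(tmp$v1, tmp$v2)))

# Multiset DataFrame operations that keep duplicates (SQL INTERSECT ALL / EXCEPT ALL)
df1 <- createDataFrame(data.frame(x = c(1, 1, 2, 3)))
df2 <- createDataFrame(data.frame(x = c(1, 2, 2)))
head(intersectAll(df1, df2))  # duplicates preserved, unlike intersect()
head(exceptAll(df1, df2))     # duplicates preserved, unlike except()
```

Per the roxygen docs added in this patch, intersectAll and exceptAll resolve columns by position rather than by name, matching standard SQL semantics for INTERSECT ALL and EXCEPT ALL.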
@@ -3008,6 +3013,47 @@ setMethod("array_contains", column(jc) }) +#' @details +#' \code{array_distinct}: Removes duplicate values from the array. +#' +#' @rdname column_collection_functions +#' @aliases array_distinct array_distinct,Column-method +#' @note array_distinct since 2.4.0 +setMethod("array_distinct", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_distinct", x@jc) + column(jc) + }) + +#' @details +#' \code{array_except}: Returns an array of the elements in the first array but not in the second +#' array, without duplicates. The order of elements in the result is not determined. +#' +#' @rdname column_collection_functions +#' @aliases array_except array_except,Column-method +#' @note array_except since 2.4.0 +setMethod("array_except", + signature(x = "Column", y = "Column"), + function(x, y) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_except", x@jc, y@jc) + column(jc) + }) + +#' @details +#' \code{array_intersect}: Returns an array of the elements in the intersection of the given two +#' arrays, without duplicates. +#' +#' @rdname column_collection_functions +#' @aliases array_intersect array_intersect,Column-method +#' @note array_intersect since 2.4.0 +setMethod("array_intersect", + signature(x = "Column", y = "Column"), + function(x, y) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_intersect", x@jc, y@jc) + column(jc) + }) + #' @details #' \code{array_join}: Concatenates the elements of column using the delimiter. #' Null values are replaced with nullReplacement if set, otherwise they are ignored. @@ -3071,6 +3117,19 @@ setMethod("array_position", column(jc) }) +#' @details +#' \code{array_remove}: Removes all elements that equal to element from the given array. +#' +#' @rdname column_collection_functions +#' @aliases array_remove array_remove,Column-method +#' @note array_remove since 2.4.0 +setMethod("array_remove", + signature(x = "Column", value = "ANY"), + function(x, value) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_remove", x@jc, value) + column(jc) + }) + #' @details #' \code{array_repeat}: Creates an array containing \code{x} repeated the number of times #' given by \code{count}. @@ -3120,6 +3179,51 @@ setMethod("arrays_overlap", column(jc) }) +#' @details +#' \code{array_union}: Returns an array of the elements in the union of the given two arrays, +#' without duplicates. +#' +#' @rdname column_collection_functions +#' @aliases array_union array_union,Column-method +#' @note array_union since 2.4.0 +setMethod("array_union", + signature(x = "Column", y = "Column"), + function(x, y) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_union", x@jc, y@jc) + column(jc) + }) + +#' @details +#' \code{arrays_zip}: Returns a merged array of structs in which the N-th struct contains all N-th +#' values of input arrays. +#' +#' @rdname column_collection_functions +#' @aliases arrays_zip arrays_zip,Column-method +#' @note arrays_zip since 2.4.0 +setMethod("arrays_zip", + signature(x = "Column"), + function(x, ...) { + jcols <- lapply(list(x, ...), function(arg) { + stopifnot(class(arg) == "Column") + arg@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "arrays_zip", jcols) + column(jc) + }) + +#' @details +#' \code{shuffle}: Returns a random permutation of the given array. 
+#' +#' @rdname column_collection_functions +#' @aliases shuffle shuffle,Column-method +#' @note shuffle since 2.4.0 +setMethod("shuffle", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "shuffle", x@jc) + column(jc) + }) + #' @details #' \code{flatten}: Creates a single array from an array of arrays. #' If a structure of nested arrays is deeper than two levels, only one level of nesting is removed. @@ -3147,6 +3251,21 @@ setMethod("map_entries", column(jc) }) +#' @details +#' \code{map_from_arrays}: Creates a new map column. The array in the first column is used for +#' keys. The array in the second column is used for values. All elements in the array for key +#' should not be null. +#' +#' @rdname column_collection_functions +#' @aliases map_from_arrays map_from_arrays,Column-method +#' @note map_from_arrays since 2.4.0 +setMethod("map_from_arrays", + signature(x = "Column", y = "Column"), + function(x, y) { + jc <- callJStatic("org.apache.spark.sql.functions", "map_from_arrays", x@jc, y@jc) + column(jc) + }) + #' @details #' \code{map_keys}: Returns an unordered array containing the keys of the map. #' diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 9321bbaf96ff8..27c1b312d645c 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -471,6 +471,9 @@ setGeneric("explain", function(x, ...) { standardGeneric("explain") }) #' @rdname except setGeneric("except", function(x, y) { standardGeneric("except") }) +#' @rdname exceptAll +setGeneric("exceptAll", function(x, y) { standardGeneric("exceptAll") }) + #' @rdname nafunctions setGeneric("fillna", function(x, value, cols = NULL) { standardGeneric("fillna") }) @@ -495,6 +498,9 @@ setGeneric("insertInto", function(x, tableName, ...) { standardGeneric("insertIn #' @rdname intersect setGeneric("intersect", function(x, y) { standardGeneric("intersect") }) +#' @rdname intersectAll +setGeneric("intersectAll", function(x, y) { standardGeneric("intersectAll") }) + #' @rdname isLocal setGeneric("isLocal", function(x) { standardGeneric("isLocal") }) @@ -757,6 +763,18 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun #' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_distinct", function(x) { standardGeneric("array_distinct") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_except", function(x, y) { standardGeneric("array_except") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_intersect", function(x, y) { standardGeneric("array_intersect") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("array_join", function(x, delimiter, ...) 
{ standardGeneric("array_join") }) @@ -773,6 +791,10 @@ setGeneric("array_min", function(x) { standardGeneric("array_min") }) #' @name NULL setGeneric("array_position", function(x, value) { standardGeneric("array_position") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_remove", function(x, value) { standardGeneric("array_remove") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("array_repeat", function(x, count) { standardGeneric("array_repeat") }) @@ -785,6 +807,14 @@ setGeneric("array_sort", function(x) { standardGeneric("array_sort") }) #' @name NULL setGeneric("arrays_overlap", function(x, y) { standardGeneric("arrays_overlap") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_union", function(x, y) { standardGeneric("array_union") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("arrays_zip", function(x, ...) { standardGeneric("arrays_zip") }) + #' @rdname column_string_functions #' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) @@ -1050,6 +1080,10 @@ setGeneric("ltrim", function(x, trimString) { standardGeneric("ltrim") }) #' @name NULL setGeneric("map_entries", function(x) { standardGeneric("map_entries") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("map_from_arrays", function(x, y) { standardGeneric("map_from_arrays") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("map_keys", function(x) { standardGeneric("map_keys") }) @@ -1198,6 +1232,10 @@ setGeneric("shiftRight", function(y, x) { standardGeneric("shiftRight") }) #' @name NULL setGeneric("shiftRightUnsigned", function(y, x) { standardGeneric("shiftRightUnsigned") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("shuffle", function(x) { standardGeneric("shuffle") }) + #' @rdname column_math_functions #' @name NULL setGeneric("signum", function(x) { standardGeneric("signum") }) diff --git a/R/pkg/R/mllib_fpm.R b/R/pkg/R/mllib_fpm.R index e2394906d8012..4ad34fe82328f 100644 --- a/R/pkg/R/mllib_fpm.R +++ b/R/pkg/R/mllib_fpm.R @@ -116,10 +116,11 @@ setMethod("spark.freqItemsets", signature(object = "FPGrowthModel"), # Get association rules. #' @return A \code{SparkDataFrame} with association rules. -#' The \code{SparkDataFrame} contains three columns: +#' The \code{SparkDataFrame} contains four columns: #' \code{antecedent} (an array of the same type as the input column), #' \code{consequent} (an array of the same type as the input column), -#' and \code{condfidence} (confidence). +#' \code{condfidence} (confidence for the rule) +#' and \code{lift} (lift for the rule) #' @rdname spark.fpGrowth #' @aliases associationRules,FPGrowthModel-method #' @note spark.associationRules(FPGrowthModel) since 2.2.0 diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 6769be038efa9..0e60842dd44c8 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -362,7 +362,18 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' For regression, must be "variance". For classification, must be one of #' "entropy" and "gini", default is "gini". #' @param featureSubsetStrategy The number of features to consider for splits at each tree node. -#' Supported options: "auto", "all", "onethird", "sqrt", "log2", (0.0-1.0], [1-n]. +#' Supported options: "auto" (choose automatically for task: If +#' numTrees == 1, set to "all." 
If numTrees > 1 +#' (forest), set to "sqrt" for classification and +#' to "onethird" for regression), +#' "all" (use all features), +#' "onethird" (use 1/3 of the features), +#' "sqrt" (use sqrt(number of features)), +#' "log2" (use log2(number of features)), +#' "n": (when n is in the range (0, 1.0], use +#' n * number of features. When n is in the range +#' (1, number of features), use n features). +#' Default is "auto". #' @param seed integer seed for random number generation. #' @param subsamplingRate Fraction of the training data used for learning each decision tree, in #' range (0, 1]. diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index f7c1663d32c96..d3a9cbae7d808 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -167,7 +167,7 @@ sparkR.sparkContext <- function( submitOps <- getClientModeSparkSubmitOpts( Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"), sparkEnvirMap) - checkJavaVersion() + invisible(checkJavaVersion()) launchBackend( args = path, sparkHome = sparkHome, diff --git a/R/pkg/R/streaming.R b/R/pkg/R/streaming.R index fc83463f72cd4..5eccbdc9d3818 100644 --- a/R/pkg/R/streaming.R +++ b/R/pkg/R/streaming.R @@ -163,7 +163,7 @@ setMethod("isActive", #' #' @param x a StreamingQuery. #' @param timeout time to wait in milliseconds, if omitted, wait indefinitely until \code{stopQuery} -#' is called or an error has occured. +#' is called or an error has occurred. #' @return TRUE if query has terminated within the timeout period; nothing if timeout is not #' specified. #' @rdname awaitTermination diff --git a/R/pkg/inst/tests/testthat/test_basic.R b/R/pkg/inst/tests/testthat/test_basic.R index 823d26f12feee..80df3d8ce6e59 100644 --- a/R/pkg/inst/tests/testthat/test_basic.R +++ b/R/pkg/inst/tests/testthat/test_basic.R @@ -18,6 +18,10 @@ context("basic tests for CRAN") test_that("create DataFrame from list or data.frame", { + tryCatch(checkJavaVersion(), + error = function(e) { skip("error on Java check") }, + warning = function(e) { skip("warning on Java check") }) + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE, sparkConfig = sparkRTestConfig) @@ -50,6 +54,10 @@ test_that("create DataFrame from list or data.frame", { }) test_that("spark.glm and predict", { + tryCatch(checkJavaVersion(), + error = function(e) { skip("error on Java check") }, + warning = function(e) { skip("warning on Java check") }) + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE, sparkConfig = sparkRTestConfig) diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index ba458d2b9ddfb..c2adf613acb02 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -62,7 +62,7 @@ compute <- function(mode, partition, serializer, deserializer, key, # Transform the result data.frame back to a list of rows output <- split(output, seq(nrow(output))) } else { - # Serialize the ouput to a byte array + # Serialize the output to a byte array stopifnot(serializer == "byte") } } else { diff --git a/R/pkg/tests/fulltests/test_context.R b/R/pkg/tests/fulltests/test_context.R index f0d0a5114f89f..288a2714a554e 100644 --- a/R/pkg/tests/fulltests/test_context.R +++ b/R/pkg/tests/fulltests/test_context.R @@ -240,3 +240,10 @@ test_that("add and get file to be downloaded with Spark job on every node", { unlink(path, recursive = TRUE) sparkR.session.stop() }) + +test_that("SPARK-25234: parallelize should not have integer overflow", { + sc <- sparkR.sparkContext(master = sparkRTestMaster) + # 47000 * 47000 exceeds integer range + parallelize(sc, 1:47000, 
47000) + sparkR.session.stop() +}) diff --git a/R/pkg/tests/fulltests/test_mllib_classification.R b/R/pkg/tests/fulltests/test_mllib_classification.R index a46c47dccd02e..023686e75d50a 100644 --- a/R/pkg/tests/fulltests/test_mllib_classification.R +++ b/R/pkg/tests/fulltests/test_mllib_classification.R @@ -382,10 +382,10 @@ test_that("spark.mlp", { trainidxs <- base::sample(nrow(data), nrow(data) * 0.7) traindf <- as.DataFrame(data[trainidxs, ]) testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other"))) - model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 3)) + model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 2)) predictions <- predict(model, testdf) expect_error(collect(predictions)) - model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 3), handleInvalid = "skip") + model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 2), handleInvalid = "skip") predictions <- predict(model, testdf) expect_equal(class(collect(predictions)$clicked[1]), "list") diff --git a/R/pkg/tests/fulltests/test_mllib_fpm.R b/R/pkg/tests/fulltests/test_mllib_fpm.R index 69dda52f0c279..d80f66a25de1c 100644 --- a/R/pkg/tests/fulltests/test_mllib_fpm.R +++ b/R/pkg/tests/fulltests/test_mllib_fpm.R @@ -44,7 +44,8 @@ test_that("spark.fpGrowth", { expected_association_rules <- data.frame( antecedent = I(list(list("2"), list("3"))), consequent = I(list(list("1"), list("1"))), - confidence = c(1, 1) + confidence = c(1, 1), + lift = c(1, 1) ) expect_equivalent(expected_association_rules, collect(spark.associationRules(model))) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 36e0f78bb0599..0c4bdb31b027b 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -734,8 +734,8 @@ test_that("test cache, uncache and clearCache", { clearCache() expect_true(dropTempView("table1")) - expect_error(uncacheTable("foo"), - "Error in uncacheTable : analysis error - Table or view not found: foo") + expect_error(uncacheTable("zxwtyswklpf"), + "Error in uncacheTable : analysis error - Table or view not found: zxwtyswklpf") }) test_that("insertInto() on a registered table", { @@ -1503,6 +1503,27 @@ test_that("column functions", { result <- collect(select(df2, reverse(df2[[1]])))[[1]] expect_equal(result, "cba") + # Test array_distinct() and array_remove() + df <- createDataFrame(list(list(list(1L, 2L, 3L, 1L, 2L)), list(list(6L, 5L, 5L, 4L, 6L)))) + result <- collect(select(df, array_distinct(df[[1]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L), list(6L, 5L, 4L))) + + result <- collect(select(df, array_remove(df[[1]], 2L)))[[1]] + expect_equal(result, list(list(1L, 3L, 1L), list(6L, 5L, 5L, 4L, 6L))) + + # Test arrays_zip() + df <- createDataFrame(list(list(list(1L, 2L), list(3L, 4L))), schema = c("c1", "c2")) + result <- collect(select(df, arrays_zip(df[[1]], df[[2]])))[[1]] + expected_entries <- list(listToStruct(list(c1 = 1L, c2 = 3L)), + listToStruct(list(c1 = 2L, c2 = 4L))) + expect_equal(result, list(expected_entries)) + + # Test map_from_arrays() + df <- createDataFrame(list(list(list("x", "y"), list(1, 2))), schema = c("k", "v")) + result <- collect(select(df, map_from_arrays(df$k, df$v)))[[1]] + expected_entries <- list(as.environment(list(x = 1, y = 2))) + expect_equal(result, expected_entries) + # Test array_repeat() df <- createDataFrame(list(list("a", 3L), list("b", 2L))) result <- collect(select(df, array_repeat(df[[1]], df[[2]])))[[1]] @@ -1577,6 +1598,25 @@ test_that("column functions", { result <- 
collect(select(df, element_at(df$map, "y")))[[1]] expect_equal(result, 2) + # Test array_except(), array_intersect() and array_union() + df <- createDataFrame(list(list(list(1L, 2L, 3L), list(3L, 1L)), + list(list(1L, 2L), list(3L, 4L)), + list(list(1L, 2L, 3L), list(3L, 4L)))) + result1 <- collect(select(df, array_except(df[[1]], df[[2]])))[[1]] + expect_equal(result1, list(list(2L), list(1L, 2L), list(1L, 2L))) + + result2 <- collect(select(df, array_intersect(df[[1]], df[[2]])))[[1]] + expect_equal(result2, list(list(1L, 3L), list(), list(3L))) + + result3 <- collect(select(df, array_union(df[[1]], df[[2]])))[[1]] + expect_equal(result3, list(list(1L, 2L, 3L), list(1L, 2L, 3L, 4L), list(1L, 2L, 3L, 4L))) + + # Test shuffle() + df <- createDataFrame(list(list(list(1L, 20L, 3L, 5L)), list(list(4L, 5L, 6L, 7L)))) + result <- collect(select(df, shuffle(df[[1]])))[[1]] + expect_true(setequal(result[[1]], c(1L, 20L, 3L, 5L))) + expect_true(setequal(result[[2]], c(4L, 5L, 6L, 7L))) + # Test that stats::lag is working expect_equal(length(lag(ldeaths, 12)), 72) @@ -1646,6 +1686,15 @@ test_that("column functions", { expect_true(any(apply(s, 1, function(x) { x[[1]]$age == 16 }))) } + # Test to_json() supports arrays of primitive types and arrays + df <- sql("SELECT array(19, 42, 70) as age") + j <- collect(select(df, alias(to_json(df$age), "json"))) + expect_equal(j[order(j$json), ][1], "[19,42,70]") + + df <- sql("SELECT array(array(1, 2), array(3, 4)) as matrix") + j <- collect(select(df, alias(to_json(df$matrix), "json"))) + expect_equal(j[order(j$json), ][1], "[[1,2],[3,4]]") + # passing option df <- as.DataFrame(list(list("col" = "{\"date\":\"21/10/2014\"}"))) schema2 <- structType(structField("date", "date")) @@ -1830,9 +1879,9 @@ test_that("date functions on a DataFrame", { expect_equal(collect(select(df2, minute(df2$b)))[, 1], c(34, 24)) expect_equal(collect(select(df2, second(df2$b)))[, 1], c(0, 34)) expect_equal(collect(select(df2, from_utc_timestamp(df2$b, "JST")))[, 1], - c(as.POSIXlt("2012-12-13 21:34:00 UTC"), as.POSIXlt("2014-12-15 10:24:34 UTC"))) + c(as.POSIXct("2012-12-13 21:34:00 UTC"), as.POSIXct("2014-12-15 10:24:34 UTC"))) expect_equal(collect(select(df2, to_utc_timestamp(df2$b, "JST")))[, 1], - c(as.POSIXlt("2012-12-13 03:34:00 UTC"), as.POSIXlt("2014-12-14 16:24:34 UTC"))) + c(as.POSIXct("2012-12-13 03:34:00 UTC"), as.POSIXct("2014-12-14 16:24:34 UTC"))) expect_gt(collect(select(df2, unix_timestamp()))[1, 1], 0) expect_gt(collect(select(df2, unix_timestamp(df2$b)))[1, 1], 0) expect_gt(collect(select(df2, unix_timestamp(lit("2015-01-01"), "yyyy-MM-dd")))[1, 1], 0) @@ -2461,6 +2510,25 @@ test_that("union(), unionByName(), rbind(), except(), and intersect() on a DataF unlink(jsonPath2) }) +test_that("intersectAll() and exceptAll()", { + df1 <- createDataFrame(list(list("a", 1), list("a", 1), list("a", 1), + list("a", 1), list("b", 3), list("c", 4)), + schema = c("a", "b")) + df2 <- createDataFrame(list(list("a", 1), list("a", 1), list("b", 3)), schema = c("a", "b")) + intersectAllExpected <- data.frame("a" = c("a", "a", "b"), "b" = c(1, 1, 3), + stringsAsFactors = FALSE) + exceptAllExpected <- data.frame("a" = c("a", "a", "c"), "b" = c(1, 1, 4), + stringsAsFactors = FALSE) + intersectAllDf <- arrange(intersectAll(df1, df2), df1$a) + expect_is(intersectAllDf, "SparkDataFrame") + exceptAllDf <- arrange(exceptAll(df1, df2), df1$a) + expect_is(exceptAllDf, "SparkDataFrame") + intersectAllActual <- collect(intersectAllDf) + expect_identical(intersectAllActual, intersectAllExpected) 
+ exceptAllActual <- collect(exceptAllDf) + expect_identical(exceptAllActual, exceptAllExpected) +}) + test_that("withColumn() and withColumnRenamed()", { df <- read.json(jsonPath) newDF <- withColumn(df, "newAge", df$age + 2) @@ -3592,11 +3660,12 @@ test_that("Collect on DataFrame when NAs exists at the top of a timestamp column test_that("catalog APIs, currentDatabase, setCurrentDatabase, listDatabases", { expect_equal(currentDatabase(), "default") expect_error(setCurrentDatabase("default"), NA) - expect_error(setCurrentDatabase("foo"), - "Error in setCurrentDatabase : analysis error - Database 'foo' does not exist") + expect_error(setCurrentDatabase("zxwtyswklpf"), + paste0("Error in setCurrentDatabase : analysis error - Database ", + "'zxwtyswklpf' does not exist")) dbs <- collect(listDatabases()) expect_equal(names(dbs), c("name", "description", "locationUri")) - expect_equal(dbs[[1]], "default") + expect_equal(which(dbs[, 1] == "default"), 1) }) test_that("catalog APIs, listTables, listColumns, listFunctions", { @@ -3619,8 +3688,9 @@ test_that("catalog APIs, listTables, listColumns, listFunctions", { expect_equal(colnames(c), c("name", "description", "dataType", "nullable", "isPartition", "isBucket")) expect_equal(collect(c)[[1]][[1]], "speed") - expect_error(listColumns("foo", "default"), - "Error in listColumns : analysis error - Table 'foo' does not exist in database 'default'") + expect_error(listColumns("zxwtyswklpf", "default"), + paste("Error in listColumns : analysis error - Table", + "'zxwtyswklpf' does not exist in database 'default'")) f <- listFunctions() expect_true(nrow(f) >= 200) # 250 @@ -3628,8 +3698,9 @@ test_that("catalog APIs, listTables, listColumns, listFunctions", { c("name", "database", "description", "className", "isTemporary")) expect_equal(take(orderBy(f, "className"), 1)$className, "org.apache.spark.sql.catalyst.expressions.Abs") - expect_error(listFunctions("foo_db"), - "Error in listFunctions : analysis error - Database 'foo_db' does not exist") + expect_error(listFunctions("zxwtyswklpf_db"), + paste("Error in listFunctions : analysis error - Database", + "'zxwtyswklpf_db' does not exist")) # recoverPartitions does not work with tempory view expect_error(recoverPartitions("cars"), diff --git a/R/pkg/tests/fulltests/test_utils.R b/R/pkg/tests/fulltests/test_utils.R index f0292ab335592..b2b6f34aaa085 100644 --- a/R/pkg/tests/fulltests/test_utils.R +++ b/R/pkg/tests/fulltests/test_utils.R @@ -103,7 +103,7 @@ test_that("cleanClosure on R functions", { expect_true("l" %in% ls(env)) expect_true("f" %in% ls(env)) expect_equal(get("l", envir = env, inherits = FALSE), l) - # "y" should be in the environemnt of g. + # "y" should be in the environment of g. 
newG <- get("g", envir = env, inherits = FALSE) env <- environment(newG) expect_equal(length(ls(env)), 1) diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index d4713de7806a1..090363c5f8a3e 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -590,6 +590,7 @@ summary(model) Predict values on training data ```{r} prediction <- predict(model, training) +head(select(prediction, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` #### Logistic Regression @@ -613,6 +614,7 @@ summary(model) Predict values on training data ```{r} fitted <- predict(model, training) +head(select(fitted, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` Multinomial logistic regression against three classes @@ -652,7 +654,7 @@ We use Titanic data set to show how to use `spark.mlp` in classification. t <- as.data.frame(Titanic) training <- createDataFrame(t) # fit a Multilayer Perceptron Classification Model -model <- spark.mlp(training, Survived ~ Age + Sex, blockSize = 128, layers = c(2, 3), solver = "l-bfgs", maxIter = 100, tol = 0.5, stepSize = 1, seed = 1, initialWeights = c( 0, 0, 0, 5, 5, 5, 9, 9, 9)) +model <- spark.mlp(training, Survived ~ Age + Sex, blockSize = 128, layers = c(2, 2), solver = "l-bfgs", maxIter = 100, tol = 0.5, stepSize = 1, seed = 1, initialWeights = c( 0, 0, 5, 5, 9, 9)) ``` To avoid lengthy display, we only present partial results of the model summary. You can check the full result from your sparkR shell. @@ -807,6 +809,7 @@ df <- createDataFrame(t) dtModel <- spark.decisionTree(df, Survived ~ ., type = "classification", maxDepth = 2) summary(dtModel) predictions <- predict(dtModel, df) +head(select(predictions, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` #### Gradient-Boosted Trees @@ -822,6 +825,7 @@ df <- createDataFrame(t) gbtModel <- spark.gbt(df, Survived ~ ., type = "classification", maxDepth = 2, maxIter = 2) summary(gbtModel) predictions <- predict(gbtModel, df) +head(select(predictions, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` #### Random Forest @@ -837,6 +841,7 @@ df <- createDataFrame(t) rfModel <- spark.randomForest(df, Survived ~ ., type = "classification", maxDepth = 2, numTrees = 2) summary(rfModel) predictions <- predict(rfModel, df) +head(select(predictions, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` #### Bisecting k-Means diff --git a/README.md b/README.md index 531d330234062..fd8c7f656968e 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,7 @@ storage systems. Because the protocols have changed in different versions of Hadoop, you must build Spark against the same version that your cluster runs. Please refer to the build documentation at -["Specifying the Hadoop Version"](http://spark.apache.org/docs/latest/building-spark.html#specifying-the-hadoop-version) +["Specifying the Hadoop Version and Enabling YARN"](http://spark.apache.org/docs/latest/building-spark.html#specifying-the-hadoop-version-and-enabling-yarn) for detailed guidance on building for a particular distribution of Hadoop, including building for particular Hive and Hive Thriftserver distributions. 
diff --git a/appveyor.yml b/appveyor.yml index aee94c59612d2..7fb45745a036f 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -48,7 +48,7 @@ install: - cmd: R -e "packageVersion('knitr'); packageVersion('rmarkdown'); packageVersion('testthat'); packageVersion('e1071'); packageVersion('survival')" build_script: - - cmd: mvn -DskipTests -Psparkr -Phive -Phive-thriftserver package + - cmd: mvn -DskipTests -Psparkr -Phive package environment: NOT_CRAN: true diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index a3f1bcffaea57..d6371051ef7fb 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -49,6 +49,7 @@ function build { # Set image build arguments accordingly if this is a source repo and not a distribution archive. IMG_PATH=resource-managers/kubernetes/docker/src/main/dockerfiles BUILD_ARGS=( + ${BUILD_PARAMS} --build-arg img_path=$IMG_PATH --build-arg @@ -57,18 +58,20 @@ function build { else # Not passed as an argument to docker, but used to validate the Spark directory. IMG_PATH="kubernetes/dockerfiles" - BUILD_ARGS=() + BUILD_ARGS=(${BUILD_PARAMS}) fi if [ ! -d "$IMG_PATH" ]; then error "Cannot find docker image. This script must be run from a runnable distribution of Apache Spark." fi local BINDING_BUILD_ARGS=( + ${BUILD_PARAMS} --build-arg base_img=$(image_ref spark) ) local BASEDOCKERFILE=${BASEDOCKERFILE:-"$IMG_PATH/spark/Dockerfile"} local PYDOCKERFILE=${PYDOCKERFILE:-"$IMG_PATH/spark/bindings/python/Dockerfile"} + local RDOCKERFILE=${RDOCKERFILE:-"$IMG_PATH/spark/bindings/R/Dockerfile"} docker build $NOCACHEARG "${BUILD_ARGS[@]}" \ -t $(image_ref spark) \ @@ -77,11 +80,16 @@ function build { docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ -t $(image_ref spark-py) \ -f "$PYDOCKERFILE" . + + docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ + -t $(image_ref spark-r) \ + -f "$RDOCKERFILE" . } function push { docker push "$(image_ref spark)" docker push "$(image_ref spark-py)" + docker push "$(image_ref spark-r)" } function usage { @@ -95,12 +103,15 @@ Commands: push Push a pre-built image to a registry. Requires a repository address to be provided. Options: - -f file Dockerfile to build for JVM based Jobs. By default builds the Dockerfile shipped with Spark. - -p file Dockerfile with Python baked in. By default builds the Dockerfile shipped with Spark. - -r repo Repository address. - -t tag Tag to apply to the built image, or to identify the image to be pushed. - -m Use minikube's Docker daemon. - -n Build docker image with --no-cache + -f file Dockerfile to build for JVM based Jobs. By default builds the Dockerfile shipped with Spark. + -p file Dockerfile to build for PySpark Jobs. Builds Python dependencies and ships with Spark. + -R file Dockerfile to build for SparkR Jobs. Builds R dependencies and ships with Spark. + -r repo Repository address. + -t tag Tag to apply to the built image, or to identify the image to be pushed. + -m Use minikube's Docker daemon. + -n Build docker image with --no-cache + -b arg Build arg to build or push the image. For multiple build args, this option needs to + be used separately for each build arg. Using minikube when building images will do so directly into minikube's Docker daemon. 
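In other words, after this change an R image is built and pushed alongside the JVM and Python images, and arbitrary Docker build arguments can be forwarded to all three builds. For example (the build-arg name and value here are hypothetical, chosen only to illustrate the flag): ./bin/docker-image-tool.sh -r docker.io/myrepo -t v2.4.0 -b BASE_IMAGE=openjdk:8-alpine build. Because -b appends one --build-arg per use, it has to be repeated for each build argument, as the updated usage text notes.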
There is no need to push the images into minikube in that case, they'll be automatically @@ -129,16 +140,20 @@ REPO= TAG= BASEDOCKERFILE= PYDOCKERFILE= +RDOCKERFILE= NOCACHEARG= -while getopts f:mr:t:n option +BUILD_PARAMS= +while getopts f:p:R:mr:t:n:b: option do case "${option}" in f) BASEDOCKERFILE=${OPTARG};; p) PYDOCKERFILE=${OPTARG};; + R) RDOCKERFILE=${OPTARG};; r) REPO=${OPTARG};; t) TAG=${OPTARG};; n) NOCACHEARG="--no-cache";; + b) BUILD_PARAMS=${BUILD_PARAMS}" --build-arg "${OPTARG};; m) if ! which minikube 1>/dev/null; then error "Cannot find minikube." diff --git a/build/mvn b/build/mvn index ae4276dbc7e32..2487b81abb4ea 100755 --- a/build/mvn +++ b/build/mvn @@ -67,6 +67,9 @@ install_app() { fi } +# See simple version normalization: http://stackoverflow.com/questions/16989598/bash-comparing-version-numbers +function version { echo "$@" | awk -F. '{ printf("%03d%03d%03d\n", $1,$2,$3); }'; } + # Determine the Maven version from the root pom.xml file and # install maven under the build/ folder if needed. install_mvn() { @@ -75,8 +78,6 @@ install_mvn() { if [ "$MVN_BIN" ]; then local MVN_DETECTED_VERSION="$(mvn --version | head -n1 | awk '{print $3}')" fi - # See simple version normalization: http://stackoverflow.com/questions/16989598/bash-comparing-version-numbers - function version { echo "$@" | awk -F. '{ printf("%03d%03d%03d\n", $1,$2,$3); }'; } if [ $(version $MVN_DETECTED_VERSION) -lt $(version $MVN_VERSION) ]; then local APACHE_MIRROR=${APACHE_MIRROR:-'https://www.apache.org/dyn/closer.lua?action=download&filename='} @@ -91,15 +92,23 @@ install_mvn() { # Install zinc under the build/ folder install_zinc() { - local zinc_path="zinc-0.3.15/bin/zinc" - [ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1 - local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.lightbend.com} + local ZINC_VERSION=0.3.15 + ZINC_BIN="$(command -v zinc)" + if [ "$ZINC_BIN" ]; then + local ZINC_DETECTED_VERSION="$(zinc -version | head -n1 | awk '{print $5}')" + fi - install_app \ - "${TYPESAFE_MIRROR}/zinc/0.3.15" \ - "zinc-0.3.15.tgz" \ - "${zinc_path}" - ZINC_BIN="${_DIR}/${zinc_path}" + if [ $(version $ZINC_DETECTED_VERSION) -lt $(version $ZINC_VERSION) ]; then + local zinc_path="zinc-${ZINC_VERSION}/bin/zinc" + [ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1 + local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.lightbend.com} + + install_app \ + "${TYPESAFE_MIRROR}/zinc/${ZINC_VERSION}" \ + "zinc-${ZINC_VERSION}.tgz" \ + "${zinc_path}" + ZINC_BIN="${_DIR}/${zinc_path}" + fi } # Determine the Scala version from the root pom.xml file, set the Scala URL, diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java index 0e491efac9181..58e2a8f25f34f 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java @@ -234,7 +234,7 @@ public void close() throws IOException { * Closes the given iterator if the DB is still open. Trying to close a JNI LevelDB handle * with a closed DB can cause JVM crashes, so this ensures that situation does not happen. 
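The version helper hoisted to top level in build/mvn (so both install_mvn and install_zinc can share it) zero-pads each dotted component to three digits, which makes plain numeric comparison agree with version ordering. The same idea in Scala, as a sketch assuming purely numeric x.y.z strings:

    // Each component becomes three digits: "0.3.15" -> "000003015" and
    // "0.3.9" -> "000003009", so ordinary comparison now orders versions.
    def versionKey(v: String): String =
      v.split('.').padTo(3, "0").take(3).map(c => f"${c.toInt}%03d").mkString

    assert(versionKey("0.3.9") < versionKey("0.3.15"))  // 9 < 15, as intended
    assert(versionKey("3.5.4") < versionKey("3.10.0"))  // 5 < 10, unlike raw string order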
*/ - void closeIterator(LevelDBIterator it) throws IOException { + void closeIterator(LevelDBIterator<?> it) throws IOException { synchronized (this._db) { DB _db = this._db.get(); if (_db != null) {
diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java index 510b3058a4e3c..9abf26f02f7a7 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java @@ -35,7 +35,7 @@ public void testObjectWriteReadDelete() throws Exception { try { store.read(CustomType1.class, t.key); - fail("Expected exception for non-existant object."); + fail("Expected exception for non-existent object."); } catch (NoSuchElementException nsee) { // Expected. }
diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java index b8123ac81d29a..205f7df87c5bc 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java @@ -80,7 +80,7 @@ public void testObjectWriteReadDelete() throws Exception { try { db.read(CustomType1.class, t.key); - fail("Expected exception for non-existant object."); + fail("Expected exception for non-existent object."); } catch (NoSuchElementException nsee) { // Expected. }
diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index 8b8f9892847c3..45fee541a4f5d 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -77,16 +77,16 @@ public ByteBuffer nioByteBuffer() throws IOException { return channel.map(FileChannel.MapMode.READ_ONLY, offset, length); } } catch (IOException e) { + String errorMessage = "Error in reading " + this; try { if (channel != null) { long size = channel.size(); - throw new IOException("Error in reading " + this + " (actual file length " + size + ")", - e); + errorMessage = "Error in reading " + this + " (actual file length " + size + ")"; } } catch (IOException ignored) { // ignore } - throw new IOException("Error in opening " + this, e); + throw new IOException(errorMessage, e); } finally { JavaUtils.closeQuietly(channel); } @@ -95,26 +95,24 @@ public ByteBuffer nioByteBuffer() throws IOException { @Override public InputStream createInputStream() throws IOException { FileInputStream is = null; + boolean shouldClose = true; try { is = new FileInputStream(file); ByteStreams.skipFully(is, offset); - return new LimitedInputStream(is, length); + InputStream r = new LimitedInputStream(is, length); + shouldClose = false; + return r; } catch (IOException e) { - try { - if (is != null) { - long size = file.length(); - throw new IOException("Error in reading " + this + " (actual file length " + size + ")", - e); - } - } catch (IOException ignored) { - // ignore - } finally { + String errorMessage = "Error in reading " + this; + if (is != null) { + long size = file.length(); + errorMessage = "Error in reading " + this + " (actual file length " + size + ")"; + } + throw new IOException(errorMessage, e); + } finally {
+ if (shouldClose) { JavaUtils.closeQuietly(is); } - throw new IOException("Error in opening " + this, e); - } catch (RuntimeException e) { - JavaUtils.closeQuietly(is); - throw e; } }
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index 325225dc0ea2c..20d840baeaf6c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -318,7 +318,7 @@ private class StdChannelListener } @Override - public void operationComplete(Future future) throws Exception { + public void operationComplete(Future<? super Void> future) throws Exception { if (future.isSuccess()) { if (logger.isTraceEnabled()) { long timeTaken = System.currentTimeMillis() - startTime;
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 7a3d96ceaef0c..596b0ea5dba9b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -212,8 +212,8 @@ public void handle(ResponseMessage message) throws Exception { if (entry != null) { StreamCallback callback = entry.getValue(); if (resp.byteCount > 0) { - StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, - callback); + StreamInterceptor<ResponseMessage> interceptor = new StreamInterceptor<>( + this, resp.streamId, resp.byteCount, callback); try { TransportFrameDecoder frameDecoder = (TransportFrameDecoder) channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java index e04524dde0a75..b64e4b7a970b5 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java @@ -240,7 +240,7 @@ public boolean release(int decrement) { @Override public long transferTo(WritableByteChannel target, long position) throws IOException { - Preconditions.checkArgument(position == transfered(), "Invalid position."); + Preconditions.checkArgument(position == transferred(), "Invalid position."); do { if (currentEncrypted == null) { @@ -267,7 +267,7 @@ private void encryptMore() throws IOException { int copied = byteRawChannel.write(buf.nioBuffer()); buf.skipBytes(copied); } else { - region.transferTo(byteRawChannel, region.transfered()); + region.transferTo(byteRawChannel, region.transferred()); } cos.write(byteRawChannel.getData(), 0, byteRawChannel.length()); cos.flush();
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java index e7b66a6f33a82..b81c25afc737f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java @@ -140,8 +140,24 @@ private int copyByteBuf(ByteBuf buf, WritableByteChannel target)
throws IOExcept // SPARK-24578: cap the sub-region's size of returned nio buffer to improve the performance // for the case that the passed-in buffer has too many components. int length = Math.min(buf.readableBytes(), NIO_BUFFER_LIMIT); - ByteBuffer buffer = buf.nioBuffer(buf.readerIndex(), length); - int written = target.write(buffer); + // If the ByteBuf holds more then one ByteBuffer we should better call nioBuffers(...) + // to eliminate extra memory copies. + int written = 0; + if (buf.nioBufferCount() == 1) { + ByteBuffer buffer = buf.nioBuffer(buf.readerIndex(), length); + written = target.write(buffer); + } else { + ByteBuffer[] buffers = buf.nioBuffers(buf.readerIndex(), length); + for (ByteBuffer buffer: buffers) { + int remaining = buffer.remaining(); + int w = target.write(buffer); + written += w; + if (w < remaining) { + // Could not write all, we need to break now. + break; + } + } + } buf.skipBytes(written); return written; } diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java index 3ac9081d78a75..e1275689ae6a0 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java @@ -135,13 +135,14 @@ static class EncryptedMessage extends AbstractFileRegion { private final boolean isByteBuf; private final ByteBuf buf; private final FileRegion region; + private final int maxOutboundBlockSize; /** * A channel used to buffer input data for encryption. The channel has an upper size bound * so that if the input is larger than the allowed buffer, it will be broken into multiple - * chunks. + * chunks. Made non-final to enable lazy initialization, which saves memory. */ - private final ByteArrayWritableChannel byteChannel; + private ByteArrayWritableChannel byteChannel; private ByteBuf currentHeader; private ByteBuffer currentChunk; @@ -157,7 +158,7 @@ static class EncryptedMessage extends AbstractFileRegion { this.isByteBuf = msg instanceof ByteBuf; this.buf = isByteBuf ? (ByteBuf) msg : null; this.region = isByteBuf ? null : (FileRegion) msg; - this.byteChannel = new ByteArrayWritableChannel(maxOutboundBlockSize); + this.maxOutboundBlockSize = maxOutboundBlockSize; } /** @@ -230,17 +231,17 @@ public boolean release(int decrement) { * data into memory at once, and can avoid ballooning memory usage when transferring large * messages such as shuffle blocks. * - * The {@link #transfered()} counter also behaves a little funny, in that it won't go forward + * The {@link #transferred()} counter also behaves a little funny, in that it won't go forward * until a whole chunk has been written. This is done because the code can't use the actual * number of bytes written to the channel as the transferred count (see {@link #count()}). * Instead, once an encrypted chunk is written to the output (including its header), the - * size of the original block will be added to the {@link #transfered()} amount. + * size of the original block will be added to the {@link #transferred()} amount. 
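The copyByteBuf() rewrite above matters because nioBuffer(...) on a CompositeByteBuf with many components may consolidate them into a freshly allocated buffer, an extra copy that nioBuffers(...) avoids. The control flow of the new multi-buffer branch, distilled into a Scala sketch (names invented for illustration; this is not the patch's code):

    import java.nio.ByteBuffer
    import java.nio.channels.WritableByteChannel

    // Write the buffers in order; a short write means the channel cannot take
    // more right now, so stop and report how many bytes were consumed.
    def writeOnce(target: WritableByteChannel, buffers: Array[ByteBuffer]): Int = {
      var written = 0
      var i = 0
      var channelFull = false
      while (i < buffers.length && !channelFull) {
        val remaining = buffers(i).remaining()
        val w = target.write(buffers(i))
        written += w
        channelFull = w < remaining
        i += 1
      }
      written
    }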
*/ @Override public long transferTo(final WritableByteChannel target, final long position) throws IOException { - Preconditions.checkArgument(position == transfered(), "Invalid position."); + Preconditions.checkArgument(position == transferred(), "Invalid position."); long reportedWritten = 0L; long actuallyWritten = 0L; @@ -272,7 +273,7 @@ public long transferTo(final WritableByteChannel target, final long position) currentChunkSize = 0; currentReportedBytes = 0; } - } while (currentChunk == null && transfered() + reportedWritten < count()); + } while (currentChunk == null && transferred() + reportedWritten < count()); // Returning 0 triggers a backoff mechanism in netty which may harm performance. Instead, // we return 1 until we can (i.e. until the reported count would actually match the size @@ -292,12 +293,15 @@ public long transferTo(final WritableByteChannel target, final long position) } private void nextChunk() throws IOException { + if (byteChannel == null) { + byteChannel = new ByteArrayWritableChannel(maxOutboundBlockSize); + } byteChannel.reset(); if (isByteBuf) { int copied = byteChannel.write(buf.nioBuffer()); buf.skipBytes(copied); } else { - region.transferTo(byteChannel, region.transfered()); + region.transferTo(byteChannel, region.transferred()); } byte[] encrypted = backend.wrap(byteChannel.getData(), 0, byteChannel.length());
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index e1d7b2dbff60f..9fac96dbe450d 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -234,7 +234,7 @@ public void onComplete(String streamId) throws IOException { callback.onSuccess(ByteBuffer.allocate(0)); } catch (Exception ex) { IOException ioExc = new IOException("Failure post-processing complete stream;" + - " failing this rpc and leaving channel active"); + " failing this rpc and leaving channel active", ex); callback.onFailure(ioExc); streamHandler.onFailure(streamId, ioExc); } @@ -252,8 +252,8 @@ public String getID() { } }; if (req.bodyByteCount > 0) { - StreamInterceptor interceptor = new StreamInterceptor(this, wrappedCallback.getID(), - req.bodyByteCount, wrappedCallback); + StreamInterceptor<RequestMessage> interceptor = new StreamInterceptor<>( + this, wrappedCallback.getID(), req.bodyByteCount, wrappedCallback); frameDecoder.setInterceptor(interceptor); } else { wrappedCallback.onComplete(wrappedCallback.getID());
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java index 60f51125c07fd..9c85ab2f5f06f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -70,11 +70,14 @@ public TransportServer( this.appRpcHandler = appRpcHandler; this.bootstraps = Lists.newArrayList(Preconditions.checkNotNull(bootstraps)); + boolean shouldClose = true; try { init(hostToBind, portToBind); - } catch (RuntimeException e) { - JavaUtils.closeQuietly(this); - throw e; + shouldClose = false; + } finally { + if (shouldClose) { + JavaUtils.closeQuietly(this); + } } } @@ -148,11 +151,11 @@ public void close() {
channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS); channelFuture = null; } - if (bootstrap != null && bootstrap.group() != null) { - bootstrap.group().shutdownGracefully(); + if (bootstrap != null && bootstrap.config().group() != null) { + bootstrap.config().group().shutdownGracefully(); } - if (bootstrap != null && bootstrap.childGroup() != null) { - bootstrap.childGroup().shutdownGracefully(); + if (bootstrap != null && bootstrap.config().childGroup() != null) { + bootstrap.config().childGroup().shutdownGracefully(); } bootstrap = null; } diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java index 5e85180bd6f9f..33d6eb4a83a0c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -17,7 +17,6 @@ package org.apache.spark.network.util; -import java.lang.reflect.Field; import java.util.concurrent.ThreadFactory; import io.netty.buffer.PooledByteBufAllocator; @@ -111,24 +110,14 @@ public static PooledByteBufAllocator createPooledByteBufAllocator( } return new PooledByteBufAllocator( allowDirectBufs && PlatformDependent.directBufferPreferred(), - Math.min(getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"), numCores), - Math.min(getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"), allowDirectBufs ? numCores : 0), - getPrivateStaticField("DEFAULT_PAGE_SIZE"), - getPrivateStaticField("DEFAULT_MAX_ORDER"), - allowCache ? getPrivateStaticField("DEFAULT_TINY_CACHE_SIZE") : 0, - allowCache ? getPrivateStaticField("DEFAULT_SMALL_CACHE_SIZE") : 0, - allowCache ? getPrivateStaticField("DEFAULT_NORMAL_CACHE_SIZE") : 0 + Math.min(PooledByteBufAllocator.defaultNumHeapArena(), numCores), + Math.min(PooledByteBufAllocator.defaultNumDirectArena(), allowDirectBufs ? numCores : 0), + PooledByteBufAllocator.defaultPageSize(), + PooledByteBufAllocator.defaultMaxOrder(), + allowCache ? PooledByteBufAllocator.defaultTinyCacheSize() : 0, + allowCache ? PooledByteBufAllocator.defaultSmallCacheSize() : 0, + allowCache ? PooledByteBufAllocator.defaultNormalCacheSize() : 0, + allowCache ? PooledByteBufAllocator.defaultUseCacheForAllThreads() : false ); } - - /** Used to get defaults from Netty's private static fields. */ - private static int getPrivateStaticField(String name) { - try { - Field f = PooledByteBufAllocator.DEFAULT.getClass().getDeclaredField(name); - f.setAccessible(true); - return f.getInt(null); - } catch (Exception e) { - throw new RuntimeException(e); - } - } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 91497b9492219..34e4bb5912dcb 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -209,7 +209,7 @@ public String keyFactoryAlgorithm() { * (128 bits by default), which is not generally the case with user passwords. 
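TransportServer's constructor now uses the same close-on-failure shape as FileSegmentManagedBuffer.createInputStream() earlier in this patch: a flag flipped only once ownership is safely handed off, with cleanup in a finally block so that throwables other than RuntimeException no longer leak the resource. The idiom, distilled into a generic Scala sketch (not Spark API):

    // Run setup that may throw anything; close the resource only on failure.
    // On success the caller takes ownership of the still-open resource.
    def initOrClose[R <: AutoCloseable](resource: R)(setup: R => Unit): R = {
      var handedOff = false
      try {
        setup(resource)
        handedOff = true
        resource
      } finally {
        if (!handedOff) resource.close()
      }
    }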
*/ public int keyFactoryIterations() { - return conf.getInt("spark.networy.crypto.keyFactoryIterations", 1024); + return conf.getInt("spark.network.crypto.keyFactoryIterations", 1024); } /**
diff --git a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java index bc94f7ca63a96..6fb44fea8c5a4 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -116,8 +116,8 @@ public void encode(ChannelHandlerContext ctx, FileRegion in, List<Object> out) throws Exception { ByteArrayWritableChannel channel = new ByteArrayWritableChannel(Ints.checkedCast(in.count())); - while (in.transfered() < in.count()) { - in.transferTo(channel, in.transfered()); + while (in.transferred() < in.count()) { + in.transferTo(channel, in.transferred()); } out.add(Unpooled.wrappedBuffer(channel.getData())); }
diff --git a/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java index ecb66fcf2ff76..3bff34e210e3c 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java @@ -22,6 +22,7 @@ import java.nio.channels.WritableByteChannel; import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.Unpooled; import org.apache.spark.network.util.AbstractFileRegion; import org.junit.Test; @@ -48,7 +49,36 @@ public void testShortWrite() throws Exception { @Test public void testByteBufBody() throws Exception { + testByteBufBody(Unpooled.copyLong(42)); + } + + @Test + public void testCompositeByteBufBodySingleBuffer() throws Exception { + ByteBuf header = Unpooled.copyLong(42); + CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer(); + compositeByteBuf.addComponent(true, header); + assertEquals(1, compositeByteBuf.nioBufferCount()); + testByteBufBody(compositeByteBuf); + } + + @Test + public void testCompositeByteBufBodyMultipleBuffers() throws Exception { ByteBuf header = Unpooled.copyLong(42); + CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer(); + compositeByteBuf.addComponent(true, header.retainedSlice(0, 4)); + compositeByteBuf.addComponent(true, header.slice(4, 4)); + assertEquals(2, compositeByteBuf.nioBufferCount()); + testByteBufBody(compositeByteBuf); + } + + /** + * Test writing a {@link MessageWithHeader} using the given {@link ByteBuf} as header. + * + * @param header the header to use. + * @throws Exception thrown on error.
+ */ + private void testByteBufBody(ByteBuf header) throws Exception { + long expectedHeaderValue = header.getLong(header.readerIndex()); ByteBuf bodyPassedToNettyManagedBuffer = Unpooled.copyLong(84); assertEquals(1, header.refCnt()); assertEquals(1, bodyPassedToNettyManagedBuffer.refCnt()); @@ -61,7 +91,7 @@ public void testByteBufBody() throws Exception { MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, managedBuf.size()); ByteBuf result = doWrite(msg, 1); assertEquals(msg.count(), result.readableBytes()); - assertEquals(42, result.readLong()); + assertEquals(expectedHeaderValue, result.readLong()); assertEquals(84, result.readLong()); assertTrue(msg.release()); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index 9af6759f5d5f3..a68a297519b66 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -42,7 +42,7 @@ public abstract class BlockTransferMessage implements Encodable { /** Preceding every serialized message is its type, which allows us to deserialize it. */ public enum Type { OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4), - HEARTBEAT(5); + HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6); private final byte id; @@ -67,6 +67,7 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) { case 3: return StreamHandle.decode(buf); case 4: return RegisterDriver.decode(buf); case 5: return ShuffleServiceHeartbeat.decode(buf); + case 6: return UploadBlockStream.decode(buf); default: throw new IllegalArgumentException("Unknown message type: " + type); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlockStream.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlockStream.java new file mode 100644 index 0000000000000..9df30967d5bb2 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlockStream.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + +/** + * A request to Upload a block, which the destination should receive as a stream. 
+ * + * The actual block data is not contained here. It will be passed to the StreamCallbackWithID + * that is returned from RpcHandler.receiveStream() + */ +public class UploadBlockStream extends BlockTransferMessage { + public final String blockId; + public final byte[] metadata; + + public UploadBlockStream(String blockId, byte[] metadata) { + this.blockId = blockId; + this.metadata = metadata; + } + + @Override + protected Type type() { return Type.UPLOAD_BLOCK_STREAM; } + + @Override + public int hashCode() { + int objectsHashCode = Objects.hashCode(blockId); + return objectsHashCode * 41 + Arrays.hashCode(metadata); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("blockId", blockId) + .add("metadata size", metadata.length) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof UploadBlockStream) { + UploadBlockStream o = (UploadBlockStream) other; + return Objects.equal(blockId, o.blockId) + && Arrays.equals(metadata, o.metadata); + } + return false; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(blockId) + + Encoders.ByteArrays.encodedLength(metadata); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, blockId); + Encoders.ByteArrays.encode(buf, metadata); + } + + public static UploadBlockStream decode(ByteBuf buf) { + String blockId = Encoders.Strings.decode(buf); + byte[] metadata = Encoders.ByteArrays.decode(buf); + return new UploadBlockStream(blockId, metadata); + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java index 62b75ae8aa01d..73577437ac506 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.expressions; -import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.Platform; /** * Simulates Hive's hashing function from Hive v1.2.1 @@ -39,21 +38,12 @@ public static int hashLong(long input) { return (int) ((input >>> 32) ^ input); } - public static int hashUnsafeBytesBlock(MemoryBlock mb) { - long lengthInBytes = mb.size(); + public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes) { assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int result = 0; - for (long i = 0; i < lengthInBytes; i++) { - result = (result * 31) + (int) mb.getByte(i); + for (int i = 0; i < lengthInBytes; i++) { + result = (result * 31) + (int) Platform.getByte(base, offset + i); } return result; } - - public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes) { - return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes)); - } - - public static int hashUTF8String(UTF8String str) { - return hashUnsafeBytesBlock(str.getMemoryBlock()); - } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index 54dcadf3a7754..aca6fca00c48b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -187,7 +187,7 @@ public static void setMemory(long address, byte value, 
long size) { } public static void copyMemory( - Object src, long srcOffset, Object dst, long dstOffset, long length) { + Object src, long srcOffset, Object dst, long dstOffset, long length) { // Check if dstOffset is before or after srcOffset to determine if we should copy // forward or backwards. This is necessary in case src and dst overlap. if (dstOffset < srcOffset) { diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index ef0f78d95d1ee..cec8c30887e2f 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -18,7 +18,6 @@ package org.apache.spark.unsafe.array; import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.memory.MemoryBlock; public class ByteArrayMethods { @@ -53,25 +52,15 @@ public static long roundNumberOfBytesToNearestWord(long numBytes) { public static int MAX_ROUNDED_ARRAY_LENGTH = Integer.MAX_VALUE - 15; private static final boolean unaligned = Platform.unaligned(); - /** - * MemoryBlock equality check for MemoryBlocks. - * @return true if the arrays are equal, false otherwise - */ - public static boolean arrayEqualsBlock( - MemoryBlock leftBase, long leftOffset, MemoryBlock rightBase, long rightOffset, long length) { - return arrayEquals(leftBase.getBaseObject(), leftBase.getBaseOffset() + leftOffset, - rightBase.getBaseObject(), rightBase.getBaseOffset() + rightOffset, length); - } - /** * Optimized byte array equality check for byte arrays. * @return true if the arrays are equal, false otherwise */ public static boolean arrayEquals( - Object leftBase, long leftOffset, Object rightBase, long rightOffset, long length) { + Object leftBase, long leftOffset, Object rightBase, long rightOffset, final long length) { int i = 0; - // check if starts align and we can get both offsets to be aligned + // check if stars align and we can get both offsets to be aligned if ((leftOffset % 8) == (rightOffset % 8)) { while ((leftOffset + i) % 8 != 0 && i < length) { if (Platform.getByte(leftBase, leftOffset + i) != diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java index b74d2de0691d5..2cd39bd60c2ac 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java @@ -17,6 +17,7 @@ package org.apache.spark.unsafe.array; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; /** @@ -32,12 +33,16 @@ public final class LongArray { private static final long WIDTH = 8; private final MemoryBlock memory; + private final Object baseObj; + private final long baseOffset; private final long length; public LongArray(MemoryBlock memory) { assert memory.size() < (long) Integer.MAX_VALUE * 8: "Array size >= Integer.MAX_VALUE elements"; this.memory = memory; + this.baseObj = memory.getBaseObject(); + this.baseOffset = memory.getBaseOffset(); this.length = memory.size() / WIDTH; } @@ -46,11 +51,11 @@ public MemoryBlock memoryBlock() { } public Object getBaseObject() { - return memory.getBaseObject(); + return baseObj; } public long getBaseOffset() { - return memory.getBaseOffset(); + return baseOffset; } /** @@ -64,8 +69,8 @@ public long size() { * Fill this all with 0L. 
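The arrayEquals hunk above compares eight bytes per step once both cursors reach a word boundary, and falls back to single bytes for any unaligned prefix and the tail. Below is a pure-Java sketch of the same idea using ByteBuffer instead of Platform/Unsafe; whole arrays start word-aligned relative to the wrap, so the alignment preamble from the real code is not needed here.

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Word-at-a-time equality: 8-byte longs for the body, single bytes for the tail.
public final class WordwiseEquals {
  static boolean arrayEquals(byte[] left, byte[] right) {
    if (left.length != right.length) return false;
    ByteBuffer l = ByteBuffer.wrap(left).order(ByteOrder.nativeOrder());
    ByteBuffer r = ByteBuffer.wrap(right).order(ByteOrder.nativeOrder());
    int i = 0;
    while (i + 8 <= left.length) {      // compare full words while one remains
      if (l.getLong(i) != r.getLong(i)) return false;
      i += 8;
    }
    for (; i < left.length; i++) {      // then the trailing bytes
      if (left[i] != right[i]) return false;
    }
    return true;
  }

  public static void main(String[] args) {
    System.out.println(arrayEquals("hello world!!".getBytes(), "hello world!!".getBytes()));
  }
}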
*/ public void zeroOut() { - for (long off = 0; off < length * WIDTH; off += WIDTH) { - memory.putLong(off, 0); + for (long off = baseOffset; off < baseOffset + length * WIDTH; off += WIDTH) { + Platform.putLong(baseObj, off, 0); } } @@ -75,7 +80,7 @@ public void zeroOut() { public void set(int index, long value) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - memory.putLong(index * WIDTH, value); + Platform.putLong(baseObj, baseOffset + index * WIDTH, value); } /** @@ -84,6 +89,6 @@ public void set(int index, long value) { public long get(int index) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - return memory.getLong(index * WIDTH); + return Platform.getLong(baseObj, baseOffset + index * WIDTH); } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java index aff6e93d647fe..d239de6083ad0 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -17,10 +17,7 @@ package org.apache.spark.unsafe.hash; -import com.google.common.primitives.Ints; - -import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.Platform; /** * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction. @@ -52,70 +49,49 @@ public static int hashInt(int input, int seed) { } public int hashUnsafeWords(Object base, long offset, int lengthInBytes) { - return hashUnsafeWordsBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); + return hashUnsafeWords(base, offset, lengthInBytes, seed); } - public static int hashUnsafeWordsBlock(MemoryBlock base, int seed) { + public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. - int lengthInBytes = Ints.checkedCast(base.size()); assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; - int h1 = hashBytesByIntBlock(base, seed); + int h1 = hashBytesByInt(base, offset, lengthInBytes, seed); return fmix(h1, lengthInBytes); } - public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { - // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. - return hashUnsafeWordsBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); - } - - public static int hashUnsafeBytesBlock(MemoryBlock base, int seed) { + public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { // This is not compatible with original and another implementations. // But remain it for backward compatibility for the components existing before 2.3. 
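With the hunks above applied, LongArray resolves its base object and offset once in the constructor and then addresses elements with static Platform calls. A short usage sketch against the patched classes (assuming this module on the classpath and no other changes):

import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.memory.MemoryBlock;

// LongArray as a fixed-width view over a MemoryBlock, addressed through
// the cached baseObj/baseOffset pair introduced in this patch.
public class LongArrayDemo {
  public static void main(String[] args) {
    long[] backing = new long[4];   // 32 bytes on the Java heap
    LongArray arr = new LongArray(MemoryBlock.fromLongArray(backing));
    arr.set(0, 42L);
    arr.set(3, -1L);
    System.out.println(arr.size());   // 4
    System.out.println(arr.get(0));   // 42
    System.out.println(backing[3]);   // -1 (writes go straight to the backing array)
    arr.zeroOut();
    System.out.println(arr.get(0));   // 0
  }
}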
- int lengthInBytes = Ints.checkedCast(base.size()); assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int lengthAligned = lengthInBytes - lengthInBytes % 4; - int h1 = hashBytesByIntBlock(base.subBlock(0, lengthAligned), seed); + int h1 = hashBytesByInt(base, offset, lengthAligned, seed); for (int i = lengthAligned; i < lengthInBytes; i++) { - int halfWord = base.getByte(i); + int halfWord = Platform.getByte(base, offset + i); int k1 = mixK1(halfWord); h1 = mixH1(h1, k1); } return fmix(h1, lengthInBytes); } - public static int hashUTF8String(UTF8String str, int seed) { - return hashUnsafeBytesBlock(str.getMemoryBlock(), seed); - } - - public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { - return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); - } - public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, int seed) { - return hashUnsafeBytes2Block(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); - } - - public static int hashUnsafeBytes2Block(MemoryBlock base, int seed) { - // This is compatible with original and other implementations. + // This is compatible with original and another implementations. // Use this method for new components after Spark 2.3. - int lengthInBytes = Ints.checkedCast(base.size()); - assert (lengthInBytes >= 0) : "lengthInBytes cannot be negative"; + assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int lengthAligned = lengthInBytes - lengthInBytes % 4; - int h1 = hashBytesByIntBlock(base.subBlock(0, lengthAligned), seed); + int h1 = hashBytesByInt(base, offset, lengthAligned, seed); int k1 = 0; for (int i = lengthAligned, shift = 0; i < lengthInBytes; i++, shift += 8) { - k1 ^= (base.getByte(i) & 0xFF) << shift; + k1 ^= (Platform.getByte(base, offset + i) & 0xFF) << shift; } h1 ^= mixK1(k1); return fmix(h1, lengthInBytes); } - private static int hashBytesByIntBlock(MemoryBlock base, int seed) { - long lengthInBytes = base.size(); + private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) { assert (lengthInBytes % 4 == 0); int h1 = seed; - for (long i = 0; i < lengthInBytes; i += 4) { - int halfWord = base.getInt(i); + for (int i = 0; i < lengthInBytes; i += 4) { + int halfWord = Platform.getInt(base, offset + i); int k1 = mixK1(halfWord); h1 = mixH1(h1, k1); } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java deleted file mode 100644 index 9f238632bc87a..0000000000000 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java +++ /dev/null @@ -1,128 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
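The hashUnsafeBytes2 path above hashes the 4-byte-aligned body word by word and packs any trailing bytes into one final k1. Here is a self-contained restatement in plain Java; it assumes the little-endian word order that Platform.getInt yields on x86, and the mix/fmix constants are the published Murmur3 x86_32 ones rather than anything quoted from this patch.

// Pure-Java sketch of the hashUnsafeBytes2 logic, readable without Platform.
public final class Murmur3Sketch {
  private static final int C1 = 0xcc9e2d51, C2 = 0x1b873593;

  static int mixK1(int k1) { k1 *= C1; k1 = Integer.rotateLeft(k1, 15); return k1 * C2; }
  static int mixH1(int h1, int k1) { h1 ^= k1; h1 = Integer.rotateLeft(h1, 13); return h1 * 5 + 0xe6546b64; }

  static int fmix(int h1, int length) {
    h1 ^= length;
    h1 ^= h1 >>> 16; h1 *= 0x85ebca6b;
    h1 ^= h1 >>> 13; h1 *= 0xc2b2ae35;
    return h1 ^ (h1 >>> 16);
  }

  static int hash(byte[] data, int seed) {
    int aligned = data.length - data.length % 4;
    int h1 = seed;
    for (int i = 0; i < aligned; i += 4) {   // body: little-endian 4-byte words
      int word = (data[i] & 0xFF) | (data[i + 1] & 0xFF) << 8
          | (data[i + 2] & 0xFF) << 16 | (data[i + 3] & 0xFF) << 24;
      h1 = mixH1(h1, mixK1(word));
    }
    int k1 = 0;                               // tail: pack the remaining bytes
    for (int i = aligned, shift = 0; i < data.length; i++, shift += 8) {
      k1 ^= (data[i] & 0xFF) << shift;
    }
    h1 ^= mixK1(k1);
    return fmix(h1, data.length);
  }

  public static void main(String[] args) {
    System.out.println(Integer.toHexString(hash("abc".getBytes(), 42)));
  }
}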
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.memory; - -import com.google.common.primitives.Ints; - -import org.apache.spark.unsafe.Platform; - -/** - * A consecutive block of memory with a byte array on Java heap. - */ -public final class ByteArrayMemoryBlock extends MemoryBlock { - - private final byte[] array; - - public ByteArrayMemoryBlock(byte[] obj, long offset, long size) { - super(obj, offset, size); - this.array = obj; - assert(offset + size <= Platform.BYTE_ARRAY_OFFSET + obj.length) : - "The sum of size " + size + " and offset " + offset + " should not be larger than " + - "the size of the given memory space " + (obj.length + Platform.BYTE_ARRAY_OFFSET); - } - - public ByteArrayMemoryBlock(long length) { - this(new byte[Ints.checkedCast(length)], Platform.BYTE_ARRAY_OFFSET, length); - } - - @Override - public MemoryBlock subBlock(long offset, long size) { - checkSubBlockRange(offset, size); - if (offset == 0 && size == this.size()) return this; - return new ByteArrayMemoryBlock(array, this.offset + offset, size); - } - - public byte[] getByteArray() { return array; } - - /** - * Creates a memory block pointing to the memory used by the byte array. - */ - public static ByteArrayMemoryBlock fromArray(final byte[] array) { - return new ByteArrayMemoryBlock(array, Platform.BYTE_ARRAY_OFFSET, array.length); - } - - @Override - public int getInt(long offset) { - return Platform.getInt(array, this.offset + offset); - } - - @Override - public void putInt(long offset, int value) { - Platform.putInt(array, this.offset + offset, value); - } - - @Override - public boolean getBoolean(long offset) { - return Platform.getBoolean(array, this.offset + offset); - } - - @Override - public void putBoolean(long offset, boolean value) { - Platform.putBoolean(array, this.offset + offset, value); - } - - @Override - public byte getByte(long offset) { - return array[(int)(this.offset + offset - Platform.BYTE_ARRAY_OFFSET)]; - } - - @Override - public void putByte(long offset, byte value) { - array[(int)(this.offset + offset - Platform.BYTE_ARRAY_OFFSET)] = value; - } - - @Override - public short getShort(long offset) { - return Platform.getShort(array, this.offset + offset); - } - - @Override - public void putShort(long offset, short value) { - Platform.putShort(array, this.offset + offset, value); - } - - @Override - public long getLong(long offset) { - return Platform.getLong(array, this.offset + offset); - } - - @Override - public void putLong(long offset, long value) { - Platform.putLong(array, this.offset + offset, value); - } - - @Override - public float getFloat(long offset) { - return Platform.getFloat(array, this.offset + offset); - } - - @Override - public void putFloat(long offset, float value) { - Platform.putFloat(array, this.offset + offset, value); - } - - @Override - public double getDouble(long offset) { - return Platform.getDouble(array, this.offset + offset); - } - - @Override - public void putDouble(long offset, double value) { - Platform.putDouble(array, this.offset + offset, value); - } -} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index 36caf80888cda..2733760dd19ef 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ 
b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -23,6 +23,8 @@ import java.util.LinkedList; import java.util.Map; +import org.apache.spark.unsafe.Platform; + /** * A simple {@link MemoryAllocator} that can allocate up to 16GB using a JVM long primitive array. */ @@ -56,7 +58,7 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { final long[] array = arrayReference.get(); if (array != null) { assert (array.length * 8L >= size); - MemoryBlock memory = OnHeapMemoryBlock.fromArray(array, size); + MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } @@ -68,7 +70,7 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { } } long[] array = new long[numWords]; - MemoryBlock memory = OnHeapMemoryBlock.fromArray(array, size); + MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } @@ -77,13 +79,12 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { @Override public void free(MemoryBlock memory) { - assert(memory instanceof OnHeapMemoryBlock); - assert (memory.getBaseObject() != null) : + assert (memory.obj != null) : "baseObject was null; are you trying to use the on-heap allocator to free off-heap memory?"; - assert (memory.getPageNumber() != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : "page has already been freed"; - assert ((memory.getPageNumber() == MemoryBlock.NO_PAGE_NUMBER) - || (memory.getPageNumber() == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : + assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER) + || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : "TMM-allocated pages must first be freed via TMM.freePage(), not directly in allocator " + "free()"; @@ -93,12 +94,12 @@ public void free(MemoryBlock memory) { } // Mark the page as freed (so we can detect double-frees). - memory.setPageNumber(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER); + memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER; // As an additional layer of defense against use-after-free bugs, we mutate the // MemoryBlock to null out its reference to the long[] array. 
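The allocate() hunks above consult a pool of previously freed long[] buffers, held behind WeakReferences and keyed by size, before allocating fresh memory. A minimal sketch of that scheme follows; it keeps only the bucketing and weak references, while the real allocator also applies a pooling threshold (shouldPool) and the debug fill.

import java.lang.ref.WeakReference;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Map;

// Freed buffers are remembered per aligned size, but never pin memory:
// a WeakReference lets the GC reclaim a pooled array under pressure.
public class PooledLongArrays {
  private final Map<Long, LinkedList<WeakReference<long[]>>> pool = new HashMap<>();

  public long[] allocate(long size) {
    long aligned = ((size + 7) / 8) * 8;
    LinkedList<WeakReference<long[]>> refs = pool.get(aligned);
    while (refs != null && !refs.isEmpty()) {
      long[] array = refs.pop().get();
      if (array != null) return array;   // the reference may have been collected
    }
    return new long[(int) (aligned / 8)];
  }

  public void free(long[] array, long size) {
    long aligned = ((size + 7) / 8) * 8;
    pool.computeIfAbsent(aligned, k -> new LinkedList<>())
        .add(new WeakReference<>(array));
  }
}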
- long[] array = ((OnHeapMemoryBlock)memory).getLongArray(); - memory.resetObjAndOffset(); + long[] array = (long[]) memory.obj; + memory.setObjAndOffset(null, 0); long alignedSize = ((size + 7) / 8) * 8; if (shouldPool(alignedSize)) { diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java index 38315fb97b46a..7b588681d9790 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java @@ -38,7 +38,7 @@ public interface MemoryAllocator { void free(MemoryBlock memory); - UnsafeMemoryAllocator UNSAFE = new UnsafeMemoryAllocator(); + MemoryAllocator UNSAFE = new UnsafeMemoryAllocator(); - HeapMemoryAllocator HEAP = new HeapMemoryAllocator(); + MemoryAllocator HEAP = new HeapMemoryAllocator(); } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java index ca7213bbf92da..c333857358d30 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -22,10 +22,10 @@ import org.apache.spark.unsafe.Platform; /** - * A representation of a consecutive memory block in Spark. It defines the common interfaces - * for memory accessing and mutating. + * A consecutive block of memory, starting at a {@link MemoryLocation} with a fixed size. */ -public abstract class MemoryBlock { +public class MemoryBlock extends MemoryLocation { + /** Special `pageNumber` value for pages which were not allocated by TaskMemoryManagers */ public static final int NO_PAGE_NUMBER = -1; @@ -45,163 +45,38 @@ public abstract class MemoryBlock { */ public static final int FREED_IN_ALLOCATOR_PAGE_NUMBER = -3; - @Nullable - protected Object obj; - - protected long offset; - - protected long length; + private final long length; /** * Optional page number; used when this MemoryBlock represents a page allocated by a - * TaskMemoryManager. This field can be updated using setPageNumber method so that - * this can be modified by the TaskMemoryManager, which lives in a different package. + * TaskMemoryManager. This field is public so that it can be modified by the TaskMemoryManager, + * which lives in a different package. */ - private int pageNumber = NO_PAGE_NUMBER; + public int pageNumber = NO_PAGE_NUMBER; - protected MemoryBlock(@Nullable Object obj, long offset, long length) { - if (offset < 0 || length < 0) { - throw new IllegalArgumentException( - "Length " + length + " and offset " + offset + "must be non-negative"); - } - this.obj = obj; - this.offset = offset; + public MemoryBlock(@Nullable Object obj, long offset, long length) { + super(obj, offset); this.length = length; } - protected MemoryBlock() { - this(null, 0, 0); - } - - public final Object getBaseObject() { - return obj; - } - - public final long getBaseOffset() { - return offset; - } - - public void resetObjAndOffset() { - this.obj = null; - this.offset = 0; - } - /** * Returns the size of the memory block. */ - public final long size() { + public long size() { return length; } - public final void setPageNumber(int pageNum) { - pageNumber = pageNum; - } - - public final int getPageNumber() { - return pageNumber; - } - - /** - * Fills the memory block with the specified byte value. 
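The asserts in free() above lean on the pageNumber sentinels that this patch exposes as a public field on MemoryBlock. Below is a compact sketch of the lifecycle they enforce; the sentinel values and messages follow the patch, the demo class itself is hypothetical, and the asserts only fire when the JVM runs with -ea.

// Sentinels let free() distinguish "never handed to a TaskMemoryManager"
// from "already freed", catching double-free and cross-allocator bugs in tests.
public class PageLifecycleDemo {
  static final int NO_PAGE_NUMBER = -1;
  static final int FREED_IN_TMM_PAGE_NUMBER = -2;
  static final int FREED_IN_ALLOCATOR_PAGE_NUMBER = -3;

  int pageNumber = NO_PAGE_NUMBER;

  void free() {
    assert pageNumber != FREED_IN_ALLOCATOR_PAGE_NUMBER : "page has already been freed";
    assert pageNumber == NO_PAGE_NUMBER || pageNumber == FREED_IN_TMM_PAGE_NUMBER
        : "TMM-allocated pages must first be freed via TMM.freePage()";
    pageNumber = FREED_IN_ALLOCATOR_PAGE_NUMBER;
  }

  public static void main(String[] args) {
    PageLifecycleDemo page = new PageLifecycleDemo();
    page.free();        // ok
    try {
      page.free();      // trips the double-free assert under -ea
    } catch (AssertionError expected) {
      System.out.println("caught: " + expected.getMessage());
    }
  }
}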
- */ - public final void fill(byte value) { - Platform.setMemory(obj, offset, length, value); - } - - /** - * Instantiate MemoryBlock for given object type with new offset - */ - public static final MemoryBlock allocateFromObject(Object obj, long offset, long length) { - MemoryBlock mb = null; - if (obj instanceof byte[]) { - byte[] array = (byte[])obj; - mb = new ByteArrayMemoryBlock(array, offset, length); - } else if (obj instanceof long[]) { - long[] array = (long[])obj; - mb = new OnHeapMemoryBlock(array, offset, length); - } else if (obj == null) { - // we assume that to pass null pointer means off-heap - mb = new OffHeapMemoryBlock(offset, length); - } else { - throw new UnsupportedOperationException( - "Instantiate MemoryBlock for type " + obj.getClass() + " is not supported now"); - } - return mb; - } - /** - * Just instantiate the sub-block with the same type of MemoryBlock with the new size and relative - * offset from the original offset. The data is not copied. - * If parameters are invalid, an exception is thrown. + * Creates a memory block pointing to the memory used by the long array. */ - public abstract MemoryBlock subBlock(long offset, long size); - - protected void checkSubBlockRange(long offset, long size) { - if (offset < 0 || size < 0) { - throw new ArrayIndexOutOfBoundsException( - "Size " + size + " and offset " + offset + " must be non-negative"); - } - if (offset + size > length) { - throw new ArrayIndexOutOfBoundsException("The sum of size " + size + " and offset " + - offset + " should not be larger than the length " + length + " in the MemoryBlock"); - } + public static MemoryBlock fromLongArray(final long[] array) { + return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8L); } /** - * getXXX/putXXX does not ensure guarantee behavior if the offset is invalid. e.g cause illegal - * memory access, throw an exception, or etc. - * getXXX/putXXX uses an index based on this.offset that includes the size of metadata such as - * JVM object header. The offset is 0-based and is expected as an logical offset in the memory - * block. + * Fills the memory block with the specified byte value. 
*/ - public abstract int getInt(long offset); - - public abstract void putInt(long offset, int value); - - public abstract boolean getBoolean(long offset); - - public abstract void putBoolean(long offset, boolean value); - - public abstract byte getByte(long offset); - - public abstract void putByte(long offset, byte value); - - public abstract short getShort(long offset); - - public abstract void putShort(long offset, short value); - - public abstract long getLong(long offset); - - public abstract void putLong(long offset, long value); - - public abstract float getFloat(long offset); - - public abstract void putFloat(long offset, float value); - - public abstract double getDouble(long offset); - - public abstract void putDouble(long offset, double value); - - public static final void copyMemory( - MemoryBlock src, long srcOffset, MemoryBlock dst, long dstOffset, long length) { - assert(srcOffset + length <= src.length && dstOffset + length <= dst.length); - Platform.copyMemory(src.getBaseObject(), src.getBaseOffset() + srcOffset, - dst.getBaseObject(), dst.getBaseOffset() + dstOffset, length); - } - - public static final void copyMemory(MemoryBlock src, MemoryBlock dst, long length) { - assert(length <= src.length && length <= dst.length); - Platform.copyMemory(src.getBaseObject(), src.getBaseOffset(), - dst.getBaseObject(), dst.getBaseOffset(), length); - } - - public final void copyFrom(Object src, long srcOffset, long dstOffset, long length) { - assert(length <= this.length - srcOffset); - Platform.copyMemory(src, srcOffset, obj, offset + dstOffset, length); - } - - public final void writeTo(long srcOffset, Object dst, long dstOffset, long length) { - assert(length <= this.length - srcOffset); - Platform.copyMemory(obj, offset + srcOffset, dst, dstOffset, length); + public void fill(byte value) { + Platform.setMemory(obj, offset, length, value); } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java new file mode 100644 index 0000000000000..74ebc87dc978c --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import javax.annotation.Nullable; + +/** + * A memory location. Tracked either by a memory address (with off-heap allocation), + * or by an offset from a JVM object (in-heap allocation). 
+ */ +public class MemoryLocation { + + @Nullable + Object obj; + + long offset; + + public MemoryLocation(@Nullable Object obj, long offset) { + this.obj = obj; + this.offset = offset; + } + + public MemoryLocation() { + this(null, 0); + } + + public void setObjAndOffset(Object newObj, long newOffset) { + this.obj = newObj; + this.offset = newOffset; + } + + public final Object getBaseObject() { + return obj; + } + + public final long getBaseOffset() { + return offset; + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java deleted file mode 100644 index 3431b08980eb8..0000000000000 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.memory; - -import org.apache.spark.unsafe.Platform; - -public class OffHeapMemoryBlock extends MemoryBlock { - public static final OffHeapMemoryBlock NULL = new OffHeapMemoryBlock(0, 0); - - public OffHeapMemoryBlock(long address, long size) { - super(null, address, size); - } - - @Override - public MemoryBlock subBlock(long offset, long size) { - checkSubBlockRange(offset, size); - if (offset == 0 && size == this.size()) return this; - return new OffHeapMemoryBlock(this.offset + offset, size); - } - - @Override - public final int getInt(long offset) { - return Platform.getInt(null, this.offset + offset); - } - - @Override - public final void putInt(long offset, int value) { - Platform.putInt(null, this.offset + offset, value); - } - - @Override - public final boolean getBoolean(long offset) { - return Platform.getBoolean(null, this.offset + offset); - } - - @Override - public final void putBoolean(long offset, boolean value) { - Platform.putBoolean(null, this.offset + offset, value); - } - - @Override - public final byte getByte(long offset) { - return Platform.getByte(null, this.offset + offset); - } - - @Override - public final void putByte(long offset, byte value) { - Platform.putByte(null, this.offset + offset, value); - } - - @Override - public final short getShort(long offset) { - return Platform.getShort(null, this.offset + offset); - } - - @Override - public final void putShort(long offset, short value) { - Platform.putShort(null, this.offset + offset, value); - } - - @Override - public final long getLong(long offset) { - return Platform.getLong(null, this.offset + offset); - } - - @Override - public final void putLong(long offset, long value) { - Platform.putLong(null, this.offset + offset, value); - } - - @Override - public final float getFloat(long offset) { - return Platform.getFloat(null, this.offset + offset); - } - 
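MemoryLocation, restored above, is the core of Spark's unified memory addressing: a possibly-null base object plus a long offset, passed straight through to Platform. A usage sketch of the same accessor serving both modes, using only Platform calls that appear elsewhere in this patch:

import org.apache.spark.unsafe.Platform;

// One read path, two kinds of memory: a non-null base object with an
// array offset (on-heap), or a null base with an absolute address (off-heap).
public class BaseAndOffsetDemo {
  public static void main(String[] args) {
    // On-heap: base is a long[], offset starts past the array header.
    long[] heap = new long[]{7L};
    System.out.println(Platform.getLong(heap, Platform.LONG_ARRAY_OFFSET)); // 7

    // Off-heap: base is null, offset is a raw address.
    long address = Platform.allocateMemory(8);
    try {
      Platform.putLong(null, address, 7L);
      System.out.println(Platform.getLong(null, address)); // 7
    } finally {
      Platform.freeMemory(address);
    }
  }
}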
- @Override - public final void putFloat(long offset, float value) { - Platform.putFloat(null, this.offset + offset, value); - } - - @Override - public final double getDouble(long offset) { - return Platform.getDouble(null, this.offset + offset); - } - - @Override - public final void putDouble(long offset, double value) { - Platform.putDouble(null, this.offset + offset, value); - } -} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java deleted file mode 100644 index ee42bc27c9c5f..0000000000000 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.memory; - -import com.google.common.primitives.Ints; - -import org.apache.spark.unsafe.Platform; - -/** - * A consecutive block of memory with a long array on Java heap. - */ -public final class OnHeapMemoryBlock extends MemoryBlock { - - private final long[] array; - - public OnHeapMemoryBlock(long[] obj, long offset, long size) { - super(obj, offset, size); - this.array = obj; - assert(offset + size <= obj.length * 8L + Platform.LONG_ARRAY_OFFSET) : - "The sum of size " + size + " and offset " + offset + " should not be larger than " + - "the size of the given memory space " + (obj.length * 8L + Platform.LONG_ARRAY_OFFSET); - } - - public OnHeapMemoryBlock(long size) { - this(new long[Ints.checkedCast((size + 7) / 8)], Platform.LONG_ARRAY_OFFSET, size); - } - - @Override - public MemoryBlock subBlock(long offset, long size) { - checkSubBlockRange(offset, size); - if (offset == 0 && size == this.size()) return this; - return new OnHeapMemoryBlock(array, this.offset + offset, size); - } - - public long[] getLongArray() { return array; } - - /** - * Creates a memory block pointing to the memory used by the long array. 
- */ - public static OnHeapMemoryBlock fromArray(final long[] array) { - return new OnHeapMemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8L); - } - - public static OnHeapMemoryBlock fromArray(final long[] array, long size) { - return new OnHeapMemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); - } - - @Override - public int getInt(long offset) { - return Platform.getInt(array, this.offset + offset); - } - - @Override - public void putInt(long offset, int value) { - Platform.putInt(array, this.offset + offset, value); - } - - @Override - public boolean getBoolean(long offset) { - return Platform.getBoolean(array, this.offset + offset); - } - - @Override - public void putBoolean(long offset, boolean value) { - Platform.putBoolean(array, this.offset + offset, value); - } - - @Override - public byte getByte(long offset) { - return Platform.getByte(array, this.offset + offset); - } - - @Override - public void putByte(long offset, byte value) { - Platform.putByte(array, this.offset + offset, value); - } - - @Override - public short getShort(long offset) { - return Platform.getShort(array, this.offset + offset); - } - - @Override - public void putShort(long offset, short value) { - Platform.putShort(array, this.offset + offset, value); - } - - @Override - public long getLong(long offset) { - return Platform.getLong(array, this.offset + offset); - } - - @Override - public void putLong(long offset, long value) { - Platform.putLong(array, this.offset + offset, value); - } - - @Override - public float getFloat(long offset) { - return Platform.getFloat(array, this.offset + offset); - } - - @Override - public void putFloat(long offset, float value) { - Platform.putFloat(array, this.offset + offset, value); - } - - @Override - public double getDouble(long offset) { - return Platform.getDouble(array, this.offset + offset); - } - - @Override - public void putDouble(long offset, double value) { - Platform.putDouble(array, this.offset + offset, value); - } -} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java index 5310bdf2779a9..4368fb615ba1e 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java @@ -25,9 +25,9 @@ public class UnsafeMemoryAllocator implements MemoryAllocator { @Override - public OffHeapMemoryBlock allocate(long size) throws OutOfMemoryError { + public MemoryBlock allocate(long size) throws OutOfMemoryError { long address = Platform.allocateMemory(size); - OffHeapMemoryBlock memory = new OffHeapMemoryBlock(address, size); + MemoryBlock memory = new MemoryBlock(null, address, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } @@ -36,25 +36,22 @@ public OffHeapMemoryBlock allocate(long size) throws OutOfMemoryError { @Override public void free(MemoryBlock memory) { - assert(memory instanceof OffHeapMemoryBlock) : - "UnsafeMemoryAllocator can only free OffHeapMemoryBlock."; - if (memory == OffHeapMemoryBlock.NULL) return; - assert (memory.getPageNumber() != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + assert (memory.obj == null) : + "baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?"; + assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : "page has already been freed"; - assert 
((memory.getPageNumber() == MemoryBlock.NO_PAGE_NUMBER) - || (memory.getPageNumber() == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : + assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER) + || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : "TMM-allocated pages must be freed via TMM.freePage(), not directly in allocator free()"; if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE); } - Platform.freeMemory(memory.offset); - // As an additional layer of defense against use-after-free bugs, we mutate the // MemoryBlock to reset its pointer. - memory.resetObjAndOffset(); + memory.offset = 0; // Mark the page as freed (so we can detect double-frees). - memory.setPageNumber(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER); + memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER; } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index e91fc4391425c..dff4a73f3e9da 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -34,8 +34,6 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; -import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; -import org.apache.spark.unsafe.memory.MemoryBlock; import static org.apache.spark.unsafe.Platform.*; @@ -53,13 +51,12 @@ public final class UTF8String implements Comparable, Externalizable, // These are only updated by readExternal() or read() @Nonnull - private MemoryBlock base; - // While numBytes has the same value as base.size(), to keep as int avoids cast from long to int + private Object base; + private long offset; private int numBytes; - public MemoryBlock getMemoryBlock() { return base; } - public Object getBaseObject() { return base.getBaseObject(); } - public long getBaseOffset() { return base.getBaseOffset(); } + public Object getBaseObject() { return base; } + public long getBaseOffset() { return offset; } /** * A char in UTF-8 encoding can take 1-4 bytes depending on the first byte which @@ -112,8 +109,7 @@ public final class UTF8String implements Comparable, Externalizable, */ public static UTF8String fromBytes(byte[] bytes) { if (bytes != null) { - return new UTF8String( - new ByteArrayMemoryBlock(bytes, BYTE_ARRAY_OFFSET, bytes.length)); + return new UTF8String(bytes, BYTE_ARRAY_OFFSET, bytes.length); } else { return null; } @@ -126,13 +122,19 @@ public static UTF8String fromBytes(byte[] bytes) { */ public static UTF8String fromBytes(byte[] bytes, int offset, int numBytes) { if (bytes != null) { - return new UTF8String( - new ByteArrayMemoryBlock(bytes, BYTE_ARRAY_OFFSET + offset, numBytes)); + return new UTF8String(bytes, BYTE_ARRAY_OFFSET + offset, numBytes); } else { return null; } } + /** + * Creates an UTF8String from given address (base and offset) and length. + */ + public static UTF8String fromAddress(Object base, long offset, int numBytes) { + return new UTF8String(base, offset, numBytes); + } + /** * Creates an UTF8String from String. 
*/ @@ -149,13 +151,16 @@ public static UTF8String blankString(int length) { return fromBytes(spaces); } - public UTF8String(MemoryBlock base) { + protected UTF8String(Object base, long offset, int numBytes) { this.base = base; - this.numBytes = Ints.checkedCast(base.size()); + this.offset = offset; + this.numBytes = numBytes; } // for serialization - public UTF8String() {} + public UTF8String() { + this(null, 0, 0); + } /** * Writes the content of this string into a memory address, identified by an object and an offset. @@ -163,7 +168,7 @@ public UTF8String() {} * bytes in this string. */ public void writeToMemory(Object target, long targetOffset) { - base.writeTo(0, target, targetOffset, numBytes); + Platform.copyMemory(base, offset, target, targetOffset, numBytes); } public void writeTo(ByteBuffer buffer) { @@ -183,9 +188,8 @@ public void writeTo(ByteBuffer buffer) { */ @Nonnull public ByteBuffer getByteBuffer() { - long offset = base.getBaseOffset(); - if (base instanceof ByteArrayMemoryBlock && offset >= BYTE_ARRAY_OFFSET) { - final byte[] bytes = ((ByteArrayMemoryBlock) base).getByteArray(); + if (base instanceof byte[] && offset >= BYTE_ARRAY_OFFSET) { + final byte[] bytes = (byte[]) base; // the offset includes an object header... this is only needed for unsafe copies final long arrayOffset = offset - BYTE_ARRAY_OFFSET; @@ -252,12 +256,12 @@ public long getPrefix() { long mask = 0; if (IS_LITTLE_ENDIAN) { if (numBytes >= 8) { - p = base.getLong(0); + p = Platform.getLong(base, offset); } else if (numBytes > 4) { - p = base.getLong(0); + p = Platform.getLong(base, offset); mask = (1L << (8 - numBytes) * 8) - 1; } else if (numBytes > 0) { - p = (long) base.getInt(0); + p = (long) Platform.getInt(base, offset); mask = (1L << (8 - numBytes) * 8) - 1; } else { p = 0; @@ -266,12 +270,12 @@ public long getPrefix() { } else { // byteOrder == ByteOrder.BIG_ENDIAN if (numBytes >= 8) { - p = base.getLong(0); + p = Platform.getLong(base, offset); } else if (numBytes > 4) { - p = base.getLong(0); + p = Platform.getLong(base, offset); mask = (1L << (8 - numBytes) * 8) - 1; } else if (numBytes > 0) { - p = ((long) base.getInt(0)) << 32; + p = ((long) Platform.getInt(base, offset)) << 32; mask = (1L << (8 - numBytes) * 8) - 1; } else { p = 0; @@ -286,13 +290,12 @@ public long getPrefix() { */ public byte[] getBytes() { // avoid copy if `base` is `byte[]` - long offset = base.getBaseOffset(); - if (offset == BYTE_ARRAY_OFFSET && base instanceof ByteArrayMemoryBlock - && (((ByteArrayMemoryBlock) base).getByteArray()).length == numBytes) { - return ((ByteArrayMemoryBlock) base).getByteArray(); + if (offset == BYTE_ARRAY_OFFSET && base instanceof byte[] + && ((byte[]) base).length == numBytes) { + return (byte[]) base; } else { byte[] bytes = new byte[numBytes]; - base.writeTo(0, bytes, BYTE_ARRAY_OFFSET, numBytes); + copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes); return bytes; } } @@ -322,7 +325,7 @@ public UTF8String substring(final int start, final int until) { if (i > j) { byte[] bytes = new byte[i - j]; - base.writeTo(j, bytes, BYTE_ARRAY_OFFSET, i - j); + copyMemory(base, offset + j, bytes, BYTE_ARRAY_OFFSET, i - j); return fromBytes(bytes); } else { return EMPTY_UTF8; @@ -363,14 +366,14 @@ public boolean contains(final UTF8String substring) { * Returns the byte at position `i`. 
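Because base and offset are now stored as a raw pair, the fromAddress factory added above can wrap any slice of an existing array without copying. A short usage sketch (method names come from this patch; the string content is purely illustrative):

import java.nio.charset.StandardCharsets;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.types.UTF8String;

// A zero-copy UTF8String view over part of a byte[]: base object plus
// an offset that skips the array header and the first six bytes.
public class Utf8ViewDemo {
  public static void main(String[] args) {
    byte[] bytes = "hello spark".getBytes(StandardCharsets.UTF_8);
    UTF8String slice =
        UTF8String.fromAddress(bytes, Platform.BYTE_ARRAY_OFFSET + 6, 5);
    System.out.println(slice);              // spark
    System.out.println(slice.numChars());   // 5
  }
}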
*/ private byte getByte(int i) { - return base.getByte(i); + return Platform.getByte(base, offset + i); } private boolean matchAt(final UTF8String s, int pos) { if (s.numBytes + pos > numBytes || pos < 0) { return false; } - return ByteArrayMethods.arrayEqualsBlock(base, pos, s.base, 0, s.numBytes); + return ByteArrayMethods.arrayEquals(base, offset + pos, s.base, s.offset, s.numBytes); } public boolean startsWith(final UTF8String prefix) { @@ -497,7 +500,8 @@ public int findInSet(UTF8String match) { for (int i = 0; i < numBytes; i++) { if (getByte(i) == (byte) ',') { if (i - (lastComma + 1) == match.numBytes && - ByteArrayMethods.arrayEqualsBlock(base, lastComma + 1, match.base, 0, match.numBytes)) { + ByteArrayMethods.arrayEquals(base, offset + (lastComma + 1), match.base, match.offset, + match.numBytes)) { return n; } lastComma = i; @@ -505,7 +509,8 @@ public int findInSet(UTF8String match) { } } if (numBytes - (lastComma + 1) == match.numBytes && - ByteArrayMethods.arrayEqualsBlock(base, lastComma + 1, match.base, 0, match.numBytes)) { + ByteArrayMethods.arrayEquals(base, offset + (lastComma + 1), match.base, match.offset, + match.numBytes)) { return n; } return 0; @@ -520,7 +525,7 @@ public int findInSet(UTF8String match) { private UTF8String copyUTF8String(int start, int end) { int len = end - start + 1; byte[] newBytes = new byte[len]; - base.writeTo(start, newBytes, BYTE_ARRAY_OFFSET, len); + copyMemory(base, offset + start, newBytes, BYTE_ARRAY_OFFSET, len); return UTF8String.fromBytes(newBytes); } @@ -667,7 +672,8 @@ public UTF8String reverse() { int i = 0; // position in byte while (i < numBytes) { int len = numBytesForFirstByte(getByte(i)); - base.writeTo(i, result, BYTE_ARRAY_OFFSET + result.length - i - len, len); + copyMemory(this.base, this.offset + i, result, + BYTE_ARRAY_OFFSET + result.length - i - len, len); i += len; } @@ -681,7 +687,7 @@ public UTF8String repeat(int times) { } byte[] newBytes = new byte[numBytes * times]; - base.writeTo(0, newBytes, BYTE_ARRAY_OFFSET, numBytes); + copyMemory(this.base, this.offset, newBytes, BYTE_ARRAY_OFFSET, numBytes); int copied = 1; while (copied < times) { @@ -718,7 +724,7 @@ public int indexOf(UTF8String v, int start) { if (i + v.numBytes > numBytes) { return -1; } - if (ByteArrayMethods.arrayEqualsBlock(base, i, v.base, 0, v.numBytes)) { + if (ByteArrayMethods.arrayEquals(base, offset + i, v.base, v.offset, v.numBytes)) { return c; } i += numBytesForFirstByte(getByte(i)); @@ -734,7 +740,7 @@ public int indexOf(UTF8String v, int start) { private int find(UTF8String str, int start) { assert (str.numBytes > 0); while (start <= numBytes - str.numBytes) { - if (ByteArrayMethods.arrayEqualsBlock(base, start, str.base, 0, str.numBytes)) { + if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) { return start; } start += 1; @@ -748,7 +754,7 @@ private int find(UTF8String str, int start) { private int rfind(UTF8String str, int start) { assert (str.numBytes > 0); while (start >= 0) { - if (ByteArrayMethods.arrayEqualsBlock(base, start, str.base, 0, str.numBytes)) { + if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) { return start; } start -= 1; @@ -781,7 +787,7 @@ public UTF8String subStringIndex(UTF8String delim, int count) { return EMPTY_UTF8; } byte[] bytes = new byte[idx]; - base.writeTo(0, bytes, BYTE_ARRAY_OFFSET, idx); + copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, idx); return fromBytes(bytes); } else { @@ -801,7 +807,7 @@ public UTF8String 
subStringIndex(UTF8String delim, int count) { } int size = numBytes - delim.numBytes - idx; byte[] bytes = new byte[size]; - base.writeTo(idx + delim.numBytes, bytes, BYTE_ARRAY_OFFSET, size); + copyMemory(base, offset + idx + delim.numBytes, bytes, BYTE_ARRAY_OFFSET, size); return fromBytes(bytes); } } @@ -824,15 +830,15 @@ public UTF8String rpad(int len, UTF8String pad) { UTF8String remain = pad.substring(0, spaces - padChars * count); byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes]; - base.writeTo(0, data, BYTE_ARRAY_OFFSET, this.numBytes); + copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET, this.numBytes); int offset = this.numBytes; int idx = 0; while (idx < count) { - pad.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); + copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); ++ idx; offset += pad.numBytes; } - remain.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); + copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); return UTF8String.fromBytes(data); } @@ -860,13 +866,13 @@ public UTF8String lpad(int len, UTF8String pad) { int offset = 0; int idx = 0; while (idx < count) { - pad.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); + copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); ++ idx; offset += pad.numBytes; } - remain.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); + copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); offset += remain.numBytes; - base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, numBytes()); + copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET + offset, numBytes()); return UTF8String.fromBytes(data); } @@ -891,8 +897,8 @@ public static UTF8String concat(UTF8String... inputs) { int offset = 0; for (int i = 0; i < inputs.length; i++) { int len = inputs[i].numBytes; - inputs[i].base.writeTo( - 0, + copyMemory( + inputs[i].base, inputs[i].offset, result, BYTE_ARRAY_OFFSET + offset, len); offset += len; @@ -931,8 +937,8 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { for (int i = 0, j = 0; i < inputs.length; i++) { if (inputs[i] != null) { int len = inputs[i].numBytes; - inputs[i].base.writeTo( - 0, + copyMemory( + inputs[i].base, inputs[i].offset, result, BYTE_ARRAY_OFFSET + offset, len); offset += len; @@ -940,8 +946,8 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { j++; // Add separator if this is not the last input. 
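The rpad, lpad, and concat hunks in this region all follow one pattern: size the result array once, then copy each piece at a running offset. The same pattern in plain Java, with System.arraycopy standing in for Platform.copyMemory (the demo class is hypothetical):

import java.nio.charset.StandardCharsets;

// Single allocation, then sequential copies at a running offset.
public class ConcatDemo {
  static byte[] concat(byte[]... inputs) {
    long total = 0;
    for (byte[] in : inputs) total += in.length;
    byte[] result = new byte[Math.toIntExact(total)]; // overflow-checked, like the original's long math
    int offset = 0;
    for (byte[] in : inputs) {
      System.arraycopy(in, 0, result, offset, in.length);
      offset += in.length;
    }
    return result;
  }

  public static void main(String[] args) {
    byte[] out = concat("foo".getBytes(StandardCharsets.UTF_8),
                        "bar".getBytes(StandardCharsets.UTF_8));
    System.out.println(new String(out, StandardCharsets.UTF_8)); // foobar
  }
}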
if (j < numInputs) { - separator.base.writeTo( - 0, + copyMemory( + separator.base, separator.offset, result, BYTE_ARRAY_OFFSET + offset, separator.numBytes); offset += separator.numBytes; @@ -1215,7 +1221,7 @@ public UTF8String clone() { public UTF8String copy() { byte[] bytes = new byte[numBytes]; - base.writeTo(0, bytes, BYTE_ARRAY_OFFSET, numBytes); + copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes); return fromBytes(bytes); } @@ -1223,10 +1229,11 @@ public UTF8String copy() { public int compareTo(@Nonnull final UTF8String other) { int len = Math.min(numBytes, other.numBytes); int wordMax = (len / 8) * 8; - MemoryBlock rbase = other.base; + long roffset = other.offset; + Object rbase = other.base; for (int i = 0; i < wordMax; i += 8) { - long left = base.getLong(i); - long right = rbase.getLong(i); + long left = getLong(base, offset + i); + long right = getLong(rbase, roffset + i); if (left != right) { if (IS_LITTLE_ENDIAN) { return Long.compareUnsigned(Long.reverseBytes(left), Long.reverseBytes(right)); @@ -1237,7 +1244,7 @@ public int compareTo(@Nonnull final UTF8String other) { } for (int i = wordMax; i < len; i++) { // In UTF-8, the byte should be unsigned, so we should compare them as unsigned int. - int res = (getByte(i) & 0xFF) - (rbase.getByte(i) & 0xFF); + int res = (getByte(i) & 0xFF) - (Platform.getByte(rbase, roffset + i) & 0xFF); if (res != 0) { return res; } @@ -1256,7 +1263,7 @@ public boolean equals(final Object other) { if (numBytes != o.numBytes) { return false; } - return ByteArrayMethods.arrayEqualsBlock(base, 0, o.base, 0, numBytes); + return ByteArrayMethods.arrayEquals(base, offset, o.base, o.offset, numBytes); } else { return false; } @@ -1312,8 +1319,8 @@ public int levenshteinDistance(UTF8String other) { num_bytes_j != numBytesForFirstByte(s.getByte(i_bytes))) { cost = 1; } else { - cost = (ByteArrayMethods.arrayEqualsBlock(t.base, j_bytes, s.base, - i_bytes, num_bytes_j)) ? 0 : 1; + cost = (ByteArrayMethods.arrayEquals(t.base, t.offset + j_bytes, s.base, + s.offset + i_bytes, num_bytes_j)) ? 
0 : 1; } d[i + 1] = Math.min(Math.min(d[i] + 1, p[i + 1] + 1), p[i] + cost); } @@ -1328,7 +1335,7 @@ public int levenshteinDistance(UTF8String other) { @Override public int hashCode() { - return Murmur3_x86_32.hashUnsafeBytesBlock(base,42); + return Murmur3_x86_32.hashUnsafeBytes(base, offset, numBytes, 42); } /** @@ -1391,10 +1398,10 @@ public void writeExternal(ObjectOutput out) throws IOException { } public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + offset = BYTE_ARRAY_OFFSET; numBytes = in.readInt(); - byte[] bytes = new byte[numBytes]; - in.readFully(bytes); - base = ByteArrayMemoryBlock.fromArray(bytes); + base = new byte[numBytes]; + in.readFully((byte[]) base); } @Override @@ -1406,10 +1413,10 @@ public void write(Kryo kryo, Output out) { @Override public void read(Kryo kryo, Input in) { - numBytes = in.readInt(); - byte[] bytes = new byte[numBytes]; - in.read(bytes); - base = ByteArrayMemoryBlock.fromArray(bytes); + this.offset = BYTE_ARRAY_OFFSET; + this.numBytes = in.readInt(); + this.base = new byte[numBytes]; + in.read((byte[]) base); } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 583a148b3845d..3ad9ac7b4de9c 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -81,7 +81,7 @@ public void freeingOnHeapMemoryBlockResetsBaseObjectAndOffset() { MemoryAllocator.HEAP.free(block); Assert.assertNull(block.getBaseObject()); Assert.assertEquals(0, block.getBaseOffset()); - Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.getPageNumber()); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber); } @Test @@ -92,7 +92,7 @@ public void freeingOffHeapMemoryBlockResetsOffset() { MemoryAllocator.UNSAFE.free(block); Assert.assertNull(block.getBaseObject()); Assert.assertEquals(0, block.getBaseOffset()); - Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.getPageNumber()); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber); } @Test(expected = AssertionError.class) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java index 8c2e98c2bfc54..fb8e53b3348f3 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java @@ -20,13 +20,14 @@ import org.junit.Assert; import org.junit.Test; -import org.apache.spark.unsafe.memory.OnHeapMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; public class LongArraySuite { @Test public void basicTest() { - LongArray arr = new LongArray(new OnHeapMemoryBlock(16)); + long[] bytes = new long[2]; + LongArray arr = new LongArray(MemoryBlock.fromLongArray(bytes)); arr.set(0, 1L); arr.set(1, 2L); arr.set(1, 3L); diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java index d7ed005db1891..6348a73bf3895 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java @@ -70,24 +70,6 @@ public void testKnownBytesInputs() 
{ Murmur3_x86_32.hashUnsafeBytes2(tes, Platform.BYTE_ARRAY_OFFSET, tes.length, 0)); } - @Test - public void testKnownWordsInputs() { - byte[] bytes = new byte[16]; - long offset = Platform.BYTE_ARRAY_OFFSET; - for (int i = 0; i < 16; i++) { - bytes[i] = 0; - } - Assert.assertEquals(-300363099, hasher.hashUnsafeWords(bytes, offset, 16, 42)); - for (int i = 0; i < 16; i++) { - bytes[i] = -1; - } - Assert.assertEquals(-1210324667, hasher.hashUnsafeWords(bytes, offset, 16, 42)); - for (int i = 0; i < 16; i++) { - bytes[i] = (byte)i; - } - Assert.assertEquals(-634919701, hasher.hashUnsafeWords(bytes, offset, 16, 42)); - } - @Test public void randomizedStressTest() { int size = 65536; diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java deleted file mode 100644 index ef5ff8ee70ec0..0000000000000 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java +++ /dev/null @@ -1,179 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
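Looking back at the compareTo hunk a few files up: on little-endian hardware each 8-byte word is byte-reversed before an unsigned 64-bit compare, which makes word-at-a-time comparison agree with byte-by-byte (lexicographic) order. A pure-Java sketch of that trick, with ByteBuffer standing in for Platform.getLong:

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Lexicographic byte comparison done 8 bytes at a time.
public final class WordwiseCompare {
  static final boolean IS_LITTLE_ENDIAN =
      ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN;

  static int compare(byte[] a, byte[] b) {
    int len = Math.min(a.length, b.length);
    int wordMax = (len / 8) * 8;
    ByteBuffer ab = ByteBuffer.wrap(a).order(ByteOrder.nativeOrder());
    ByteBuffer bb = ByteBuffer.wrap(b).order(ByteOrder.nativeOrder());
    for (int i = 0; i < wordMax; i += 8) {
      long left = ab.getLong(i);
      long right = bb.getLong(i);
      if (left != right) {
        if (IS_LITTLE_ENDIAN) {           // reverse so unsigned compare == byte order
          left = Long.reverseBytes(left);
          right = Long.reverseBytes(right);
        }
        return Long.compareUnsigned(left, right);
      }
    }
    for (int i = wordMax; i < len; i++) {
      int res = (a[i] & 0xFF) - (b[i] & 0xFF);  // bytes compared as unsigned
      if (res != 0) return res;
    }
    return a.length - b.length;
  }

  public static void main(String[] args) {
    System.out.println(compare("apple".getBytes(), "apricot".getBytes()) < 0); // true
  }
}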
- */ - -package org.apache.spark.unsafe.memory; - -import org.apache.spark.unsafe.Platform; -import org.junit.Assert; -import org.junit.Test; - -import java.nio.ByteOrder; - -import static org.hamcrest.core.StringContains.containsString; - -public class MemoryBlockSuite { - private static final boolean bigEndianPlatform = - ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN); - - private void check(MemoryBlock memory, Object obj, long offset, int length) { - memory.setPageNumber(1); - memory.fill((byte)-1); - memory.putBoolean(0, true); - memory.putByte(1, (byte)127); - memory.putShort(2, (short)257); - memory.putInt(4, 0x20000002); - memory.putLong(8, 0x1234567089ABCDEFL); - memory.putFloat(16, 1.0F); - memory.putLong(20, 0x1234567089ABCDEFL); - memory.putDouble(28, 2.0); - MemoryBlock.copyMemory(memory, 0L, memory, 36, 4); - int[] a = new int[2]; - a[0] = 0x12345678; - a[1] = 0x13579BDF; - memory.copyFrom(a, Platform.INT_ARRAY_OFFSET, 40, 8); - byte[] b = new byte[8]; - memory.writeTo(40, b, Platform.BYTE_ARRAY_OFFSET, 8); - - Assert.assertEquals(obj, memory.getBaseObject()); - Assert.assertEquals(offset, memory.getBaseOffset()); - Assert.assertEquals(length, memory.size()); - Assert.assertEquals(1, memory.getPageNumber()); - Assert.assertEquals(true, memory.getBoolean(0)); - Assert.assertEquals((byte)127, memory.getByte(1 )); - Assert.assertEquals((short)257, memory.getShort(2)); - Assert.assertEquals(0x20000002, memory.getInt(4)); - Assert.assertEquals(0x1234567089ABCDEFL, memory.getLong(8)); - Assert.assertEquals(1.0F, memory.getFloat(16), 0); - Assert.assertEquals(0x1234567089ABCDEFL, memory.getLong(20)); - Assert.assertEquals(2.0, memory.getDouble(28), 0); - Assert.assertEquals(true, memory.getBoolean(36)); - Assert.assertEquals((byte)127, memory.getByte(37 )); - Assert.assertEquals((short)257, memory.getShort(38)); - Assert.assertEquals(a[0], memory.getInt(40)); - Assert.assertEquals(a[1], memory.getInt(44)); - if (bigEndianPlatform) { - Assert.assertEquals(a[0], - ((int)b[0] & 0xff) << 24 | ((int)b[1] & 0xff) << 16 | - ((int)b[2] & 0xff) << 8 | ((int)b[3] & 0xff)); - Assert.assertEquals(a[1], - ((int)b[4] & 0xff) << 24 | ((int)b[5] & 0xff) << 16 | - ((int)b[6] & 0xff) << 8 | ((int)b[7] & 0xff)); - } else { - Assert.assertEquals(a[0], - ((int)b[3] & 0xff) << 24 | ((int)b[2] & 0xff) << 16 | - ((int)b[1] & 0xff) << 8 | ((int)b[0] & 0xff)); - Assert.assertEquals(a[1], - ((int)b[7] & 0xff) << 24 | ((int)b[6] & 0xff) << 16 | - ((int)b[5] & 0xff) << 8 | ((int)b[4] & 0xff)); - } - for (int i = 48; i < memory.size(); i++) { - Assert.assertEquals((byte) -1, memory.getByte(i)); - } - - assert(memory.subBlock(0, memory.size()) == memory); - - try { - memory.subBlock(-8, 8); - Assert.fail(); - } catch (Exception expected) { - Assert.assertThat(expected.getMessage(), containsString("non-negative")); - } - - try { - memory.subBlock(0, -8); - Assert.fail(); - } catch (Exception expected) { - Assert.assertThat(expected.getMessage(), containsString("non-negative")); - } - - try { - memory.subBlock(0, length + 8); - Assert.fail(); - } catch (Exception expected) { - Assert.assertThat(expected.getMessage(), containsString("should not be larger than")); - } - - try { - memory.subBlock(8, length - 4); - Assert.fail(); - } catch (Exception expected) { - Assert.assertThat(expected.getMessage(), containsString("should not be larger than")); - } - - try { - memory.subBlock(length + 8, 4); - Assert.fail(); - } catch (Exception expected) { - Assert.assertThat(expected.getMessage(), containsString("should not 
be larger than")); - } - - memory.setPageNumber(MemoryBlock.NO_PAGE_NUMBER); - } - - @Test - public void testByteArrayMemoryBlock() { - byte[] obj = new byte[56]; - long offset = Platform.BYTE_ARRAY_OFFSET; - int length = obj.length; - - MemoryBlock memory = new ByteArrayMemoryBlock(obj, offset, length); - check(memory, obj, offset, length); - - memory = ByteArrayMemoryBlock.fromArray(obj); - check(memory, obj, offset, length); - - obj = new byte[112]; - memory = new ByteArrayMemoryBlock(obj, offset, length); - check(memory, obj, offset, length); - } - - @Test - public void testOnHeapMemoryBlock() { - long[] obj = new long[7]; - long offset = Platform.LONG_ARRAY_OFFSET; - int length = obj.length * 8; - - MemoryBlock memory = new OnHeapMemoryBlock(obj, offset, length); - check(memory, obj, offset, length); - - memory = OnHeapMemoryBlock.fromArray(obj); - check(memory, obj, offset, length); - - obj = new long[14]; - memory = new OnHeapMemoryBlock(obj, offset, length); - check(memory, obj, offset, length); - } - - @Test - public void testOffHeapArrayMemoryBlock() { - MemoryAllocator memoryAllocator = new UnsafeMemoryAllocator(); - MemoryBlock memory = memoryAllocator.allocate(56); - Object obj = memory.getBaseObject(); - long offset = memory.getBaseOffset(); - int length = 56; - - check(memory, obj, offset, length); - memoryAllocator.free(memory); - - long address = Platform.allocateMemory(112); - memory = new OffHeapMemoryBlock(address, length); - obj = memory.getBaseObject(); - offset = memory.getBaseOffset(); - check(memory, obj, offset, length); - Platform.freeMemory(address); - } -} diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 42dda30480702..dae13f03b02ff 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -25,8 +25,7 @@ import java.util.*; import com.google.common.collect.ImmutableMap; -import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; -import org.apache.spark.unsafe.memory.OnHeapMemoryBlock; +import org.apache.spark.unsafe.Platform; import org.junit.Test; import static org.junit.Assert.*; @@ -513,6 +512,21 @@ public void soundex() { assertEquals(fromString("世界千世").soundex(), fromString("世界千世")); } + @Test + public void writeToOutputStreamUnderflow() throws IOException { + // offset underflow is apparently supported? 
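For orientation before the test body: `fromAddress` wraps an existing object in place from a (base object, unsafe offset, byte count) triple, and the test deliberately backs the offset up in front of the array's first element. A minimal standalone sketch of the ordinary, in-bounds use of the same API, assuming only that Spark's `unsafe` classes are on the classpath:

import java.io.ByteArrayOutputStream
import java.nio.charset.StandardCharsets

import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.types.UTF8String

object FromAddressSketch {
  def main(args: Array[String]): Unit = {
    val bytes = "01234567".getBytes(StandardCharsets.UTF_8)
    // Wrap the byte array in place: no copy is made at construction time.
    val s = UTF8String.fromAddress(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length)
    val out = new ByteArrayOutputStream()
    s.writeTo(out) // stream out exactly the wrapped byte range
    assert(out.toString("UTF-8") == "01234567")
  }
}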
+ final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + final byte[] test = "01234567".getBytes(StandardCharsets.UTF_8); + + for (int i = 1; i <= Platform.BYTE_ARRAY_OFFSET; ++i) { + UTF8String.fromAddress(test, Platform.BYTE_ARRAY_OFFSET - i, test.length + i) + .writeTo(outputStream); + final ByteBuffer buffer = ByteBuffer.wrap(outputStream.toByteArray(), i, test.length); + assertEquals("01234567", StandardCharsets.UTF_8.decode(buffer).toString()); + outputStream.reset(); + } + } + @Test public void writeToOutputStreamSlice() throws IOException { final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); @@ -520,7 +534,7 @@ public void writeToOutputStreamSlice() throws IOException { for (int i = 0; i < test.length; ++i) { for (int j = 0; j < test.length - i; ++j) { - new UTF8String(ByteArrayMemoryBlock.fromArray(test).subBlock(i, j)) + UTF8String.fromAddress(test, Platform.BYTE_ARRAY_OFFSET + i, j) .writeTo(outputStream); assertArrayEquals(Arrays.copyOfRange(test, i, i + j), outputStream.toByteArray()); @@ -551,7 +565,7 @@ public void writeToOutputStreamOverflow() throws IOException { for (final long offset : offsets) { try { - new UTF8String(ByteArrayMemoryBlock.fromArray(test).subBlock(offset, test.length)) + fromAddress(test, BYTE_ARRAY_OFFSET + offset, test.length) .writeTo(outputStream); throw new IllegalStateException(Long.toString(offset)); @@ -578,25 +592,26 @@ public void writeToOutputStream() throws IOException { } @Test - public void writeToOutputStreamLongArray() throws IOException { + public void writeToOutputStreamIntArray() throws IOException { // verify that writes work on objects that are not byte arrays - final ByteBuffer buffer = StandardCharsets.UTF_8.encode("3千大千世界"); + final ByteBuffer buffer = StandardCharsets.UTF_8.encode("大千世界"); buffer.position(0); buffer.order(ByteOrder.nativeOrder()); final int length = buffer.limit(); - assertEquals(16, length); + assertEquals(12, length); - final int longs = length / 8; - final long[] array = new long[longs]; + final int ints = length / 4; + final int[] array = new int[ints]; - for (int i = 0; i < longs; ++i) { - array[i] = buffer.getLong(); + for (int i = 0; i < ints; ++i) { + array[i] = buffer.getInt(); } final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - new UTF8String(OnHeapMemoryBlock.fromArray(array)).writeTo(outputStream); - assertEquals("3千大千世界", outputStream.toString("UTF-8")); + fromAddress(array, Platform.INT_ARRAY_OFFSET, length) + .writeTo(outputStream); + assertEquals("大千世界", outputStream.toString("UTF-8")); } @Test diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala index 48004e812a8bf..7d3331f44f015 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala @@ -192,8 +192,8 @@ class UTF8StringPropertyCheckSuite extends FunSuite with GeneratorDrivenProperty val nullalbeSeq = Gen.listOf(Gen.oneOf[String](null: String, randomString)) test("concat") { - def concat(orgin: Seq[String]): String = - if (orgin.contains(null)) null else orgin.mkString + def concat(origin: Seq[String]): String = + if (origin.contains(null)) null else origin.mkString forAll { (inputs: Seq[String]) => assert(UTF8String.concat(inputs.map(toUTF8): _*) === toUTF8(inputs.mkString)) diff --git 
a/core/pom.xml b/core/pom.xml index 220522d3a8296..5fa3a86de6b01 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -56,7 +56,7 @@ org.apache.xbean - xbean-asm5-shaded + xbean-asm6-shaded org.apache.hadoop @@ -88,8 +88,8 @@ ${project.version} - net.java.dev.jets3t - jets3t + javax.activation + activation org.apache.curator diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index 94c5c11b61a50..731f6fc767dfd 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -103,6 +103,12 @@ public final void onExecutorMetricsUpdate( onEvent(executorMetricsUpdate); } + @Override + public final void onStageExecutorMetrics( + SparkListenerStageExecutorMetrics executorMetrics) { + onEvent(executorMetrics); + } + @Override public final void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { onEvent(executorAdded); diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 8651a639c07f7..d07faf1da1248 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -311,7 +311,7 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { // this could trigger spilling to free some pages. return allocatePage(size, consumer); } - page.setPageNumber(pageNumber); + page.pageNumber = pageNumber; pageTable[pageNumber] = page; if (logger.isTraceEnabled()) { logger.trace("Allocate page number {} ({} bytes)", pageNumber, acquired); @@ -323,25 +323,25 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}. */ public void freePage(MemoryBlock page, MemoryConsumer consumer) { - assert (page.getPageNumber() != MemoryBlock.NO_PAGE_NUMBER) : + assert (page.pageNumber != MemoryBlock.NO_PAGE_NUMBER) : "Called freePage() on memory that wasn't allocated with allocatePage()"; - assert (page.getPageNumber() != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + assert (page.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : "Called freePage() on a memory block that has already been freed"; - assert (page.getPageNumber() != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) : + assert (page.pageNumber != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) : "Called freePage() on a memory block that has already been freed"; - assert(allocatedPages.get(page.getPageNumber())); - pageTable[page.getPageNumber()] = null; + assert(allocatedPages.get(page.pageNumber)); + pageTable[page.pageNumber] = null; synchronized (this) { - allocatedPages.clear(page.getPageNumber()); + allocatedPages.clear(page.pageNumber); } if (logger.isTraceEnabled()) { - logger.trace("Freed page number {} ({} bytes)", page.getPageNumber(), page.size()); + logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size()); } long pageSize = page.size(); // Clear the page number before passing the block to the MemoryAllocator's free(). // Doing this allows the MemoryAllocator to detect when a TaskMemoryManager-managed // page has been inappropriately directly freed without calling TMM.freePage(). 
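The double-free protocol that the comment above describes is small enough to model in isolation. The following is an illustrative sketch, not Spark's code: the constant names mirror `MemoryBlock`'s sentinels, and the exact values are assumptions.

object PageLifecycleSketch {
  final val NoPageNumber = -1      // fresh block, never registered with a TaskMemoryManager
  final val FreedInTmm = -2        // TMM.freePage() has run; the allocator's free() is now legal
  final val FreedInAllocator = -3  // the allocator's free() has run; any further free is a bug

  final class Page(var pageNumber: Int = NoPageNumber)

  def tmmFreePage(page: Page): Unit = {
    assert(page.pageNumber != NoPageNumber, "freePage() on memory not from allocatePage()")
    assert(page.pageNumber != FreedInTmm && page.pageNumber != FreedInAllocator, "double free")
    page.pageNumber = FreedInTmm   // clear the page number before handing the block over
    allocatorFree(page)
  }

  def allocatorFree(page: Page): Unit = {
    assert(page.pageNumber == FreedInTmm || page.pageNumber == NoPageNumber,
      "TMM-managed page freed without TaskMemoryManager.freePage()")
    page.pageNumber = FreedInAllocator
  }
}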
- page.setPageNumber(MemoryBlock.FREED_IN_TMM_PAGE_NUMBER); + page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; memoryManager.tungstenMemoryAllocator().free(page); releaseExecutionMemory(pageSize, consumer); } @@ -363,7 +363,7 @@ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { // relative to the page's base offset; this relative offset will fit in 51 bits. offsetInPage -= page.getBaseOffset(); } - return encodePageNumberAndOffset(page.getPageNumber(), offsetInPage); + return encodePageNumberAndOffset(page.pageNumber, offsetInPage); } @VisibleForTesting @@ -434,7 +434,7 @@ public long cleanUpAllAllocatedMemory() { for (MemoryBlock page : pageTable) { if (page != null) { logger.debug("unreleased page: " + page + " in task " + taskAttemptId); - page.setPageNumber(MemoryBlock.FREED_IN_TMM_PAGE_NUMBER); + page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; memoryManager.tungstenMemoryAllocator().free(page); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 323a5d3c52831..e3bd5496cf5ba 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -125,7 +125,7 @@ public void write(Iterator> records) throws IOException { if (!records.hasNext()) { partitionLengths = new long[numPartitions]; shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null); - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, 0); return; } final SerializerInstance serInstance = serializer.newInstance(); @@ -167,7 +167,8 @@ public void write(Iterator> records) throws IOException { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); } } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), partitionLengths, writeMetrics.recordsWritten()); } @VisibleForTesting diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index c3a07b2abf896..c7d2db4217d96 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -43,6 +43,7 @@ import org.apache.spark.storage.FileSegment; import org.apache.spark.storage.TempShuffleBlockId; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UnsafeAlignedOffset; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.Utils; @@ -184,6 +185,7 @@ private void writeSortedFile(boolean isLastFile) { blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse); int currentPartition = -1; + final int uaoSize = UnsafeAlignedOffset.getUaoSize(); while (sortedRecords.hasNext()) { sortedRecords.loadNext(); final int partition = sortedRecords.packedRecordPointer.getPartitionId(); @@ -200,8 +202,8 @@ private void writeSortedFile(boolean isLastFile) { final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); final Object recordPage = 
taskMemoryManager.getPage(recordPointer); final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer); - int dataRemaining = Platform.getInt(recordPage, recordOffsetInPage); - long recordReadPosition = recordOffsetInPage + 4; // skip over record length + int dataRemaining = UnsafeAlignedOffset.getSize(recordPage, recordOffsetInPage); + long recordReadPosition = recordOffsetInPage + uaoSize; // skip over record length while (dataRemaining > 0) { final int toTransfer = Math.min(diskWriteBufferSize, dataRemaining); Platform.copyMemory( @@ -389,15 +391,16 @@ public void insertRecord(Object recordBase, long recordOffset, int length, int p } growPointerArrayIfNecessary(); - // Need 4 bytes to store the record length. - final int required = length + 4; + final int uaoSize = UnsafeAlignedOffset.getUaoSize(); + // Need 4 or 8 bytes to store the record length. + final int required = length + uaoSize; acquireNewPageIfNecessary(required); assert(currentPage != null); final Object base = currentPage.getBaseObject(); final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor); - Platform.putInt(base, pageCursor, length); - pageCursor += 4; + UnsafeAlignedOffset.putSize(base, pageCursor, length); + pageCursor += uaoSize; Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length); pageCursor += length; inMemSorter.insertRecord(recordAddress, partitionId); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index 8f49859746b89..0d069125dc60e 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -20,6 +20,7 @@ import java.util.Comparator; import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.Sorter; @@ -65,7 +66,7 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) { */ private int usableCapacity = 0; - private int initialSize; + private final int initialSize; ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize, boolean useRadixSort) { this.consumer = consumer; @@ -94,17 +95,31 @@ public int numRecords() { } public void reset() { + // Reset `pos` here so that `spill` triggered by the below `allocateArray` will be no-op. + pos = 0; if (consumer != null) { consumer.freeArray(array); + // As `array` has been released, we should set it to `null` to avoid accessing it before + // `allocateArray` returns. `usableCapacity` is also set to `0` to avoid any codes writing + // data to `ShuffleInMemorySorter` when `array` is `null` (e.g., in + // ShuffleExternalSorter.growPointerArrayIfNecessary, we may try to access + // `ShuffleInMemorySorter` when `allocateArray` throws SparkOutOfMemoryError). 
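The hazard that comment guards against is easier to see in isolation. A minimal sketch under assumed names, where `alloc` and `release` stand in for `consumer.allocateArray` and `consumer.freeArray`:

final class ResettableBuffer(alloc: Int => Array[Long], release: Array[Long] => Unit) {
  private var array: Array[Long] = alloc(16)
  private var pos: Int = 0

  def reset(): Unit = {
    pos = 0            // a spill triggered by the allocation below sees an empty sorter
    val n = array.length
    release(array)     // give the old buffer back first
    array = null       // the freed buffer must not be reachable while we allocate
    array = alloc(n)   // may throw; the error's spill path can re-enter this object
  }
}

If `alloc` throws, the object is left empty but consistent: `pos` is zero and `array` is null, so nothing can touch the freed buffer.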
+ array = null; + usableCapacity = 0; array = consumer.allocateArray(initialSize); usableCapacity = getUsableCapacity(); } - pos = 0; } public void expandPointerArray(LongArray newArray) { assert(newArray.size() > array.size()); - MemoryBlock.copyMemory(array.memoryBlock(), newArray.memoryBlock(), pos * 8L); + Platform.copyMemory( + array.getBaseObject(), + array.getBaseOffset(), + newArray.getBaseObject(), + newArray.getBaseOffset(), + pos * 8L + ); consumer.freeArray(array); array = newArray; usableCapacity = getUsableCapacity(); @@ -173,7 +188,10 @@ public ShuffleSorterIterator getSortedIterator() { PackedRecordPointer.PARTITION_ID_START_BYTE_INDEX, PackedRecordPointer.PARTITION_ID_END_BYTE_INDEX, false, false); } else { - MemoryBlock unused = array.memoryBlock().subBlock(pos * 8L, (array.size() - pos) * 8L); + MemoryBlock unused = new MemoryBlock( + array.getBaseObject(), + array.getBaseOffset() + pos * 8L, + (array.size() - pos) * 8L); LongArray buffer = new LongArray(unused); Sorter sorter = new Sorter<>(new ShuffleSortDataFormat(buffer)); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java index 254449e95443e..717bdd79d47ef 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java @@ -17,8 +17,8 @@ package org.apache.spark.shuffle.sort; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; -import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.SortDataFormat; final class ShuffleSortDataFormat extends SortDataFormat { @@ -60,8 +60,13 @@ public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) { @Override public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) { - MemoryBlock.copyMemory(src.memoryBlock(), srcPos * 8L, - dst.memoryBlock(),dstPos * 8L,length * 8L); + Platform.copyMemory( + src.getBaseObject(), + src.getBaseOffset() + srcPos * 8L, + dst.getBaseObject(), + dst.getBaseOffset() + dstPos * 8L, + length * 8L + ); } @Override diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 4839d04522f10..069e6d5f224d7 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -248,7 +248,8 @@ void closeAndWriteOutput() throws IOException { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); } } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), partitionLengths, writeMetrics.recordsWritten()); } @VisibleForTesting diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 9a767dd739b91..9b6cbab38cbcc 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -662,7 +662,7 @@ public int getValueLength() { * It is only valid to call this method immediately after calling `lookup()` using the same key. *

* <p>
- * The key and value must be word-aligned (that is, their sizes must multiples of 8).
+ * The key and value must be word-aligned (that is, their sizes must be a multiple of 8).
* <p>

* After calling this method, calls to `get[Key|Value]Address()` and `get[Key|Value]Length` diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 4fc19b1721518..5056652a2420b 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -402,7 +402,7 @@ public void insertRecord( growPointerArrayIfNecessary(); int uaoSize = UnsafeAlignedOffset.getUaoSize(); - // Need 4 bytes to store the record length. + // Need 4 or 8 bytes to store the record length. final int required = length + uaoSize; acquireNewPageIfNecessary(required); @@ -544,7 +544,7 @@ public long spill() throws IOException { // is accessing the current record. We free this page in that caller's next loadNext() // call. for (MemoryBlock page : allocatedPages) { - if (!loaded || page.getPageNumber() != + if (!loaded || page.pageNumber != ((UnsafeInMemorySorter.SortedIterator)upstream).getCurrentPageNumber()) { released += page.size(); freePage(page); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 717823ebbd320..75690ae264838 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -26,6 +26,7 @@ import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.SparkOutOfMemoryError; import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.UnsafeAlignedOffset; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -215,7 +216,12 @@ public void expandPointerArray(LongArray newArray) { if (newArray.size() < array.size()) { throw new SparkOutOfMemoryError("Not enough memory to grow pointer array"); } - MemoryBlock.copyMemory(array.memoryBlock(), newArray.memoryBlock(), pos * 8L); + Platform.copyMemory( + array.getBaseObject(), + array.getBaseOffset(), + newArray.getBaseObject(), + newArray.getBaseOffset(), + pos * 8L); consumer.freeArray(array); array = newArray; usableCapacity = getUsableCapacity(); @@ -342,7 +348,10 @@ public UnsafeSorterIterator getSortedIterator() { array, nullBoundaryPos, (pos - nullBoundaryPos) / 2L, 0, 7, radixSortSupport.sortDescending(), radixSortSupport.sortSigned()); } else { - MemoryBlock unused = array.memoryBlock().subBlock(pos * 8L, (array.size() - pos) * 8L); + MemoryBlock unused = new MemoryBlock( + array.getBaseObject(), + array.getBaseOffset() + pos * 8L, + (array.size() - pos) * 8L); LongArray buffer = new LongArray(unused); Sorter sorter = new Sorter<>(new UnsafeSortDataFormat(buffer)); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java index ff0dcc259a4ad..ab800288dcb43 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java @@ -51,7 +51,7 @@ public void addSpillIfNotEmpty(UnsafeSorterIterator spillReader) 
throws IOExcept if (spillReader.hasNext()) { // We only add the spillReader to the priorityQueue if it is not empty. We do this to // make sure the hasNext method of UnsafeSorterIterator returned by getSortedIterator - // does not return wrong result because hasNext will returns true + // does not return wrong result because hasNext will return true // at least priorityQueue.size() times. If we allow n spillReaders in the // priorityQueue, we will have n extra empty records in the result of UnsafeSorterIterator. spillReader.loadNext(); diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala new file mode 100644 index 0000000000000..6439ca5db06e9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.util.{Timer, TimerTask} +import java.util.concurrent.ConcurrentHashMap +import java.util.function.{Consumer, Function} + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted} + +/** + * For each barrier stage attempt, only at most one barrier() call can be active at any time, thus + * we can use (stageId, stageAttemptId) to identify the stage attempt where the barrier() call is + * from. + */ +private case class ContextBarrierId(stageId: Int, stageAttemptId: Int) { + override def toString: String = s"Stage $stageId (Attempt $stageAttemptId)" +} + +/** + * A coordinator that handles all global sync requests from BarrierTaskContext. Each global sync + * request is generated by `BarrierTaskContext.barrier()`, and identified by + * stageId + stageAttemptId + barrierEpoch. Reply all the blocking global sync requests upon + * all the requests for a group of `barrier()` calls are received. If the coordinator is unable to + * collect enough global sync requests within a configured time, fail all the requests and return + * an Exception with timeout message. + */ +private[spark] class BarrierCoordinator( + timeoutInSecs: Long, + listenerBus: LiveListenerBus, + override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { + + // TODO SPARK-25030 Create a Timer() in the mainClass submitted to SparkSubmit makes it unable to + // fetch result, we shall fix the issue. + private lazy val timer = new Timer("BarrierCoordinator barrier epoch increment timer") + + // Listen to StageCompleted event, clear corresponding ContextBarrierState. 
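Stripped of the timeout and cleanup paths, the sync each `ContextBarrierState` below implements reduces to: collect `numTasks` requests for the current epoch, then release every waiter at once. A minimal model of just that core, with `() => Unit` thunks standing in for `RpcCallContext.reply`:

import scala.collection.mutable.ArrayBuffer

final class BarrierStateSketch(numTasks: Int) {
  private var epoch = 0
  private val waiters = ArrayBuffer.empty[() => Unit]

  def handleRequest(requestEpoch: Int)(reply: () => Unit): Unit = synchronized {
    require(requestEpoch == epoch, s"stale sync request for epoch $requestEpoch")
    waiters += reply
    if (waiters.size == numTasks) { // everyone arrived: release the barrier
      waiters.foreach(_.apply())
      waiters.clear()
      epoch += 1                    // the next barrier() call starts a fresh epoch
    }
  }
}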
+ private val listener = new SparkListener { + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { + val stageInfo = stageCompleted.stageInfo + val barrierId = ContextBarrierId(stageInfo.stageId, stageInfo.attemptNumber) + // Clear ContextBarrierState from a finished stage attempt. + cleanupBarrierStage(barrierId) + } + } + + // Record all active stage attempts that make barrier() call(s), and the corresponding internal + // state. + private val states = new ConcurrentHashMap[ContextBarrierId, ContextBarrierState] + + override def onStart(): Unit = { + super.onStart() + listenerBus.addToStatusQueue(listener) + } + + override def onStop(): Unit = { + try { + states.forEachValue(1, clearStateConsumer) + states.clear() + listenerBus.removeListener(listener) + } finally { + super.onStop() + } + } + + /** + * Provide the current state of a barrier() call. A state is created when a new stage attempt + * sends out a barrier() call, and recycled on stage completed. + * + * @param barrierId Identifier of the barrier stage that make a barrier() call. + * @param numTasks Number of tasks of the barrier stage, all barrier() calls from the stage shall + * collect `numTasks` requests to succeed. + */ + private class ContextBarrierState( + val barrierId: ContextBarrierId, + val numTasks: Int) { + + // There may be multiple barrier() calls from a barrier stage attempt, `barrierEpoch` is used + // to identify each barrier() call. It shall get increased when a barrier() call succeeds, or + // reset when a barrier() call fails due to timeout. + private var barrierEpoch: Int = 0 + + // An array of RPCCallContexts for barrier tasks that are waiting for reply of a barrier() + // call. + private val requesters: ArrayBuffer[RpcCallContext] = new ArrayBuffer[RpcCallContext](numTasks) + + // A timer task that ensures we may timeout for a barrier() call. + private var timerTask: TimerTask = null + + // Init a TimerTask for a barrier() call. + private def initTimerTask(): Unit = { + timerTask = new TimerTask { + override def run(): Unit = synchronized { + // Timeout current barrier() call, fail all the sync requests. + requesters.foreach(_.sendFailure(new SparkException("The coordinator didn't get all " + + s"barrier sync requests for barrier epoch $barrierEpoch from $barrierId within " + + s"$timeoutInSecs second(s)."))) + cleanupBarrierStage(barrierId) + } + } + } + + // Cancel the current active TimerTask and release resources. + private def cancelTimerTask(): Unit = { + if (timerTask != null) { + timerTask.cancel() + timer.purge() + timerTask = null + } + } + + // Process the global sync request. The barrier() call succeed if collected enough requests + // within a configured time, otherwise fail all the pending requests. + def handleRequest(requester: RpcCallContext, request: RequestToSync): Unit = synchronized { + val taskId = request.taskAttemptId + val epoch = request.barrierEpoch + + // Require the number of tasks is correctly set from the BarrierTaskContext. + require(request.numTasks == numTasks, s"Number of tasks of $barrierId is " + + s"${request.numTasks} from Task $taskId, previously it was $numTasks.") + + // Check whether the epoch from the barrier tasks matches current barrierEpoch. + logInfo(s"Current barrier epoch for $barrierId is $barrierEpoch.") + if (epoch != barrierEpoch) { + requester.sendFailure(new SparkException(s"The request to sync of $barrierId with " + + s"barrier epoch $barrierEpoch has already finished. 
Maybe task $taskId is not " + + "properly killed.")) + } else { + // If this is the first sync message received for a barrier() call, start timer to ensure + // we may timeout for the sync. + if (requesters.isEmpty) { + initTimerTask() + timer.schedule(timerTask, timeoutInSecs * 1000) + } + // Add the requester to array of RPCCallContexts pending for reply. + requesters += requester + logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " + + s"$taskId, current progress: ${requesters.size}/$numTasks.") + if (maybeFinishAllRequesters(requesters, numTasks)) { + // Finished current barrier() call successfully, clean up ContextBarrierState and + // increase the barrier epoch. + logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received all updates from " + + s"tasks, finished successfully.") + barrierEpoch += 1 + requesters.clear() + cancelTimerTask() + } + } + } + + // Finish all the blocking barrier sync requests from a stage attempt successfully if we + // have received all the sync requests. + private def maybeFinishAllRequesters( + requesters: ArrayBuffer[RpcCallContext], + numTasks: Int): Boolean = { + if (requesters.size == numTasks) { + requesters.foreach(_.reply(())) + true + } else { + false + } + } + + // Cleanup the internal state of a barrier stage attempt. + def clear(): Unit = synchronized { + // The global sync fails so the stage is expected to retry another attempt, all sync + // messages come from current stage attempt shall fail. + barrierEpoch = -1 + requesters.clear() + cancelTimerTask() + } + } + + // Clean up the [[ContextBarrierState]] that correspond to a specific stage attempt. + private def cleanupBarrierStage(barrierId: ContextBarrierId): Unit = { + val barrierState = states.remove(barrierId) + if (barrierState != null) { + barrierState.clear() + } + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _) => + // Get or init the ContextBarrierState correspond to the stage attempt. + val barrierId = ContextBarrierId(stageId, stageAttemptId) + states.computeIfAbsent(barrierId, new Function[ContextBarrierId, ContextBarrierState] { + override def apply(key: ContextBarrierId): ContextBarrierState = + new ContextBarrierState(key, numTasks) + }) + val barrierState = states.get(barrierId) + + barrierState.handleRequest(context, request) + } + + private val clearStateConsumer = new Consumer[ContextBarrierState] { + override def accept(state: ContextBarrierState) = state.clear() + } +} + +private[spark] sealed trait BarrierCoordinatorMessage extends Serializable + +/** + * A global sync request message from BarrierTaskContext, by `barrier()` call. Each request is + * identified by stageId + stageAttemptId + barrierEpoch. + * + * @param numTasks The number of global sync requests the BarrierCoordinator shall receive + * @param stageId ID of current stage + * @param stageAttemptId ID of current stage attempt + * @param taskAttemptId Unique ID of current task + * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls. 
+ */ +private[spark] case class RequestToSync( + numTasks: Int, + stageId: Int, + stageAttemptId: Int, + taskAttemptId: Long, + barrierEpoch: Int) extends BarrierCoordinatorMessage diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala new file mode 100644 index 0000000000000..90a5c4130f799 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.util.{Properties, Timer, TimerTask} + +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.Logging +import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.metrics.source.Source +import org.apache.spark.rpc.{RpcEndpointRef, RpcTimeout} +import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.util._ + +/** + * :: Experimental :: + * A [[TaskContext]] with extra contextual info and tooling for tasks in a barrier stage. + * Use [[BarrierTaskContext#get]] to obtain the barrier context for a running barrier task. + */ +@Experimental +@Since("2.4.0") +class BarrierTaskContext private[spark] ( + taskContext: TaskContext) extends TaskContext with Logging { + + // Find the driver side RPCEndpointRef of the coordinator that handles all the barrier() calls. + private val barrierCoordinator: RpcEndpointRef = { + val env = SparkEnv.get + RpcUtils.makeDriverRef("barrierSync", env.conf, env.rpcEnv) + } + + private val timer = new Timer("Barrier task timer for barrier() calls.") + + // Local barrierEpoch that identify a barrier() call from current task, it shall be identical + // with the driver side epoch. + private var barrierEpoch = 0 + + // Number of tasks of the current barrier stage, a barrier() call must collect enough requests + // from different tasks within the same barrier stage attempt to succeed. + private lazy val numTasks = getTaskInfos().size + + /** + * :: Experimental :: + * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to + * MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same + * stage have reached this routine. + * + * CAUTION! In a barrier stage, each task must have the same number of barrier() calls, in all + * possible code branches. Otherwise, you may get the job hanging or a SparkException after + * timeout. Some examples of '''misuses''' are listed below: + * 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it + * shall lead to timeout of the function call. 
+ * {{{ + * rdd.barrier().mapPartitions { iter => + * val context = BarrierTaskContext.get() + * if (context.partitionId() == 0) { + * // Do nothing. + * } else { + * context.barrier() + * } + * iter + * } + * }}} + * + * 2. Include barrier() function in a try-catch code block, this may lead to timeout of the + * second function call. + * {{{ + * rdd.barrier().mapPartitions { iter => + * val context = BarrierTaskContext.get() + * try { + * // Do something that might throw an Exception. + * doSomething() + * context.barrier() + * } catch { + * case e: Exception => logWarning("...", e) + * } + * context.barrier() + * iter + * } + * }}} + */ + @Experimental + @Since("2.4.0") + def barrier(): Unit = { + logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " + + s"the global sync, current barrier epoch is $barrierEpoch.") + logTrace("Current callSite: " + Utils.getCallSite()) + + val startTime = System.currentTimeMillis() + val timerTask = new TimerTask { + override def run(): Unit = { + logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) waiting " + + s"under the global sync since $startTime, has been waiting for " + + s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch " + + s"is $barrierEpoch.") + } + } + // Log the update of global sync every 60 seconds. + timer.schedule(timerTask, 60000, 60000) + + try { + barrierCoordinator.askSync[Unit]( + message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, + barrierEpoch), + // Set a fixed timeout for RPC here, so users shall get a SparkException thrown by + // BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework. + timeout = new RpcTimeout(31536000 /* = 3600 * 24 * 365 */ seconds, "barrierTimeout")) + barrierEpoch += 1 + logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) finished " + + "global sync successfully, waited for " + + s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch is " + + s"$barrierEpoch.") + } catch { + case e: SparkException => + logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) failed " + + "to perform global sync, waited for " + + s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch " + + s"is $barrierEpoch.") + throw e + } finally { + timerTask.cancel() + timer.purge() + } + } + + /** + * :: Experimental :: + * Returns [[BarrierTaskInfo]] for all tasks in this barrier stage, ordered by partition ID. 
+ */ + @Experimental + @Since("2.4.0") + def getTaskInfos(): Array[BarrierTaskInfo] = { + val addressesStr = Option(taskContext.getLocalProperty("addresses")).getOrElse("") + addressesStr.split(",").map(_.trim()).map(new BarrierTaskInfo(_)) + } + + // delegate methods + + override def isCompleted(): Boolean = taskContext.isCompleted() + + override def isInterrupted(): Boolean = taskContext.isInterrupted() + + override def isRunningLocally(): Boolean = taskContext.isRunningLocally() + + override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = { + taskContext.addTaskCompletionListener(listener) + this + } + + override def addTaskFailureListener(listener: TaskFailureListener): this.type = { + taskContext.addTaskFailureListener(listener) + this + } + + override def stageId(): Int = taskContext.stageId() + + override def stageAttemptNumber(): Int = taskContext.stageAttemptNumber() + + override def partitionId(): Int = taskContext.partitionId() + + override def attemptNumber(): Int = taskContext.attemptNumber() + + override def taskAttemptId(): Long = taskContext.taskAttemptId() + + override def getLocalProperty(key: String): String = taskContext.getLocalProperty(key) + + override def taskMetrics(): TaskMetrics = taskContext.taskMetrics() + + override def getMetricsSources(sourceName: String): Seq[Source] = { + taskContext.getMetricsSources(sourceName) + } + + override private[spark] def killTaskIfInterrupted(): Unit = taskContext.killTaskIfInterrupted() + + override private[spark] def getKillReason(): Option[String] = taskContext.getKillReason() + + override private[spark] def taskMemoryManager(): TaskMemoryManager = { + taskContext.taskMemoryManager() + } + + override private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit = { + taskContext.registerAccumulator(a) + } + + override private[spark] def setFetchFailed(fetchFailed: FetchFailedException): Unit = { + taskContext.setFetchFailed(fetchFailed) + } + + override private[spark] def markInterrupted(reason: String): Unit = { + taskContext.markInterrupted(reason) + } + + override private[spark] def markTaskFailed(error: Throwable): Unit = { + taskContext.markTaskFailed(error) + } + + override private[spark] def markTaskCompleted(error: Option[Throwable]): Unit = { + taskContext.markTaskCompleted(error) + } + + override private[spark] def fetchFailed: Option[FetchFailedException] = { + taskContext.fetchFailed + } + + override private[spark] def getLocalProperties: Properties = taskContext.getLocalProperties +} + +@Experimental +@Since("2.4.0") +object BarrierTaskContext { + /** + * :: Experimental :: + * Returns the currently active BarrierTaskContext. This can be called inside of user functions to + * access contextual information about running barrier tasks. + */ + @Experimental + @Since("2.4.0") + def get(): BarrierTaskContext = TaskContext.get().asInstanceOf[BarrierTaskContext] +} diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskInfo.scala b/core/src/main/scala/org/apache/spark/BarrierTaskInfo.scala new file mode 100644 index 0000000000000..347239b1d7db4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/BarrierTaskInfo.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.spark.annotation.{Experimental, Since} + + +/** + * :: Experimental :: + * Carries all task infos of a barrier task. + * + * @param address the IPv4 address(host:port) of the executor that a barrier task is running on + */ +@Experimental +@Since("2.4.0") +class BarrierTaskInfo private[spark] (val address: String) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index aa363eeffffb8..c3e5b96a55884 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -25,7 +25,7 @@ import scala.util.control.{ControlThrowable, NonFatal} import com.codahale.metrics.{Gauge, MetricRegistry} -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.internal.config._ import org.apache.spark.metrics.source.Source import org.apache.spark.scheduler._ @@ -212,7 +212,7 @@ private[spark] class ExecutorAllocationManager( } // Require external shuffle service for dynamic allocation // Otherwise, we may lose shuffle files when killing executors - if (!conf.getBoolean("spark.shuffle.service.enabled", false) && !testing) { + if (!conf.get(config.SHUFFLE_SERVICE_ENABLED) && !testing) { throw new SparkException("Dynamic allocation of executors requires the external " + "shuffle service. 
You may enable this through spark.shuffle.service.enabled.") } @@ -488,9 +488,15 @@ private[spark] class ExecutorAllocationManager( newExecutorTotal = numExistingExecutors if (testing || executorsRemoved.nonEmpty) { executorsRemoved.foreach { removedExecutorId => + // If it is a cached block, it uses cachedExecutorIdleTimeoutS for timeout + val idleTimeout = if (blockManagerMaster.hasCachedBlocks(removedExecutorId)) { + cachedExecutorIdleTimeoutS + } else { + executorIdleTimeoutS + } newExecutorTotal -= 1 logInfo(s"Removing executor $removedExecutorId because it has been idle for " + - s"$executorIdleTimeoutS seconds (new desired total will be $newExecutorTotal)") + s"$idleTimeout seconds (new desired total will be $newExecutorTotal)") executorsPendingToRemove.add(removedExecutorId) } executorsRemoved diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index bcbc8df0d5865..ab0ae55ed357d 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -22,6 +22,7 @@ import java.util.concurrent.{ScheduledFuture, TimeUnit} import scala.collection.mutable import scala.concurrent.Future +import org.apache.spark.executor.ExecutorMetrics import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler._ @@ -37,7 +38,8 @@ import org.apache.spark.util._ private[spark] case class Heartbeat( executorId: String, accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])], // taskId -> accumulator updates - blockManagerId: BlockManagerId) + blockManagerId: BlockManagerId, + executorUpdates: ExecutorMetrics) // executor level updates /** * An event that SparkContext uses to notify HeartbeatReceiver that SparkContext.taskScheduler is @@ -119,14 +121,14 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) context.reply(true) // Messages received from executors - case heartbeat @ Heartbeat(executorId, accumUpdates, blockManagerId) => + case heartbeat @ Heartbeat(executorId, accumUpdates, blockManagerId, executorMetrics) => if (scheduler != null) { if (executorLastSeen.contains(executorId)) { executorLastSeen(executorId) = clock.getTimeMillis() eventLoopThread.submit(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { val unknownExecutor = !scheduler.executorHeartbeatReceived( - executorId, accumUpdates, blockManagerId) + executorId, accumUpdates, blockManagerId, executorMetrics) val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) context.reply(response) } diff --git a/core/src/main/scala/org/apache/spark/Heartbeater.scala b/core/src/main/scala/org/apache/spark/Heartbeater.scala new file mode 100644 index 0000000000000..5ba1b9b2d828e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/Heartbeater.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.util.concurrent.TimeUnit + +import org.apache.spark.executor.ExecutorMetrics +import org.apache.spark.internal.Logging +import org.apache.spark.memory.MemoryManager +import org.apache.spark.metrics.ExecutorMetricType +import org.apache.spark.util.{ThreadUtils, Utils} + +/** + * Creates a heartbeat thread which will call the specified reportHeartbeat function at + * intervals of intervalMs. + * + * @param memoryManager the memory manager for execution and storage memory. + * @param reportHeartbeat the heartbeat reporting function to call. + * @param name the thread name for the heartbeater. + * @param intervalMs the interval between heartbeats. + */ +private[spark] class Heartbeater( + memoryManager: MemoryManager, + reportHeartbeat: () => Unit, + name: String, + intervalMs: Long) extends Logging { + // Executor for the heartbeat task + private val heartbeater = ThreadUtils.newDaemonSingleThreadScheduledExecutor(name) + + /** Schedules a task to report a heartbeat. */ + def start(): Unit = { + // Wait a random interval so the heartbeats don't end up in sync + val initialDelay = intervalMs + (math.random * intervalMs).asInstanceOf[Int] + + val heartbeatTask = new Runnable() { + override def run(): Unit = Utils.logUncaughtExceptions(reportHeartbeat()) + } + heartbeater.scheduleAtFixedRate(heartbeatTask, initialDelay, intervalMs, TimeUnit.MILLISECONDS) + } + + /** Stops the heartbeat thread. */ + def stop(): Unit = { + heartbeater.shutdown() + heartbeater.awaitTermination(10, TimeUnit.SECONDS) + } + + /** + * Get the current executor level metrics. 
These are returned as an array, with the index + * determined by MetricGetter.values + */ + def getCurrentMetrics(): ExecutorMetrics = { + val metrics = ExecutorMetricType.values.map(_.getMetricValue(memoryManager)).toArray + new ExecutorMetrics(metrics) + } +} + diff --git a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala index f8a6f1d0d8cbb..ff85e11409e35 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala @@ -23,5 +23,9 @@ package org.apache.spark * @param shuffleId ID of the shuffle * @param bytesByPartitionId approximate number of output bytes for each map output partition * (may be inexact due to use of compressed map statuses) + * @param recordsByPartitionId number of output records for each map output partition */ -private[spark] class MapOutputStatistics(val shuffleId: Int, val bytesByPartitionId: Array[Long]) +private[spark] class MapOutputStatistics( + val shuffleId: Int, + val bytesByPartitionId: Array[Long], + val recordsByPartitionId: Array[Long]) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 73646051f264c..41575ce4e6e3d 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -434,6 +434,18 @@ private[spark] class MapOutputTrackerMaster( } } + /** Unregister all map output information of the given shuffle. */ + def unregisterAllMapOutput(shuffleId: Int) { + shuffleStatuses.get(shuffleId) match { + case Some(shuffleStatus) => + shuffleStatus.removeOutputsByFilter(x => true) + incrementEpoch() + case None => + throw new SparkException( + s"unregisterAllMapOutput called for nonexistent shuffle ID $shuffleId.") + } + } + /** Unregister shuffle data */ def unregisterShuffle(shuffleId: Int) { shuffleStatuses.remove(shuffleId).foreach { shuffleStatus => @@ -510,16 +522,19 @@ private[spark] class MapOutputTrackerMaster( def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { shuffleStatuses(dep.shuffleId).withMapStatuses { statuses => val totalSizes = new Array[Long](dep.partitioner.numPartitions) + val recordsByMapTask = new Array[Long](statuses.length) + val parallelAggThreshold = conf.get( SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD) val parallelism = math.min( Runtime.getRuntime.availableProcessors(), statuses.length.toLong * totalSizes.length / parallelAggThreshold + 1).toInt if (parallelism <= 1) { - for (s <- statuses) { + statuses.zipWithIndex.foreach { case (s, index) => for (i <- 0 until totalSizes.length) { totalSizes(i) += s.getSizeForBlock(i) } + recordsByMapTask(index) = s.numberOfOutput } } else { val threadPool = ThreadUtils.newDaemonFixedThreadPool(parallelism, "map-output-aggregate") @@ -536,8 +551,11 @@ private[spark] class MapOutputTrackerMaster( } finally { threadPool.shutdown() } + statuses.zipWithIndex.foreach { case (s, index) => + recordsByMapTask(index) = s.numberOfOutput + } } - new MapOutputStatistics(dep.shuffleId, totalSizes) + new MapOutputStatistics(dep.shuffleId, totalSizes, recordsByMapTask) } } diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index c940cb25d478b..515237558fd87 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ 
b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -33,6 +33,9 @@ import org.apache.spark.util.random.SamplingUtils /** * An object that defines how the elements in a key-value pair RDD are partitioned by key. * Maps each key to a partition ID, from 0 to `numPartitions - 1`. + * + * Note that, partitioner must be deterministic, i.e. it must return the same partition id given + * the same partition key. */ abstract class Partitioner extends Serializable { def numPartitions: Int diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 74bfb5d6d2ea3..d943087ab6b80 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -213,6 +213,7 @@ class SparkContext(config: SparkConf) extends Logging { private var _files: Seq[String] = _ private var _shutdownHookRef: AnyRef = _ private var _statusStore: AppStatusStore = _ + private var _heartbeater: Heartbeater = _ /* ------------------------------------------------------------------------------------- * | Accessors and public fields. These provide access to the internal state of the | @@ -254,7 +255,7 @@ class SparkContext(config: SparkConf) extends Logging { conf: SparkConf, isLocal: Boolean, listenerBus: LiveListenerBus): SparkEnv = { - SparkEnv.createDriverEnv(conf, isLocal, listenerBus, SparkContext.numDriverCores(master)) + SparkEnv.createDriverEnv(conf, isLocal, listenerBus, SparkContext.numDriverCores(master, conf)) } private[spark] def env: SparkEnv = _env @@ -496,6 +497,11 @@ class SparkContext(config: SparkConf) extends Logging { _dagScheduler = new DAGScheduler(this) _heartbeatReceiver.ask[Boolean](TaskSchedulerIsSet) + // create and start the heartbeater for collecting memory metrics + _heartbeater = new Heartbeater(env.memoryManager, reportHeartBeat, "driver-heartbeater", + conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s")) + _heartbeater.start() + // start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's // constructor _taskScheduler.start() @@ -571,7 +577,12 @@ class SparkContext(config: SparkConf) extends Logging { _shutdownHookRef = ShutdownHookManager.addShutdownHook( ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY) { () => logInfo("Invoking stop() from shutdown hook") - stop() + try { + stop() + } catch { + case e: Throwable => + logWarning("Ignoring Exception while stopping SparkContext from shutdown hook", e) + } } } catch { case NonFatal(e) => @@ -1496,6 +1507,8 @@ class SparkContext(config: SparkConf) extends Logging { * @param path can be either a local file, a file in HDFS (or other Hadoop-supported * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, * use `SparkFiles.get(fileName)` to find its download location. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addFile(path: String): Unit = { addFile(path, false) @@ -1516,11 +1529,17 @@ class SparkContext(config: SparkConf) extends Logging { * use `SparkFiles.get(fileName)` to find its download location. * @param recursive if true, a directory can be given in `path`. Currently directories are * only supported for Hadoop-supported filesystems. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. 
*/ def addFile(path: String, recursive: Boolean): Unit = { val uri = new Path(path).toUri val schemeCorrectedPath = uri.getScheme match { - case null | "local" => new File(path).getCanonicalFile.toURI.toString + case null => new File(path).getCanonicalFile.toURI.toString + case "local" => + logWarning("File with 'local' scheme is not supported to add to file server, since " + + "it is already available on every node.") + return case _ => path } @@ -1555,6 +1574,9 @@ class SparkContext(config: SparkConf) extends Logging { Utils.fetchFile(uri.toString, new File(SparkFiles.getRootDirectory()), conf, env.securityManager, hadoopConfiguration, timestamp, useCache = false) postEnvironmentUpdate() + } else { + logWarning(s"The path $path has been added already. Overwriting of added paths " + + "is not supported in the current version.") } } @@ -1586,6 +1608,15 @@ class SparkContext(config: SparkConf) extends Logging { } } + /** + * Get the max number of tasks that can currently be launched concurrently. + * Note that you should not cache the value returned by this method, because the number can + * change as executors are added or removed. + * + * @return The max number of tasks that can currently be launched concurrently. + */ + private[spark] def maxNumConcurrentTasks(): Int = schedulerBackend.maxNumConcurrentTasks() + /** * Update the cluster manager on our scheduling needs. Three bits of information are included * to help it make decisions. @@ -1803,6 +1834,8 @@ class SparkContext(config: SparkConf) extends Logging { * * @param path can be either a local file, a file in HDFS (or other Hadoop-supported filesystems), * an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addJar(path: String) { def addJarFile(file: File): String = { @@ -1849,6 +1882,9 @@ class SparkContext(config: SparkConf) extends Logging { if (addedJars.putIfAbsent(key, timestamp).isEmpty) { logInfo(s"Added JAR $path at $key with timestamp $timestamp") postEnvironmentUpdate() + } else { + logWarning(s"The jar $path has been added already. Overwriting of added jars " + + "is not supported in the current version.") } } } @@ -1914,6 +1950,12 @@ class SparkContext(config: SparkConf) extends Logging { Utils.tryLogNonFatalError { _executorAllocationManager.foreach(_.stop()) } + if (_dagScheduler != null) { + Utils.tryLogNonFatalError { + _dagScheduler.stop() + } + _dagScheduler = null + } if (_listenerBusStarted) { Utils.tryLogNonFatalError { listenerBus.stop() @@ -1923,11 +1965,11 @@ class SparkContext(config: SparkConf) extends Logging { Utils.tryLogNonFatalError { _eventLogger.foreach(_.stop()) } - if (_dagScheduler != null) { + if (_heartbeater != null) { Utils.tryLogNonFatalError { - _dagScheduler.stop() + _heartbeater.stop() } - _dagScheduler = null + _heartbeater = null } if (env != null && _heartbeatReceiver != null) { Utils.tryLogNonFatalError { @@ -2399,6 +2441,14 @@ class SparkContext(config: SparkConf) extends Logging { } } + /** Reports heartbeat metrics for the driver. */ + private def reportHeartBeat(): Unit = { + val driverUpdates = _heartbeater.getCurrentMetrics() + val accumUpdates = new Array[(Long, Int, Int, Seq[AccumulableInfo])](0) + listenerBus.post(SparkListenerExecutorMetricsUpdate("driver", accumUpdates, + Some(driverUpdates))) + } + // In order to prevent multiple SparkContexts from being active at the same time, mark this // context as having finished construction.
// NOTE: this must be placed at the end of the SparkContext constructor. @@ -2652,9 +2702,16 @@ object SparkContext extends Logging { } /** - * The number of driver cores to use for execution in local mode, 0 otherwise. + * The number of cores available to the driver to use for tasks such as I/O with Netty */ private[spark] def numDriverCores(master: String): Int = { + numDriverCores(master, null) + } + + /** + * The number of cores available to the driver to use for tasks such as I/O with Netty + */ + private[spark] def numDriverCores(master: String, conf: SparkConf): Int = { def convertToInt(threads: String): Int = { if (threads == "*") Runtime.getRuntime.availableProcessors() else threads.toInt } @@ -2662,7 +2719,13 @@ object SparkContext extends Logging { case "local" => 1 case SparkMasterRegex.LOCAL_N_REGEX(threads) => convertToInt(threads) case SparkMasterRegex.LOCAL_N_FAILURES_REGEX(threads, _) => convertToInt(threads) - case _ => 0 // driver is not used for execution + case "yarn" => + if (conf != null && conf.getOption("spark.submit.deployMode").contains("cluster")) { + conf.getInt("spark.driver.cores", 0) + } else { + 0 + } + case _ => 0 // Either driver is not being used, or its core count will be interpolated later } } diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 69739745aa6cf..2b939dabb1105 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -123,7 +123,10 @@ abstract class TaskContext extends Serializable { * * Exceptions thrown by the listener will result in failure of the task. */ - def addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext = { + def addTaskCompletionListener[U](f: (TaskContext) => U): TaskContext = { + // Note that due to this scala bug: https://github.com/scala/bug/issues/11016, we need to make + // this function polymorphic for every scala version >= 2.12, otherwise an overloaded method + // resolution error occurs at compile time. addTaskCompletionListener(new TaskCompletionListener { override def onTaskCompletion(context: TaskContext): Unit = f(context) }) @@ -218,4 +221,18 @@ abstract class TaskContext extends Serializable { */ private[spark] def setFetchFailed(fetchFailed: FetchFailedException): Unit + /** Marks the task for interruption, i.e. cancellation. */ + private[spark] def markInterrupted(reason: String): Unit + + /** Marks the task as failed and triggers the failure listeners. */ + private[spark] def markTaskFailed(error: Throwable): Unit + + /** Marks the task as completed and triggers the completion listeners. */ + private[spark] def markTaskCompleted(error: Option[Throwable]): Unit + + /** Optionally returns the stored fetch failure in the task. */ + private[spark] def fetchFailed: Option[FetchFailedException] + + /** Gets local properties set upstream in the driver. */ + private[spark] def getLocalProperties: Properties } diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 0791fe856ef15..89730424e5acf 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -30,6 +30,7 @@ import org.apache.spark.metrics.source.Source import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util._ + /** * A [[TaskContext]] implementation. 
* @@ -98,9 +99,8 @@ private[spark] class TaskContextImpl( this } - /** Marks the task as failed and triggers the failure listeners. */ @GuardedBy("this") - private[spark] def markTaskFailed(error: Throwable): Unit = synchronized { + private[spark] override def markTaskFailed(error: Throwable): Unit = synchronized { if (failed) return failed = true failure = error @@ -109,9 +109,8 @@ private[spark] class TaskContextImpl( } } - /** Marks the task as completed and triggers the completion listeners. */ @GuardedBy("this") - private[spark] def markTaskCompleted(error: Option[Throwable]): Unit = synchronized { + private[spark] override def markTaskCompleted(error: Option[Throwable]): Unit = synchronized { if (completed) return completed = true invokeListeners(onCompleteCallbacks, "TaskCompletionListener", error) { @@ -140,8 +139,7 @@ private[spark] class TaskContextImpl( } } - /** Marks the task for interruption, i.e. cancellation. */ - private[spark] def markInterrupted(reason: String): Unit = { + private[spark] override def markInterrupted(reason: String): Unit = { reasonIfKilled = Some(reason) } @@ -176,8 +174,7 @@ private[spark] class TaskContextImpl( this._fetchFailedException = Option(fetchFailed) } - private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException + private[spark] override def fetchFailed: Option[FetchFailedException] = _fetchFailedException - // TODO: shall we publish it and define it in `TaskContext`? - private[spark] def getLocalProperties(): Properties = localProperties + private[spark] override def getLocalProperties(): Properties = localProperties } diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index b5c4c705dcbc7..c2ebd388a2365 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -23,7 +23,7 @@ import java.nio.charset.StandardCharsets import java.security.SecureRandom import java.security.cert.X509Certificate import java.util.{Arrays, Properties} -import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit} +import java.util.concurrent.{TimeoutException, TimeUnit} import java.util.jar.{JarEntry, JarOutputStream} import javax.net.ssl._ import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} @@ -172,22 +172,24 @@ private[spark] object TestUtils { /** * Run some code involving jobs submitted to the given context and assert that the jobs spilled. */ - def assertSpilled[T](sc: SparkContext, identifier: String)(body: => T): Unit = { - val spillListener = new SpillListener - sc.addSparkListener(spillListener) - body - assert(spillListener.numSpilledStages > 0, s"expected $identifier to spill, but did not") + def assertSpilled(sc: SparkContext, identifier: String)(body: => Unit): Unit = { + val listener = new SpillListener + withListener(sc, listener) { _ => + body + } + assert(listener.numSpilledStages > 0, s"expected $identifier to spill, but did not") } /** * Run some code involving jobs submitted to the given context and assert that the jobs * did not spill. 
*/ - def assertNotSpilled[T](sc: SparkContext, identifier: String)(body: => T): Unit = { - val spillListener = new SpillListener - sc.addSparkListener(spillListener) - body - assert(spillListener.numSpilledStages == 0, s"expected $identifier to not spill, but did") + def assertNotSpilled(sc: SparkContext, identifier: String)(body: => Unit): Unit = { + val listener = new SpillListener + withListener(sc, listener) { _ => + body + } + assert(listener.numSpilledStages == 0, s"expected $identifier to not spill, but did") } /** @@ -233,6 +235,21 @@ private[spark] object TestUtils { } } + /** + * Runs some code with the given listener installed in the SparkContext. After the code runs, + * this method will wait until all events posted to the listener bus are processed, and then + * remove the listener from the bus. + */ + def withListener[L <: SparkListener](sc: SparkContext, listener: L) (body: L => Unit): Unit = { + sc.addSparkListener(listener) + try { + body(listener) + } finally { + sc.listenerBus.waitUntilEmpty(TimeUnit.SECONDS.toMillis(10)) + sc.listenerBus.removeListener(listener) + } + } + /** * Wait until at least `numExecutors` executors are up, or throw `TimeoutException` if the waiting * time elapses before `numExecutors` executors are up. Exposed for testing. @@ -289,21 +306,17 @@ private[spark] object TestUtils { private class SpillListener extends SparkListener { private val stageIdToTaskMetrics = new mutable.HashMap[Int, ArrayBuffer[TaskMetrics]] private val spilledStageIds = new mutable.HashSet[Int] - private val stagesDone = new CountDownLatch(1) - def numSpilledStages: Int = { - // Long timeout, just in case somehow the job end isn't notified. - // Fails if a timeout occurs - assert(stagesDone.await(10, TimeUnit.SECONDS)) + def numSpilledStages: Int = synchronized { spilledStageIds.size } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { stageIdToTaskMetrics.getOrElseUpdate( taskEnd.stageId, new ArrayBuffer[TaskMetrics]) += taskEnd.taskMetrics } - override def onStageCompleted(stageComplete: SparkListenerStageCompleted): Unit = { + override def onStageCompleted(stageComplete: SparkListenerStageCompleted): Unit = synchronized { val stageId = stageComplete.stageInfo.stageId val metrics = stageIdToTaskMetrics.remove(stageId).toSeq.flatten val spilled = metrics.map(_.memoryBytesSpilled).sum > 0 @@ -311,8 +324,4 @@ private class SpillListener extends SparkListener { spilledStageIds += stageId } } - - override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { - stagesDone.countDown() - } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index f1936bf587282..09c83849e26b2 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -668,6 +668,8 @@ class JavaSparkContext(val sc: SparkContext) * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, * use `SparkFiles.get(fileName)` to find its download location. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored.
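(Illustrative sketch, not part of this patch: how the TestUtils.withListener helper introduced above is meant to be used. It installs the listener, runs the body, waits for the listener bus to drain, and removes the listener, which is what lets SpillListener drop its CountDownLatch.)

    val listener = new SpillListener
    TestUtils.withListener(sc, listener) { _ =>
      sc.parallelize(1 to 1000).map(i => (i % 10, i)).groupByKey().count()
    }
    // By this point all events have been processed, so the assertion is race-free.
    assert(listener.numSpilledStages == 0)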
*/ def addFile(path: String) { sc.addFile(path) } @@ -681,6 +683,8 @@ class JavaSparkContext(val sc: SparkContext) * * A directory can be given if the recursive option is set to true. Currently directories are only * supported for Hadoop-supported filesystems. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addFile(path: String, recursive: Boolean): Unit = { sc.addFile(path, recursive) } @@ -690,6 +694,8 @@ class JavaSparkContext(val sc: SparkContext) * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future. * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported * filesystems), or an HTTP, HTTPS or FTP URI. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addJar(path: String) { sc.addJar(path) } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index a1ee2f7d1b119..e639a842754bd 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -45,12 +45,10 @@ import org.apache.spark.util._ private[spark] class PythonRDD( parent: RDD[_], func: PythonFunction, - preservePartitoning: Boolean) + preservePartitoning: Boolean, + isFromBarrier: Boolean = false) extends RDD[Array[Byte]](parent) { - val bufferSize = conf.getInt("spark.buffer.size", 65536) - val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true) - override def getPartitions: Array[Partition] = firstParent.partitions override val partitioner: Option[Partitioner] = { @@ -60,9 +58,12 @@ private[spark] class PythonRDD( val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { - val runner = PythonRunner(func, bufferSize, reuseWorker) + val runner = PythonRunner(func) runner.compute(firstParent.iterator(split, context), split.index, context) } + + @transient protected lazy override val isBarrier_ : Boolean = + isFromBarrier || dependencies.exists(_.rdd.isBarrier()) } /** @@ -398,6 +399,26 @@ private[spark] object PythonRDD extends Logging { * data collected from this job, and the secret for authentication. */ def serveIterator(items: Iterator[_], threadName: String): Array[Any] = { + serveToStream(threadName) { out => + writeIteratorToStream(items, new DataOutputStream(out)) + } + } + + /** + * Create a socket server and background thread to execute the writeFunc + * with the given OutputStream. + * + * The socket server accepts only one connection, and closes if no connection + * arrives within 15 seconds. + * + * Once a connection comes in, it will execute the block of code and pass in + * the socket output stream. + * + * The thread will terminate after the block of code is executed or if any + * exception happens.
+ */ + private[spark] def serveToStream( + threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = { val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) // Close the socket if no connection in 15 seconds serverSocket.setSoTimeout(15000) @@ -409,9 +430,9 @@ private[spark] object PythonRDD extends Logging { val sock = serverSocket.accept() authHelper.authClient(sock) - val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + val out = new BufferedOutputStream(sock.getOutputStream) Utils.tryWithSafeFinally { - writeIteratorToStream(items, out) + writeFunc(out) } { out.close() sock.close() @@ -586,8 +607,9 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[By */ private[spark] class PythonAccumulatorV2( @transient private val serverHost: String, - private val serverPort: Int) - extends CollectionAccumulator[Array[Byte]] { + private val serverPort: Int, + private val secretToken: String) + extends CollectionAccumulator[Array[Byte]] with Logging { Utils.checkHost(serverHost) @@ -602,17 +624,22 @@ private[spark] class PythonAccumulatorV2( private def openSocket(): Socket = synchronized { if (socket == null || socket.isClosed) { socket = new Socket(serverHost, serverPort) + logInfo(s"Connected to AccumulatorServer at host: $serverHost port: $serverPort") + // send the secret just for the initial authentication when opening a new connection + socket.getOutputStream.write(secretToken.getBytes(StandardCharsets.UTF_8)) } socket } // Need to override so the types match with PythonFunction - override def copyAndReset(): PythonAccumulatorV2 = new PythonAccumulatorV2(serverHost, serverPort) + override def copyAndReset(): PythonAccumulatorV2 = { + new PythonAccumulatorV2(serverHost, serverPort, secretToken) + } override def merge(other: AccumulatorV2[Array[Byte], JList[Array[Byte]]]): Unit = synchronized { val otherPythonAccumulator = other.asInstanceOf[PythonAccumulatorV2] // This conditional isn't strictly speaking needed - merging only currently happens on the - // driver program - but that isn't gauranteed so incase this changes. + // driver program - but that isn't guaranteed so in case this changes.
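(Illustrative sketch, not part of this patch: using the serveToStream helper added above. The thread name is arbitrary; per serveIterator's doc, the returned array is assumed to carry the bound port and the auth secret for the connecting side.)

    val Array(port, secret) = PythonRDD.serveToStream("serve-answer") { out =>
      val dataOut = new java.io.DataOutputStream(out)
      dataOut.writeInt(42)  // whatever bytes the connecting client expects
      dataOut.flush()
    }
    // `port` and `secret` are handed to the Python process, which must connect
    // and authenticate within the 15-second accept timeout.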
if (serverHost == null) { // We are on the worker super.merge(otherPythonAccumulator) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index ebabedf950e39..4c53bc269a104 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -20,12 +20,15 @@ package org.apache.spark.api.python import java.io._ import java.net._ import java.nio.charset.StandardCharsets +import java.nio.charset.StandardCharsets.UTF_8 import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ import org.apache.spark._ import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.PYSPARK_EXECUTOR_MEMORY +import org.apache.spark.security.SocketAuthHelper import org.apache.spark.util._ @@ -60,14 +63,20 @@ private[spark] object PythonEvalType { */ private[spark] abstract class BasePythonRunner[IN, OUT]( funcs: Seq[ChainedPythonFunctions], - bufferSize: Int, - reuseWorker: Boolean, evalType: Int, argOffsets: Array[Array[Int]]) extends Logging { require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs") + private val conf = SparkEnv.get.conf + private val bufferSize = conf.getInt("spark.buffer.size", 65536) + private val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true) + // Each Python worker gets an equal part of the allocation. The worker pool will grow to the + // number of concurrent tasks, which is determined by the number of cores in this executor. + private val memoryMb = conf.get(PYSPARK_EXECUTOR_MEMORY) + .map(_ / conf.getInt("spark.executor.cores", 1)) + // All the Python functions should have the same exec, version and envvars. protected val envVars = funcs.head.funcs.head.envVars protected val pythonExec = funcs.head.funcs.head.pythonExec @@ -76,6 +85,12 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( // TODO: support accumulator in multiple UDF protected val accumulator = funcs.head.funcs.head.accumulator + // Expose a ServerSocket to support method calls via socket from Python side. + private[spark] var serverSocket: Option[ServerSocket] = None + + // Authentication helper used when serving method calls via socket from Python side. + private lazy val authHelper = new SocketAuthHelper(conf) + def compute( inputIterator: Iterator[IN], partitionIndex: Int, @@ -87,6 +102,9 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( if (reuseWorker) { envVars.put("SPARK_REUSE_WORKER", "1") } + if (memoryMb.isDefined) { + envVars.put("PYSPARK_EXECUTOR_MEMORY_MB", memoryMb.get.toString) + } val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap) // Whether the worker is released into the idle pool val released = new AtomicBoolean(false) @@ -94,7 +112,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( // Start a thread to feed the process input from our parent's iterator val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context) - context.addTaskCompletionListener { _ => + context.addTaskCompletionListener[Unit] { _ => writerThread.shutdownOnTaskCompletion() if (!reuseWorker || !released.get) { try { @@ -180,12 +198,79 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( dataOut.writeInt(partitionIndex) // Python version of driver PythonRDD.writeUTF(pythonVer, dataOut) + // Init a ServerSocket to accept method calls from Python side.
+ val isBarrier = context.isInstanceOf[BarrierTaskContext] + if (isBarrier) { + serverSocket = Some(new ServerSocket(/* port */ 0, + /* backlog */ 1, + InetAddress.getByName("localhost"))) + // A call to accept() for ServerSocket shall block indefinitely. + serverSocket.foreach(_.setSoTimeout(0)) + new Thread("accept-connections") { + setDaemon(true) + + override def run(): Unit = { + while (!serverSocket.get.isClosed()) { + var sock: Socket = null + try { + sock = serverSocket.get.accept() + // Wait for function call from python side. + sock.setSoTimeout(10000) + authHelper.authClient(sock) + val input = new DataInputStream(sock.getInputStream()) + input.readInt() match { + case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => + // The barrier() function may wait indefinitely, so the socket shall not time out + // before the function finishes. + sock.setSoTimeout(0) + barrierAndServe(sock) + + case _ => + val out = new DataOutputStream(new BufferedOutputStream( + sock.getOutputStream)) + writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out) + } + } catch { + case e: SocketException if e.getMessage.contains("Socket closed") => + // It is possible that the ServerSocket is not closed, but the native socket + // has already been closed; we shall catch and silently ignore this case. + } finally { + if (sock != null) { + sock.close() + } + } + } + } + }.start() + } + val secret = if (isBarrier) { + authHelper.secret + } else { + "" + } + // Close ServerSocket on task completion. + serverSocket.foreach { server => + context.addTaskCompletionListener[Unit](_ => server.close()) + } + val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0) + if (boundPort == -1) { + val message = "ServerSocket failed to bind to Java side." + logError(message) + throw new SparkException(message) + } else if (isBarrier) { + logDebug(s"Started ServerSocket on port $boundPort.") + } // Write out the TaskContextInfo + dataOut.writeBoolean(isBarrier) + dataOut.writeInt(boundPort) + val secretBytes = secret.getBytes(UTF_8) + dataOut.writeInt(secretBytes.length) + dataOut.write(secretBytes, 0, secretBytes.length) dataOut.writeInt(context.stageId()) dataOut.writeInt(context.partitionId()) dataOut.writeInt(context.attemptNumber()) dataOut.writeLong(context.taskAttemptId()) - val localProps = context.asInstanceOf[TaskContextImpl].getLocalProperties.asScala + val localProps = context.getLocalProperties.asScala dataOut.writeInt(localProps.size) localProps.foreach { case (k, v) => PythonRDD.writeUTF(k, dataOut) @@ -243,6 +328,30 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( } } } + + /** + * Gateway to call BarrierTaskContext.barrier().
+ */ + def barrierAndServe(sock: Socket): Unit = { + require(serverSocket.isDefined, "No available ServerSocket to redirect the barrier() call.") + + val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + try { + context.asInstanceOf[BarrierTaskContext].barrier() + writeUTF(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS, out) + } catch { + case e: SparkException => + writeUTF(e.getMessage, out) + } finally { + out.close() + } + } + + def writeUTF(str: String, dataOut: DataOutputStream) { + val bytes = str.getBytes(UTF_8) + dataOut.writeInt(bytes.length) + dataOut.write(bytes) + } } abstract class ReaderIterator( @@ -385,20 +494,17 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( private[spark] object PythonRunner { - def apply(func: PythonFunction, bufferSize: Int, reuseWorker: Boolean): PythonRunner = { - new PythonRunner(Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuseWorker) + def apply(func: PythonFunction): PythonRunner = { + new PythonRunner(Seq(ChainedPythonFunctions(Seq(func)))) } } /** * A helper class to run Python mapPartition in Spark. */ -private[spark] class PythonRunner( - funcs: Seq[ChainedPythonFunctions], - bufferSize: Int, - reuseWorker: Boolean) +private[spark] class PythonRunner(funcs: Seq[ChainedPythonFunctions]) extends BasePythonRunner[Array[Byte], Array[Byte]]( - funcs, bufferSize, reuseWorker, PythonEvalType.NON_UDF, Array(Array(0))) { + funcs, PythonEvalType.NON_UDF, Array(Array(0))) { protected override def newWriterThread( env: SparkEnv, @@ -465,3 +571,9 @@ private[spark] object SpecialLengths { val NULL = -5 val START_ARROW_STREAM = -6 } + +private[spark] object BarrierTaskContextMessageProtocol { + val BARRIER_FUNCTION = 1 + val BARRIER_RESULT_SUCCESS = "success" + val ERROR_UNRECOGNIZED_FUNCTION = "Not recognized function call from python side." 
+} diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 3b2e809408e0f..7ce2581555014 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -96,11 +96,11 @@ private[spark] class RBackend { channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS) channelFuture = null } - if (bootstrap != null && bootstrap.group() != null) { - bootstrap.group().shutdownGracefully() + if (bootstrap != null && bootstrap.config().group() != null) { + bootstrap.config().group().shutdownGracefully() } if (bootstrap != null && bootstrap.childGroup() != null) { - bootstrap.childGroup().shutdownGracefully() + bootstrap.config().childGroup().shutdownGracefully() } bootstrap = null jvmObjectTracker.clear() diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index e125095cf4777..cbd49e070f2eb 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -262,7 +262,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) val blockManager = SparkEnv.get.blockManager Option(TaskContext.get()) match { case Some(taskContext) => - taskContext.addTaskCompletionListener(_ => blockManager.releaseLock(blockId)) + taskContext.addTaskCompletionListener[Unit](_ => blockManager.releaseLock(blockId)) case None => // This should only happen on the driver, where broadcast variables may be accessed // outside of running tasks (e.g. when computing rdd.partitions()). In order to allow diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index b59a4fe66587c..f6b3c37f0fe72 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -22,7 +22,7 @@ import java.util.concurrent.CountDownLatch import scala.collection.JavaConverters._ import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.TransportContext import org.apache.spark.network.crypto.AuthServerBootstrap @@ -45,8 +45,8 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana protected val masterMetricsSystem = MetricsSystem.createMetricsSystem("shuffleService", sparkConf, securityManager) - private val enabled = sparkConf.getBoolean("spark.shuffle.service.enabled", false) - private val port = sparkConf.getInt("spark.shuffle.service.port", 7337) + private val enabled = sparkConf.get(config.SHUFFLE_SERVICE_ENABLED) + private val port = sparkConf.get(config.SHUFFLE_SERVICE_PORT) private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, "shuffle", numUsableCores = 0) @@ -131,7 +131,7 @@ object ExternalShuffleService extends Logging { // we override this value since this service is started from the command line // and we assume the user really wants it to be running - sparkConf.set("spark.shuffle.service.enabled", "true") + sparkConf.set(config.SHUFFLE_SERVICE_ENABLED.key, "true") server = newShuffleService(sparkConf, securityManager) server.start() diff --git 
a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 84aa8944fc1c7..be293f88a9d4a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkConf import org.apache.spark.deploy.master.Master import org.apache.spark.deploy.worker.Worker -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.rpc.RpcEnv import org.apache.spark.util.Utils @@ -52,7 +52,7 @@ class LocalSparkCluster( // Disable REST server on Master in this mode unless otherwise specified val _conf = conf.clone() .setIfMissing("spark.master.rest.enabled", "false") - .set("spark.shuffle.service.enabled", "false") + .set(config.SHUFFLE_SERVICE_ENABLED.key, "false") /* Start the Master */ val (rpcEnv, webUiPort, _) = Master.startRpcEnvAndEndpoint(localHostname, 0, 0, _conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 8353e64a619cf..4cc0063d010ef 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -31,7 +31,6 @@ import scala.util.control.NonFatal import com.google.common.primitives.Longs import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} -import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.security.token.{Token, TokenIdentifier} @@ -108,7 +107,7 @@ class SparkHadoopUtil extends Logging { } /** - * Return an appropriate (subclass) of Configuration. Creating config can initializes some Hadoop + * Return an appropriate (subclass of) Configuration. Creating config can initialize some Hadoop * subsystems.
*/ def newConfiguration(conf: SparkConf): Configuration = { @@ -367,28 +366,6 @@ class SparkHadoopUtil extends Logging { buffer.toString } - private[spark] def checkAccessPermission(status: FileStatus, mode: FsAction): Boolean = { - val perm = status.getPermission - val ugi = UserGroupInformation.getCurrentUser - - if (ugi.getShortUserName == status.getOwner) { - if (perm.getUserAction.implies(mode)) { - return true - } - } else if (ugi.getGroupNames.contains(status.getGroup)) { - if (perm.getGroupAction.implies(mode)) { - return true - } - } else if (perm.getOtherAction.implies(mode)) { - return true - } - - logDebug(s"Permission denied: user=${ugi.getShortUserName}, " + - s"path=${status.getPath}:${status.getOwner}:${status.getGroup}" + - s"${if (status.isDirectory) "d" else "-"}$perm") - false - } - def serialize(creds: Credentials): Array[Byte] = { val byteStream = new ByteArrayOutputStream val dataStream = new DataOutputStream(byteStream) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index e83d82f847c61..cf902db8709e7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -181,6 +181,7 @@ private[spark] class SparkSubmit extends Logging { if (args.isStandaloneCluster && args.useRest) { try { logInfo("Running Spark using the REST application submission protocol.") + doRunMain() } catch { // Fail over to use the legacy submission gateway case e: SubmitRestConnectionException => @@ -285,8 +286,6 @@ private[spark] class SparkSubmit extends Logging { case (STANDALONE, CLUSTER) if args.isR => error("Cluster deploy mode is currently not supported for R " + "applications on standalone clusters.") - case (KUBERNETES, _) if args.isR => - error("R applications are currently not supported for Kubernetes.") case (LOCAL, CLUSTER) => error("Cluster deploy mode is not compatible with master \"local\"") case (_, CLUSTER) if isShell(args.primaryResource) => @@ -385,7 +384,7 @@ private[spark] class SparkSubmit extends Logging { val forceDownloadSchemes = sparkConf.get(FORCE_DOWNLOAD_SCHEMES) def shouldDownload(scheme: String): Boolean = { - forceDownloadSchemes.contains(scheme) || + forceDownloadSchemes.contains("*") || forceDownloadSchemes.contains(scheme) || Try { FileSystem.getFileSystemClass(scheme, hadoopConf) }.isFailure } @@ -578,7 +577,8 @@ private[spark] class SparkSubmit extends Logging { } // Add the main application jar and any added jars to classpath in case YARN client // requires these jars. - // This assumes both primaryResource and user jars are local jars, otherwise it will not be + // This assumes both primaryResource and user jars are local jars, or already downloaded + // locally by configuring "spark.yarn.dist.forceDownloadSchemes"; otherwise they will not be // added to the classpath of YARN client.
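(Illustrative sketch, not part of this patch: the '*' wildcard added to shouldDownload above, with the filesystem-availability clause omitted. Values are hypothetical.)

    val forceDownloadSchemes = Seq("*")  // spark.yarn.dist.forceDownloadSchemes=*
    def shouldDownload(scheme: String): Boolean =
      forceDownloadSchemes.contains("*") || forceDownloadSchemes.contains(scheme)
    shouldDownload("http")   // true: the wildcard forces every scheme to be downloaded
    shouldDownload("hdfs")   // also true under "*"; false if the list were Seq("http")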
if (isYarnCluster) { if (isUserJar(args.primaryResource)) { @@ -698,10 +698,16 @@ private[spark] class SparkSubmit extends Logging { if (args.pyFiles != null) { childArgs ++= Array("--other-py-files", args.pyFiles) } - } else { + } else if (args.isR) { + childArgs ++= Array("--primary-r-file", args.primaryResource) + childArgs ++= Array("--main-class", "org.apache.spark.deploy.RRunner") + } else { childArgs ++= Array("--primary-java-resource", args.primaryResource) childArgs ++= Array("--main-class", args.mainClass) } + } else { + childArgs ++= Array("--main-class", args.mainClass) } if (args.childArgs != null) { args.childArgs.foreach { arg => diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index fb232101114b9..0998757715457 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -82,7 +82,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S var driverCores: String = null var submissionToKill: String = null var submissionToRequestStatusFor: String = null - var useRest: Boolean = true // used internally + var useRest: Boolean = false // used internally /** Default properties present in the currently defined defaults file. */ lazy val defaultSparkProperties: HashMap[String, String] = { @@ -115,6 +115,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S // Use `sparkProperties` map along with env vars to fill in any missing parameters loadEnvironmentArguments() + useRest = sparkProperties.getOrElse("spark.master.rest.enabled", "false").toBoolean + validateArguments() /** diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index bf1eeb0c1bf59..44d23908146c7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -21,11 +21,12 @@ import java.io.{File, FileNotFoundException, IOException} import java.nio.file.Files import java.nio.file.attribute.PosixFilePermissions import java.util.{Date, ServiceLoader} -import java.util.concurrent.{ExecutorService, TimeUnit} +import java.util.concurrent.{ConcurrentHashMap, ExecutorService, Future, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.concurrent.ExecutionException import scala.io.Source import scala.util.Try import scala.xml.Node @@ -33,8 +34,7 @@ import scala.xml.Node import com.fasterxml.jackson.annotation.JsonIgnore import com.google.common.io.ByteStreams import com.google.common.util.concurrent.MoreExecutors -import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.fs.permission.FsAction +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.hdfs.DistributedFileSystem import org.apache.hadoop.hdfs.protocol.HdfsConstants import org.apache.hadoop.security.AccessControlException @@ -114,7 +114,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) "; groups with admin permissions" + HISTORY_UI_ADMIN_ACLS_GROUPS.toString) private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - private val fs = new Path(logDir).getFileSystem(hadoopConf) + // Visible for testing +
private[history] val fs: FileSystem = new Path(logDir).getFileSystem(hadoopConf) // Used by check event thread and clean log thread. // Scheduled thread pool size must be one, otherwise it will have concurrent issues about fs @@ -161,6 +162,25 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) new HistoryServerDiskManager(conf, path, listing, clock) } + private val blacklist = new ConcurrentHashMap[String, Long] + + // Visible for testing + private[history] def isBlacklisted(path: Path): Boolean = { + blacklist.containsKey(path.getName) + } + + private def blacklist(path: Path): Unit = { + blacklist.put(path.getName, clock.getTimeMillis()) + } + + /** + * Removes expired entries in the blacklist, according to the provided `expireTimeInSeconds`. + */ + private def clearBlacklist(expireTimeInSeconds: Long): Unit = { + val expiredThreshold = clock.getTimeMillis() - expireTimeInSeconds * 1000 + blacklist.asScala.retain((_, creationTime) => creationTime >= expiredThreshold) + } + private val activeUIs = new mutable.HashMap[(String, Option[String]), LoadedAppUI]() /** @@ -418,7 +438,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // reading a garbage file is safe, but we would log an error which can be scary to // the end-user. !entry.getPath().getName().startsWith(".") && - SparkHadoopUtil.get.checkAccessPermission(entry, FsAction.READ) + !isBlacklisted(entry.getPath) } .filter { entry => try { @@ -461,32 +481,37 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) logDebug(s"New/updated attempts found: ${updated.size} ${updated.map(_.getPath)}") } - val tasks = updated.map { entry => + val tasks = updated.flatMap { entry => try { - replayExecutor.submit(new Runnable { + val task: Future[Unit] = replayExecutor.submit(new Runnable { override def run(): Unit = mergeApplicationListing(entry, newLastScanTime, true) - }) + }, Unit) + Some(task -> entry.getPath) } catch { // let the iteration over the updated entries break, since an exception on // replayExecutor.submit (..) indicates the ExecutorService is unable // to take any more submissions at this time case e: Exception => logError(s"Exception while submitting event log for replay", e) - null + None } - }.filter(_ != null) + } pendingReplayTasksCount.addAndGet(tasks.size) // Wait for all tasks to finish. This makes sure that checkForLogs // is not scheduled again while some tasks are already running in // the replayExecutor. - tasks.foreach { task => + tasks.foreach { case (task, path) => try { task.get() } catch { case e: InterruptedException => throw e + case e: ExecutionException if e.getCause.isInstanceOf[AccessControlException] => + // We don't have read permissions on the log file + logWarning(s"Unable to read log $path", e.getCause) + blacklist(path) case e: Exception => logError("Exception while merging application listings", e) } finally { @@ -779,6 +804,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) listing.delete(classOf[LogInfo], log.logPath) } } + // Clean the blacklist from the expired entries. 
+ clearBlacklist(CLEAN_INTERVAL_S) } /** @@ -938,13 +965,17 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } private def deleteLog(log: Path): Unit = { - try { - fs.delete(log, true) - } catch { - case _: AccessControlException => - logInfo(s"No permission to delete $log, ignoring.") - case ioe: IOException => - logError(s"IOException in cleaning $log", ioe) + if (isBlacklisted(log)) { + logDebug(s"Skipping deleting $log as we don't have permissions on it.") + } else { + try { + fs.delete(log, true) + } catch { + case _: AccessControlException => + logInfo(s"No permission to delete $log, ignoring.") + case ioe: IOException => + logError(s"IOException in cleaning $log", ioe) + } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 2c78c15773af2..e1184248af460 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -121,10 +121,18 @@ private[deploy] class Master( } // Alternative application submission gateway that is stable across Spark versions - private val restServerEnabled = conf.getBoolean("spark.master.rest.enabled", true) + private val restServerEnabled = conf.getBoolean("spark.master.rest.enabled", false) private var restServer: Option[StandaloneRestServer] = None private var restServerBoundPort: Option[Int] = None + { + val authKey = SecurityManager.SPARK_AUTH_SECRET_CONF + require(conf.getOption(authKey).isEmpty || !restServerEnabled, + s"The RestSubmissionServer does not support authentication via ${authKey}. Either turn " + + "off the RestSubmissionServer with spark.master.rest.enabled=false, or do not use " + + "authentication.") + } + override def onStart(): Unit = { logInfo("Starting Spark master at " + masterUrl) logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 742a95841a138..31a8e3e60c067 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -233,30 +233,44 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { private[rest] def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = { import scala.concurrent.ExecutionContext.Implicits.global val responseFuture = Future { - val dataStream = - if (connection.getResponseCode == HttpServletResponse.SC_OK) { - connection.getInputStream - } else { - connection.getErrorStream + val responseCode = connection.getResponseCode + + if (responseCode != HttpServletResponse.SC_OK) { + val errString = Some(Source.fromInputStream(connection.getErrorStream()) + .getLines().mkString("\n")) + if (responseCode == HttpServletResponse.SC_INTERNAL_SERVER_ERROR && + !connection.getContentType().contains("application/json")) { + throw new SubmitRestProtocolException(s"Server responded with exception:\n${errString}") + } + logError(s"Server responded with error:\n${errString}") + val error = new ErrorResponse + if (responseCode == RestSubmissionServer.SC_UNKNOWN_PROTOCOL_VERSION) { + error.highestProtocolVersion = RestSubmissionServer.PROTOCOL_VERSION + } + error.message = errString.get + error + } else { + val dataStream = connection.getInputStream + + // If the server threw an 
exception while writing a response, it will not have a body + if (dataStream == null) { + throw new SubmitRestProtocolException("Server returned empty body") + } + val responseJson = Source.fromInputStream(dataStream).mkString + logDebug(s"Response from the server:\n$responseJson") + val response = SubmitRestProtocolMessage.fromJson(responseJson) + response.validate() + response match { + // If the response is an error, log the message + case error: ErrorResponse => + logError(s"Server responded with error:\n${error.message}") + error + // Otherwise, simply return the response + case response: SubmitRestProtocolResponse => response + case unexpected => + throw new SubmitRestProtocolException( + s"Message received from server was not a response:\n${unexpected.toJson}") } - // If the server threw an exception while writing a response, it will not have a body - if (dataStream == null) { - throw new SubmitRestProtocolException("Server returned empty body") - } - val responseJson = Source.fromInputStream(dataStream).mkString - logDebug(s"Response from the server:\n$responseJson") - val response = SubmitRestProtocolMessage.fromJson(responseJson) - response.validate() - response match { - // If the response is an error, log the message - case error: ErrorResponse => - logError(s"Server responded with error:\n${error.message}") - error - // Otherwise, simply return the response - case response: SubmitRestProtocolResponse => response - case unexpected => - throw new SubmitRestProtocolException( - s"Message received from server was not a response:\n${unexpected.toJson}") } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala index 3d99d085408c6..e59bf3f0eaf44 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala @@ -51,6 +51,7 @@ private[spark] abstract class RestSubmissionServer( val host: String, val requestedPort: Int, val masterConf: SparkConf) extends Logging { + protected val submitRequestServlet: SubmitRequestServlet protected val killRequestServlet: KillRequestServlet protected val statusRequestServlet: StatusRequestServlet diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index ee1ca0bba5749..d5ea2523c628b 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -36,7 +36,7 @@ import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.ExternalShuffleService import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.metrics.MetricsSystem import org.apache.spark.rpc._ import org.apache.spark.util.{SparkUncaughtExceptionHandler, ThreadUtils, Utils} @@ -758,6 +758,7 @@ private[deploy] class Worker( private[deploy] object Worker extends Logging { val SYSTEM_NAME = "sparkWorker" val ENDPOINT_NAME = "Worker" + private val SSL_NODE_LOCAL_CONFIG_PATTERN = """\-Dspark\.ssl\.useNodeLocalConf\=(.+)""".r def main(argStrings: Array[String]) { Thread.setDefaultUncaughtExceptionHandler(new SparkUncaughtExceptionHandler( @@ -772,7 +773,7 @@ private[deploy] object Worker extends Logging { // bound, we may 
launch no more than one external shuffle service on each host. // When this happens, we should give an explicit reason for the failure instead of failing silently. For // more detail see SPARK-20989. - val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) + val externalShuffleServiceEnabled = conf.get(config.SHUFFLE_SERVICE_ENABLED) val sparkWorkerInstances = scala.sys.env.getOrElse("SPARK_WORKER_INSTANCES", "1").toInt require(externalShuffleServiceEnabled == false || sparkWorkerInstances <= 1, "Starting multiple workers on one host fails because we may launch no more than one " + @@ -803,9 +804,8 @@ private[deploy] object Worker extends Logging { } def isUseLocalNodeSSLConfig(cmd: Command): Boolean = { - val pattern = """\-Dspark\.ssl\.useNodeLocalConf\=(.+)""".r val result = cmd.javaOpts.collectFirst { - case pattern(_result) => _result.toBoolean + case SSL_NODE_LOCAL_CONFIG_PATTERN(_result) => _result.toBoolean } result.getOrElse(false) } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index b1856ff0f3247..072277cb78dc1 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -38,7 +38,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager} import org.apache.spark.rpc.RpcTimeout -import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task, TaskDescription} +import org.apache.spark.scheduler._ import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} import org.apache.spark.util._ @@ -148,7 +148,8 @@ private[spark] class Executor( private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] // Executor for the heartbeat task.
- private val heartbeater = ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-heartbeater") + private val heartbeater = new Heartbeater(env.memoryManager, reportHeartBeat, + "executor-heartbeater", conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s")) // must be initialized before the heartbeater is started private val heartbeatReceiverRef = @@ -167,7 +168,7 @@ private[spark] class Executor( */ private var heartbeatFailures = 0 - startDriverHeartbeater() + heartbeater.start() private[executor] def numRunningTasks: Int = runningTasks.size() @@ -216,8 +217,12 @@ private[spark] class Executor( def stop(): Unit = { env.metricsSystem.report() - heartbeater.shutdown() - heartbeater.awaitTermination(10, TimeUnit.SECONDS) + try { + heartbeater.stop() + } catch { + case NonFatal(e) => + logWarning("Unable to stop heartbeater", e) + } threadPool.shutdown() if (!isLocal) { env.stop() @@ -363,14 +368,14 @@ private[spark] class Executor( threadMXBean.getCurrentThreadCpuTime } else 0L var threwException = true - val value = try { + val value = Utils.tryWithSafeFinally { val res = task.run( taskAttemptId = taskId, attemptNumber = taskDescription.attemptNumber, metricsSystem = env.metricsSystem) threwException = false res - } finally { + } { val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId) val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() @@ -787,6 +792,9 @@ private[spark] class Executor( val accumUpdates = new ArrayBuffer[(Long, Seq[AccumulatorV2[_, _]])]() val curGCTime = computeTotalGcTime() + // Get executor-level memory metrics + val executorUpdates = heartbeater.getCurrentMetrics() + for (taskRunner <- runningTasks.values().asScala) { if (taskRunner.task != null) { taskRunner.task.metrics.mergeShuffleReadMetrics() @@ -795,7 +803,8 @@ private[spark] class Executor( } } - val message = Heartbeat(executorId, accumUpdates.toArray, env.blockManager.blockManagerId) + val message = Heartbeat(executorId, accumUpdates.toArray, env.blockManager.blockManagerId, + executorUpdates) try { val response = heartbeatReceiverRef.askSync[HeartbeatResponse]( message, RpcTimeout(conf, "spark.executor.heartbeatInterval", "10s")) @@ -815,21 +824,6 @@ private[spark] class Executor( } } } - - /** - * Schedules a task to report heartbeat and partial metrics for active tasks to driver. - */ - private def startDriverHeartbeater(): Unit = { - val intervalMs = conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s") - - // Wait a random interval so the heartbeats don't end up in sync - val initialDelay = intervalMs + (math.random * intervalMs).asInstanceOf[Int] - - val heartbeatTask = new Runnable() { - override def run(): Unit = Utils.logUncaughtExceptions(reportHeartBeat()) - } - heartbeater.scheduleAtFixedRate(heartbeatTask, initialDelay, intervalMs, TimeUnit.MILLISECONDS) - } } private[spark] object Executor { diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorMetrics.scala new file mode 100644 index 0000000000000..2933f3ba6d3b5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorMetrics.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.executor + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.metrics.ExecutorMetricType + +/** + * :: DeveloperApi :: + * Metrics tracked for executors and the driver. + * + * Executor-level metrics are sent from each executor to the driver as part of the Heartbeat. + */ +@DeveloperApi +class ExecutorMetrics private[spark] extends Serializable { + + // Metrics are indexed by ExecutorMetricType.values + private val metrics = new Array[Long](ExecutorMetricType.values.length) + + // The first element is initialized to -1, indicating that the values for the array + // haven't been set yet. + metrics(0) = -1 + + /** Returns the value for the specified metricType. */ + def getMetricValue(metricType: ExecutorMetricType): Long = { + metrics(ExecutorMetricType.metricIdxMap(metricType)) + } + + /** Returns true if the values for the metrics have been set, false otherwise. */ + def isSet(): Boolean = metrics(0) > -1 + + private[spark] def this(metrics: Array[Long]) { + this() + Array.copy(metrics, 0, this.metrics, 0, Math.min(metrics.size, this.metrics.size)) + } + + /** + * Constructor: create the ExecutorMetrics with the values specified. + * + * @param executorMetrics map of executor metric name to value + */ + private[spark] def this(executorMetrics: Map[String, Long]) { + this() + (0 until ExecutorMetricType.values.length).foreach { idx => + metrics(idx) = executorMetrics.getOrElse(ExecutorMetricType.values(idx).name, 0L) + } + } + + /** + * Compare the specified executor metrics values with the current executor metric values, + * and update the value for any metrics where the new value for the metric is larger.
+ * + * @param executorMetrics the executor metrics to compare + * @return true if there is a new peak value for any metric + */ + private[spark] def compareAndUpdatePeakValues(executorMetrics: ExecutorMetrics): Boolean = { + var updated = false + + (0 until ExecutorMetricType.values.length).foreach { idx => + if (executorMetrics.metrics(idx) > metrics(idx)) { + updated = true + metrics(idx) = executorMetrics.metrics(idx) + } + } + updated + } +} diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala index 669ce63325d0e..a8264022a0aff 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala @@ -17,9 +17,12 @@ package org.apache.spark.executor +import java.lang.management.ManagementFactory import java.util.concurrent.ThreadPoolExecutor +import javax.management.{MBeanServer, ObjectName} import scala.collection.JavaConverters._ +import scala.util.control.NonFatal import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.hadoop.fs.FileSystem @@ -73,6 +76,24 @@ class ExecutorSource(threadPool: ThreadPoolExecutor, executorId: String) extends registerFileSystemStat(scheme, "write_ops", _.getWriteOps(), 0) } + // Dropwizard metrics gauge measuring the executor's process CPU time. + // This Gauge will try to get and return the JVM Process CPU time or return -1 otherwise. + // The CPU time value is returned in nanoseconds. + // It will use proprietary extensions such as com.sun.management.OperatingSystemMXBean or + // com.ibm.lang.management.OperatingSystemMXBean, if available. + metricRegistry.register(MetricRegistry.name("jvmCpuTime"), new Gauge[Long] { + val mBean: MBeanServer = ManagementFactory.getPlatformMBeanServer + val name = new ObjectName("java.lang", "type", "OperatingSystem") + override def getValue: Long = { + try { + // return JVM process CPU time if the ProcessCpuTime method is available + mBean.getAttribute(name, "ProcessCpuTime").asInstanceOf[Long] + } catch { + case NonFatal(_) => -1L + } + } + }) + // Expose executor task metrics using the Dropwizard metrics system.
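(Illustrative sketch, not part of this patch: how compareAndUpdatePeakValues above supports peak tracking. The constructors are private[spark]; the values are hypothetical.)

    val peaks = new ExecutorMetrics(Array.fill(ExecutorMetricType.values.length)(0L))
    val snapshot = new ExecutorMetrics(Array.fill(ExecutorMetricType.values.length)(512L))
    val newPeak = peaks.compareAndUpdatePeakValues(snapshot)  // true: every slot grew
    // `peaks` now holds the element-wise maximum of the two snapshots.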
   // The list is taken from TaskMetrics.scala
   val METRIC_CPU_TIME = metricRegistry.counter(MetricRegistry.name("cpuTime"))
diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
index 17cdba4f1305b..ab020aaf6fa4f 100644
--- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
+++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
@@ -47,7 +47,7 @@ private[spark] abstract class StreamFileInputFormat[T]
   def setMinPartitions(sc: SparkContext, context: JobContext, minPartitions: Int) {
     val defaultMaxSplitBytes = sc.getConf.get(config.FILES_MAX_PARTITION_BYTES)
     val openCostInBytes = sc.getConf.get(config.FILES_OPEN_COST_IN_BYTES)
-    val defaultParallelism = sc.defaultParallelism
+    val defaultParallelism = Math.max(sc.defaultParallelism, minPartitions)
     val files = listStatus(context).asScala
     val totalBytes = files.filterNot(_.isDirectory).map(_.getLen + openCostInBytes).sum
     val bytesPerCore = totalBytes / defaultParallelism
diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
index f47cd38d712c3..04c5c4b90e8a1 100644
--- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
+++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
@@ -53,6 +53,19 @@ private[spark] class WholeTextFileInputFormat
     val totalLen = files.map(file => if (file.isDirectory) 0L else file.getLen).sum
     val maxSplitSize = Math.ceil(totalLen * 1.0 /
       (if (minPartitions == 0) 1 else minPartitions)).toLong
+
+    // For small files we need to ensure the min split size per node & rack <= maxSplitSize
+    val config = context.getConfiguration
+    val minSplitSizePerNode = config.getLong(CombineFileInputFormat.SPLIT_MINSIZE_PERNODE, 0L)
+    val minSplitSizePerRack = config.getLong(CombineFileInputFormat.SPLIT_MINSIZE_PERRACK, 0L)
+
+    if (maxSplitSize < minSplitSizePerNode) {
+      super.setMinSplitSizeNode(maxSplitSize)
+    }
+
+    if (maxSplitSize < minSplitSizePerRack) {
+      super.setMinSplitSizeRack(maxSplitSize)
+    }
     super.setMaxSplitSize(maxSplitSize)
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 38a043c85ae33..8d827189ebb57 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -69,9 +69,17 @@ package object config {
       .bytesConf(ByteUnit.KiB)
       .createWithDefaultString("100k")
 
+  private[spark] val EVENT_LOG_STAGE_EXECUTOR_METRICS =
+    ConfigBuilder("spark.eventLog.logStageExecutorMetrics.enabled")
+      .booleanConf
+      .createWithDefault(false)
+
   private[spark] val EVENT_LOG_OVERWRITE =
     ConfigBuilder("spark.eventLog.overwrite").booleanConf.createWithDefault(false)
 
+  private[spark] val EVENT_LOG_CALLSITE_LONG_FORM =
+    ConfigBuilder("spark.eventLog.longForm.enabled").booleanConf.createWithDefault(false)
+
   private[spark] val EXECUTOR_CLASS_PATH =
     ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_CLASSPATH).stringConf.createOptional
@@ -111,6 +119,10 @@ package object config {
       .checkValue(_ >= 0, "The off-heap memory size must not be negative")
       .createWithDefault(0)
 
+  private[spark] val PYSPARK_EXECUTOR_MEMORY = ConfigBuilder("spark.executor.pyspark.memory")
+    .bytesConf(ByteUnit.MiB)
+    .createOptional
+
   private[spark] val IS_PYTHON_APP =
ConfigBuilder("spark.yarn.isPython").internal() .booleanConf.createWithDefault(false) @@ -137,6 +149,9 @@ package object config { private[spark] val SHUFFLE_SERVICE_ENABLED = ConfigBuilder("spark.shuffle.service.enabled").booleanConf.createWithDefault(false) + private[spark] val SHUFFLE_SERVICE_PORT = + ConfigBuilder("spark.shuffle.service.port").intConf.createWithDefault(7337) + private[spark] val KEYTAB = ConfigBuilder("spark.yarn.keytab") .doc("Location of user's keytab.") .stringConf.createOptional @@ -429,7 +444,11 @@ package object config { "external shuffle service, this feature can only be worked when external shuffle" + "service is newer than Spark 2.2.") .bytesConf(ByteUnit.BYTE) - .createWithDefault(Long.MaxValue) + // fetch-to-mem is guaranteed to fail if the message is bigger than 2 GB, so we might + // as well use fetch-to-disk in that case. The message includes some metadata in addition + // to the block data itself (in particular UploadBlock has a lot of metadata), so we leave + // extra room. + .createWithDefault(Int.MaxValue - 512) private[spark] val TASK_METRICS_TRACK_UPDATED_BLOCK_STATUSES = ConfigBuilder("spark.taskMetrics.trackUpdatedBlockStatuses") @@ -483,10 +502,11 @@ package object config { private[spark] val FORCE_DOWNLOAD_SCHEMES = ConfigBuilder("spark.yarn.dist.forceDownloadSchemes") - .doc("Comma-separated list of schemes for which files will be downloaded to the " + + .doc("Comma-separated list of schemes for which resources will be downloaded to the " + "local disk prior to being added to YARN's distributed cache. For use in cases " + "where the YARN service does not support schemes that are supported by Spark, like http, " + - "https and ftp.") + "https and ftp, or jars required to be in the local YARN client's classpath. Wildcard " + + "'*' is denoted to download resources for all the schemes.") .stringConf .toSequence .createWithDefault(Nil) @@ -559,4 +579,48 @@ package object config { .intConf .checkValue(v => v > 0, "The value should be a positive integer.") .createWithDefault(2000) + + private[spark] val MEMORY_MAP_LIMIT_FOR_TESTS = + ConfigBuilder("spark.storage.memoryMapLimitForTests") + .internal() + .doc("For testing only, controls the size of chunks when memory mapping a file") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(Int.MaxValue) + + private[spark] val BARRIER_SYNC_TIMEOUT = + ConfigBuilder("spark.barrier.sync.timeout") + .doc("The timeout in seconds for each barrier() call from a barrier task. If the " + + "coordinator didn't receive all the sync messages from barrier tasks within the " + + "configed time, throw a SparkException to fail all the tasks. The default value is set " + + "to 31536000(3600 * 24 * 365) so the barrier() call shall wait for one year.") + .timeConf(TimeUnit.SECONDS) + .checkValue(v => v > 0, "The value should be a positive time value.") + .createWithDefaultString("365d") + + private[spark] val BARRIER_MAX_CONCURRENT_TASKS_CHECK_INTERVAL = + ConfigBuilder("spark.scheduler.barrier.maxConcurrentTasksCheck.interval") + .doc("Time in seconds to wait between a max concurrent tasks check failure and the next " + + "check. A max concurrent tasks check ensures the cluster can launch more concurrent " + + "tasks than required by a barrier stage on job submitted. The check can fail in case " + + "a cluster has just started and not enough executors have registered, so we wait for a " + + "little while and try to perform the check again. 
+        "If the check fails more than the configured max number of times for a job, the job " +
+        "submission fails. Note this config only applies to jobs that contain one or more " +
+        "barrier stages; we won't perform the check on non-barrier jobs.")
+      .timeConf(TimeUnit.SECONDS)
+      .createWithDefaultString("15s")
+
+  private[spark] val BARRIER_MAX_CONCURRENT_TASKS_CHECK_MAX_FAILURES =
+    ConfigBuilder("spark.scheduler.barrier.maxConcurrentTasksCheck.maxFailures")
+      .doc("Number of max concurrent tasks check failures allowed before failing a job " +
+        "submission. A max concurrent tasks check ensures the cluster can launch more " +
+        "concurrent tasks than required by a barrier stage when a job is submitted. The check " +
+        "can fail if a cluster has just started and not enough executors have registered, so " +
+        "we wait for a little while and try to perform the check again. If the check fails " +
+        "more than the configured max number of times for a job, the job submission fails. " +
+        "Note this config only applies to jobs that contain one or more barrier stages; we " +
+        "won't perform the check on non-barrier jobs.")
+      .intConf
+      .checkValue(v => v > 0, "The max failures should be a positive value.")
+      .createWithDefault(40)
 }
diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
index 0641adc2ab699..4fde2d0beaa71 100644
--- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
@@ -180,6 +180,34 @@ private[spark] abstract class MemoryManager(
     onHeapStorageMemoryPool.memoryUsed + offHeapStorageMemoryPool.memoryUsed
   }
 
+  /**
+   * On heap execution memory currently in use, in bytes.
+   */
+  final def onHeapExecutionMemoryUsed: Long = synchronized {
+    onHeapExecutionMemoryPool.memoryUsed
+  }
+
+  /**
+   * Off heap execution memory currently in use, in bytes.
+   */
+  final def offHeapExecutionMemoryUsed: Long = synchronized {
+    offHeapExecutionMemoryPool.memoryUsed
+  }
+
+  /**
+   * On heap storage memory currently in use, in bytes.
+   */
+  final def onHeapStorageMemoryUsed: Long = synchronized {
+    onHeapStorageMemoryPool.memoryUsed
+  }
+
+  /**
+   * Off heap storage memory currently in use, in bytes.
+   */
+  final def offHeapStorageMemoryUsed: Long = synchronized {
+    offHeapStorageMemoryPool.memoryUsed
+  }
+
   /**
    * Returns the execution memory consumption, in bytes, for the given task.
    */
diff --git a/core/src/main/scala/org/apache/spark/metrics/ExecutorMetricType.scala b/core/src/main/scala/org/apache/spark/metrics/ExecutorMetricType.scala
new file mode 100644
index 0000000000000..cd10dad25e87b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/metrics/ExecutorMetricType.scala
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.metrics
+
+import java.lang.management.{BufferPoolMXBean, ManagementFactory}
+import javax.management.ObjectName
+
+import org.apache.spark.memory.MemoryManager
+
+/**
+ * Executor metric types for executor-level metrics stored in ExecutorMetrics.
+ */
+sealed trait ExecutorMetricType {
+  private[spark] def getMetricValue(memoryManager: MemoryManager): Long
+  private[spark] val name = getClass().getName().stripSuffix("$").split("""\.""").last
+}
+
+private[spark] abstract class MemoryManagerExecutorMetricType(
+    f: MemoryManager => Long) extends ExecutorMetricType {
+  override private[spark] def getMetricValue(memoryManager: MemoryManager): Long = {
+    f(memoryManager)
+  }
+}
+
+private[spark] abstract class MBeanExecutorMetricType(mBeanName: String)
+  extends ExecutorMetricType {
+  private val bean = ManagementFactory.newPlatformMXBeanProxy(
+    ManagementFactory.getPlatformMBeanServer,
+    new ObjectName(mBeanName).toString, classOf[BufferPoolMXBean])
+
+  override private[spark] def getMetricValue(memoryManager: MemoryManager): Long = {
+    bean.getMemoryUsed
+  }
+}
+
+case object JVMHeapMemory extends ExecutorMetricType {
+  override private[spark] def getMetricValue(memoryManager: MemoryManager): Long = {
+    ManagementFactory.getMemoryMXBean.getHeapMemoryUsage().getUsed()
+  }
+}
+
+case object JVMOffHeapMemory extends ExecutorMetricType {
+  override private[spark] def getMetricValue(memoryManager: MemoryManager): Long = {
+    ManagementFactory.getMemoryMXBean.getNonHeapMemoryUsage().getUsed()
+  }
+}
+
+case object OnHeapExecutionMemory extends MemoryManagerExecutorMetricType(
+  _.onHeapExecutionMemoryUsed)
+
+case object OffHeapExecutionMemory extends MemoryManagerExecutorMetricType(
+  _.offHeapExecutionMemoryUsed)
+
+case object OnHeapStorageMemory extends MemoryManagerExecutorMetricType(
+  _.onHeapStorageMemoryUsed)
+
+case object OffHeapStorageMemory extends MemoryManagerExecutorMetricType(
+  _.offHeapStorageMemoryUsed)
+
+case object OnHeapUnifiedMemory extends MemoryManagerExecutorMetricType(
+  (m => m.onHeapExecutionMemoryUsed + m.onHeapStorageMemoryUsed))
+
+case object OffHeapUnifiedMemory extends MemoryManagerExecutorMetricType(
+  (m => m.offHeapExecutionMemoryUsed + m.offHeapStorageMemoryUsed))
+
+case object DirectPoolMemory extends MBeanExecutorMetricType(
+  "java.nio:type=BufferPool,name=direct")
+
+case object MappedPoolMemory extends MBeanExecutorMetricType(
+  "java.nio:type=BufferPool,name=mapped")
+
+private[spark] object ExecutorMetricType {
+  // List of all executor metric types
+  val values = IndexedSeq(
+    JVMHeapMemory,
+    JVMOffHeapMemory,
+    OnHeapExecutionMemory,
+    OffHeapExecutionMemory,
+    OnHeapStorageMemory,
+    OffHeapStorageMemory,
+    OnHeapUnifiedMemory,
+    OffHeapUnifiedMemory,
+    DirectPoolMemory,
+    MappedPoolMemory
+  )
+
+  // Map of executor metric type to its index in values.
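+  // (used by ExecutorMetrics.getMetricValue for a constant-time lookup)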
+  val metricIdxMap =
+    Map[ExecutorMetricType, Int](ExecutorMetricType.values.zipWithIndex: _*)
+}
diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
index b3f8bfe8b1d48..e94a01244474c 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
@@ -20,6 +20,7 @@ package org.apache.spark.network
 import scala.reflect.ClassTag
 
 import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.client.StreamCallbackWithID
 import org.apache.spark.storage.{BlockId, StorageLevel}
 
 private[spark]
@@ -43,6 +44,17 @@ trait BlockDataManager {
       level: StorageLevel,
       classTag: ClassTag[_]): Boolean
 
+  /**
+   * Put the given block that will be received as a stream.
+   *
+   * When this method is called, the block data itself is not available -- it will be passed to
+   * the returned StreamCallbackWithID.
+   */
+  def putBlockDataAsStream(
+      blockId: BlockId,
+      level: StorageLevel,
+      classTag: ClassTag[_]): StreamCallbackWithID
+
   /**
    * Release locks acquired by [[putBlockData()]] and [[getBlockData()]].
    */
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
index eb4cf94164fd4..7076701421e2e 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -26,9 +26,9 @@ import scala.reflect.ClassTag
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.BlockDataManager
 import org.apache.spark.network.buffer.NioManagedBuffer
-import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
+import org.apache.spark.network.client.{RpcResponseCallback, StreamCallbackWithID, TransportClient}
 import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}
-import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks, StreamHandle, UploadBlock}
+import org.apache.spark.network.shuffle.protocol._
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.storage.{BlockId, StorageLevel}
@@ -73,10 +73,32 @@ class NettyBlockRpcServer(
        }
        val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
        val blockId = BlockId(uploadBlock.blockId)
+        logDebug(s"Receiving replicated block $blockId with level ${level} " +
+          s"from ${client.getSocketAddress}")
        blockManager.putBlockData(blockId, data, level, classTag)
        responseContext.onSuccess(ByteBuffer.allocate(0))
    }
  }
 
+  override def receiveStream(
+      client: TransportClient,
+      messageHeader: ByteBuffer,
+      responseContext: RpcResponseCallback): StreamCallbackWithID = {
+    val message =
+      BlockTransferMessage.Decoder.fromByteBuffer(messageHeader).asInstanceOf[UploadBlockStream]
+    val (level: StorageLevel, classTag: ClassTag[_]) = {
+      serializer
+        .newInstance()
+        .deserialize(ByteBuffer.wrap(message.metadata))
+        .asInstanceOf[(StorageLevel, ClassTag[_])]
+    }
+    val blockId = BlockId(message.blockId)
+    logDebug(s"Receiving replicated block $blockId with level ${level} as stream " +
+      s"from ${client.getSocketAddress}")
+    // This will return immediately, but will set up a callback on streamData which will still
+    // do all the processing in the netty thread.
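+    // (the block data is not available yet; it is passed to the returned
+    // StreamCallbackWithID as it arrives)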
+    blockManager.putBlockDataAsStream(blockId, level, classTag)
+  }
+
   override def getStreamManager(): StreamManager = streamManager
 }
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index b7d8c35032763..1905632a936d3 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -27,13 +27,14 @@ import scala.reflect.ClassTag
 import com.codahale.metrics.{Metric, MetricSet}
 
 import org.apache.spark.{SecurityManager, SparkConf}
+import org.apache.spark.internal.config
 import org.apache.spark.network._
-import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
 import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory}
 import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap}
 import org.apache.spark.network.server._
 import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher, TempFileManager}
-import org.apache.spark.network.shuffle.protocol.UploadBlock
+import org.apache.spark.network.shuffle.protocol.{UploadBlock, UploadBlockStream}
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.storage.{BlockId, StorageLevel}
@@ -148,20 +149,28 @@ private[spark] class NettyBlockTransferService(
     // Everything else is encoded using our binary protocol.
     val metadata = JavaUtils.bufferToArray(serializer.newInstance().serialize((level, classTag)))
 
-    // Convert or copy nio buffer into array in order to serialize it.
-    val array = JavaUtils.bufferToArray(blockData.nioByteBuffer())
+    val asStream = blockData.size() > conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)
+    val callback = new RpcResponseCallback {
+      override def onSuccess(response: ByteBuffer): Unit = {
+        logTrace(s"Successfully uploaded block $blockId${if (asStream) " as stream" else ""}")
+        result.success((): Unit)
+      }
 
-    client.sendRpc(new UploadBlock(appId, execId, blockId.name, metadata, array).toByteBuffer,
-      new RpcResponseCallback {
-        override def onSuccess(response: ByteBuffer): Unit = {
-          logTrace(s"Successfully uploaded block $blockId")
-          result.success((): Unit)
-        }
-        override def onFailure(e: Throwable): Unit = {
-          logError(s"Error while uploading block $blockId", e)
-          result.failure(e)
-        }
-      })
+      override def onFailure(e: Throwable): Unit = {
+        logError(s"Error while uploading $blockId${if (asStream) " as stream" else ""}", e)
+        result.failure(e)
+      }
+    }
+    if (asStream) {
+      val streamHeader = new UploadBlockStream(blockId.name, metadata).toByteBuffer
+      client.uploadStream(new NioManagedBuffer(streamHeader), blockData, callback)
+    } else {
+      // Convert or copy nio buffer into array in order to serialize it.
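+      // (this materializes the whole block in memory, which is why blocks larger than
+      // MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM take the uploadStream path above)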
+      val array = JavaUtils.bufferToArray(blockData.nioByteBuffer())
+
+      client.sendRpc(new UploadBlock(appId, execId, blockId.name, metadata, array).toByteBuffer,
+        callback)
+    }
 
     result.future
   }
diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
index 4e036c2ed49b5..23cf19d55b4ae 100644
--- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
@@ -30,7 +30,7 @@ private[spark]
 class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[BlockId])
   extends RDD[T](sc, Nil) {
 
-  @transient lazy val _locations = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get)
+  @transient lazy val _locations = BlockManager.blockIdsToLocations(blockIds, SparkEnv.get)
   @volatile private var _isValid = true
 
   override def getPartitions: Array[Partition] = {
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 44895abc7bd4d..3974580cfaa11 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -278,7 +278,7 @@ class HadoopRDD[K, V](
         null
       }
       // Register an on-task-completion callback to close the input stream.
-      context.addTaskCompletionListener { context =>
+      context.addTaskCompletionListener[Unit] { context =>
         // Update the bytes read before closing is to make sure lingering bytesRead statistics in
         // this thread get correctly added.
         updateBytesRead()
diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
index aab46b8954bf7..56ef3e107a980 100644
--- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
@@ -77,7 +77,7 @@ class JdbcRDD[T: ClassTag](
 
   override def compute(thePart: Partition, context: TaskContext): Iterator[T] = new NextIterator[T]
   {
-    context.addTaskCompletionListener{ context => closeIfNeeded() }
+    context.addTaskCompletionListener[Unit]{ context => closeIfNeeded() }
     val part = thePart.asInstanceOf[JdbcPartition]
     val conn = getConnection()
     val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
index e4587c96eae1c..aa61997122cf4 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
@@ -23,11 +23,25 @@ import org.apache.spark.{Partition, TaskContext}
 
 /**
  * An RDD that applies the provided function to every partition of the parent RDD.
+ *
+ * @param prev the parent RDD.
+ * @param f The function used to map a tuple of (TaskContext, partition index, input iterator) to
+ *          an output iterator.
+ * @param preservesPartitioning Whether the input function preserves the partitioner, which should
+ *                              be `false` unless `prev` is a pair RDD and the input function
+ *                              doesn't modify the keys.
+ * @param isFromBarrier Indicates whether this RDD is transformed from an RDDBarrier; a stage
+ *                      containing at least one RDDBarrier will be turned into a barrier stage.
+ * @param isOrderSensitive whether the function is order-sensitive. If it is order-sensitive, it
+ *                         may return a completely different result when the input order is
+ *                         changed. Stateful functions are typically order-sensitive.
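+ *                         For example, the row-distribution function used by `repartition` is
+ *                         order-sensitive: it assigns rows to target partitions starting from a
+ *                         position-dependent seed, so reordering its input changes its output.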
+ */
 private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
     var prev: RDD[T],
     f: (TaskContext, Int, Iterator[T]) => Iterator[U],  // (TaskContext, partition index, iterator)
-    preservesPartitioning: Boolean = false)
+    preservesPartitioning: Boolean = false,
+    isFromBarrier: Boolean = false,
+    isOrderSensitive: Boolean = false)
   extends RDD[U](prev) {
 
   override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None
@@ -41,4 +55,15 @@ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
     super.clearDependencies()
     prev = null
   }
+
+  @transient protected lazy override val isBarrier_ : Boolean =
+    isFromBarrier || dependencies.exists(_.rdd.isBarrier())
+
+  override protected def getOutputDeterministicLevel = {
+    if (isOrderSensitive && prev.outputDeterministicLevel == DeterministicLevel.UNORDERED) {
+      DeterministicLevel.INDETERMINATE
+    } else {
+      super.getOutputDeterministicLevel
+    }
+  }
 }
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index ff66a04859d10..2d66d25ba39fa 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -214,7 +214,7 @@ class NewHadoopRDD[K, V](
       }
 
       // Register an on-task-completion callback to close the input stream.
-      context.addTaskCompletionListener { context =>
+      context.addTaskCompletionListener[Unit] { context =>
         // Update the bytesRead before closing is to make sure lingering bytesRead statistics in
         // this thread get correctly added.
         updateBytesRead()
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 0574abdca32ac..61ad6dfdb2215 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -33,7 +33,7 @@ import org.apache.hadoop.mapred.TextOutputFormat
 
 import org.apache.spark._
 import org.apache.spark.Partitioner._
-import org.apache.spark.annotation.{DeveloperApi, Since}
+import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.internal.Logging
 import org.apache.spark.partial.BoundedDouble
@@ -462,8 +462,9 @@ abstract class RDD[T: ClassTag](
 
       // include a shuffle step so that our upstream tasks are still distributed
       new CoalescedRDD(
-        new ShuffledRDD[Int, T, T](mapPartitionsWithIndex(distributePartition),
-          new HashPartitioner(numPartitions)),
+        new ShuffledRDD[Int, T, T](
+          mapPartitionsWithIndexInternal(distributePartition, isOrderSensitive = true),
+          new HashPartitioner(numPartitions)),
         numPartitions,
         partitionCoalescer).values
     } else {
@@ -807,16 +808,21 @@ abstract class RDD[T: ClassTag](
    * serializable and don't require closure cleaning.
    *
    * @param preservesPartitioning indicates whether the input function preserves the partitioner,
-   * which should be `false` unless this is a pair RDD and the input function doesn't modify
-   * the keys.
+   *                              which should be `false` unless this is a pair RDD and the input
+   *                              function doesn't modify the keys.
+   * @param isOrderSensitive whether the function is order-sensitive. If it is order-sensitive,
+   *                         it may return a completely different result when the input order is
+   *                         changed. Stateful functions are typically order-sensitive.
+   */
   private[spark] def mapPartitionsWithIndexInternal[U: ClassTag](
       f: (Int, Iterator[T]) => Iterator[U],
-      preservesPartitioning: Boolean = false): RDD[U] = withScope {
+      preservesPartitioning: Boolean = false,
+      isOrderSensitive: Boolean = false): RDD[U] = withScope {
     new MapPartitionsRDD(
       this,
       (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter),
-      preservesPartitioning)
+      preservesPartitioning = preservesPartitioning,
+      isOrderSensitive = isOrderSensitive)
   }
 
   /**
@@ -1636,6 +1642,16 @@ abstract class RDD[T: ClassTag](
     }
   }
 
+  /**
+   * Return whether this RDD is reliably checkpointed and materialized.
+   */
+  private[rdd] def isReliablyCheckpointed: Boolean = {
+    checkpointData match {
+      case Some(reliable: ReliableRDDCheckpointData[_]) if reliable.isCheckpointed => true
+      case _ => false
+    }
+  }
+
   /**
    * Gets the name of the directory to which this RDD was checkpointed.
    * This is not defined if the RDD is checkpointed locally.
@@ -1647,6 +1663,22 @@ abstract class RDD[T: ClassTag](
     }
   }
 
+  /**
+   * :: Experimental ::
+   * Marks the current stage as a barrier stage, where Spark must launch all tasks together.
+   * In case of a task failure, instead of only restarting the failed task, Spark will abort the
+   * entire stage and re-launch all tasks for this stage.
+   * The barrier execution mode feature is experimental and it only handles limited scenarios.
+   * Please read the linked SPIP and design docs to understand the limitations and future plans.
+   * @return an [[RDDBarrier]] instance that provides actions within a barrier stage
+   * @see [[org.apache.spark.BarrierTaskContext]]
+   * @see SPIP: Barrier Execution Mode
+   * @see Design Doc
+   */
+  @Experimental
+  @Since("2.4.0")
+  def barrier(): RDDBarrier[T] = withScope(new RDDBarrier[T](this))
+
   // =======================================================================
   // Other internal methods and fields
   // =======================================================================
@@ -1839,6 +1871,81 @@ abstract class RDD[T: ClassTag](
   def toJavaRDD() : JavaRDD[T] = {
     new JavaRDD(this)(elementClassTag)
   }
+
+  /**
+   * Whether the RDD is in a barrier stage. Spark must launch all the tasks at the same time for a
+   * barrier stage.
+   *
+   * An RDD is in a barrier stage if at least one of its parent RDDs, or the RDD itself, is mapped
+   * from an [[RDDBarrier]]. This function always returns false for a [[ShuffledRDD]], since a
+   * [[ShuffledRDD]] indicates the start of a new stage.
+   *
+   * A [[MapPartitionsRDD]] can be transformed from an [[RDDBarrier]], in which case the
+   * [[MapPartitionsRDD]] shall be marked as barrier.
+   */
+  private[spark] def isBarrier(): Boolean = isBarrier_
+
+  // For performance, cache the value to avoid repeatedly computing `isBarrier()` on a long
+  // RDD chain.
+  @transient protected lazy val isBarrier_ : Boolean =
+    dependencies.filter(!_.isInstanceOf[ShuffleDependency[_, _, _]]).exists(_.rdd.isBarrier())
+
+  /**
+   * Returns the deterministic level of this RDD's output. Please refer to [[DeterministicLevel]]
+   * for the definition.
+   *
+   * By default, a reliably checkpointed RDD, or an RDD without parents (a root RDD), is
+   * DETERMINATE. For RDDs with parents, we will generate a deterministic level candidate per
+   * parent according to the dependency. The deterministic level of the current RDD is the least
+   * deterministic of the candidates. Please override [[getOutputDeterministicLevel]] to
+   * provide custom logic of calculating the output deterministic level.
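+   * For example, the output of an RDD computed from a shuffle is at best UNORDERED, since
+   * reducers fetch shuffle blocks in an effectively random order, and an order-sensitive map
+   * over UNORDERED output becomes INDETERMINATE (see MapPartitionsRDD above).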
+   */
+  // TODO: make it public so users can set deterministic level to their custom RDDs.
+  // TODO: this can be per-partition. e.g. UnionRDD can have different deterministic level for
+  // different partitions.
+  private[spark] final lazy val outputDeterministicLevel: DeterministicLevel.Value = {
+    if (isReliablyCheckpointed) {
+      DeterministicLevel.DETERMINATE
+    } else {
+      getOutputDeterministicLevel
+    }
+  }
+
+  @DeveloperApi
+  protected def getOutputDeterministicLevel: DeterministicLevel.Value = {
+    val deterministicLevelCandidates = dependencies.map {
+      // The shuffle is not really happening, treat it like a narrow dependency and assume the
+      // output deterministic level of the current RDD is the same as its parent's.
+      case dep: ShuffleDependency[_, _, _] if dep.rdd.partitioner.exists(_ == dep.partitioner) =>
+        dep.rdd.outputDeterministicLevel
+
+      case dep: ShuffleDependency[_, _, _] =>
+        if (dep.rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE) {
+          // If the map output was indeterminate, the shuffle output will be indeterminate as well.
+          DeterministicLevel.INDETERMINATE
+        } else if (dep.keyOrdering.isDefined && dep.aggregator.isDefined) {
+          // If an aggregator is specified (and so keys are unique) and a key ordering is also
+          // specified, then the ordering is consistent.
+          DeterministicLevel.DETERMINATE
+        } else {
+          // In Spark, the reducer fetches multiple remote shuffle blocks at the same time, and
+          // the arrival order of these shuffle blocks is totally random. Even if the parent map
+          // RDD is DETERMINATE, the reduce RDD is always UNORDERED.
+          DeterministicLevel.UNORDERED
+        }
+
+      // For a narrow dependency, assume the output deterministic level of the current RDD is
+      // the same as its parent's.
+      case dep => dep.rdd.outputDeterministicLevel
+    }
+
+    if (deterministicLevelCandidates.isEmpty) {
+      // By default we assume the root RDD is determinate.
+      DeterministicLevel.DETERMINATE
+    } else {
+      deterministicLevelCandidates.maxBy(_.id)
+    }
+  }
 }
@@ -1892,3 +1999,18 @@ object RDD {
     new DoubleRDDFunctions(rdd.map(x => num.toDouble(x)))
   }
 }
+
+/**
+ * The deterministic level of an RDD's output (i.e. what `RDD#compute` returns). This explains
+ * how the output can differ when Spark reruns the tasks for the RDD. There are 3 deterministic
+ * levels:
+ * 1. DETERMINATE: The RDD output is always the same data set in the same order after a rerun.
+ * 2. UNORDERED: The RDD output is always the same data set but the order can be different
+ *    after a rerun.
+ * 3. INDETERMINATE: The RDD output can be different after a rerun.
+ *
+ * Note that the output of an RDD usually relies on the parent RDDs. When a parent RDD's output
+ * is INDETERMINATE, it's very likely that the RDD's output is also INDETERMINATE.
+ */
+private[spark] object DeterministicLevel extends Enumeration {
+  val DETERMINATE, UNORDERED, INDETERMINATE = Value
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala
new file mode 100644
index 0000000000000..42802f7113a19
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.TaskContext
+import org.apache.spark.annotation.{Experimental, Since}
+
+/**
+ * :: Experimental ::
+ * Wraps an RDD in a barrier stage, which forces Spark to launch tasks of this stage together.
+ * [[org.apache.spark.rdd.RDDBarrier]] instances are created by
+ * [[org.apache.spark.rdd.RDD#barrier]].
+ */
+@Experimental
+@Since("2.4.0")
+class RDDBarrier[T: ClassTag] private[spark] (rdd: RDD[T]) {
+
+  /**
+   * :: Experimental ::
+   * Returns a new RDD by applying a function to each partition of the wrapped RDD,
+   * where tasks are launched together in a barrier stage.
+   * The interface is the same as [[org.apache.spark.rdd.RDD#mapPartitions]].
+   * Please see the API doc there.
+   * @see [[org.apache.spark.BarrierTaskContext]]
+   */
+  @Experimental
+  @Since("2.4.0")
+  def mapPartitions[S: ClassTag](
+      f: Iterator[T] => Iterator[S],
+      preservesPartitioning: Boolean = false): RDD[S] = rdd.withScope {
+    val cleanedF = rdd.sparkContext.clean(f)
+    new MapPartitionsRDD(
+      rdd,
+      (context: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(iter),
+      preservesPartitioning,
+      isFromBarrier = true
+    )
+  }
+
+  // TODO: [SPARK-25247] add extra conf to RDDBarrier, e.g., timeout.
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
index 979152b55f957..8273d8a9eb476 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
@@ -300,7 +300,7 @@ private[spark] object ReliableCheckpointRDD extends Logging {
     val deserializeStream = serializer.deserializeStream(fileInputStream)
 
     // Register an on-task-completion callback to close the input stream.
-    context.addTaskCompletionListener(context => deserializeStream.close())
+    context.addTaskCompletionListener[Unit](context => deserializeStream.close())
 
     deserializeStream.asIterator.asInstanceOf[Iterator[T]]
   }
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index 26eaa9aa3d03f..e8f9b27b7eb55 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -110,4 +110,6 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag](
     super.clearDependencies()
     prev = null
   }
+
+  private[spark] override def isBarrier(): Boolean = false
 }
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index a2936d6ad539c..47576959322d1 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -50,7 +50,7 @@ private[netty] class NettyRpcEnv(
   private[netty] val transportConf = SparkTransportConf.fromSparkConf(
     conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"),
     "rpc",
-    conf.getInt("spark.rpc.io.threads", 0))
+    conf.getInt("spark.rpc.io.threads", numUsableCores))
 
   private val dispatcher: Dispatcher = new Dispatcher(this, numUsableCores)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala
index 949e88f606275..6e4d062749d5f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala
@@ -60,4 +60,10 @@ private[spark] class ActiveJob(
 
   val finished = Array.fill[Boolean](numPartitions)(false)
 
   var numFinished = 0
+
+  /** Resets the status of all partitions in this job so they are marked as not finished. */
+  def resetAllPartitions(): Unit = {
+    (0 until numPartitions).foreach(finished.update(_, false))
+    numFinished = 0
+  }
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/BarrierJobAllocationFailed.scala b/core/src/main/scala/org/apache/spark/scheduler/BarrierJobAllocationFailed.scala
new file mode 100644
index 0000000000000..803a0a1226d6c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/BarrierJobAllocationFailed.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import org.apache.spark.SparkException
+
+/**
+ * Exception thrown when a job is submitted with barrier stage(s) that fail a required check.
+ */
+private[spark] class BarrierJobAllocationFailed(message: String) extends SparkException(message)
+
+private[spark] class BarrierJobUnsupportedRDDChainException
+  extends BarrierJobAllocationFailed(
+    BarrierJobAllocationFailed.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN)
+
+private[spark] class BarrierJobRunWithDynamicAllocationException
+  extends BarrierJobAllocationFailed(
+    BarrierJobAllocationFailed.ERROR_MESSAGE_RUN_BARRIER_WITH_DYN_ALLOCATION)
+
+private[spark] class BarrierJobSlotsNumberCheckFailed
+  extends BarrierJobAllocationFailed(
+    BarrierJobAllocationFailed.ERROR_MESSAGE_BARRIER_REQUIRE_MORE_SLOTS_THAN_CURRENT_TOTAL_NUMBER)
+
+private[spark] object BarrierJobAllocationFailed {
+
+  // Error message when running a barrier stage that has an unsupported RDD chain pattern.
+  val ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN =
+    "[SPARK-24820][SPARK-24821]: Barrier execution mode does not allow the following pattern of " +
+      "RDD chain within a barrier stage:\n1. Ancestor RDDs that have different number of " +
+      "partitions from the resulting RDD (eg. union()/coalesce()/first()/take()/" +
+      "PartitionPruningRDD). A workaround for first()/take() can be barrierRdd.collect().head " +
+      "(scala) or barrierRdd.collect()[0] (python).\n" +
+      "2. An RDD that depends on multiple barrier RDDs (eg. barrierRdd1.zip(barrierRdd2))."
+
+  // Error message when running a barrier stage with dynamic resource allocation enabled.
+  val ERROR_MESSAGE_RUN_BARRIER_WITH_DYN_ALLOCATION =
+    "[SPARK-24942]: Barrier execution mode does not support dynamic resource allocation for " +
+      "now. You can disable dynamic resource allocation by setting Spark conf " +
+      "\"spark.dynamicAllocation.enabled\" to \"false\"."
+
+  // Error message when running a barrier stage that requires more slots than are currently
+  // available in total.
+  val ERROR_MESSAGE_BARRIER_REQUIRE_MORE_SLOTS_THAN_CURRENT_TOTAL_NUMBER =
+    "[SPARK-24819]: Barrier execution mode does not allow running a barrier stage that requires " +
+      "more slots than the total number of slots in the cluster currently. Please initialize a " +
+      "new cluster with more CPU cores or repartition the input RDD(s) to reduce the number of " +
+      "slots required to run this barrier stage."
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index f74425d73b392..47108353583a8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -19,8 +19,9 @@ package org.apache.spark.scheduler
 
 import java.io.NotSerializableException
 import java.util.Properties
-import java.util.concurrent.TimeUnit
+import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
 import java.util.concurrent.atomic.AtomicInteger
+import java.util.function.BiFunction
 
 import scala.annotation.tailrec
 import scala.collection.Map
@@ -34,12 +35,12 @@ import org.apache.commons.lang3.SerializationUtils
 
 import org.apache.spark._
 import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
-import org.apache.spark.rdd.{RDD, RDDCheckpointData}
+import org.apache.spark.rdd.{DeterministicLevel, RDD, RDDCheckpointData}
 import org.apache.spark.rpc.RpcTimeout
 import org.apache.spark.storage._
 import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
@@ -111,8 +112,7 @@ import org.apache.spark.util._
 *  - When adding a new data structure, update `DAGSchedulerSuite.assertDataStructuresEmpty` to
 *    include the new structure. This will help to catch memory leaks.
 */
-private[spark]
-class DAGScheduler(
+private[spark] class DAGScheduler(
     private[scheduler] val sc: SparkContext,
     private[scheduler] val taskScheduler: TaskScheduler,
     listenerBus: LiveListenerBus,
@@ -203,6 +203,24 @@ class DAGScheduler(
     sc.getConf.getInt("spark.stage.maxConsecutiveAttempts",
       DAGScheduler.DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS)
 
+  /**
+   * Number of max concurrent tasks check failures for each barrier job.
+   */
+  private[scheduler] val barrierJobIdToNumTasksCheckFailures = new ConcurrentHashMap[Int, Int]
+
+  /**
+   * Time in seconds to wait between a max concurrent tasks check failure and the next check.
+   */
+  private val timeIntervalNumTasksCheck = sc.getConf
+    .get(config.BARRIER_MAX_CONCURRENT_TASKS_CHECK_INTERVAL)
+
+  /**
+   * Max number of max concurrent tasks check failures allowed for a job before failing the job
+   * submission.
+   */
+  private val maxFailureNumTasksCheck = sc.getConf
+    .get(config.BARRIER_MAX_CONCURRENT_TASKS_CHECK_MAX_FAILURES)
+
   private val messageScheduler =
     ThreadUtils.newDaemonSingleThreadScheduledExecutor("dag-scheduler-message")
@@ -246,8 +264,11 @@ class DAGScheduler(
       execId: String,
       // (taskId, stageId, stageAttemptId, accumUpdates)
       accumUpdates: Array[(Long, Int, Int, Seq[AccumulableInfo])],
-      blockManagerId: BlockManagerId): Boolean = {
-    listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, accumUpdates))
+      blockManagerId: BlockManagerId,
+      // executor metrics indexed by ExecutorMetricType.values
+      executorUpdates: ExecutorMetrics): Boolean = {
+    listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, accumUpdates,
+      Some(executorUpdates)))
     blockManagerMaster.driverEndpoint.askSync[Boolean](
       BlockManagerHeartbeat(blockManagerId), new RpcTimeout(600 seconds, "BlockManagerHeartbeat"))
   }
@@ -340,6 +361,21 @@ class DAGScheduler(
     }
   }
 
+  /**
+   * Check to make sure we don't launch a barrier stage with an unsupported RDD chain pattern.
+   * The following patterns are not supported:
+   * 1. Ancestor RDDs that have different number of partitions from the resulting RDD (eg.
+   *    union()/coalesce()/first()/take()/PartitionPruningRDD);
+   * 2. An RDD that depends on multiple barrier RDDs (eg. barrierRdd1.zip(barrierRdd2)).
+   */
+  private def checkBarrierStageWithRDDChainPattern(rdd: RDD[_], numTasksInStage: Int): Unit = {
+    val predicate: RDD[_] => Boolean = (r =>
+      r.getNumPartitions == numTasksInStage && r.dependencies.filter(_.rdd.isBarrier()).size <= 1)
+    if (rdd.isBarrier() && !traverseParentRDDsWithinStage(rdd, predicate)) {
+      throw new BarrierJobUnsupportedRDDChainException
+    }
+  }
+
   /**
    * Creates a ShuffleMapStage that generates the given shuffle dependency's partitions. If a
    * previously run stage generated the same shuffle data, this function will copy the output
@@ -348,6 +384,9 @@ class DAGScheduler(
    */
   def createShuffleMapStage(shuffleDep: ShuffleDependency[_, _, _], jobId: Int): ShuffleMapStage = {
     val rdd = shuffleDep.rdd
+    checkBarrierStageWithDynamicAllocation(rdd)
+    checkBarrierStageWithNumSlots(rdd)
+    checkBarrierStageWithRDDChainPattern(rdd, rdd.getNumPartitions)
     val numTasks = rdd.partitions.length
     val parents = getOrCreateParentStages(rdd, jobId)
     val id = nextStageId.getAndIncrement()
@@ -367,6 +406,36 @@ class DAGScheduler(
     stage
   }
 
+  /**
+   * We don't support running a barrier stage with dynamic resource allocation enabled, as it can
+   * lead to some confusing behaviors (eg. with dynamic resource allocation enabled, it may happen
+   * that we acquire some executors (but not enough to launch all the tasks in a barrier stage)
+   * and later release them due to executor idle timeout, and then acquire again).
+   *
+   * We perform the check on job submission and fail fast if running a barrier stage with dynamic
+   * resource allocation enabled.
+   *
+   * TODO SPARK-24942 Improve cluster resource management with jobs containing barrier stage
+   */
+  private def checkBarrierStageWithDynamicAllocation(rdd: RDD[_]): Unit = {
+    if (rdd.isBarrier() && Utils.isDynamicAllocationEnabled(sc.getConf)) {
+      throw new BarrierJobRunWithDynamicAllocationException
+    }
+  }
+
+  /**
+   * Check whether the barrier stage requires more slots (to be able to launch all tasks in the
+   * barrier stage together) than the total number of active slots currently. Fail the current
+   * check if trying to submit a barrier stage that requires more slots than the current total
+   * number.
+   * If the check fails consecutively beyond a configured number of times for a job, the job
+   * submission fails.
+   */
+  private def checkBarrierStageWithNumSlots(rdd: RDD[_]): Unit = {
+    if (rdd.isBarrier() && rdd.getNumPartitions > sc.maxNumConcurrentTasks) {
+      throw new BarrierJobSlotsNumberCheckFailed
+    }
+  }
+
   /**
    * Create a ResultStage associated with the provided jobId.
    */
@@ -376,6 +445,9 @@ class DAGScheduler(
       partitions: Array[Int],
       jobId: Int,
       callSite: CallSite): ResultStage = {
+    checkBarrierStageWithDynamicAllocation(rdd)
+    checkBarrierStageWithNumSlots(rdd)
+    checkBarrierStageWithRDDChainPattern(rdd, partitions.toSet.size)
     val parents = getOrCreateParentStages(rdd, jobId)
     val id = nextStageId.getAndIncrement()
     val stage = new ResultStage(id, rdd, func, partitions, parents, jobId, callSite)
@@ -451,6 +523,32 @@ class DAGScheduler(
     parents
   }
 
+  /**
+   * Traverses the given RDD and its ancestors within the same stage and checks whether all of the
+   * RDDs satisfy a given predicate.
+   */
+  private def traverseParentRDDsWithinStage(rdd: RDD[_], predicate: RDD[_] => Boolean): Boolean = {
+    val visited = new HashSet[RDD[_]]
+    val waitingForVisit = new ArrayStack[RDD[_]]
+    waitingForVisit.push(rdd)
+    while (waitingForVisit.nonEmpty) {
+      val toVisit = waitingForVisit.pop()
+      if (!visited(toVisit)) {
+        if (!predicate(toVisit)) {
+          return false
+        }
+        visited += toVisit
+        toVisit.dependencies.foreach {
+          case _: ShuffleDependency[_, _, _] =>
+            // Not within the same stage as the current rdd, do nothing.
+          case dependency =>
+            waitingForVisit.push(dependency.rdd)
+        }
+      }
+    }
+    true
+  }
+
   private def getMissingParentStages(stage: Stage): List[Stage] = {
     val missing = new HashSet[Stage]
     val visited = new HashSet[RDD[_]]
@@ -866,11 +964,38 @@ class DAGScheduler(
       // HadoopRDD whose underlying HDFS files have been deleted.
       finalStage = createResultStage(finalRDD, func, partitions, jobId, callSite)
     } catch {
+      case e: BarrierJobSlotsNumberCheckFailed =>
+        logWarning(s"Job $jobId requires running a barrier stage that needs more slots than " +
+          "are currently available in the cluster in total.")
+        // If jobId doesn't exist in the map, Scala converts its value null to 0: Int
+        // automatically.
+        val numCheckFailures = barrierJobIdToNumTasksCheckFailures.compute(jobId,
+          new BiFunction[Int, Int, Int] {
+            override def apply(key: Int, value: Int): Int = value + 1
+          })
+        if (numCheckFailures <= maxFailureNumTasksCheck) {
+          messageScheduler.schedule(
+            new Runnable {
+              override def run(): Unit = eventProcessLoop.post(JobSubmitted(jobId, finalRDD, func,
+                partitions, callSite, listener, properties))
+            },
+            timeIntervalNumTasksCheck,
+            TimeUnit.SECONDS
+          )
+          return
+        } else {
+          // Job failed, clear internal data.
+          barrierJobIdToNumTasksCheckFailures.remove(jobId)
+          listener.jobFailed(e)
+          return
+        }
+
      case e: Exception =>
        logWarning("Creating new stage failed due to exception - job: " + jobId, e)
        listener.jobFailed(e)
        return
    }
+    // Job submitted, clear internal data.
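+    // (the barrier slots check has passed, so any failure count recorded for this job is stale)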
+    barrierJobIdToNumTasksCheckFailures.remove(jobId)
     val job = new ActiveJob(jobId, finalStage, callSite, listener, properties)
     clearCacheLocs()
@@ -1062,7 +1187,7 @@ class DAGScheduler(
           stage.pendingPartitions += id
           new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber,
             taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId),
-            Option(sc.applicationId), sc.applicationAttemptId)
+            Option(sc.applicationId), sc.applicationAttemptId, stage.rdd.isBarrier())
         }
 
       case stage: ResultStage =>
@@ -1072,7 +1197,8 @@ class DAGScheduler(
           val locs = taskIdToLocations(id)
           new ResultTask(stage.id, stage.latestInfo.attemptNumber,
             taskBinary, part, locs, id, properties, serializedTaskMetrics,
-            Option(jobId), Option(sc.applicationId), sc.applicationAttemptId)
+            Option(jobId), Option(sc.applicationId), sc.applicationAttemptId,
+            stage.rdd.isBarrier())
         }
     }
   } catch {
@@ -1250,18 +1376,10 @@ class DAGScheduler(
 
       case smt: ShuffleMapTask =>
         val shuffleStage = stage.asInstanceOf[ShuffleMapStage]
+        shuffleStage.pendingPartitions -= task.partitionId
         val status = event.result.asInstanceOf[MapStatus]
        val execId = status.location.executorId
        logDebug("ShuffleMapTask finished on " + execId)
-        if (stageIdToStage(task.stageId).latestInfo.attemptNumber == task.stageAttemptId) {
-          // This task was for the currently running attempt of the stage. Since the task
-          // completed successfully from the perspective of the TaskSetManager, mark it as
-          // no longer pending (the TaskSetManager may consider the task complete even
-          // when the output needs to be ignored because the task's epoch is too small below.
-          // In this case, when pending partitions is empty, there will still be missing
-          // output locations, which will cause the DAGScheduler to resubmit the stage below.)
-          shuffleStage.pendingPartitions -= task.partitionId
-        }
        if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
          logInfo(s"Ignoring possibly bogus $smt completion from executor $execId")
        } else {
@@ -1270,13 +1388,6 @@ class DAGScheduler(
          // available.
          mapOutputTracker.registerMapOutput(
            shuffleStage.shuffleDep.shuffleId, smt.partitionId, status)
-          // Remove the task's partition from pending partitions. This may have already been
-          // done above, but will not have been done yet in cases where the task attempt was
-          // from an earlier attempt of the stage (i.e., not the attempt that's currently
-          // running). This allows the DAGScheduler to mark the stage as complete when one
-          // copy of each task has finished successfully, even if the currently active stage
-          // still has tasks running.
-          shuffleStage.pendingPartitions -= task.partitionId
        }
 
        if (runningStages.contains(shuffleStage) && shuffleStage.pendingPartitions.isEmpty) {
@@ -1311,17 +1422,6 @@ class DAGScheduler(
          }
        }
 
-      case Resubmitted =>
-        logInfo("Resubmitted " + task + ", so marking it as still running")
-        stage match {
-          case sms: ShuffleMapStage =>
-            sms.pendingPartitions += task.partitionId
-
-          case _ =>
-            assert(false, "TaskSetManagers should only send Resubmitted task statuses for " +
-              "tasks in ShuffleMapStages.")
-        }
-
      case FetchFailed(bmAddress, shuffleId, mapId, _, failureMessage) =>
        val failedStage = stageIdToStage(task.stageId)
        val mapStage = shuffleIdToMapStage(shuffleId)
@@ -1331,9 +1431,9 @@ class DAGScheduler(
            s" ${task.stageAttemptId} and there is a more recent attempt for that stage " +
            s"(attempt ${failedStage.latestInfo.attemptNumber}) running")
        } else {
-          failedStage.fetchFailedAttemptIds.add(task.stageAttemptId)
+          failedStage.failedAttemptIds.add(task.stageAttemptId)
          val shouldAbortStage =
-            failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts ||
+            failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts ||
            disallowStageRetryForTest
 
          // It is likely that we receive multiple FetchFailed for a single stage (because we have
@@ -1349,6 +1449,31 @@ class DAGScheduler(
              s"longer running")
          }
 
+          if (mapStage.rdd.isBarrier()) {
+            // Mark all the map outputs as broken in the map stage, to ensure all tasks are
+            // retried on the resubmitted stage attempt.
+            mapOutputTracker.unregisterAllMapOutput(shuffleId)
+          } else if (mapId != -1) {
+            // Mark the map whose fetch failed as broken in the map stage.
+            mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
+          }
+
+          if (failedStage.rdd.isBarrier()) {
+            failedStage match {
+              case failedMapStage: ShuffleMapStage =>
+                // Mark all the map outputs as broken in the map stage, to ensure all tasks are
+                // retried on the resubmitted stage attempt.
+                mapOutputTracker.unregisterAllMapOutput(failedMapStage.shuffleDep.shuffleId)
+
+              case failedResultStage: ResultStage =>
+                // Abort the failed result stage since we may have committed output for some
+                // partitions.
+                val reason = "Could not recover from a failed barrier ResultStage. Most recent " +
+                  s"failure reason: $failureMessage"
+                abortStage(failedResultStage, reason, None)
+            }
+          }
+
          if (shouldAbortStage) {
            val abortMessage = if (disallowStageRetryForTest) {
              "Fetch failure will not retry stage due to testing config"
@@ -1365,6 +1490,63 @@ class DAGScheduler(
            failedStages += failedStage
            failedStages += mapStage
            if (noResubmitEnqueued) {
+              // If the map stage is INDETERMINATE, which means the map tasks may return a
+              // different result when retried, we need to retry all the tasks of the failed
+              // stage and its succeeding stages, because the input data will be changed after
+              // the map tasks are retried.
+              // Note that if the map stage is UNORDERED, we are fine. The shuffle partitioner is
+              // guaranteed to be determinate, so the input data of the reducers will not change
+              // even if the map tasks are retried.
+              if (mapStage.rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE) {
+                // It's a little tricky to find all the succeeding stages of `failedStage`,
+                // because each stage only knows its parents, not its children. Here we traverse
+                // the stages from the leaf nodes (the result stages of active jobs), and roll
+                // back all the stages in the stage chains that connect to the `failedStage`. To
+                // speed up the stage traversing, we collect the stages to roll back first.
If a stage needs to + // roll back, all its succeeding stages need to roll back too. + val stagesToRollback = scala.collection.mutable.HashSet(failedStage) + + def collectStagesToRollback(stageChain: List[Stage]): Unit = { + if (stagesToRollback.contains(stageChain.head)) { + stageChain.drop(1).foreach(s => stagesToRollback += s) + } else { + stageChain.head.parents.foreach { s => + collectStagesToRollback(s :: stageChain) + } + } + } + + def generateErrorMessage(stage: Stage): String = { + "A shuffle map stage with indeterminate output failed and was retried. " + + s"However, Spark cannot roll back the $stage to re-process the input data, " + + "and has to fail this job. Please eliminate the indeterminacy by " + + "checkpointing the RDD before repartition and try again." + } + + activeJobs.foreach(job => collectStagesToRollback(job.finalStage :: Nil)) + + stagesToRollback.foreach { + case mapStage: ShuffleMapStage => + val numMissingPartitions = mapStage.findMissingPartitions().length + if (numMissingPartitions < mapStage.numTasks) { + // TODO: support rolling back shuffle files. + // Currently the shuffle writing is "first write wins", so we can't re-run a + // shuffle map stage and overwrite existing shuffle files. We have to finish + // SPARK-8029 first. + abortStage(mapStage, generateErrorMessage(mapStage), None) + } + + case resultStage: ResultStage if resultStage.activeJob.isDefined => + val numMissingPartitions = resultStage.findMissingPartitions().length + if (numMissingPartitions < resultStage.numTasks) { + // TODO: support rolling back result tasks. + abortStage(resultStage, generateErrorMessage(resultStage), None) + } + + case _ => + } + } + // We expect one executor failure to trigger many FetchFailures in rapid succession, // but all of those task failures can typically be handled by a single resubmission of // the failed stage. We avoid flooding the scheduler's event queue with resubmit @@ -1375,7 +1557,7 @@ class DAGScheduler( // simpler while not producing an overwhelming number of scheduler events. logInfo( s"Resubmitting $mapStage (${mapStage.name}) and " + - s"$failedStage (${failedStage.name}) due to fetch failure" + s"$failedStage (${failedStage.name}) due to fetch failure" ) messageScheduler.schedule( new Runnable { @@ -1386,10 +1568,6 @@ ) } } - // Mark the map whose fetch failed as broken in the map stage - if (mapId != -1) { - mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) - } // TODO: mark the executor as failed only if there were lots of fetch failures on it if (bmAddress != null) { @@ -1411,6 +1589,91 @@ } } + case failure: TaskFailedReason if task.isBarrier => + // Also handle task failure reasons here. + failure match { + case Resubmitted => + handleResubmittedFailure(task, stage) + + case _ => // Do nothing. + } + + // Always fail the current stage and retry all the tasks when a barrier task fails.
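+          // (Illustrative note, not part of the original patch: barrier tasks are expected to
+          // coordinate through a global sync (BarrierTaskContext), so a surviving task may block
+          // waiting for a failed peer; retrying only the failed task would not unblock the
+          // stage, hence the whole stage attempt is retried.)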
+ val failedStage = stageIdToStage(task.stageId) + if (failedStage.latestInfo.attemptNumber != task.stageAttemptId) { + logInfo(s"Ignoring task failure from $task as it's from $failedStage attempt" + + s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + + s"(attempt ${failedStage.latestInfo.attemptNumber}) running") + } else { + logInfo(s"Marking $failedStage (${failedStage.name}) as failed due to a barrier task " + + "failure.") + val message = s"Stage failed because barrier task $task finished unsuccessfully.\n" + + failure.toErrorString + try { + // killAllTaskAttempts will fail if a SchedulerBackend does not implement killTask. + val reason = s"Task $task from barrier stage $failedStage (${failedStage.name}) " + + "failed." + taskScheduler.killAllTaskAttempts(stageId, interruptThread = false, reason) + } catch { + case e: UnsupportedOperationException => + // Cannot continue with the barrier stage if we fail to cancel the zombie barrier tasks. + // TODO SPARK-24877 leave the zombie tasks and ignore their completion events. + logWarning(s"Could not kill all tasks for stage $stageId", e) + abortStage(failedStage, "Could not kill zombie barrier tasks for stage " + + s"$failedStage (${failedStage.name})", Some(e)) + } + markStageAsFinished(failedStage, Some(message)) + + failedStage.failedAttemptIds.add(task.stageAttemptId) + // TODO Refactor the failure handling logic to combine similar code with that of + // FetchFailed. + val shouldAbortStage = + failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || + disallowStageRetryForTest + + if (shouldAbortStage) { + val abortMessage = if (disallowStageRetryForTest) { + "Barrier stage will not retry stage due to testing config. Most recent failure " + + s"reason: $message" + } else { + s"""$failedStage (${failedStage.name}) + |has failed the maximum allowable number of + |times: $maxConsecutiveStageAttempts. + |Most recent failure reason: $message + """.stripMargin.replaceAll("\n", " ") + } + abortStage(failedStage, abortMessage, None) + } else { + failedStage match { + case failedMapStage: ShuffleMapStage => + // Mark all the map outputs of the map stage as broken, to ensure all the tasks are + // retried on the resubmitted stage attempt. + mapOutputTracker.unregisterAllMapOutput(failedMapStage.shuffleDep.shuffleId) + + case failedResultStage: ResultStage => + // Abort the failed result stage since we may have committed output for some + // partitions. + val reason = "Could not recover from a failed barrier ResultStage. Most recent " + + s"failure reason: $message" + abortStage(failedResultStage, reason, None) + } + // In case multiple task failures are triggered for a single stage attempt, ensure we only + // resubmit the failed stage once.
+ val noResubmitEnqueued = !failedStages.contains(failedStage) + failedStages += failedStage + if (noResubmitEnqueued) { + logInfo(s"Resubmitting $failedStage (${failedStage.name}) due to barrier stage " + + "failure.") + messageScheduler.schedule(new Runnable { + override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) + }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS) + } + } + } + + case Resubmitted => + handleResubmittedFailure(task, stage) + case _: TaskCommitDenied => // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits @@ -1426,6 +1689,18 @@ class DAGScheduler( } } + private def handleResubmittedFailure(task: Task[_], stage: Stage): Unit = { + logInfo(s"Resubmitted $task, so marking it as still running.") + stage match { + case sms: ShuffleMapStage => + sms.pendingPartitions += task.partitionId + + case _ => + throw new SparkException("TaskSetManagers should only send Resubmitted task " + + "statuses for tasks in ShuffleMapStages.") + } + } + private[scheduler] def markMapStageJobsAsFinished(shuffleStage: ShuffleMapStage): Unit = { // Mark any map-stage jobs waiting on this stage as finished if (shuffleStage.isAvailable && shuffleStage.mapStageJobs.nonEmpty) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 69bc51c1ecf90..1629e1797977f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -23,8 +23,7 @@ import java.nio.charset.StandardCharsets import java.util.EnumSet import java.util.Locale -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, Map} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, FSDataOutputStream, Path} @@ -36,6 +35,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{SPARK_VERSION, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.executor.ExecutorMetrics import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.io.CompressionCodec @@ -51,6 +51,7 @@ import org.apache.spark.util.{JsonProtocol, Utils} * spark.eventLog.overwrite - Whether to overwrite any existing files. * spark.eventLog.dir - Path to the directory in which events are logged. * spark.eventLog.buffer.kb - Buffer size to use when writing to output streams + * spark.eventLog.logStageExecutorMetrics.enabled - Whether to log stage executor metrics */ private[spark] class EventLoggingListener( appId: String, @@ -69,6 +70,7 @@ private[spark] class EventLoggingListener( private val shouldCompress = sparkConf.get(EVENT_LOG_COMPRESS) private val shouldOverwrite = sparkConf.get(EVENT_LOG_OVERWRITE) private val shouldLogBlockUpdates = sparkConf.get(EVENT_LOG_BLOCK_UPDATES) + private val shouldLogStageExecutorMetrics = sparkConf.get(EVENT_LOG_STAGE_EXECUTOR_METRICS) private val testing = sparkConf.get(EVENT_LOG_TESTING) private val outputBufferSize = sparkConf.get(EVENT_LOG_OUTPUT_BUFFER_SIZE).toInt private val fileSystem = Utils.getHadoopFileSystem(logBaseDir, hadoopConf) @@ -93,6 +95,9 @@ private[spark] class EventLoggingListener( // Visible for tests only. 
private[scheduler] val logPath = getLogPath(logBaseDir, appId, appAttemptId, compressionCodecName) + // map of (stageId, stageAttempt) to the peak executor metrics for the stage + private val liveStageExecutorMetrics = Map.empty[(Int, Int), Map[String, ExecutorMetrics]] + /** * Creates the log file in the configured log directory. */ @@ -155,7 +160,14 @@ } // Events that do not trigger a flush - override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = logEvent(event) + override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = { + logEvent(event) + if (shouldLogStageExecutorMetrics) { + // record the peak metrics for the new stage + liveStageExecutorMetrics.put((event.stageInfo.stageId, event.stageInfo.attemptNumber()), + Map.empty[String, ExecutorMetrics]) + } + } override def onTaskStart(event: SparkListenerTaskStart): Unit = logEvent(event) @@ -169,6 +181,26 @@ // Events that trigger a flush override def onStageCompleted(event: SparkListenerStageCompleted): Unit = { + if (shouldLogStageExecutorMetrics) { + // clear out any previous attempts that did not have a stage completed event + val prevAttemptId = event.stageInfo.attemptNumber() - 1 + for (attemptId <- 0 to prevAttemptId) { + liveStageExecutorMetrics.remove((event.stageInfo.stageId, attemptId)) + } + + // log the peak executor metrics for the stage, for each live executor, + // whether or not the executor is running tasks for the stage + val executorOpt = liveStageExecutorMetrics.remove( + (event.stageInfo.stageId, event.stageInfo.attemptNumber())) + executorOpt.foreach { execMap => + execMap.foreach { case (executorId, peakExecutorMetrics) => + logEvent(new SparkListenerStageExecutorMetrics(executorId, event.stageInfo.stageId, + event.stageInfo.attemptNumber(), peakExecutorMetrics)) + } + } + } + + // log stage completed event logEvent(event, flushLogger = true) } @@ -234,8 +266,18 @@ } } - // No-op because logging every update would be overkill - override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { } + override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { + if (shouldLogStageExecutorMetrics) { + // For the active stages, record any new peak values of the executor memory metrics + event.executorUpdates.foreach { executorUpdates => + liveStageExecutorMetrics.values.foreach { peakExecutorMetrics => + val peakMetrics = peakExecutorMetrics.getOrElseUpdate( + event.execId, new ExecutorMetrics()) + peakMetrics.compareAndUpdatePeakValues(executorUpdates) + } + } + } + } override def onOtherEvent(event: SparkListenerEvent): Unit = { if (event.logEvent) { @@ -296,7 +338,7 @@ private[spark] object EventLoggingListener extends Logging { private val LOG_FILE_PERMISSIONS = new FsPermission(Integer.parseInt("770", 8).toShort) // A cache for compression codecs to avoid creating the same codec many times - private val codecMap = new mutable.HashMap[String, CompressionCodec] + private val codecMap = Map.empty[String, CompressionCodec] /** * Write metadata about an event log to the given stream.
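The onExecutorMetricsUpdate/onStageCompleted pair above folds every periodic executor update into a per-stage peak, and writes one SparkListenerStageExecutorMetrics event per executor when the stage completes. A minimal sketch of the peak-folding idea, using a cut-down stand-in for ExecutorMetrics (the real class in this patch is indexed by MetricGetter.values):

    // Sketch only: element-wise maximum over metric values, as compareAndUpdatePeakValues does.
    class PeakMetrics(val values: Array[Long]) {
      // Returns true if any peak was updated by this sample.
      def compareAndUpdatePeakValues(sample: PeakMetrics): Boolean = {
        var updated = false
        for (i <- values.indices if sample.values(i) > values(i)) {
          values(i) = sample.values(i)
          updated = true
        }
        updated
      }
    }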
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 659694dd189ad..7e1d75fe723d6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -31,7 +31,8 @@ import org.apache.spark.util.Utils /** * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the - * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks. + * task ran on, the sizes of outputs for each reducer, and the number of outputs of the map task, + * for passing on to the reduce tasks. */ private[spark] sealed trait MapStatus { /** Location where this task was run. */ @@ -44,18 +45,23 @@ private[spark] sealed trait MapStatus { * necessary for correctness, since block fetchers are allowed to skip zero-size blocks. */ def getSizeForBlock(reduceId: Int): Long + + /** + * The number of outputs for the map task. + */ + def numberOfOutput: Long } private[spark] object MapStatus { - def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long], numOutput: Long): MapStatus = { if (uncompressedSizes.length > Option(SparkEnv.get) .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS)) .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get)) { - HighlyCompressedMapStatus(loc, uncompressedSizes) + HighlyCompressedMapStatus(loc, uncompressedSizes, numOutput) } else { - new CompressedMapStatus(loc, uncompressedSizes) + new CompressedMapStatus(loc, uncompressedSizes, numOutput) } } @@ -98,29 +104,34 @@ private[spark] object MapStatus { */ private[spark] class CompressedMapStatus( private[this] var loc: BlockManagerId, - private[this] var compressedSizes: Array[Byte]) + private[this] var compressedSizes: Array[Byte], + private[this] var numOutput: Long) extends MapStatus with Externalizable { - protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only + protected def this() = this(null, null.asInstanceOf[Array[Byte]], -1) // For deserialization only - def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) { - this(loc, uncompressedSizes.map(MapStatus.compressSize)) + def this(loc: BlockManagerId, uncompressedSizes: Array[Long], numOutput: Long) { + this(loc, uncompressedSizes.map(MapStatus.compressSize), numOutput) } override def location: BlockManagerId = loc + override def numberOfOutput: Long = numOutput + override def getSizeForBlock(reduceId: Int): Long = { MapStatus.decompressSize(compressedSizes(reduceId)) } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) + out.writeLong(numOutput) out.writeInt(compressedSizes.length) out.write(compressedSizes) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { loc = BlockManagerId(in) + numOutput = in.readLong() val len = in.readInt() compressedSizes = new Array[Byte](len) in.readFully(compressedSizes) @@ -143,17 +154,20 @@ private[spark] class HighlyCompressedMapStatus private ( private[this] var numNonEmptyBlocks: Int, private[this] var emptyBlocks: RoaringBitmap, private[this] var avgSize: Long, - private var hugeBlockSizes: Map[Int, Byte]) + private var hugeBlockSizes: Map[Int, Byte], + private[this] var numOutput: Long) extends MapStatus with Externalizable { // loc could be null when the default 
constructor is called during deserialization require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0, "Average size can only be zero for map stages that produced no output") - protected def this() = this(null, -1, null, -1, null) // For deserialization only + protected def this() = this(null, -1, null, -1, null, -1) // For deserialization only override def location: BlockManagerId = loc + override def numberOfOutput: Long = numOutput + override def getSizeForBlock(reduceId: Int): Long = { assert(hugeBlockSizes != null) if (emptyBlocks.contains(reduceId)) { @@ -168,6 +182,7 @@ private[spark] class HighlyCompressedMapStatus private ( override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) + out.writeLong(numOutput) emptyBlocks.writeExternal(out) out.writeLong(avgSize) out.writeInt(hugeBlockSizes.size) @@ -179,6 +194,7 @@ private[spark] class HighlyCompressedMapStatus private ( override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { loc = BlockManagerId(in) + numOutput = in.readLong() emptyBlocks = new RoaringBitmap() emptyBlocks.readExternal(in) avgSize = in.readLong() @@ -194,7 +210,10 @@ private[spark] class HighlyCompressedMapStatus private ( } private[spark] object HighlyCompressedMapStatus { - def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = { + def apply( + loc: BlockManagerId, + uncompressedSizes: Array[Long], + numOutput: Long): HighlyCompressedMapStatus = { // We must keep track of which blocks are empty so that we don't report a zero-sized // block as being non-empty (or vice-versa) when using the average block size. var i = 0 @@ -235,6 +254,6 @@ private[spark] object HighlyCompressedMapStatus { emptyBlocks.trim() emptyBlocks.runOptimize() new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize, - hugeBlockSizesArray.toMap) + hugeBlockSizesArray.toMap, numOutput) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index e36c759a42556..aafeae05b566c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -48,7 +48,9 @@ import org.apache.spark.rdd.RDD * @param jobId id of the job this task belongs to * @param appId id of the app this task belongs to * @param appAttemptId attempt id of the app this task belongs to - */ + * @param isBarrier whether this task belongs to a barrier stage. Spark must launch all the tasks + * at the same time for a barrier stage. 
+ */ private[spark] class ResultTask[T, U]( stageId: Int, stageAttemptId: Int, @@ -60,9 +62,10 @@ serializedTaskMetrics: Array[Byte], jobId: Option[Int] = None, appId: Option[String] = None, - appAttemptId: Option[String] = None) + appAttemptId: Option[String] = None, + isBarrier: Boolean = false) extends Task[U](stageId, stageAttemptId, partition.index, localProperties, serializedTaskMetrics, - jobId, appId, appAttemptId) + jobId, appId, appAttemptId, isBarrier) with Serializable { @transient private[this] val preferredLocs: Seq[TaskLocation] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala index 22db3350abfa7..c187ee146301b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -69,4 +69,13 @@ private[spark] trait SchedulerBackend { */ def getDriverLogUrls: Option[Map[String, String]] = None + /** + * Get the max number of tasks that can be launched concurrently at the moment. + * Note that the returned value should not be cached, because it can change when executors + * are added or removed. + * + * @return The max number of tasks that can currently be launched concurrently. + */ + def maxNumConcurrentTasks(): Int + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 7a25c47e2cab3..f2cd65fd523ab 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -49,6 +49,8 @@ import org.apache.spark.shuffle.ShuffleWriter * @param jobId id of the job this task belongs to * @param appId id of the app this task belongs to * @param appAttemptId attempt id of the app this task belongs to + * @param isBarrier whether this task belongs to a barrier stage. Spark must launch all the tasks + * at the same time for a barrier stage. */ private[spark] class ShuffleMapTask( stageId: Int, @@ -60,9 +62,10 @@ serializedTaskMetrics: Array[Byte], jobId: Option[Int] = None, appId: Option[String] = None, - appAttemptId: Option[String] = None) + appAttemptId: Option[String] = None, + isBarrier: Boolean = false) extends Task[MapStatus](stageId, stageAttemptId, partition.index, localProperties, - serializedTaskMetrics, jobId, appId, appAttemptId) + serializedTaskMetrics, jobId, appId, appAttemptId, isBarrier) with Logging { /** A constructor used only in test suites. This does not require passing in an RDD.
*/ diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 8a112f6a37b96..293e8369677f0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -26,7 +26,7 @@ import com.fasterxml.jackson.annotation.JsonTypeInfo import org.apache.spark.{SparkConf, TaskEndReason} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.executor.TaskMetrics +import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo} import org.apache.spark.ui.SparkUI @@ -160,11 +160,29 @@ case class SparkListenerBlockUpdated(blockUpdatedInfo: BlockUpdatedInfo) extends * Periodic updates from executors. * @param execId executor id * @param accumUpdates sequence of (taskId, stageId, stageAttemptId, accumUpdates) + * @param executorUpdates executor level metrics updates */ @DeveloperApi case class SparkListenerExecutorMetricsUpdate( execId: String, - accumUpdates: Seq[(Long, Int, Int, Seq[AccumulableInfo])]) + accumUpdates: Seq[(Long, Int, Int, Seq[AccumulableInfo])], + executorUpdates: Option[ExecutorMetrics] = None) + extends SparkListenerEvent + +/** + * Peak metric values for the executor for the stage, written to the history log at stage + * completion. + * @param execId executor id + * @param stageId stage id + * @param stageAttemptId stage attempt + * @param executorMetrics executor level metrics, indexed by MetricGetter.values + */ +@DeveloperApi +case class SparkListenerStageExecutorMetrics( + execId: String, + stageId: Int, + stageAttemptId: Int, + executorMetrics: ExecutorMetrics) extends SparkListenerEvent @DeveloperApi @@ -264,6 +282,13 @@ private[spark] trait SparkListenerInterface { */ def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit + /** + * Called with the peak memory metrics for a given (executor, stage) combination. Note that this + * is only present when reading from the event log (as in the history server), and is never + * called in a live application. + */ + def onStageExecutorMetrics(executorMetrics: SparkListenerStageExecutorMetrics): Unit + /** * Called when the driver registers a new executor. 
*/ @@ -361,6 +386,9 @@ abstract class SparkListener extends SparkListenerInterface { override def onExecutorMetricsUpdate( executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = { } + override def onStageExecutorMetrics( + executorMetrics: SparkListenerStageExecutorMetrics): Unit = { } + override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = { } override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = { } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index ff19cc65552e0..8f6b7ad309602 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -57,6 +57,8 @@ private[spark] trait SparkListenerBus listener.onApplicationEnd(applicationEnd) case metricsUpdate: SparkListenerExecutorMetricsUpdate => listener.onExecutorMetricsUpdate(metricsUpdate) + case stageExecutorMetrics: SparkListenerStageExecutorMetrics => + listener.onStageExecutorMetrics(stageExecutorMetrics) case executorAdded: SparkListenerExecutorAdded => listener.onExecutorAdded(executorAdded) case executorRemoved: SparkListenerExecutorRemoved => diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 290fd073caf27..26cca334d3bd5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -82,15 +82,15 @@ private[scheduler] abstract class Stage( private var _latestInfo: StageInfo = StageInfo.fromStage(this, nextAttemptId) /** - * Set of stage attempt IDs that have failed with a FetchFailure. We keep track of these - * failures in order to avoid endless retries if a stage keeps failing with a FetchFailure. + * Set of stage attempt IDs that have failed. We keep track of these failures in order to avoid + * endless retries if a stage keeps failing. * We keep track of each attempt ID that has failed to avoid recording duplicate failures if * multiple tasks from the same stage attempt fail (SPARK-5945). */ - val fetchFailedAttemptIds = new HashSet[Int] + val failedAttemptIds = new HashSet[Int] private[scheduler] def clearFailures() : Unit = { - fetchFailedAttemptIds.clear() + failedAttemptIds.clear() } /** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index f536fc2a5f0a1..eb059f12be6d3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -49,6 +49,8 @@ import org.apache.spark.util._ * @param jobId id of the job this task belongs to * @param appId id of the app this task belongs to * @param appAttemptId attempt id of the app this task belongs to + * @param isBarrier whether this task belongs to a barrier stage. Spark must launch all the tasks + * at the same time for a barrier stage. 
*/ private[spark] abstract class Task[T]( val stageId: Int, @@ -60,7 +62,8 @@ SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array(), val jobId: Option[Int] = None, val appId: Option[String] = None, - val appAttemptId: Option[String] = None) extends Serializable { + val appAttemptId: Option[String] = None, + val isBarrier: Boolean = false) extends Serializable { @transient lazy val metrics: TaskMetrics = SparkEnv.get.closureSerializer.newInstance().deserialize(ByteBuffer.wrap(serializedTaskMetrics)) @@ -77,7 +80,9 @@ attemptNumber: Int, metricsSystem: MetricsSystem): T = { SparkEnv.get.blockManager.registerTask(taskAttemptId) - context = new TaskContextImpl( + // TODO SPARK-24874 Allow creating a BarrierTaskContext based on partitions, instead of + // whether the stage is a barrier stage. + val taskContext = new TaskContextImpl( stageId, stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal partitionId, @@ -87,6 +92,13 @@ localProperties, metricsSystem, metrics) + + context = if (isBarrier) { + new BarrierTaskContext(taskContext) + } else { + taskContext + } + TaskContext.setTaskContext(context) taskThread = Thread.currentThread() @@ -161,7 +173,7 @@ var epoch: Long = -1 // Task context, to be initialized in run(). - @transient var context: TaskContextImpl = _ + @transient var context: TaskContext = _ // The actual Thread on which the task is running, if any. Initialized in run(). @volatile @transient private var taskThread: Thread = _ diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala index c98b87148e404..bb4a4442b9433 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala @@ -50,6 +50,7 @@ private[spark] class TaskDescription( val executorId: String, val name: String, val index: Int, // Index within this task's TaskSet + val partitionId: Int, val addedFiles: Map[String, Long], val addedJars: Map[String, Long], val properties: Properties, @@ -76,6 +77,7 @@ dataOut.writeUTF(taskDescription.executorId) dataOut.writeUTF(taskDescription.name) dataOut.writeInt(taskDescription.index) + dataOut.writeInt(taskDescription.partitionId) // Write files. serializeStringLongMap(taskDescription.addedFiles, dataOut) @@ -117,6 +119,7 @@ val executorId = dataIn.readUTF() val name = dataIn.readUTF() val index = dataIn.readInt() + val partitionId = dataIn.readInt() // Read files. val taskFiles = deserializeStringLongMap(dataIn) @@ -138,7 +141,7 @@ // Create a sub-buffer for the serialized task into its own buffer (to be deserialized later).
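// Note: decode reads the fields back in exactly the order encode wrote them (executorId, name, index, and the newly added partitionId), so the two methods must stay in sync.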
val serializedTask = byteBuffer.slice() - new TaskDescription(taskId, attemptNumber, executorId, name, index, taskFiles, taskJars, - properties, serializedTask) + new TaskDescription(taskId, attemptNumber, executorId, name, index, partitionId, taskFiles, + taskJars, properties, serializedTask) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 90644fea23ab1..94221eb0d5515 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -17,6 +17,7 @@ package org.apache.spark.scheduler +import org.apache.spark.executor.ExecutorMetrics import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.AccumulatorV2 @@ -51,16 +52,22 @@ // Submit a sequence of tasks to run. def submitTasks(taskSet: TaskSet): Unit - // Cancel a stage. + // Kill all the tasks in a stage and fail the stage and all the jobs that depend on the stage. + // Throws UnsupportedOperationException if the backend doesn't support killing tasks. def cancelTasks(stageId: Int, interruptThread: Boolean): Unit /** * Kills a task attempt. + * Throws UnsupportedOperationException if the backend doesn't support killing a task. * * @return Whether the task was successfully killed. */ def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean + // Kill all the running task attempts in a stage. + // Throws UnsupportedOperationException if the backend doesn't support killing tasks. + def killAllTaskAttempts(stageId: Int, interruptThread: Boolean, reason: String): Unit + // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called. def setDAGScheduler(dagScheduler: DAGScheduler): Unit @@ -68,14 +75,15 @@ def defaultParallelism(): Int /** - * Update metrics for in-progress tasks and let the master know that the BlockManager is still - * alive. Return true if the driver knows about the given block manager. Otherwise, return false, - * indicating that the block manager should re-register. + * Update metrics for in-progress tasks and executor metrics, and let the master know that the + * BlockManager is still alive. Return true if the driver knows about the given block manager. + * Otherwise, return false, indicating that the block manager should re-register. */ def executorHeartbeatReceived( execId: String, accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])], - blockManagerId: BlockManagerId): Boolean + blockManagerId: BlockManagerId, + executorUpdates: ExecutorMetrics): Boolean /** * Get an application ID associated with the job.
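The reworked contract above makes cancelTasks a composition of the new killAllTaskAttempts and a per-TaskSetManager abort; the concrete TaskSchedulerImpl code follows in the next hunk. A hedged sketch of that decomposition (the abortTaskSets helper is invented for the sketch, not part of the patch):

    trait StageCancellation {
      def killAllTaskAttempts(stageId: Int, interruptThread: Boolean, reason: String): Unit
      def abortTaskSets(stageId: Int, message: String): Unit // assumed helper for the sketch
      // Cancelling a stage = kill its running attempts, then fail its task sets (and thus
      // the jobs that depend on the stage).
      def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = {
        killAllTaskAttempts(stageId, interruptThread, reason = "Stage cancelled")
        abortTaskSets(stageId, s"Stage $stageId cancelled")
      }
    }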
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 598b62f85a1fa..4f870e85ad38d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer import java.util.{Locale, Timer, TimerTask} -import java.util.concurrent.TimeUnit +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import java.util.concurrent.atomic.AtomicLong import scala.collection.Set @@ -28,8 +28,10 @@ import scala.util.Random import org.apache.spark._ import org.apache.spark.TaskState.TaskState +import org.apache.spark.executor.ExecutorMetrics import org.apache.spark.internal.Logging import org.apache.spark.internal.config +import org.apache.spark.rpc.RpcEndpoint import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.scheduler.TaskLocality.TaskLocality import org.apache.spark.storage.BlockManagerId @@ -90,7 +92,7 @@ private[spark] class TaskSchedulerImpl( private val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]] // Protected by `this` - private[scheduler] val taskIdToTaskSetManager = new HashMap[Long, TaskSetManager] + private[scheduler] val taskIdToTaskSetManager = new ConcurrentHashMap[Long, TaskSetManager] val taskIdToExecutorId = new HashMap[Long, String] @volatile private var hasReceivedTask = false @@ -138,6 +140,19 @@ private[spark] class TaskSchedulerImpl( // This is a var so that we can reset it for testing purposes. private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this) + private lazy val barrierSyncTimeout = conf.get(config.BARRIER_SYNC_TIMEOUT) + + private[scheduler] var barrierCoordinator: RpcEndpoint = null + + private def maybeInitBarrierCoordinator(): Unit = { + if (barrierCoordinator == null) { + barrierCoordinator = new BarrierCoordinator(barrierSyncTimeout, sc.listenerBus, + sc.env.rpcEnv) + sc.env.rpcEnv.setupEndpoint("barrierSync", barrierCoordinator) + logInfo("Registered BarrierCoordinator endpoint") + } + } + override def setDAGScheduler(dagScheduler: DAGScheduler) { this.dagScheduler = dagScheduler } @@ -222,18 +237,11 @@ private[spark] class TaskSchedulerImpl( override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized { logInfo("Cancelling stage " + stageId) + // Kill all running tasks for the stage. + killAllTaskAttempts(stageId, interruptThread, reason = "Stage cancelled") + // Cancel all attempts for the stage. taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts => attempts.foreach { case (_, tsm) => - // There are two possible cases here: - // 1. The task set manager has been created and some tasks have been scheduled. - // In this case, send a kill signal to the executors to kill the task and then abort - // the stage. - // 2. The task set manager has been created but no tasks have been scheduled. In this case, - // simply abort the stage. 
- tsm.runningTasksSet.foreach { tid => - taskIdToExecutorId.get(tid).foreach(execId => - backend.killTask(tid, execId, interruptThread, reason = "Stage cancelled")) - } tsm.abort("Stage %s cancelled".format(stageId)) logInfo("Stage %d was cancelled".format(stageId)) } @@ -252,6 +260,27 @@ } } + override def killAllTaskAttempts( + stageId: Int, + interruptThread: Boolean, + reason: String): Unit = synchronized { + logInfo(s"Killing all running tasks in stage $stageId: $reason") + taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts => + attempts.foreach { case (_, tsm) => + // There are two possible cases here: + // 1. The task set manager has been created and some tasks have been scheduled. + // In this case, send a kill signal to the executors to kill the task. + // 2. The task set manager has been created but no tasks have been scheduled. In this case, + // simply continue. + tsm.runningTasksSet.foreach { tid => + taskIdToExecutorId.get(tid).foreach { execId => + backend.killTask(tid, execId, interruptThread, reason) + } + } + } + } + } + /** * Called to indicate that all task attempts (including speculated tasks) associated with the * given TaskSetManager have completed, so state associated with the TaskSetManager should be @@ -274,7 +303,8 @@ maxLocality: TaskLocality, shuffledOffers: Seq[WorkerOffer], availableCpus: Array[Int], - tasks: IndexedSeq[ArrayBuffer[TaskDescription]]) : Boolean = { + tasks: IndexedSeq[ArrayBuffer[TaskDescription]], + addressesWithDescs: ArrayBuffer[(String, TaskDescription)]) : Boolean = { var launchedTask = false // nodes and executors that are blacklisted for the entire application have already been // filtered out by this point @@ -286,11 +316,16 @@ for (task <- taskSet.resourceOffer(execId, host, maxLocality)) { tasks(i) += task val tid = task.taskId - taskIdToTaskSetManager(tid) = taskSet + taskIdToTaskSetManager.put(tid, taskSet) taskIdToExecutorId(tid) = execId executorIdToRunningTaskIds(execId).add(tid) availableCpus(i) -= CPUS_PER_TASK assert(availableCpus(i) >= 0) + // Only record the executor address for a barrier task. + if (taskSet.isBarrier) { + // The executor address is expected to be non-empty. + addressesWithDescs += (shuffledOffers(i).address.get -> task) + } launchedTask = true } } catch { @@ -346,6 +381,7 @@ // Build a list of tasks to assign to each worker. val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores / CPUS_PER_TASK)) val availableCpus = shuffledOffers.map(o => o.cores).toArray + val availableSlots = shuffledOffers.map(o => o.cores / CPUS_PER_TASK).sum val sortedTaskSets = rootPool.getSortedTaskSetQueue for (taskSet <- sortedTaskSets) { logDebug("parentName: %s, name: %s, runningTasks: %s".format( @@ -359,20 +395,58 @@ // of locality levels so that it gets a chance to launch local tasks on all of them.
// NOTE: the preferredLocality order: PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY for (taskSet <- sortedTaskSets) { - var launchedAnyTask = false - var launchedTaskAtCurrentMaxLocality = false - for (currentMaxLocality <- taskSet.myLocalityLevels) { - do { - launchedTaskAtCurrentMaxLocality = resourceOfferSingleTaskSet( - taskSet, currentMaxLocality, shuffledOffers, availableCpus, tasks) - launchedAnyTask |= launchedTaskAtCurrentMaxLocality - } while (launchedTaskAtCurrentMaxLocality) - } - if (!launchedAnyTask) { - taskSet.abortIfCompletelyBlacklisted(hostToExecutors) + // Skip the barrier taskSet if the available slots are fewer than the number of pending tasks. + if (taskSet.isBarrier && availableSlots < taskSet.numTasks) { + // Skip the launch process. + // TODO SPARK-24819 If the job requires more slots than available (both busy and free + // slots), fail the job on submit. + logInfo(s"Skip current round of resource offers for barrier stage ${taskSet.stageId} " + + s"because the barrier taskSet requires ${taskSet.numTasks} slots, while the total " + + s"number of available slots is $availableSlots.") + } else { + var launchedAnyTask = false + // Record the executors that barrier tasks are assigned to. + val addressesWithDescs = ArrayBuffer[(String, TaskDescription)]() + for (currentMaxLocality <- taskSet.myLocalityLevels) { + var launchedTaskAtCurrentMaxLocality = false + do { + launchedTaskAtCurrentMaxLocality = resourceOfferSingleTaskSet(taskSet, + currentMaxLocality, shuffledOffers, availableCpus, tasks, addressesWithDescs) + launchedAnyTask |= launchedTaskAtCurrentMaxLocality + } while (launchedTaskAtCurrentMaxLocality) + } + if (!launchedAnyTask) { + taskSet.abortIfCompletelyBlacklisted(hostToExecutors) + } + if (launchedAnyTask && taskSet.isBarrier) { + // Check whether the barrier tasks are partially launched. + // TODO SPARK-24818 handle the assert failure case (that can happen when some locality + // requirements are not fulfilled, and we should revert the launched tasks). + require(addressesWithDescs.size == taskSet.numTasks, + s"Skip current round of resource offers for barrier stage ${taskSet.stageId} " + + s"because only ${addressesWithDescs.size} out of a total number of " + + s"${taskSet.numTasks} tasks got resource offers. The resource offers may have " + + "been blacklisted or cannot fulfill task locality requirements.") + + // Materialize the barrier coordinator. + maybeInitBarrierCoordinator() + + // Record the ordered addresses of all the barrier tasks in each task's properties. + val addressesStr = addressesWithDescs + // Addresses ordered by partitionId + .sortBy(_._2.partitionId) + .map(_._1) + .mkString(",") + addressesWithDescs.foreach(_._2.properties.setProperty("addresses", addressesStr)) + + logInfo(s"Successfully scheduled all the ${addressesWithDescs.size} tasks for barrier " + + s"stage ${taskSet.stageId}.") + } } } + // TODO SPARK-24823 Cancel a job that contains barrier stage(s) if the barrier tasks don't get + // launched within a configured time.
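+    // Illustrative example (not part of the patch): a 3-task barrier stage whose tasks are
+    // assigned to hostA:7337, hostB:7337 and hostC:7337 (ordered by partitionId) gets the
+    // property "addresses" -> "hostA:7337,hostB:7337,hostC:7337" in every one of its tasks,
+    // so each barrier task can discover the addresses of all its peers.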
if (tasks.size > 0) { hasLaunchedTask = true } @@ -392,7 +466,7 @@ private[spark] class TaskSchedulerImpl( var reason: Option[ExecutorLossReason] = None synchronized { try { - taskIdToTaskSetManager.get(tid) match { + Option(taskIdToTaskSetManager.get(tid)) match { case Some(taskSet) => if (state == TaskState.LOST) { // TaskState.LOST is only used by the deprecated Mesos fine-grained scheduling mode, @@ -435,24 +509,26 @@ private[spark] class TaskSchedulerImpl( } /** - * Update metrics for in-progress tasks and let the master know that the BlockManager is still - * alive. Return true if the driver knows about the given block manager. Otherwise, return false, - * indicating that the block manager should re-register. + * Update metrics for in-progress tasks and executor metrics, and let the master know that the + * BlockManager is still alive. Return true if the driver knows about the given block manager. + * Otherwise, return false, indicating that the block manager should re-register. */ override def executorHeartbeatReceived( execId: String, accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])], - blockManagerId: BlockManagerId): Boolean = { + blockManagerId: BlockManagerId, + executorMetrics: ExecutorMetrics): Boolean = { // (taskId, stageId, stageAttemptId, accumUpdates) - val accumUpdatesWithTaskIds: Array[(Long, Int, Int, Seq[AccumulableInfo])] = synchronized { + val accumUpdatesWithTaskIds: Array[(Long, Int, Int, Seq[AccumulableInfo])] = { accumUpdates.flatMap { case (id, updates) => val accInfos = updates.map(acc => acc.toInfo(Some(acc.value), None)) - taskIdToTaskSetManager.get(id).map { taskSetMgr => + Option(taskIdToTaskSetManager.get(id)).map { taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, accInfos) } } } - dagScheduler.executorHeartbeatReceived(execId, accumUpdatesWithTaskIds, blockManagerId) + dagScheduler.executorHeartbeatReceived(execId, accumUpdatesWithTaskIds, blockManagerId, + executorMetrics) } def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long): Unit = synchronized { @@ -510,6 +586,9 @@ private[spark] class TaskSchedulerImpl( if (taskResultGetter != null) { taskResultGetter.stop() } + if (barrierCoordinator != null) { + barrierCoordinator.stop() + } starvationTimer.cancel() } @@ -697,9 +776,12 @@ private[spark] class TaskSchedulerImpl( * do not also submit those same tasks. That also means that a task completion from an earlier * attempt can lead to the entire stage getting marked as successful. 
*/ - private[scheduler] def markPartitionCompletedInAllTaskSets(stageId: Int, partitionId: Int) = { + private[scheduler] def markPartitionCompletedInAllTaskSets( + stageId: Int, + partitionId: Int, + taskInfo: TaskInfo) = { taskSetsByStageIdAndAttempt.getOrElse(stageId, Map()).values.foreach { tsm => - tsm.markPartitionCompleted(partitionId) + tsm.markPartitionCompleted(partitionId, taskInfo) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index a18c66596852a..d5e85a11cb279 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -29,7 +29,7 @@ import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.internal.{config, Logging} import org.apache.spark.scheduler.SchedulingMode._ -import org.apache.spark.util.{AccumulatorV2, Clock, SystemClock, Utils} +import org.apache.spark.util.{AccumulatorV2, Clock, LongAccumulator, SystemClock, Utils} import org.apache.spark.util.collection.MedianHeap /** @@ -84,10 +84,10 @@ val successful = new Array[Boolean](numTasks) private val numFailures = new Array[Int](numTasks) - // Set the coresponding index of Boolean var when the task killed by other attempt tasks, - // this happened while we set the `spark.speculation` to true. The task killed by others + // Add the tid of a task into this HashSet when the task is killed by another attempt task. + // This happens when we set `spark.speculation` to true. The task killed by others // should not resubmit while executor lost. - private val killedByOtherAttempt: Array[Boolean] = new Array[Boolean](numTasks) + private val killedByOtherAttempt = new HashSet[Long] val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) private[scheduler] var tasksSuccessful = 0 @@ -123,6 +123,10 @@ // TODO: We should kill any running task attempts when the task set manager becomes a zombie. private[scheduler] var isZombie = false + // Whether the taskSet runs tasks from a barrier stage. Spark must launch all the tasks at the + // same time for a barrier stage. + private[scheduler] def isBarrier = taskSet.tasks.nonEmpty && taskSet.tasks(0).isBarrier + // Set of pending tasks for each executor. These collections are actually // treated as stacks, in which new tasks are added to the end of the // ArrayBuffer and removed from the end.
This makes it faster to detect @@ -512,6 +516,7 @@ execId, taskName, index, + task.partitionId, addedFiles, addedJars, task.localProperties, @@ -723,6 +728,23 @@ def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]): Unit = { val info = taskInfos(tid) val index = info.index + // Check if any other attempt succeeded before this one, and this attempt has not yet been handled + if (successful(index) && killedByOtherAttempt.contains(tid)) { + // Undo the effect on calculatedTasks and totalResultSize made earlier when + // checking if we can fetch more results + calculatedTasks -= 1 + val resultSizeAcc = result.accumUpdates.find(a => + a.name == Some(InternalAccumulator.RESULT_SIZE)) + if (resultSizeAcc.isDefined) { + totalResultSize -= resultSizeAcc.get.asInstanceOf[LongAccumulator].value + } + + // Handle this task as a killed task + handleFailedTask(tid, TaskState.KILLED, + TaskKilled("Finished, but did not commit because another attempt succeeded")) + return + } + info.markFinished(TaskState.FINISHED, clock.getTimeMillis()) if (speculationEnabled) { successfulTaskDurations.insert(info.duration) @@ -735,7 +757,7 @@ logInfo(s"Killing attempt ${attemptInfo.attemptNumber} for task ${attemptInfo.id} " + s"in stage ${taskSet.id} (TID ${attemptInfo.taskId}) on ${attemptInfo.host} " + s"as the attempt ${info.attemptNumber} succeeded on ${info.host}") - killedByOtherAttempt(index) = true + killedByOtherAttempt += attemptInfo.taskId sched.backend.killTask( attemptInfo.taskId, attemptInfo.executorId, @@ -758,7 +780,7 @@ } // There may be multiple tasksets for this stage -- we let all of them know that the partition // was completed. This may result in some of the tasksets getting completed. - sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId) + sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId, info) // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not // "deserialize" the value when holding a lock to avoid blocking other threads.
So we call @@ -769,9 +791,12 @@ maybeFinishTaskSet() } - private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = { + private[scheduler] def markPartitionCompleted(partitionId: Int, taskInfo: TaskInfo): Unit = { partitionToIndex.get(partitionId).foreach { index => if (!successful(index)) { + if (speculationEnabled && !isZombie) { + successfulTaskDurations.insert(taskInfo.duration) + } tasksSuccessful += 1 successful(index) = true if (tasksSuccessful == numTasks) { @@ -868,6 +893,10 @@ None } + if (tasks(index).isBarrier) { + isZombie = true + } + sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, info) if (!isZombie && reason.countTowardsTaskFailures) { @@ -944,7 +973,7 @@ && !isZombie) { for ((tid, info) <- taskInfos if info.executorId == execId) { val index = taskInfos(tid).index - if (successful(index) && !killedByOtherAttempt(index)) { + if (successful(index) && !killedByOtherAttempt.contains(tid)) { successful(index) = false copiesRunning(index) -= 1 tasksSuccessful -= 1 @@ -976,8 +1005,8 @@ */ override def checkSpeculatableTasks(minTimeToSpeculation: Int): Boolean = { // Can't speculate if we only have one task, and no need to speculate if the task set is a - // zombie. + // zombie or is from a barrier stage. - if (isZombie || numTasks == 1) { + if (isZombie || isBarrier || numTasks == 1) { return false } var foundTasks = false diff --git a/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala index 810b36cddf835..6ec74913e42f2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala @@ -21,4 +21,10 @@ package org.apache.spark.scheduler * Represents free resources available on an executor. */ private[spark] -case class WorkerOffer(executorId: String, host: String, cores: Int) +case class WorkerOffer( + executorId: String, + host: String, + cores: Int, + // `address` is an optional hostPort string, it provides more useful information than `host` + // when multiple executors are launched on the same host.
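+  // For example Some("192.168.1.5:46787"); cluster backends fill it from the executor's RPC
+  // address (executorData.executorAddress.hostPort), and the local backend from its own rpcEnv.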
+ address: Option[String] = None) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 9b90e309d2e04..de7c0d813ae65 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -242,7 +242,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val activeExecutors = executorDataMap.filterKeys(executorIsAlive) val workOffers = activeExecutors.map { case (id, executorData) => - new WorkerOffer(id, executorData.executorHost, executorData.freeCores) + new WorkerOffer(id, executorData.executorHost, executorData.freeCores, + Some(executorData.executorAddress.hostPort)) }.toIndexedSeq scheduler.resourceOffers(workOffers) } @@ -267,7 +268,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp if (executorIsAlive(executorId)) { val executorData = executorDataMap(executorId) val workOffers = IndexedSeq( - new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)) + new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores, + Some(executorData.executorAddress.hostPort))) scheduler.resourceOffers(workOffers) } else { Seq.empty @@ -288,7 +290,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp for (task <- tasks.flatten) { val serializedTask = TaskDescription.encode(task) if (serializedTask.limit() >= maxRpcMessageSize) { - scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr => + Option(scheduler.taskIdToTaskSetManager.get(task.taskId)).foreach { taskSetMgr => try { var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " + "spark.rpc.message.maxSize (%d bytes). Consider increasing " + @@ -494,6 +496,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp executorDataMap.keySet.toSeq } + override def maxNumConcurrentTasks(): Int = { + executorDataMap.values.map { executor => + executor.totalCores / scheduler.CPUS_PER_TASK + }.sum + } + /** * Request an additional number of executors from the cluster manager. * @return whether the request is acknowledged. 
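The maxNumConcurrentTasks override above counts whole task slots per executor, which is what resourceOffers compares against a barrier stage's numTasks. A quick illustrative check of the arithmetic (the numbers are made up):

    // Three 8-core executors with CPUS_PER_TASK = 2 give 4 slots each, 12 in total, so a
    // barrier stage with at most 12 tasks can be launched in a single resource-offer round.
    val coresPerExecutor = Seq(8, 8, 8)
    val cpusPerTask = 2
    val maxConcurrentTasks = coresPerExecutor.map(_ / cpusPerTask).sum // 12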
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala index 4c614c5c0f602..0de57fbd5600c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala @@ -81,7 +81,8 @@ private[spark] class LocalEndpoint( } def reviveOffers() { - val offers = IndexedSeq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) + val offers = IndexedSeq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores, + Some(rpcEnv.address.hostPort))) for (task <- scheduler.resourceOffers(offers).flatten) { freeCores -= scheduler.CPUS_PER_TASK executor.launchTask(executorBackend, task) @@ -155,6 +156,8 @@ private[spark] class LocalSchedulerBackend( override def applicationId(): String = appId + override def maxNumConcurrentTasks(): Int = totalCores / scheduler.CPUS_PER_TASK + private def stop(finalState: SparkAppHandle.State): Unit = { localEndpoint.ask(StopExecutor) try { diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala index d15e7937b0523..ea38ccb289c30 100644 --- a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala +++ b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala @@ -42,43 +42,59 @@ private[spark] class SocketAuthHelper(conf: SparkConf) { * Read the auth secret from the socket and compare to the expected value. Write the reply back * to the socket. * - * If authentication fails, this method will close the socket. + * If authentication fails or an error is thrown, this method will close the socket. * * @param s The client socket. * @throws IllegalArgumentException If authentication fails. */ def authClient(s: Socket): Unit = { - // Set the socket timeout while checking the auth secret. Reset it before returning. - val currentTimeout = s.getSoTimeout() + var shouldClose = true try { - s.setSoTimeout(10000) - val clientSecret = readUtf8(s) - if (secret == clientSecret) { - writeUtf8("ok", s) - } else { - writeUtf8("err", s) - JavaUtils.closeQuietly(s) + // Set the socket timeout while checking the auth secret. Reset it before returning. + val currentTimeout = s.getSoTimeout() + try { + s.setSoTimeout(10000) + val clientSecret = readUtf8(s) + if (secret == clientSecret) { + writeUtf8("ok", s) + shouldClose = false + } else { + writeUtf8("err", s) + throw new IllegalArgumentException("Authentication failed.") + } + } finally { + s.setSoTimeout(currentTimeout) + } } finally { - s.setSoTimeout(currentTimeout) + if (shouldClose) { + JavaUtils.closeQuietly(s) + } } } /** * Authenticate with a server by writing the auth secret and checking the server's reply. * - * If authentication fails, this method will close the socket. + * If authentication fails or an error is thrown, this method will close the socket. * * @param s The socket connected to the server. * @throws IllegalArgumentException If authentication fails.
*/ def authToServer(s: Socket): Unit = { - writeUtf8(secret, s) + var shouldClose = true + try { + writeUtf8(secret, s) - val reply = readUtf8(s) - if (reply != "ok") { - JavaUtils.closeQuietly(s) - throw new IllegalArgumentException("Authentication failed.") + val reply = readUtf8(s) + if (reply != "ok") { + throw new IllegalArgumentException("Authentication failed.") + } else { + shouldClose = false + } + } finally { + if (shouldClose) { + JavaUtils.closeQuietly(s) + } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 4103dfb10175e..74b0e0b3a741a 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -104,7 +104,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) // Use completion callback to stop sorter if task was finished/cancelled. - context.addTaskCompletionListener(_ => { + context.addTaskCompletionListener[Unit](_ => { sorter.stop() }) CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index d9fad64f34c7c..0caf84c6050a8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -27,7 +27,7 @@ import org.apache.spark.shuffle._ * In sort-based shuffle, incoming records are sorted according to their target partition ids, then * written to a single map output file. Reducers fetch contiguous regions of this file in order to * read their portion of the map output. In cases where the map output data is too large to fit in - * memory, sorted subsets of the output can are spilled to disk and those on-disk files are merged + * memory, sorted subsets of the output can be spilled to disk and those on-disk files are merged * to produce the final output file. 
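For context on the SocketAuthHelper rework above: instead of closing the socket only on a mismatched secret, both authClient and authToServer now set a shouldClose flag that a finally block inspects, so any exception on the handshake path (including failures in readUtf8/writeUtf8) also closes the socket, while the success path leaves it open. A minimal, self-contained sketch of that close-on-any-failure pattern (the helper name is hypothetical):

    import java.net.Socket

    object CloseOnFailureSketch {
      // Run `body` against the socket; close the socket on every path except success.
      def withCloseOnFailure[T](s: Socket)(body: => T): T = {
        var shouldClose = true
        try {
          val result = body   // may throw; the finally block then closes the socket
          shouldClose = false // reached only if body succeeded
          result
        } finally {
          if (shouldClose) {
            try s.close() catch { case _: java.io.IOException => () } // close quietly
          }
        }
      }
    }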
* * Sort-based shuffle has two different write paths for producing its map output files: diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 274399b9cc1f3..91fc26762e533 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -70,7 +70,8 @@ private[spark] class SortShuffleWriter[K, V, C]( val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) val partitionLengths = sorter.writePartitionedFile(blockId, tmp) shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, + writeMetrics.recordsWritten) } finally { if (tmp.exists() && !tmp.delete()) { logError(s"Error while deleting temp file ${tmp.getAbsolutePath}") diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 5ea161cd0d151..f21eee1965761 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.HashMap import org.apache.spark._ -import org.apache.spark.executor.TaskMetrics +import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ import org.apache.spark.status.api.v1 @@ -66,6 +66,7 @@ private[spark] class AppStatusListener( private val liveStages = new ConcurrentHashMap[(Int, Int), LiveStage]() private val liveJobs = new HashMap[Int, LiveJob]() private val liveExecutors = new HashMap[String, LiveExecutor]() + private val deadExecutors = new HashMap[String, LiveExecutor]() private val liveTasks = new HashMap[Long, LiveTask]() private val liveRDDs = new HashMap[Int, LiveRDD]() private val pools = new HashMap[String, SchedulerPool]() @@ -204,6 +205,19 @@ private[spark] class AppStatusListener( update(rdd, now) } } + if (isExecutorActiveForLiveStages(exec)) { + // the executor was running for a currently active stage, so save it for now in + // deadExecutors, and remove when there are no active stages overlapping with the + // executor. + deadExecutors.put(event.executorId, exec) + } + } + } + + /** Was the specified executor active for any currently live stages? 
*/ + private def isExecutorActiveForLiveStages(exec: LiveExecutor): Boolean = { + liveStages.values.asScala.exists { stage => + stage.info.submissionTime.getOrElse(0L) < exec.removeTime.getTime } } @@ -350,11 +364,20 @@ private[spark] class AppStatusListener( val e = it.next() if (job.stageIds.contains(e.getKey()._1)) { val stage = e.getValue() - stage.status = v1.StageStatus.SKIPPED - job.skippedStages += stage.info.stageId - job.skippedTasks += stage.info.numTasks - it.remove() - update(stage, now) + if (v1.StageStatus.PENDING.equals(stage.status)) { + stage.status = v1.StageStatus.SKIPPED + job.skippedStages += stage.info.stageId + job.skippedTasks += stage.info.numTasks + job.activeStages -= 1 + + pools.get(stage.schedulingPool).foreach { pool => + pool.stageIds = pool.stageIds - stage.info.stageId + update(pool, now) + } + + it.remove() + update(stage, now, last = true) + } } } @@ -506,7 +529,16 @@ private[spark] class AppStatusListener( if (killedDelta > 0) { stage.killedSummary = killedTasksSummary(event.reason, stage.killedSummary) } - maybeUpdate(stage, now) + // [SPARK-24415] Wait for all tasks to finish before removing stage from live list + val removeStage = + stage.activeTasks == 0 && + (v1.StageStatus.COMPLETE.equals(stage.status) || + v1.StageStatus.FAILED.equals(stage.status)) + if (removeStage) { + update(stage, now, last = true) + } else { + maybeUpdate(stage, now) + } // Store both stage ID and task index in a single long variable for tracking at job level. val taskIndex = (event.stageId.toLong << Integer.SIZE) | event.taskInfo.index @@ -521,7 +553,7 @@ private[spark] class AppStatusListener( if (killedDelta > 0) { job.killedSummary = killedTasksSummary(event.reason, job.killedSummary) } - maybeUpdate(job, now) + conditionalLiveUpdate(job, now, removeStage) } val esummary = stage.executorSummary(event.taskInfo.executorId) @@ -532,7 +564,7 @@ private[spark] class AppStatusListener( if (metricsDelta != null) { esummary.metrics = LiveEntityHelpers.addMetrics(esummary.metrics, metricsDelta) } - maybeUpdate(esummary, now) + conditionalLiveUpdate(esummary, now, removeStage) if (!stage.cleaning && stage.savedTasks.get() > maxTasksPerStage) { stage.cleaning = true @@ -540,6 +572,9 @@ private[spark] class AppStatusListener( cleanupTasks(stage) } } + if (removeStage) { + liveStages.remove((event.stageId, event.stageAttemptId)) + } } liveExecutors.get(event.taskInfo.executorId).foreach { exec => @@ -564,17 +599,13 @@ private[spark] class AppStatusListener( // Force an update on live applications when the number of active tasks reaches 0. This is // checked in some tests (e.g. SQLTestUtilsBase) so it needs to be reliably up to date. 
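The deadExecutors map introduced above retains a removed executor only while some live stage overlaps its lifetime; isExecutorActiveForLiveStages expresses that as stage.submissionTime < exec.removeTime. A toy version of the overlap test, with times reduced to plain millisecond Longs (names hypothetical):

    object ExecutorRetentionSketch {
      final case class DeadExecutor(id: String, removeTimeMs: Long)

      // A dead executor stays cached while any live stage was submitted before its removal.
      def activeForLiveStages(exec: DeadExecutor, stageSubmissionTimesMs: Seq[Long]): Boolean =
        stageSubmissionTimesMs.exists(_ < exec.removeTimeMs)

      def main(args: Array[String]): Unit = {
        val exec = DeadExecutor("1", removeTimeMs = 1000L)
        println(activeForLiveStages(exec, Seq(500L)))  // true: the stage overlapped the executor
        println(activeForLiveStages(exec, Seq(2000L))) // false: the stage started after removal
      }
    }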
- if (exec.activeTasks == 0) { - liveUpdate(exec, now) - } else { - maybeUpdate(exec, now) - } + conditionalLiveUpdate(exec, now, exec.activeTasks == 0) } } override def onStageCompleted(event: SparkListenerStageCompleted): Unit = { val maybeStage = - Option(liveStages.remove((event.stageInfo.stageId, event.stageInfo.attemptNumber))) + Option(liveStages.get((event.stageInfo.stageId, event.stageInfo.attemptNumber))) maybeStage.foreach { stage => val now = System.nanoTime() stage.info = event.stageInfo @@ -608,7 +639,6 @@ private[spark] class AppStatusListener( } stage.executorSummaries.values.foreach(update(_, now)) - update(stage, now, last = true) val executorIdsForStage = stage.blackListedExecutors executorIdsForStage.foreach { executorId => @@ -616,8 +646,18 @@ private[spark] class AppStatusListener( removeBlackListedStageFrom(exec, event.stageInfo.stageId, now) } } + + // Remove stage only if there are no active tasks remaining + val removeStage = stage.activeTasks == 0 + update(stage, now, last = removeStage) + if (removeStage) { + liveStages.remove((event.stageInfo.stageId, event.stageInfo.attemptNumber)) + } } + // remove any dead executors that were not running for any currently active stages + deadExecutors.retain((execId, exec) => isExecutorActiveForLiveStages(exec)) + appSummary = new AppSummary(appSummary.numCompletedJobs, appSummary.numCompletedStages + 1) kvstore.write(appSummary) } @@ -646,7 +686,37 @@ private[spark] class AppStatusListener( } override def onUnpersistRDD(event: SparkListenerUnpersistRDD): Unit = { - liveRDDs.remove(event.rddId) + liveRDDs.remove(event.rddId).foreach { liveRDD => + val storageLevel = liveRDD.info.storageLevel + + // Use RDD partition info to update executor block info. + liveRDD.getPartitions().foreach { case (_, part) => + part.executors.foreach { executorId => + liveExecutors.get(executorId).foreach { exec => + exec.rddBlocks = exec.rddBlocks - 1 + } + } + } + + val now = System.nanoTime() + + // Use RDD distribution to update executor memory and disk usage info. + liveRDD.getDistributions().foreach { case (executorId, rddDist) => + liveExecutors.get(executorId).foreach { exec => + if (exec.hasMemoryInfo) { + if (storageLevel.useOffHeap) { + exec.usedOffHeap = addDeltaToValue(exec.usedOffHeap, -rddDist.offHeapUsed) + } else { + exec.usedOnHeap = addDeltaToValue(exec.usedOnHeap, -rddDist.onHeapUsed) + } + } + exec.memoryUsed = addDeltaToValue(exec.memoryUsed, -rddDist.memoryUsed) + exec.diskUsed = addDeltaToValue(exec.diskUsed, -rddDist.diskUsed) + maybeUpdate(exec, now) + } + } + } + kvstore.delete(classOf[RDDStorageInfoWrapper], event.rddId) } @@ -669,6 +739,31 @@ private[spark] class AppStatusListener( } } } + + // check if there is a new peak value for any of the executor level memory metrics + // for the live UI. SparkListenerExecutorMetricsUpdate events are only processed + // for the live UI. + event.executorUpdates.foreach { updates => + liveExecutors.get(event.execId).foreach { exec => + if (exec.peakExecutorMetrics.compareAndUpdatePeakValues(updates)) { + maybeUpdate(exec, now) + } + } + } + } + + override def onStageExecutorMetrics(executorMetrics: SparkListenerStageExecutorMetrics): Unit = { + val now = System.nanoTime() + + // check if there is a new peak value for any of the executor level memory metrics, + // while reading from the log. SparkListenerStageExecutorMetrics are only processed + // when reading logs. 
+ liveExecutors.get(executorMetrics.execId) + .orElse(deadExecutors.get(executorMetrics.execId)).map { exec => + if (exec.peakExecutorMetrics.compareAndUpdatePeakValues(executorMetrics.executorMetrics)) { + update(exec, now) + } + } } override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = { @@ -705,6 +800,11 @@ private[spark] class AppStatusListener( .sortBy(_.stageId) } + /** + * Apply a delta to a value, but ensure that it doesn't go negative. + */ + private def addDeltaToValue(old: Long, delta: Long): Long = math.max(0, old + delta) + private def updateRDDBlock(event: SparkListenerBlockUpdated, block: RDDBlockId): Unit = { val now = System.nanoTime() val executorId = event.blockUpdatedInfo.blockManagerId.executorId @@ -714,9 +814,6 @@ private[spark] class AppStatusListener( val diskDelta = event.blockUpdatedInfo.diskSize * (if (storageLevel.useDisk) 1 else -1) val memoryDelta = event.blockUpdatedInfo.memSize * (if (storageLevel.useMemory) 1 else -1) - // Function to apply a delta to a value, but ensure that it doesn't go negative. - def newValue(old: Long, delta: Long): Long = math.max(0, old + delta) - val updatedStorageLevel = if (storageLevel.isValid) { Some(storageLevel.description) } else { @@ -733,13 +830,13 @@ private[spark] class AppStatusListener( maybeExec.foreach { exec => if (exec.hasMemoryInfo) { if (storageLevel.useOffHeap) { - exec.usedOffHeap = newValue(exec.usedOffHeap, memoryDelta) + exec.usedOffHeap = addDeltaToValue(exec.usedOffHeap, memoryDelta) } else { - exec.usedOnHeap = newValue(exec.usedOnHeap, memoryDelta) + exec.usedOnHeap = addDeltaToValue(exec.usedOnHeap, memoryDelta) } } - exec.memoryUsed = newValue(exec.memoryUsed, memoryDelta) - exec.diskUsed = newValue(exec.diskUsed, diskDelta) + exec.memoryUsed = addDeltaToValue(exec.memoryUsed, memoryDelta) + exec.diskUsed = addDeltaToValue(exec.diskUsed, diskDelta) } // Update the block entry in the RDD info, keeping track of the deltas above so that we @@ -767,8 +864,8 @@ private[spark] class AppStatusListener( // Only update the partition if it's still stored in some executor, otherwise get rid of it. 
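The addDeltaToValue helper added above centralizes the clamp that several block-update paths previously implemented with a local newValue function: applying a negative delta can never push a usage counter below zero. Evaluated by hand in a Scala REPL:

    // addDeltaToValue clamps at zero: math.max(0, old + delta)
    def addDeltaToValue(old: Long, delta: Long): Long = math.max(0, old + delta)

    addDeltaToValue(100L, -30L)  // 70
    addDeltaToValue(100L, -150L) // 0, not -50: usage counters saturate at zero when
                                 // more bytes are removed than were ever tracked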
if (executors.nonEmpty) { partition.update(executors, rdd.storageLevel, - newValue(partition.memoryUsed, memoryDelta), - newValue(partition.diskUsed, diskDelta)) + addDeltaToValue(partition.memoryUsed, memoryDelta), + addDeltaToValue(partition.diskUsed, diskDelta)) } else { rdd.removePartition(block.name) } @@ -776,14 +873,14 @@ private[spark] class AppStatusListener( maybeExec.foreach { exec => if (exec.rddBlocks + rddBlocksDelta > 0) { val dist = rdd.distribution(exec) - dist.memoryUsed = newValue(dist.memoryUsed, memoryDelta) - dist.diskUsed = newValue(dist.diskUsed, diskDelta) + dist.memoryUsed = addDeltaToValue(dist.memoryUsed, memoryDelta) + dist.diskUsed = addDeltaToValue(dist.diskUsed, diskDelta) if (exec.hasMemoryInfo) { if (storageLevel.useOffHeap) { - dist.offHeapUsed = newValue(dist.offHeapUsed, memoryDelta) + dist.offHeapUsed = addDeltaToValue(dist.offHeapUsed, memoryDelta) } else { - dist.onHeapUsed = newValue(dist.onHeapUsed, memoryDelta) + dist.onHeapUsed = addDeltaToValue(dist.onHeapUsed, memoryDelta) } } dist.lastUpdate = null @@ -802,8 +899,8 @@ private[spark] class AppStatusListener( } } - rdd.memoryUsed = newValue(rdd.memoryUsed, memoryDelta) - rdd.diskUsed = newValue(rdd.diskUsed, diskDelta) + rdd.memoryUsed = addDeltaToValue(rdd.memoryUsed, memoryDelta) + rdd.diskUsed = addDeltaToValue(rdd.diskUsed, diskDelta) update(rdd, now) } @@ -882,6 +979,14 @@ private[spark] class AppStatusListener( } } + private def conditionalLiveUpdate(entity: LiveEntity, now: Long, condition: Boolean): Unit = { + if (condition) { + liveUpdate(entity, now) + } else { + maybeUpdate(entity, now) + } + } + private def cleanupExecutors(count: Long): Unit = { // Because the limit is on the number of *dead* executors, we need to calculate whether // there are actually enough dead executors to be deleted. diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 688f25a9fdea1..e237281c552b1 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -471,7 +471,7 @@ private[spark] class AppStatusStore( def operationGraphForJob(jobId: Int): Seq[RDDOperationGraph] = { val job = store.read(classOf[JobDataWrapper], jobId) - val stages = job.info.stageIds + val stages = job.info.stageIds.sorted stages.map { id => val g = store.read(classOf[RDDOperationGraphWrapper], id).toRDDOperationGraph() diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 79e3f13b826ce..8708e64db3c17 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -26,14 +26,13 @@ import scala.collection.mutable.HashMap import com.google.common.collect.Interners import org.apache.spark.JobExecutionStatus -import org.apache.spark.executor.TaskMetrics +import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} import org.apache.spark.scheduler.{AccumulableInfo, StageInfo, TaskInfo} import org.apache.spark.status.api.v1 import org.apache.spark.storage.RDDInfo import org.apache.spark.ui.SparkUI import org.apache.spark.util.AccumulatorContext import org.apache.spark.util.collection.OpenHashSet -import org.apache.spark.util.kvstore.KVStore /** * A mutable representation of a live entity in Spark (jobs, stages, tasks, et al). 
Every live @@ -268,6 +267,9 @@ private class LiveExecutor(val executorId: String, _addTime: Long) extends LiveE def hasMemoryInfo: Boolean = totalOnHeap >= 0L + // peak values for executor level metrics + val peakExecutorMetrics = new ExecutorMetrics() + def hostname: String = if (host != null) host else hostPort.split(":")(0) override protected def doUpdate(): Any = { @@ -302,10 +304,10 @@ private class LiveExecutor(val executorId: String, _addTime: Long) extends LiveE Option(removeReason), executorLogs, memoryMetrics, - blacklistedInStages) + blacklistedInStages, + Some(peakExecutorMetrics).filter(_.isSet)) new ExecutorSummaryWrapper(info) } - } private class LiveExecutorStageSummary( @@ -538,6 +540,10 @@ private class LiveRDD(val info: RDDInfo) extends LiveEntity { distributions.get(exec.executorId) } + def getPartitions(): scala.collection.Map[String, LiveRDDPartition] = partitions + + def getDistributions(): scala.collection.Map[String, LiveRDDDistribution] = distributions + override protected def doUpdate(): Any = { val dists = if (distributions.nonEmpty) { Some(distributions.values.map(_.toApi()).toSeq) @@ -581,8 +587,7 @@ private object LiveEntityHelpers { .filter { acc => // We don't need to store internal or SQL accumulables as their values will be shown in // other places, so drop them to reduce the memory usage. - !acc.internal && (!acc.metadata.isDefined || - acc.metadata.get != Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER)) + !acc.internal && acc.metadata != Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER) } .map { acc => new v1.AccumulableInfo( diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 971d7e90fa7b8..77466b62ff6ed 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -22,9 +22,14 @@ import java.util.Date import scala.xml.{NodeSeq, Text} import com.fasterxml.jackson.annotation.JsonIgnoreProperties -import com.fasterxml.jackson.databind.annotation.JsonDeserialize +import com.fasterxml.jackson.core.{JsonGenerator, JsonParser} +import com.fasterxml.jackson.core.`type`.TypeReference +import com.fasterxml.jackson.databind.{DeserializationContext, JsonDeserializer, JsonSerializer, SerializerProvider} +import com.fasterxml.jackson.databind.annotation.{JsonDeserialize, JsonSerialize} import org.apache.spark.JobExecutionStatus +import org.apache.spark.executor.ExecutorMetrics +import org.apache.spark.metrics.ExecutorMetricType case class ApplicationInfo private[spark]( id: String, @@ -98,7 +103,10 @@ class ExecutorSummary private[spark]( val removeReason: Option[String], val executorLogs: Map[String, String], val memoryMetrics: Option[MemoryMetrics], - val blacklistedInStages: Set[Int]) + val blacklistedInStages: Set[Int], + @JsonSerialize(using = classOf[ExecutorMetricsJsonSerializer]) + @JsonDeserialize(using = classOf[ExecutorMetricsJsonDeserializer]) + val peakMemoryMetrics: Option[ExecutorMetrics]) class MemoryMetrics private[spark]( val usedOnHeapStorageMemory: Long, @@ -106,6 +114,33 @@ class MemoryMetrics private[spark]( val totalOnHeapStorageMemory: Long, val totalOffHeapStorageMemory: Long) +/** deserializer for peakMemoryMetrics: convert map to ExecutorMetrics */ +private[spark] class ExecutorMetricsJsonDeserializer + extends JsonDeserializer[Option[ExecutorMetrics]] { + override def deserialize( + jsonParser: JsonParser, + deserializationContext: DeserializationContext): 
Option[ExecutorMetrics] = { + val metricsMap = jsonParser.readValueAs[Option[Map[String, Long]]]( + new TypeReference[Option[Map[String, java.lang.Long]]] {}) + metricsMap.map(metrics => new ExecutorMetrics(metrics)) + } +} +/** serializer for peakMemoryMetrics: convert ExecutorMetrics to map with metric name as key */ +private[spark] class ExecutorMetricsJsonSerializer + extends JsonSerializer[Option[ExecutorMetrics]] { + override def serialize( + metrics: Option[ExecutorMetrics], + jsonGenerator: JsonGenerator, + serializerProvider: SerializerProvider): Unit = { + metrics.foreach { m: ExecutorMetrics => + val metricsMap = ExecutorMetricType.values.map { metricType => + metricType.name -> m.getMetricValue(metricType) + }.toMap + jsonGenerator.writeObject(metricsMap) + } + } +} + class JobData private[spark]( val jobId: Int, val name: String, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index df1a4bef616b2..f5c69ad241e3a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -41,10 +41,12 @@ import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.metrics.source.Source import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.client.StreamCallbackWithID import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.{ExternalShuffleClient, TempFileManager} import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.rpc.RpcEnv +import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.serializer.{SerializerInstance, SerializerManager} import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage.memory._ @@ -128,7 +130,11 @@ private[spark] class BlockManager( extends BlockDataManager with BlockEvictionHandler with Logging { private[spark] val externalShuffleServiceEnabled = - conf.getBoolean("spark.shuffle.service.enabled", false) + conf.get(config.SHUFFLE_SERVICE_ENABLED) + private val chunkSize = + conf.getSizeAsBytes("spark.storage.memoryMapLimitForTests", Int.MaxValue.toString).toInt + private val remoteReadNioBufferConversion = + conf.getBoolean("spark.network.remoteReadNioBufferConversion", false) val diskBlockManager = { // Only perform cleanup if an external service is not serving our shuffle files. @@ -159,12 +165,13 @@ private[spark] class BlockManager( // Port used by the external shuffle service. In Yarn mode, this may be already be // set through the Hadoop configuration as the server is launched in the Yarn NM. private val externalShuffleServicePort = { - val tmpPort = Utils.getSparkOrYarnConfig(conf, "spark.shuffle.service.port", "7337").toInt + val tmpPort = Utils.getSparkOrYarnConfig(conf, config.SHUFFLE_SERVICE_PORT.key, + config.SHUFFLE_SERVICE_PORT.defaultValueString).toInt if (tmpPort == 0) { // for testing, we set "spark.shuffle.service.port" to 0 in the yarn config, so yarn finds // an open port. But we still need to tell our spark apps the right port to use. 
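The ExecutorMetricsJsonSerializer/ExecutorMetricsJsonDeserializer pair above maps ExecutorMetrics to a flat name-to-value JSON object and back; in the patch they are wired up through @JsonSerialize/@JsonDeserialize annotations. The following self-contained sketch shows the same Jackson custom ser/de pattern against a simplified metrics type, registered through a module instead of annotations; SimpleMetrics and the surrounding names are illustrative stand-ins, not Spark classes:

    import com.fasterxml.jackson.core.{JsonGenerator, JsonParser}
    import com.fasterxml.jackson.databind.{DeserializationContext, JsonDeserializer, JsonNode, JsonSerializer, ObjectMapper, SerializerProvider}
    import com.fasterxml.jackson.databind.module.SimpleModule

    // Stand-in for ExecutorMetrics: a bag of named Long gauges.
    case class SimpleMetrics(values: Map[String, Long])

    class SimpleMetricsSerializer extends JsonSerializer[SimpleMetrics] {
      override def serialize(m: SimpleMetrics, gen: JsonGenerator, sp: SerializerProvider): Unit = {
        gen.writeStartObject()
        m.values.foreach { case (k, v) => gen.writeNumberField(k, v) } // flat name -> value object
        gen.writeEndObject()
      }
    }

    class SimpleMetricsDeserializer extends JsonDeserializer[SimpleMetrics] {
      override def deserialize(p: JsonParser, ctx: DeserializationContext): SimpleMetrics = {
        val node = p.getCodec.readTree[JsonNode](p)
        val it = node.fields()
        val b = Map.newBuilder[String, Long]
        while (it.hasNext) { val e = it.next(); b += e.getKey -> e.getValue.asLong() }
        SimpleMetrics(b.result())
      }
    }

    object MetricsJsonRoundTrip {
      def main(args: Array[String]): Unit = {
        val module = new SimpleModule()
          .addSerializer(classOf[SimpleMetrics], new SimpleMetricsSerializer)
          .addDeserializer(classOf[SimpleMetrics], new SimpleMetricsDeserializer)
        val mapper = new ObjectMapper().registerModule(module)
        val json = mapper.writeValueAsString(SimpleMetrics(Map("JVMHeapMemory" -> 42L)))
        println(json) // {"JVMHeapMemory":42}
        println(mapper.readValue(json, classOf[SimpleMetrics]).values)
      }
    }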
So // only if the yarn config has the port set to 0, we prefer the value in the spark config - conf.get("spark.shuffle.service.port").toInt + conf.get(config.SHUFFLE_SERVICE_PORT.key).toInt } else { tmpPort } @@ -401,6 +408,63 @@ private[spark] class BlockManager( putBytes(blockId, new ChunkedByteBuffer(data.nioByteBuffer()), level)(classTag) } + override def putBlockDataAsStream( + blockId: BlockId, + level: StorageLevel, + classTag: ClassTag[_]): StreamCallbackWithID = { + // TODO if we're going to only put the data in the disk store, we should just write it directly + // to the final location, but that would require a deeper refactor of this code. So instead + // we just write to a temp file, and call putBytes on the data in that file. + val tmpFile = diskBlockManager.createTempLocalBlock()._2 + val channel = new CountingWritableChannel( + Channels.newChannel(serializerManager.wrapForEncryption(new FileOutputStream(tmpFile)))) + logTrace(s"Streaming block $blockId to tmp file $tmpFile") + new StreamCallbackWithID { + + override def getID: String = blockId.name + + override def onData(streamId: String, buf: ByteBuffer): Unit = { + while (buf.hasRemaining) { + channel.write(buf) + } + } + + override def onComplete(streamId: String): Unit = { + logTrace(s"Done receiving block $blockId, now putting into local blockManager") + // Read the contents of the downloaded file as a buffer to put into the blockManager. + // Note this is all happening inside the netty thread as soon as it reads the end of the + // stream. + channel.close() + // TODO SPARK-25035 Even if we're only going to write the data to disk after this, we end up + // using a lot of memory here. With encryption, we'll read the whole file into a regular + // byte buffer and OOM. Without encryption, we'll memory map the file and won't get a jvm + // OOM, but might get killed by the OS / cluster manager. We could at least read the tmp + // file as a stream in both cases. + val buffer = securityManager.getIOEncryptionKey() match { + case Some(key) => + // we need to pass in the size of the unencrypted block + val blockSize = channel.getCount + val allocator = level.memoryMode match { + case MemoryMode.ON_HEAP => ByteBuffer.allocate _ + case MemoryMode.OFF_HEAP => Platform.allocateDirectBuffer _ + } + new EncryptedBlockData(tmpFile, blockSize, conf, key).toChunkedByteBuffer(allocator) + + case None => + ChunkedByteBuffer.map(tmpFile, conf.get(config.MEMORY_MAP_LIMIT_FOR_TESTS).toInt) + } + putBytes(blockId, buffer, level)(classTag) + tmpFile.delete() + } + + override def onFailure(streamId: String, cause: Throwable): Unit = { + // the framework handles the connection itself, we just need to do local cleanup + channel.close() + tmpFile.delete() + } + } + } + /** * Get the BlockStatus for the block identified by the given ID, if it exists. * NOTE: This is mainly for testing. @@ -659,6 +723,11 @@ private[spark] class BlockManager( * Get block from remote block managers as serialized bytes. */ def getRemoteBytes(blockId: BlockId): Option[ChunkedByteBuffer] = { + // TODO if we change this method to return the ManagedBuffer, then getRemoteValues + // could just use the inputStream on the temp file, rather than memory-mapping the file. + // Until then, replication can cause the process to use too much memory and get killed + // by the OS / cluster manager (not a java OOM, since it's a memory-mapped file) even though + // we've read the data to disk. 
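The putBlockDataAsStream callback shown above never materializes the incoming block in memory: each chunk is written through a (wrap-for-encryption) channel into a temp file, and only onComplete reads the file back and hands it to putBytes. A reduced sketch of that callback shape, with a simplified trait standing in for StreamCallbackWithID and encryption omitted (all names here are illustrative):

    import java.io.{File, FileOutputStream}
    import java.nio.ByteBuffer
    import java.nio.channels.{Channels, WritableByteChannel}

    // Simplified stand-in for the StreamCallbackWithID contract used above.
    trait StreamCallback {
      def onData(buf: ByteBuffer): Unit
      def onComplete(): Unit
      def onFailure(cause: Throwable): Unit
    }

    // Stream incoming chunks to a temp file; only on completion is the file handed off.
    class FileStreamCallback(tmpFile: File, handoff: File => Unit) extends StreamCallback {
      private val channel: WritableByteChannel =
        Channels.newChannel(new FileOutputStream(tmpFile))

      override def onData(buf: ByteBuffer): Unit = {
        while (buf.hasRemaining) channel.write(buf) // write() may not drain the buffer in one call
      }

      override def onComplete(): Unit = {
        channel.close()
        handoff(tmpFile) // e.g. read the file back and put it into the local store
        tmpFile.delete()
      }

      override def onFailure(cause: Throwable): Unit = {
        channel.close()  // local cleanup only; the transport layer owns the connection
        tmpFile.delete()
      }
    }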
logDebug(s"Getting remote block $blockId") require(blockId != null, "BlockId is null") var runningFailureCount = 0 @@ -689,7 +758,7 @@ private[spark] class BlockManager( logDebug(s"Getting remote block $blockId from $loc") val data = try { blockTransferService.fetchBlockSync( - loc.host, loc.port, loc.executorId, blockId.toString, tempFileManager).nioByteBuffer() + loc.host, loc.port, loc.executorId, blockId.toString, tempFileManager) } catch { case NonFatal(e) => runningFailureCount += 1 @@ -723,7 +792,14 @@ private[spark] class BlockManager( } if (data != null) { - return Some(new ChunkedByteBuffer(data)) + // SPARK-24307 undocumented "escape-hatch" in case there are any issues in converting to + // ChunkedByteBuffer, to go back to old code-path. Can be removed post Spark 2.4 if + // new path is stable. + if (remoteReadNioBufferConversion) { + return Some(new ChunkedByteBuffer(data.nioByteBuffer())) + } else { + return Some(ChunkedByteBuffer.fromManagedBuffer(data, chunkSize)) + } } logDebug(s"The value of block $blockId is null") } @@ -1341,12 +1417,16 @@ private[spark] class BlockManager( try { val onePeerStartTime = System.nanoTime logTrace(s"Trying to replicate $blockId of ${data.size} bytes to $peer") + // This thread keeps a lock on the block, so we do not want the netty thread to unlock + // block when it finishes sending the message. + val buffer = new BlockManagerManagedBuffer(blockInfoManager, blockId, data, false, + unlockOnDeallocate = false) blockTransferService.uploadBlockSync( peer.host, peer.port, peer.executorId, blockId, - new BlockManagerManagedBuffer(blockInfoManager, blockId, data, false), + buffer, tLevel, classTag) logTrace(s"Replicated $blockId of ${data.size} bytes to $peer" + @@ -1554,7 +1634,7 @@ private[spark] class BlockManager( private[spark] object BlockManager { private val ID_GENERATOR = new IdGenerator - def blockIdsToHosts( + def blockIdsToLocations( blockIds: Array[BlockId], env: SparkEnv, blockManagerMaster: BlockManagerMaster = null): Map[BlockId, Seq[String]] = { @@ -1569,7 +1649,9 @@ private[spark] object BlockManager { val blockManagers = new HashMap[BlockId, Seq[String]] for (i <- 0 until blockIds.length) { - blockManagers(blockIds(i)) = blockLocations(i).map(_.host) + blockManagers(blockIds(i)) = blockLocations(i).map { loc => + ExecutorCacheTaskLocation(loc.host, loc.executorId).toString + } } blockManagers.toMap } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala index 3d3806126676c..5c12b5cee4d2f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala @@ -38,7 +38,8 @@ private[storage] class BlockManagerManagedBuffer( blockInfoManager: BlockInfoManager, blockId: BlockId, data: BlockData, - dispose: Boolean) extends ManagedBuffer { + dispose: Boolean, + unlockOnDeallocate: Boolean = true) extends ManagedBuffer { private val refCount = new AtomicInteger(1) @@ -58,7 +59,9 @@ private[storage] class BlockManagerManagedBuffer( } override def release(): ManagedBuffer = { - blockInfoManager.unlock(blockId) + if (unlockOnDeallocate) { + blockInfoManager.unlock(blockId) + } if (refCount.decrementAndGet() == 0 && dispose) { data.dispose() } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 
8e8f7d197c9ef..f984cf76e3463 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -54,7 +54,8 @@ class BlockManagerMasterEndpoint( // Mapping from block id to the set of block managers that have the block. private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]] - private val askThreadPool = ThreadUtils.newDaemonCachedThreadPool("block-manager-ask-thread-pool") + private val askThreadPool = + ThreadUtils.newDaemonCachedThreadPool("block-manager-ask-thread-pool", 100) private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool) private val topologyMapper = { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index 742cf4fe393f9..67544b20408a6 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -37,7 +37,7 @@ class BlockManagerSlaveEndpoint( extends ThreadSafeRpcEndpoint with Logging { private val asyncThreadPool = - ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool") + ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool", 100) private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool) // Operations that involve removing blocks may be slow and should be done asynchronously diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 39249d411b582..a820bc70b33b2 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -29,7 +29,7 @@ import com.google.common.io.Closeables import io.netty.channel.DefaultFileRegion import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.network.util.{AbstractFileRegion, JavaUtils} import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.util.Utils @@ -44,8 +44,7 @@ private[spark] class DiskStore( securityManager: SecurityManager) extends Logging { private val minMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m") - private val maxMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapLimitForTests", - Int.MaxValue.toString) + private val maxMemoryMapBytes = conf.get(config.MEMORY_MAP_LIMIT_FOR_TESTS) private val blockSizes = new ConcurrentHashMap[BlockId, Long]() def getSize(blockId: BlockId): Long = blockSizes.get(blockId) @@ -279,7 +278,7 @@ private class ReadableChannelFileRegion(source: ReadableByteChannel, blockSize: override def transferred(): Long = _transferred override def transferTo(target: WritableByteChannel, pos: Long): Long = { - assert(pos == transfered(), "Invalid position.") + assert(pos == transferred(), "Invalid position.") var written = 0L var lastWrite = -1L diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index e5abbf745cc41..19f86569c1e3c 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -17,7 +17,9 @@ package 
org.apache.spark.storage +import org.apache.spark.SparkEnv import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.internal.config._ import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.util.Utils @@ -53,10 +55,17 @@ class RDDInfo( } private[spark] object RDDInfo { + private val callsiteLongForm = SparkEnv.get.conf.get(EVENT_LOG_CALLSITE_LONG_FORM) + def fromRdd(rdd: RDD[_]): RDDInfo = { val rddName = Option(rdd.name).getOrElse(Utils.getFormattedClassName(rdd)) val parentIds = rdd.dependencies.map(_.rdd.id) + val callSite = if (callsiteLongForm) { + rdd.creationSite.longForm + } else { + rdd.creationSite.shortForm + } new RDDInfo(rdd.id, rddName, rdd.partitions.length, - rdd.getStorageLevel, parentIds, rdd.creationSite.shortForm, rdd.scope) + rdd.getStorageLevel, parentIds, callSite, rdd.scope) } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index b31862323a895..00d01dd28afb5 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -346,7 +346,7 @@ final class ShuffleBlockFetcherIterator( private[this] def initialize(): Unit = { // Add a task completion callback (called in both success case and failure case) to cleanup. - context.addTaskCompletionListener(_ => cleanup()) + context.addTaskCompletionListener[Unit](_ => cleanup()) // Split local and remote blocks. val remoteRequests = splitLocalRemoteBlocks() diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 4cc5bcb7f9baf..06fd56e54d9c8 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -827,7 +827,7 @@ private[storage] class PartiallySerializedBlock[T]( // completion listener here in order to ensure that `unrolled.dispose()` is called at least once. // The dispose() method is idempotent, so it's safe to call it unconditionally. Option(TaskContext.get()).foreach { taskContext => - taskContext.addTaskCompletionListener { _ => + taskContext.addTaskCompletionListener[Unit] { _ => // When a task completes, its unroll memory will automatically be freed. Thus we do not call // releaseUnrollMemoryForThisTask() here because we want to avoid double-freeing. 
unrolledBuffer.dispose() diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index ad0c0639521f6..43d62561e8eba 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -18,12 +18,13 @@ package org.apache.spark.util import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.lang.invoke.SerializedLambda import scala.collection.mutable.{Map, Set, Stack} import scala.language.existentials -import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor, Type} -import org.apache.xbean.asm5.Opcodes._ +import org.apache.xbean.asm6.{ClassReader, ClassVisitor, MethodVisitor, Type} +import org.apache.xbean.asm6.Opcodes._ import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.internal.Logging @@ -33,6 +34,8 @@ import org.apache.spark.internal.Logging */ private[spark] object ClosureCleaner extends Logging { + private val isScala2_11 = scala.util.Properties.versionString.contains("2.11") + // Get an ASM class reader for a given class from the JAR that loaded it private[util] def getClassReader(cls: Class[_]): ClassReader = { // Copy data over, before delegating to ClassReader - else we can run out of open file handles. @@ -159,6 +162,42 @@ private[spark] object ClosureCleaner extends Logging { clean(closure, checkSerializable, cleanTransitively, Map.empty) } + /** + * Try to get a serialized Lambda from the closure. + * + * @param closure the closure to check. + */ + private def getSerializedLambda(closure: AnyRef): Option[SerializedLambda] = { + if (isScala2_11) { + return None + } + val isClosureCandidate = + closure.getClass.isSynthetic && + closure + .getClass + .getInterfaces.exists(_.getName == "scala.Serializable") + + if (isClosureCandidate) { + try { + Option(inspect(closure)) + } catch { + case e: Exception => + // No need to check if debug is enabled here; the Spark + // logging API covers this. + logDebug("Closure is not a serialized lambda.", e) + None + } + } else { + None + } + } + + private def inspect(closure: AnyRef): SerializedLambda = { + val writeReplace = closure.getClass.getDeclaredMethod("writeReplace") + writeReplace.setAccessible(true) + writeReplace.invoke(closure).asInstanceOf[java.lang.invoke.SerializedLambda] + } + /** * Helper method to clean the given closure in place.
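The getSerializedLambda/inspect pair above relies on a detail of LambdaMetafactory-generated closures: a serializable lambda carries a synthetic writeReplace() method whose result is a java.lang.invoke.SerializedLambda describing its capturing class and implementation method. A minimal sketch of the same probe, with error handling collapsed to NonFatal (names illustrative):

    import java.lang.invoke.SerializedLambda
    import scala.util.control.NonFatal

    object LambdaProbeSketch {
      def trySerializedLambda(closure: AnyRef): Option[SerializedLambda] =
        try {
          val writeReplace = closure.getClass.getDeclaredMethod("writeReplace")
          writeReplace.setAccessible(true)
          writeReplace.invoke(closure) match {
            case sl: SerializedLambda => Some(sl) // LMF-compiled closure (Scala 2.12+)
            case _ => None
          }
        } catch {
          case NonFatal(_) => None // 2.11 anonymous-class closures have no writeReplace
        }

      def main(args: Array[String]): Unit = {
        val f: Int => Int = _ + 1
        // Prints Some(<impl method name>) on Scala 2.12+, None on 2.11.
        println(trySerializedLambda(f).map(_.getImplMethodName))
      }
    }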
* @@ -206,7 +245,12 @@ private[spark] object ClosureCleaner extends Logging { cleanTransitively: Boolean, accessedFields: Map[Class[_], Set[String]]): Unit = { - if (!isClosure(func.getClass)) { + // most likely to be the case with 2.12, 2.13 + // so we check first + // non LMF-closures should be less frequent from now on + val lambdaFunc = getSerializedLambda(func) + + if (!isClosure(func.getClass) && lambdaFunc.isEmpty) { logDebug(s"Expected a closure; got ${func.getClass.getName}") return } @@ -218,118 +262,132 @@ private[spark] object ClosureCleaner extends Logging { return } - logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}) +++") - - // A list of classes that represents closures enclosed in the given one - val innerClasses = getInnerClosureClasses(func) - - // A list of enclosing objects and their respective classes, from innermost to outermost - // An outer object at a given index is of type outer class at the same index - val (outerClasses, outerObjects) = getOuterClassesAndObjects(func) - - // For logging purposes only - val declaredFields = func.getClass.getDeclaredFields - val declaredMethods = func.getClass.getDeclaredMethods - - if (log.isDebugEnabled) { - logDebug(" + declared fields: " + declaredFields.size) - declaredFields.foreach { f => logDebug(" " + f) } - logDebug(" + declared methods: " + declaredMethods.size) - declaredMethods.foreach { m => logDebug(" " + m) } - logDebug(" + inner classes: " + innerClasses.size) - innerClasses.foreach { c => logDebug(" " + c.getName) } - logDebug(" + outer classes: " + outerClasses.size) - outerClasses.foreach { c => logDebug(" " + c.getName) } - logDebug(" + outer objects: " + outerObjects.size) - outerObjects.foreach { o => logDebug(" " + o) } - } + if (lambdaFunc.isEmpty) { + logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}) +++") + + // A list of classes that represents closures enclosed in the given one + val innerClasses = getInnerClosureClasses(func) + + // A list of enclosing objects and their respective classes, from innermost to outermost + // An outer object at a given index is of type outer class at the same index + val (outerClasses, outerObjects) = getOuterClassesAndObjects(func) + + // For logging purposes only + val declaredFields = func.getClass.getDeclaredFields + val declaredMethods = func.getClass.getDeclaredMethods + + if (log.isDebugEnabled) { + logDebug(s" + declared fields: ${declaredFields.size}") + declaredFields.foreach { f => logDebug(s" $f") } + logDebug(s" + declared methods: ${declaredMethods.size}") + declaredMethods.foreach { m => logDebug(s" $m") } + logDebug(s" + inner classes: ${innerClasses.size}") + innerClasses.foreach { c => logDebug(s" ${c.getName}") } + logDebug(s" + outer classes: ${outerClasses.size}" ) + outerClasses.foreach { c => logDebug(s" ${c.getName}") } + logDebug(s" + outer objects: ${outerObjects.size}") + outerObjects.foreach { o => logDebug(s" $o") } + } - // Fail fast if we detect return statements in closures - getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0) - - // If accessed fields is not populated yet, we assume that - // the closure we are trying to clean is the starting one - if (accessedFields.isEmpty) { - logDebug(s" + populating accessed fields because this is the starting closure") - // Initialize accessed fields with the outer classes first - // This step is needed to associate the fields to the correct classes later - initAccessedFields(accessedFields, outerClasses) - - // Populate accessed fields by visiting all 
fields and methods accessed by this and - // all of its inner closures. If transitive cleaning is enabled, this may recursively - // visits methods that belong to other classes in search of transitively referenced fields. - for (cls <- func.getClass :: innerClasses) { - getClassReader(cls).accept(new FieldAccessFinder(accessedFields, cleanTransitively), 0) + // Fail fast if we detect return statements in closures + getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0) + + // If accessed fields is not populated yet, we assume that + // the closure we are trying to clean is the starting one + if (accessedFields.isEmpty) { + logDebug(" + populating accessed fields because this is the starting closure") + // Initialize accessed fields with the outer classes first + // This step is needed to associate the fields to the correct classes later + initAccessedFields(accessedFields, outerClasses) + + // Populate accessed fields by visiting all fields and methods accessed by this and + // all of its inner closures. If transitive cleaning is enabled, this may recursively + // visit methods that belong to other classes in search of transitively referenced fields. + for (cls <- func.getClass :: innerClasses) { + getClassReader(cls).accept(new FieldAccessFinder(accessedFields, cleanTransitively), 0) + } } - } - logDebug(s" + fields accessed by starting closure: " + accessedFields.size) - accessedFields.foreach { f => logDebug(" " + f) } - - // List of outer (class, object) pairs, ordered from outermost to innermost - // Note that all outer objects but the outermost one (first one in this list) must be closures - var outerPairs: List[(Class[_], AnyRef)] = (outerClasses zip outerObjects).reverse - var parent: AnyRef = null - if (outerPairs.size > 0) { - val (outermostClass, outermostObject) = outerPairs.head - if (isClosure(outermostClass)) { - logDebug(s" + outermost object is a closure, so we clone it: ${outerPairs.head}") - } else if (outermostClass.getName.startsWith("$line")) { - // SPARK-14558: if the outermost object is a REPL line object, we should clone and clean it - // as it may carray a lot of unnecessary information, e.g. hadoop conf, spark conf, etc. - logDebug(s" + outermost object is a REPL line object, so we clone it: ${outerPairs.head}") + logDebug(s" + fields accessed by starting closure: " + accessedFields.size) + accessedFields.foreach { f => logDebug(" " + f) } + + // List of outer (class, object) pairs, ordered from outermost to innermost + // Note that all outer objects but the outermost one (first one in this list) must be closures + var outerPairs: List[(Class[_], AnyRef)] = outerClasses.zip(outerObjects).reverse + var parent: AnyRef = null + if (outerPairs.nonEmpty) { + val (outermostClass, outermostObject) = outerPairs.head + if (isClosure(outermostClass)) { + logDebug(s" + outermost object is a closure, so we clone it: ${outerPairs.head}") + } else if (outermostClass.getName.startsWith("$line")) { + // SPARK-14558: if the outermost object is a REPL line object, we should clone + // and clean it as it may carry a lot of unnecessary information, + // e.g. hadoop conf, spark conf, etc. + logDebug(s" + outermost object is a REPL line object, so we clone it: ${outerPairs.head}") + } else { + // The closure is ultimately nested inside a class; keep the object of that + // class without cloning it since we don't want to clone the user's objects.
+ // Note that we still need to keep around the outermost object itself because + // we need it to clone its child closure later (see below). + logDebug(" + outermost object is not a closure or REPL line object, " + "so do not clone it: " + outerPairs.head) + parent = outermostObject // e.g. SparkContext + outerPairs = outerPairs.tail + } } else { - // The closure is ultimately nested inside a class; keep the object of that - // class without cloning it since we don't want to clone the user's objects. - // Note that we still need to keep around the outermost object itself because - // we need it to clone its child closure later (see below). - logDebug(" + outermost object is not a closure or REPL line object, so do not clone it: " + - outerPairs.head) - parent = outermostObject // e.g. SparkContext - outerPairs = outerPairs.tail + logDebug(" + there are no enclosing objects!") } - } else { - logDebug(" + there are no enclosing objects!") - } - // Clone the closure objects themselves, nulling out any fields that are not - // used in the closure we're working on or any of its inner closures. - for ((cls, obj) <- outerPairs) { - logDebug(s" + cloning the object $obj of class ${cls.getName}") - // We null out these unused references by cloning each object and then filling in all - // required fields from the original object. We need the parent here because the Java - // language specification requires the first constructor parameter of any closure to be - // its enclosing object. - val clone = cloneAndSetFields(parent, obj, cls, accessedFields) - - // If transitive cleaning is enabled, we recursively clean any enclosing closure using - // the already populated accessed fields map of the starting closure - if (cleanTransitively && isClosure(clone.getClass)) { - logDebug(s" + cleaning cloned closure $clone recursively (${cls.getName})") - // No need to check serializable here for the outer closures because we're - // only interested in the serializability of the starting closure - clean(clone, checkSerializable = false, cleanTransitively, accessedFields) + // Clone the closure objects themselves, nulling out any fields that are not + // used in the closure we're working on or any of its inner closures. + for ((cls, obj) <- outerPairs) { + logDebug(s" + cloning the object $obj of class ${cls.getName}") + // We null out these unused references by cloning each object and then filling in all + // required fields from the original object. We need the parent here because the Java + // language specification requires the first constructor parameter of any closure to be + // its enclosing object.
+ val clone = cloneAndSetFields(parent, obj, cls, accessedFields) + + // If transitive cleaning is enabled, we recursively clean any enclosing closure using + // the already populated accessed fields map of the starting closure + if (cleanTransitively && isClosure(clone.getClass)) { + logDebug(s" + cleaning cloned closure $clone recursively (${cls.getName})") + // No need to check serializable here for the outer closures because we're + // only interested in the serializability of the starting closure + clean(clone, checkSerializable = false, cleanTransitively, accessedFields) + } + parent = clone } - parent = clone - } - // Update the parent pointer ($outer) of this closure - if (parent != null) { - val field = func.getClass.getDeclaredField("$outer") - field.setAccessible(true) - // If the starting closure doesn't actually need our enclosing object, then just null it out - if (accessedFields.contains(func.getClass) && - !accessedFields(func.getClass).contains("$outer")) { - logDebug(s" + the starting closure doesn't actually need $parent, so we null it out") - field.set(func, null) - } else { - // Update this closure's parent pointer to point to our enclosing object, - // which could either be a cloned closure or the original user object - field.set(func, parent) + // Update the parent pointer ($outer) of this closure + if (parent != null) { + val field = func.getClass.getDeclaredField("$outer") + field.setAccessible(true) + // If the starting closure doesn't actually need our enclosing object, then just null it out + if (accessedFields.contains(func.getClass) && + !accessedFields(func.getClass).contains("$outer")) { + logDebug(s" + the starting closure doesn't actually need $parent, so we null it out") + field.set(func, null) + } else { + // Update this closure's parent pointer to point to our enclosing object, + // which could either be a cloned closure or the original user object + field.set(func, parent) + } } - } - logDebug(s" +++ closure $func (${func.getClass.getName}) is now cleaned +++") + logDebug(s" +++ closure $func (${func.getClass.getName}) is now cleaned +++") + } else { + logDebug(s"Cleaning lambda: ${lambdaFunc.get.getImplMethodName}") + + // scalastyle:off classforname + val captClass = Class.forName(lambdaFunc.get.getCapturingClass.replace('/', '.'), + false, Thread.currentThread.getContextClassLoader) + // scalastyle:on classforname + // Fail fast if we detect return statements in closures + getClassReader(captClass) + .accept(new ReturnStatementFinder(Some(lambdaFunc.get.getImplMethodName)), 0) + logDebug(s" +++ Lambda closure (${lambdaFunc.get.getImplMethodName}) is now cleaned +++") + } if (checkSerializable) { ensureSerializable(func) @@ -366,20 +424,30 @@ private[spark] object ClosureCleaner extends Logging { private[spark] class ReturnStatementInClosureException extends SparkException("Return statements aren't allowed in Spark closures") -private class ReturnStatementFinder extends ClassVisitor(ASM5) { +private class ReturnStatementFinder(targetMethodName: Option[String] = None) + extends ClassVisitor(ASM6) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { + // $anonfun$ covers Java 8 lambdas if (name.contains("apply") || name.contains("$anonfun$")) { - new MethodVisitor(ASM5) { + // A method with suffix "$adapted" will be generated in cases like + // { _:Int => return; Seq()} but not { _:Int => return; true} + // closure passed is $anonfun$t$1$adapted while actual code resides in 
$anonfun$t$1 + // visitor will see only $anonfun$t$1$adapted, so we remove the suffix, see + // https://github.com/scala/scala-dev/issues/109 + val isTargetMethod = targetMethodName.isEmpty || + name == targetMethodName.get || name == targetMethodName.get.stripSuffix("$adapted") + + new MethodVisitor(ASM6) { override def visitTypeInsn(op: Int, tp: String) { - if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl")) { + if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl") && isTargetMethod) { throw new ReturnStatementInClosureException } } } } else { - new MethodVisitor(ASM5) {} + new MethodVisitor(ASM6) {} } } } @@ -403,7 +471,7 @@ private[util] class FieldAccessFinder( findTransitively: Boolean, specificMethod: Option[MethodIdentifier[_]] = None, visitedMethods: Set[MethodIdentifier[_]] = Set.empty) - extends ClassVisitor(ASM5) { + extends ClassVisitor(ASM6) { override def visitMethod( access: Int, @@ -418,7 +486,7 @@ private[util] class FieldAccessFinder( return null } - new MethodVisitor(ASM5) { + new MethodVisitor(ASM6) { override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) { if (op == GETFIELD) { for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) { @@ -458,7 +526,7 @@ private[util] class FieldAccessFinder( } } -private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM5) { +private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM6) { var myName: String = null // TODO: Recursively find inner closures that we indirectly reference, e.g. @@ -473,7 +541,7 @@ private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { - new MethodVisitor(ASM5) { + new MethodVisitor(ASM6) { override def visitMethodInsn( op: Int, owner: String, name: String, desc: String, itf: Boolean) { val argTypes = Type.getArgumentTypes(desc) diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 50c6461373dee..0cd8612b8fd1c 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -31,6 +31,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark._ import org.apache.spark.executor._ +import org.apache.spark.metrics.ExecutorMetricType import org.apache.spark.rdd.RDDOperationScope import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo @@ -98,6 +99,8 @@ private[spark] object JsonProtocol { logStartToJson(logStart) case metricsUpdate: SparkListenerExecutorMetricsUpdate => executorMetricsUpdateToJson(metricsUpdate) + case stageExecutorMetrics: SparkListenerStageExecutorMetrics => + stageExecutorMetricsToJson(stageExecutorMetrics) case blockUpdate: SparkListenerBlockUpdated => blockUpdateToJson(blockUpdate) case _ => parse(mapper.writeValueAsString(event)) @@ -236,6 +239,7 @@ private[spark] object JsonProtocol { def executorMetricsUpdateToJson(metricsUpdate: SparkListenerExecutorMetricsUpdate): JValue = { val execId = metricsUpdate.execId val accumUpdates = metricsUpdate.accumUpdates + val executorMetrics = metricsUpdate.executorUpdates.map(executorMetricsToJson(_)) ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.metricsUpdate) ~ ("Executor ID" -> execId) ~ ("Metrics Updated" -> accumUpdates.map { case (taskId, stageId, stageAttemptId, updates) =>
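The targetMethodName plumbing in ReturnStatementFinder above exists because scalac emits an extra "$adapted" bridge method for some lambdas (see scala/scala-dev#109), so the recorded implementation method name may differ from the visited one only by that suffix. The comparison, extracted into a standalone predicate (names hypothetical):

    object AdaptedSuffixSketch {
      // Match either the exact target method or the target with its "$adapted" suffix dropped.
      def isTargetMethod(visitedName: String, targetMethodName: Option[String]): Boolean =
        targetMethodName.isEmpty ||
          targetMethodName.contains(visitedName) ||
          targetMethodName.map(_.stripSuffix("$adapted")).contains(visitedName)

      def main(args: Array[String]): Unit = {
        println(isTargetMethod("$anonfun$t$1", Some("$anonfun$t$1$adapted"))) // true
        println(isTargetMethod("$anonfun$u$1", Some("$anonfun$t$1$adapted"))) // false
      }
    }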
@@ -243,7 +247,16 @@ private[spark] object JsonProtocol { ("Stage ID" -> stageId) ~ ("Stage Attempt ID" -> stageAttemptId) ~ ("Accumulator Updates" -> JArray(updates.map(accumulableInfoToJson).toList)) - }) + }) ~ + ("Executor Metrics Updated" -> executorMetrics) + } + + def stageExecutorMetricsToJson(metrics: SparkListenerStageExecutorMetrics): JValue = { + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.stageExecutorMetrics) ~ + ("Executor ID" -> metrics.execId) ~ + ("Stage ID" -> metrics.stageId) ~ + ("Stage Attempt ID" -> metrics.stageAttemptId) ~ + ("Executor Metrics" -> executorMetricsToJson(metrics.executorMetrics)) } def blockUpdateToJson(blockUpdate: SparkListenerBlockUpdated): JValue = { @@ -379,6 +392,14 @@ private[spark] object JsonProtocol { ("Updated Blocks" -> updatedBlocks) } + /** Convert executor metrics to JSON. */ + def executorMetricsToJson(executorMetrics: ExecutorMetrics): JValue = { + val metrics = ExecutorMetricType.values.map{ metricType => + JField(metricType.name, executorMetrics.getMetricValue(metricType)) + } + JObject(metrics: _*) + } + def taskEndReasonToJson(taskEndReason: TaskEndReason): JValue = { val reason = Utils.getFormattedClassName(taskEndReason) val json: JObject = taskEndReason match { @@ -531,6 +552,7 @@ private[spark] object JsonProtocol { val executorRemoved = Utils.getFormattedClassName(SparkListenerExecutorRemoved) val logStart = Utils.getFormattedClassName(SparkListenerLogStart) val metricsUpdate = Utils.getFormattedClassName(SparkListenerExecutorMetricsUpdate) + val stageExecutorMetrics = Utils.getFormattedClassName(SparkListenerStageExecutorMetrics) val blockUpdate = Utils.getFormattedClassName(SparkListenerBlockUpdated) } @@ -555,6 +577,7 @@ private[spark] object JsonProtocol { case `executorRemoved` => executorRemovedFromJson(json) case `logStart` => logStartFromJson(json) case `metricsUpdate` => executorMetricsUpdateFromJson(json) + case `stageExecutorMetrics` => stageExecutorMetricsFromJson(json) case `blockUpdate` => blockUpdateFromJson(json) case other => mapper.readValue(compact(render(json)), Utils.classForName(other)) .asInstanceOf[SparkListenerEvent] @@ -585,6 +608,15 @@ private[spark] object JsonProtocol { SparkListenerTaskGettingResult(taskInfo) } + /** Extract the executor metrics from JSON. 
*/ + def executorMetricsFromJson(json: JValue): ExecutorMetrics = { + val metrics = + ExecutorMetricType.values.map { metric => + metric.name -> jsonOption(json \ metric.name).map(_.extract[Long]).getOrElse(0L) + }.toMap + new ExecutorMetrics(metrics) + } + def taskEndFromJson(json: JValue): SparkListenerTaskEnd = { val stageId = (json \ "Stage ID").extract[Int] val stageAttemptId = @@ -691,7 +723,18 @@ private[spark] object JsonProtocol { (json \ "Accumulator Updates").extract[List[JValue]].map(accumulableInfoFromJson) (taskId, stageId, stageAttemptId, updates) } - SparkListenerExecutorMetricsUpdate(execInfo, accumUpdates) + val executorUpdates = jsonOption(json \ "Executor Metrics Updated").map { + executorUpdate => executorMetricsFromJson(executorUpdate) + } + SparkListenerExecutorMetricsUpdate(execInfo, accumUpdates, executorUpdates) + } + + def stageExecutorMetricsFromJson(json: JValue): SparkListenerStageExecutorMetrics = { + val execId = (json \ "Executor ID").extract[String] + val stageId = (json \ "Stage ID").extract[Int] + val stageAttemptId = (json \ "Stage Attempt ID").extract[Int] + val executorMetrics = executorMetricsFromJson(json \ "Executor Metrics") + SparkListenerStageExecutorMetrics(execId, stageId, stageAttemptId, executorMetrics) } def blockUpdateFromJson(json: JValue): SparkListenerBlockUpdated = { diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index d4474a90b26f1..a8f10684d5a2c 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -61,7 +61,7 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { } /** - * This can be overriden by subclasses if there is any extra cleanup to do when removing a + * This can be overridden by subclasses if there is any extra cleanup to do when removing a * listener. In particular AsyncEventQueues can clean up queues in the LiveListenerBus. */ def removeListenerOnError(listener: L): Unit = { diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 0f08a2b0ad895..cb0c20541d0d7 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -19,8 +19,12 @@ package org.apache.spark.util import java.util.concurrent._ +import scala.collection.TraversableLike +import scala.collection.generic.CanBuildFrom +import scala.language.higherKinds + import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} -import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor} +import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor, Future} import scala.concurrent.duration.{Duration, FiniteDuration} import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread} import scala.util.control.NonFatal @@ -254,4 +258,38 @@ private[spark] object ThreadUtils { executor.shutdownNow() } } + + /** + * Transforms the input collection by applying the given function to each element in parallel. + * Compared to the map() method of Scala parallel collections, this method can be interrupted + * at any time. This is useful when canceling task execution, for example. + * + * @param in - the input collection which should be transformed in parallel. + * @param prefix - the prefix assigned to the underlying thread pool.
+ * @param maxThreads - the maximum number of threads that can be created during execution. + * @param f - the lambda function to be applied to each element of `in`. + * @tparam I - the type of elements in the input collection. + * @tparam O - the type of elements in the resulting collection. + * @return a new collection in which each element is the result of applying the lambda + * function `f` to the corresponding element of `in`. + */ + def parmap[I, O, Col[X] <: TraversableLike[X, Col[X]]] + (in: Col[I], prefix: String, maxThreads: Int) + (f: I => O) + (implicit + cbf: CanBuildFrom[Col[I], Future[O], Col[Future[O]]], // for in.map + cbf2: CanBuildFrom[Col[Future[O]], O, Col[O]] // for Future.sequence + ): Col[O] = { + val pool = newForkJoinPool(prefix, maxThreads) + try { + implicit val ec = ExecutionContext.fromExecutor(pool) + + val futures = in.map(x => Future(f(x))) + val futureSeq = Future.sequence(futures) + + awaitResult(futureSeq, Duration.Inf) + } finally { + pool.shutdownNow() + } + } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index a6fd3637663e8..14f68cd6f3509 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -19,7 +19,6 @@ package org.apache.spark.util import java.io._ import java.lang.{Byte => JByte} -import java.lang.InternalError import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo} import java.lang.reflect.InvocationTargetException import java.math.{MathContext, RoundingMode} @@ -60,7 +59,7 @@ import org.slf4j.Logger import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.internal.config._ import org.apache.spark.launcher.SparkLauncher import org.apache.spark.network.util.JavaUtils @@ -83,6 +82,7 @@ private[spark] object Utils extends Logging { val random = new Random() private val sparkUncaughtExceptionHandler = new SparkUncaughtExceptionHandler + @volatile private var cachedLocalDir: String = "" /** * Define a default value for driver memory here since this value is referenced across the code @@ -462,7 +462,15 @@ private[spark] object Utils extends Logging { if (useCache && fetchCacheEnabled) { val cachedFileName = s"${url.hashCode}${timestamp}_cache" val lockFileName = s"${url.hashCode}${timestamp}_lock" - val localDir = new File(getLocalDir(conf)) + // Set the cachedLocalDir for the first time and reuse it later + if (cachedLocalDir.isEmpty) { + this.synchronized { + if (cachedLocalDir.isEmpty) { + cachedLocalDir = getLocalDir(conf) + } + } + } + val localDir = new File(cachedLocalDir) val lockFile = new File(localDir, lockFileName) val lockFileChannel = new RandomAccessFile(lockFile, "rw").getChannel() // Only one executor entry. @@ -767,13 +775,17 @@ private[spark] object Utils extends Logging { * - Otherwise, this will return java.io.tmpdir. * * Some of these configuration options might be lists of multiple paths, but this method will - * always return a single directory. + * always return a single directory. The returned directory is chosen randomly from the array + * of directories it gets from getOrCreateLocalRootDirs.
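// Editor's sketch (not part of this patch): using the parmap helper added above. The
// pool prefix and maxThreads are arbitrary example values.
val inputs: List[Int] = (1 to 100).toList
val squares: List[Int] = ThreadUtils.parmap(inputs, "parmap-example", maxThreads = 8)(i => i * i)
// Unlike inputs.par.map(f), the backing pool is torn down with shutdownNow() in a
// finally block, so in-flight work can be interrupted, e.g. when a task is cancelled.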
*/ def getLocalDir(conf: SparkConf): String = { - getOrCreateLocalRootDirs(conf).headOption.getOrElse { + val localRootDirs = getOrCreateLocalRootDirs(conf) + if (localRootDirs.isEmpty) { val configuredLocalDirs = getConfiguredLocalDirs(conf) throw new IOException( s"Failed to get a temp directory under [${configuredLocalDirs.mkString(",")}].") + } else { + localRootDirs(scala.util.Random.nextInt(localRootDirs.length)) } } @@ -809,13 +821,13 @@ private[spark] object Utils extends Logging { * logic of locating the local directories according to deployment mode. */ def getConfiguredLocalDirs(conf: SparkConf): Array[String] = { - val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) + val shuffleServiceEnabled = conf.get(config.SHUFFLE_SERVICE_ENABLED) if (isRunningInYarnContainer(conf)) { // If we are in yarn mode, systems can have different disk layouts so we must set it // to what Yarn on this system said was available. Note this assumes that Yarn has // created the directories already, and that they are secured so that only the // user has access to them. - getYarnLocalDirs(conf).split(",") + randomizeInPlace(getYarnLocalDirs(conf).split(",")) } else if (conf.getenv("SPARK_EXECUTOR_DIRS") != null) { conf.getenv("SPARK_EXECUTOR_DIRS").split(File.pathSeparator) } else if (conf.getenv("SPARK_LOCAL_DIRS") != null) { @@ -1374,7 +1386,7 @@ originalThrowable = cause try { logError("Aborting task", originalThrowable) - TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(originalThrowable) + TaskContext.get().markTaskFailed(originalThrowable) catchBlock } catch { case t: Throwable => @@ -1396,13 +1408,14 @@ } } + // A regular expression to match classes of the internal Spark APIs + // that we want to skip when finding the call site of a method. + private val SPARK_CORE_CLASS_REGEX = + """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?(\.broadcast)?\.[A-Z]""".r + private val SPARK_SQL_CLASS_REGEX = """^org\.apache\.spark\.sql.*""".r + /** Default filtering function for finding call sites using `getCallSite`. */ private def sparkInternalExclusionFunction(className: String): Boolean = { - // A regular expression to match classes of the internal Spark API's - // that we want to skip when finding the call site of a method. - val SPARK_CORE_CLASS_REGEX = - """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?(\.broadcast)?\.[A-Z]""".r - val SPARK_SQL_CLASS_REGEX = """^org\.apache\.spark\.sql.*""".r val SCALA_CORE_CLASS_PREFIX = "scala" val isSparkClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined || SPARK_SQL_CLASS_REGEX.findFirstIn(className).isDefined @@ -2038,6 +2051,30 @@ } } + + /** + * Implements the same logic as the JDK's `java.lang.String#trim` by removing leading and + * trailing non-printable characters less than or equal to '\u0020' (SPACE), but preserves the + * natural line delimiters recognized by the [[java.util.Properties]] load method. The natural + * line delimiters are removed by the JDK during load.
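// Editor's sketch (not part of this patch): the cachedLocalDir logic above is the
// classic double-checked-locking idiom; a minimal standalone version with hypothetical
// names looks like this.
object LocalDirCache {
  @volatile private var cached: String = ""
  def get(compute: => String): String = {
    if (cached.isEmpty) {         // cheap volatile read on the hot path
      synchronized {
        if (cached.isEmpty) {     // re-check under the lock
          cached = compute        // evaluated at most once
        }
      }
    }
    cached
  }
}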
Therefore any remaining ones have been specifically provided and + * escaped by the user, and must not be ignored. + * + * @param str the string to trim + * @return the trimmed value of str + */ + private[util] def trimExceptCRLF(str: String): String = { + val nonSpaceOrNaturalLineDelimiter: Char => Boolean = { ch => + ch > ' ' || ch == '\r' || ch == '\n' + } + + val firstPos = str.indexWhere(nonSpaceOrNaturalLineDelimiter) + val lastPos = str.lastIndexWhere(nonSpaceOrNaturalLineDelimiter) + if (firstPos >= 0 && lastPos >= 0) { + str.substring(firstPos, lastPos + 1) + } else { + "" + } + } + /** Load properties present in the given file. */ def getPropertiesFromFile(filename: String): Map[String, String] = { val file = new File(filename) @@ -2048,8 +2085,10 @@ try { val properties = new Properties() properties.load(inReader) - properties.stringPropertyNames().asScala.map( - k => (k, properties.getProperty(k).trim)).toMap + properties.stringPropertyNames().asScala + .map { k => (k, trimExceptCRLF(properties.getProperty(k))) } + .toMap + } catch { case e: IOException => throw new SparkException(s"Failed when loading Spark properties from $filename", e) @@ -2781,6 +2820,36 @@ private[spark] object Utils extends Logging { } } } + + /** + * Regular expression matching full-width characters. + * + * We looked at all the characters in the range 0x0000-0xFFFF (the Unicode BMP), displayed + * them under Xshell, found all the full-width characters, and then derived this regular + * expression from them. + */ + private val fullWidthRegex = ("""[""" + + // scalastyle:off nonascii + """\u1100-\u115F""" + + """\u2E80-\uA4CF""" + + """\uAC00-\uD7A3""" + + """\uF900-\uFAFF""" + + """\uFE10-\uFE19""" + + """\uFE30-\uFE6F""" + + """\uFF00-\uFF60""" + + """\uFFE0-\uFFE6""" + + // scalastyle:on nonascii + """]""").r + + /** + * Return the number of half widths in a given string. Note that a full-width character + * occupies two half widths. + * + * For a string consisting of 1 million characters, this method takes about 50 ms.
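// Editor's sketch (not part of this patch): worked examples for trimExceptCRLF (defined
// above, callable from the org.apache.spark.util package) and for stringHalfWidth, whose
// definition follows; the expected values come directly from the definitions.
assert(Utils.trimExceptCRLF("  value  ") == "value")           // ordinary whitespace trimmed
assert(Utils.trimExceptCRLF(" \n value \r ") == "\n value \r") // user-escaped CR/LF preserved
assert(Utils.trimExceptCRLF(" \t ") == "")                     // nothing above '\u0020' remains
assert(Utils.stringHalfWidth("ab") == 2)       // two half-width characters
assert(Utils.stringHalfWidth("a\u4E2Db") == 4) // three chars plus one full-width match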
+ */ + def stringHalfWidth(str: String): Int = { + if (str == null) 0 else str.length + fullWidthRegex.findAllIn(str).size + } } private[util] object CallerContext extends Logging { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 5c6dd45ec58e3..19ff109b673e1 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -80,7 +80,10 @@ class ExternalAppendOnlyMap[K, V, C]( this(createCombiner, mergeValue, mergeCombiners, serializer, blockManager, TaskContext.get()) } - @volatile private var currentMap = new SizeTrackingAppendOnlyMap[K, C] + /** + * Exposed for testing + */ + @volatile private[collection] var currentMap = new SizeTrackingAppendOnlyMap[K, C] private val spilledMaps = new ArrayBuffer[DiskMapIterator] private val sparkConf = SparkEnv.get.conf private val diskBlockManager = blockManager.diskBlockManager @@ -267,7 +270,7 @@ class ExternalAppendOnlyMap[K, V, C]( */ def destructiveIterator(inMemoryIterator: Iterator[(K, C)]): Iterator[(K, C)] = { readingIterator = new SpillableIterator(inMemoryIterator) - readingIterator + readingIterator.toCompletionIterator } /** @@ -280,8 +283,7 @@ class ExternalAppendOnlyMap[K, V, C]( "ExternalAppendOnlyMap.iterator is destructive and should only be called once.") } if (spilledMaps.isEmpty) { - CompletionIterator[(K, C), Iterator[(K, C)]]( - destructiveIterator(currentMap.iterator), freeCurrentMap()) + destructiveIterator(currentMap.iterator) } else { new ExternalIterator() } @@ -305,8 +307,8 @@ class ExternalAppendOnlyMap[K, V, C]( // Input streams are derived both from the in-memory map and spilled maps on disk // The in-memory map is sorted in place, while the spilled maps are already in sorted order - private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]](destructiveIterator( - currentMap.destructiveSortedIterator(keyComparator)), freeCurrentMap()) + private val sortedMap = destructiveIterator( + currentMap.destructiveSortedIterator(keyComparator)) private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered) inputStreams.foreach { it => @@ -565,16 +567,14 @@ class ExternalAppendOnlyMap[K, V, C]( } } - context.addTaskCompletionListener(context => cleanup()) + context.addTaskCompletionListener[Unit](context => cleanup()) } - private[this] class SpillableIterator(var upstream: Iterator[(K, C)]) + private class SpillableIterator(var upstream: Iterator[(K, C)]) extends Iterator[(K, C)] { private val SPILL_LOCK = new Object() - private var nextUpstream: Iterator[(K, C)] = null - private var cur: (K, C) = readNext() private var hasSpilled: Boolean = false @@ -585,17 +585,24 @@ class ExternalAppendOnlyMap[K, V, C]( } else { logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " + s"it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory") - nextUpstream = spillMemoryIteratorToDisk(upstream) + val nextUpstream = spillMemoryIteratorToDisk(upstream) + assert(!upstream.hasNext) hasSpilled = true + upstream = nextUpstream true } } + private def destroy(): Unit = { + freeCurrentMap() + upstream = Iterator.empty + } + + def toCompletionIterator: CompletionIterator[(K, C), SpillableIterator] = { + CompletionIterator[(K, C), SpillableIterator](this, this.destroy) + } + def readNext(): (K, C) = SPILL_LOCK.synchronized { - 
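// Editor's sketch (not part of this patch): toCompletionIterator above relies on
// CompletionIterator, which runs a cleanup callback once the wrapped iterator is
// exhausted; a minimal illustration of that behavior:
import org.apache.spark.util.CompletionIterator
var freed = false
val it = CompletionIterator[Int, Iterator[Int]](Iterator(1, 2, 3), { freed = true })
while (it.hasNext) it.next()
assert(freed) // the destroy()-style callback ran once, after the last element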
if (nextUpstream != null) { - upstream = nextUpstream - nextUpstream = null - } if (upstream.hasNext) { upstream.next() } else { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 176f84fa2a0d2..b159200d79222 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -368,8 +368,8 @@ private[spark] class ExternalSorter[K, V, C]( val bufferedIters = iterators.filter(_.hasNext).map(_.buffered) type Iter = BufferedIterator[Product2[K, C]] val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] { - // Use the reverse of comparator.compare because PriorityQueue dequeues the max - override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1) + // Use the reverse order because PriorityQueue dequeues the max + override def compare(x: Iter, y: Iter): Int = comparator.compare(y.head._1, x.head._1) }) heap.enqueue(bufferedIters: _*) // Will contain only the iterators with hasNext = true new Iterator[Product2[K, C]] { diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 60f6f537c1d54..8883e17bf3164 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -28,9 +28,9 @@ import org.apache.spark.annotation.Private * removed. * * The underlying implementation uses Scala compiler's specialization to generate optimized - * storage for two primitive types (Long and Int). It is much faster than Java's standard HashSet - * while incurring much less memory overhead. This can serve as building blocks for higher level - * data structures such as an optimized HashMap. + * storage for four primitive types (Long, Int, Double, and Float). It is much faster than Java's + * standard HashSet while incurring much less memory overhead. This can serve as building blocks + * for higher level data structures such as an optimized HashMap. * * This OpenHashSet is designed to serve as building blocks for higher level data structures * such as an optimized hash map. Compared with standard hash set implementations, this class @@ -41,7 +41,7 @@ import org.apache.spark.annotation.Private * to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing). */ @Private -class OpenHashSet[@specialized(Long, Int) T: ClassTag]( +class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( initialCapacity: Int, loadFactor: Double) extends Serializable { @@ -77,6 +77,10 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( (new LongHasher).asInstanceOf[Hasher[T]] } else if (mt == ClassTag.Int) { (new IntHasher).asInstanceOf[Hasher[T]] + } else if (mt == ClassTag.Double) { + (new DoubleHasher).asInstanceOf[Hasher[T]] + } else if (mt == ClassTag.Float) { + (new FloatHasher).asInstanceOf[Hasher[T]] } else { new Hasher[T] } @@ -293,7 +297,7 @@ object OpenHashSet { * A set of specialized hash function implementation to avoid boxing hash code computation * in the specialized implementation of OpenHashSet. 
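// Editor's sketch (not part of this patch): the merge code above flips the comparator's
// arguments because scala.collection.mutable.PriorityQueue always dequeues the maximum;
// under the reversed ordering, the iterator with the smallest head comes out first.
val minFirst = new scala.collection.mutable.PriorityQueue[Int]()(new Ordering[Int] {
  override def compare(x: Int, y: Int): Int = Ordering.Int.compare(y, x) // reversed
})
minFirst.enqueue(3, 1, 2)
assert(minFirst.dequeue() == 1) // the minimum, not the maximum, is dequeued first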
*/ - sealed class Hasher[@specialized(Long, Int) T] extends Serializable { + sealed class Hasher[@specialized(Long, Int, Double, Float) T] extends Serializable { def hash(o: T): Int = o.hashCode() } @@ -305,6 +309,17 @@ object OpenHashSet { override def hash(o: Int): Int = o } + class DoubleHasher extends Hasher[Double] { + override def hash(o: Double): Int = { + val bits = java.lang.Double.doubleToLongBits(o) + (bits ^ (bits >>> 32)).toInt + } + } + + class FloatHasher extends Hasher[Float] { + override def hash(o: Float): Int = java.lang.Float.floatToIntBits(o) + } + private def grow1(newSize: Int) {} private def move1(oldPos: Int, newPos: Int) { } diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 700ce56466c35..39f050f6ca5ad 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -17,17 +17,21 @@ package org.apache.spark.util.io -import java.io.InputStream +import java.io.{File, FileInputStream, InputStream} import java.nio.ByteBuffer -import java.nio.channels.WritableByteChannel +import java.nio.channels.{FileChannel, WritableByteChannel} +import java.nio.file.StandardOpenOption + +import scala.collection.mutable.ListBuffer import com.google.common.primitives.UnsignedBytes -import io.netty.buffer.{ByteBuf, Unpooled} import org.apache.spark.SparkEnv import org.apache.spark.internal.config +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.util.ByteArrayWritableChannel import org.apache.spark.storage.StorageUtils +import org.apache.spark.util.Utils /** * Read-only byte buffer which is physically stored as multiple chunks rather than a single @@ -81,10 +85,10 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { } /** - * Wrap this buffer to view it as a Netty ByteBuf. + * Wrap this in a custom "FileRegion" which allows us to transfer over 2 GB. */ - def toNetty: ByteBuf = { - Unpooled.wrappedBuffer(chunks.length, getChunks(): _*) + def toNetty: ChunkedByteBufferFileRegion = { + new ChunkedByteBufferFileRegion(this, bufferWriteChunkSize) } /** @@ -166,6 +170,38 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { } +object ChunkedByteBuffer { + // TODO eliminate this method if we switch BlockManager to getting InputStreams + def fromManagedBuffer(data: ManagedBuffer, maxChunkSize: Int): ChunkedByteBuffer = { + data match { + case f: FileSegmentManagedBuffer => + map(f.getFile, maxChunkSize, f.getOffset, f.getLength) + case other => + new ChunkedByteBuffer(other.nioByteBuffer()) + } + } + + def map(file: File, maxChunkSize: Int): ChunkedByteBuffer = { + map(file, maxChunkSize, 0, file.length()) + } + + def map(file: File, maxChunkSize: Int, offset: Long, length: Long): ChunkedByteBuffer = { + Utils.tryWithResource(FileChannel.open(file.toPath, StandardOpenOption.READ)) { channel => + var remaining = length + var pos = offset + val chunks = new ListBuffer[ByteBuffer]() + while (remaining > 0) { + val chunkSize = math.min(remaining, maxChunkSize) + val chunk = channel.map(FileChannel.MapMode.READ_ONLY, pos, chunkSize) + pos += chunkSize + remaining -= chunkSize + chunks += chunk + } + new ChunkedByteBuffer(chunks.toArray) + } + } +} + /** * Reads data from a ChunkedByteBuffer. 
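// Editor's sketch (not part of this patch): the new hashers intentionally mirror the
// JDK's boxed hash codes, so the specialized and boxed paths agree on hash values, and
// OpenHashSet[Double] now picks DoubleHasher via the ClassTag instead of boxing.
assert(new OpenHashSet.DoubleHasher().hash(1.5) == java.lang.Double.valueOf(1.5).hashCode)
assert(new OpenHashSet.FloatHasher().hash(1.5f) == java.lang.Float.valueOf(1.5f).hashCode)
val doubles = new OpenHashSet[Double](64, 0.7) // primary constructor as declared above
doubles.add(1.5)
assert(doubles.contains(1.5))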
* diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferFileRegion.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferFileRegion.scala new file mode 100644 index 0000000000000..9622d0ac05368 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferFileRegion.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util.io + +import java.nio.channels.WritableByteChannel + +import io.netty.channel.FileRegion +import io.netty.util.AbstractReferenceCounted + +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.AbstractFileRegion + + +/** + * This exposes a ChunkedByteBuffer as a netty FileRegion, just to allow sending > 2gb in one netty + * message. This is because netty cannot send a ByteBuf > 2g, but it can send a large FileRegion, + * even though the data is not backed by a file. + */ +private[io] class ChunkedByteBufferFileRegion( + private val chunkedByteBuffer: ChunkedByteBuffer, + private val ioChunkSize: Int) extends AbstractFileRegion { + + private var _transferred: Long = 0 + // this duplicates the original chunks, so we're free to modify the position, limit, etc. + private val chunks = chunkedByteBuffer.getChunks() + private val size = chunks.foldLeft(0L) { _ + _.remaining() } + + protected def deallocate: Unit = {} + + override def count(): Long = size + + // this is the "start position" of the overall Data in the backing file, not our current position + override def position(): Long = 0 + + override def transferred(): Long = _transferred + + private var currentChunkIdx = 0 + + def transferTo(target: WritableByteChannel, position: Long): Long = { + assert(position == _transferred) + if (position == size) return 0L + var keepGoing = true + var written = 0L + var currentChunk = chunks(currentChunkIdx) + while (keepGoing) { + while (currentChunk.hasRemaining && keepGoing) { + val ioSize = Math.min(currentChunk.remaining(), ioChunkSize) + val originalLimit = currentChunk.limit() + currentChunk.limit(currentChunk.position() + ioSize) + val thisWriteSize = target.write(currentChunk) + currentChunk.limit(originalLimit) + written += thisWriteSize + if (thisWriteSize < ioSize) { + // the channel did not accept our entire write. We do *not* keep trying -- netty wants + // us to just stop, and report how much we've written. 
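+          // (Editor's note: a WritableByteChannel in non-blocking mode may accept only
+          // part of a buffer; by returning the partial count, netty can wait for
+          // writability and call transferTo again later with position == transferred().)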
+ keepGoing = false + } + } + if (keepGoing) { + // advance to the next chunk (if there are any more) + currentChunkIdx += 1 + if (currentChunkIdx == chunks.size) { + keepGoing = false + } else { + currentChunk = chunks(currentChunkIdx) + } + } + } + _transferred += written + written + } +} diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java index d7d2d0b012bd3..a0664b30d6cc2 100644 --- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -76,7 +76,7 @@ public void freeingPageSetsPageNumberToSpecialConstant() { final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP); final MemoryBlock dataPage = manager.allocatePage(256, c); c.freePage(dataPage); - Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, dataPage.getPageNumber()); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, dataPage.pageNumber); } @Test(expected = AssertionError.class) diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 0d5c5ea7903e9..faa70f23b0ac6 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -233,6 +233,7 @@ public void writeEmptyIterator() throws Exception { writer.write(Iterators.emptyIterator()); final Option mapStatus = writer.stop(true); assertTrue(mapStatus.isDefined()); + assertEquals(0, mapStatus.get().numberOfOutput()); assertTrue(mergedOutputFile.exists()); assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile); assertEquals(0, taskMetrics.shuffleWriteMetrics().recordsWritten()); @@ -252,6 +253,7 @@ public void writeWithoutSpilling() throws Exception { writer.write(dataToWrite.iterator()); final Option mapStatus = writer.stop(true); assertTrue(mapStatus.isDefined()); + assertEquals(NUM_PARTITITONS, mapStatus.get().numberOfOutput()); assertTrue(mergedOutputFile.exists()); long sumOfPartitionSizes = 0; diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 03cec8ed81b72..53a233f698c7a 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -379,7 +379,7 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception { @Test public void randomizedStressTest() { - final int size = 65536; + final int size = 32768; // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays // into ByteBuffers in order to use them as keys here. 
final Map expected = new HashMap<>(); @@ -388,7 +388,7 @@ public void randomizedStressTest() { // Fill the map to 90% full so that we can trigger probing for (int i = 0; i < size * 0.9; i++) { final byte[] key = getRandomByteArray(rand.nextInt(256) + 1); - final byte[] value = getRandomByteArray(rand.nextInt(512) + 1); + final byte[] value = getRandomByteArray(rand.nextInt(256) + 1); if (!expected.containsKey(ByteBuffer.wrap(key))) { expected.put(ByteBuffer.wrap(key), value); final BytesToBytesMap.Location loc = map.lookup( diff --git a/core/src/test/java/test/org/apache/spark/JavaSparkContextSuite.java b/core/src/test/java/test/org/apache/spark/JavaSparkContextSuite.java index 7e9cc70d8651f..0f489fb219010 100644 --- a/core/src/test/java/test/org/apache/spark/JavaSparkContextSuite.java +++ b/core/src/test/java/test/org/apache/spark/JavaSparkContextSuite.java @@ -30,7 +30,7 @@ import org.apache.spark.*; /** - * Java apps can uses both Java-friendly JavaSparkContext and Scala SparkContext. + * Java apps can use both Java-friendly JavaSparkContext and Scala SparkContext. */ public class JavaSparkContextSuite implements Serializable { diff --git a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json index 4fecf84db65a2..eea6f595efd2a 100644 --- a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json @@ -1,4 +1,19 @@ [ { + "id" : "application_1506645932520_24630151", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-04-19T23:54:42.734GMT", + "endTime" : "2018-04-19T23:56:29.134GMT", + "lastUpdated" : "", + "duration" : 106400, + "sparkUser" : "edlu", + "completed" : true, + "appSparkVersion" : "2.4.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1524182082734, + "endTimeEpoch" : 1524182189134 + } ] +}, { "id" : "application_1516285256255_0012", "name" : "Spark shell", "attempts" : [ { diff --git a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json index 4fecf84db65a2..7bc7f31be097b 100644 --- a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json @@ -1,4 +1,19 @@ [ { + "id" : "application_1506645932520_24630151", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-04-19T23:54:42.734GMT", + "endTime" : "2018-04-19T23:56:29.134GMT", + "lastUpdated" : "", + "duration" : 106400, + "sparkUser" : "edlu", + "completed" : true, + "appSparkVersion" : "2.4.0-SNAPSHOT", + "startTimeEpoch" : 1524182082734, + "endTimeEpoch" : 1524182189134, + "lastUpdatedEpoch" : 0 + } ] +}, { "id" : "application_1516285256255_0012", "name" : "Spark shell", "attempts" : [ { diff --git a/core/src/test/resources/HistoryServerExpectations/executor_list_with_executor_metrics_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_list_with_executor_metrics_json_expectation.json new file mode 100644 index 0000000000000..9bf2086cc8e72 --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/executor_list_with_executor_metrics_json_expectation.json @@ -0,0 +1,314 @@ +[ { + "id" : "driver", + "hostPort" : 
"node0033.grid.company.com:60749", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 0, + "maxTasks" : 0, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 0, + "totalTasks" : 0, + "totalDuration" : 0, + "totalGCTime" : 0, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : false, + "maxMemory" : 1043437977, + "addTime" : "2018-04-19T23:55:05.107GMT", + "executorLogs" : { }, + "memoryMetrics" : { + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 1043437977, + "totalOffHeapStorageMemory" : 0 + }, + "blacklistedInStages" : [ ], + "peakMemoryMetrics" : { + "OnHeapStorageMemory" : 905801, + "JVMOffHeapMemory" : 205304696, + "OffHeapExecutionMemory" : 0, + "OnHeapUnifiedMemory" : 905801, + "OnHeapExecutionMemory" : 0, + "OffHeapUnifiedMemory" : 0, + "DirectPoolMemory" : 397602, + "MappedPoolMemory" : 0, + "JVMHeapMemory" : 629553808, + "OffHeapStorageMemory" : 0 + } +}, { + "id" : "7", + "hostPort" : "node6340.grid.company.com:5933", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 1, + "maxTasks" : 1, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 0, + "totalTasks" : 0, + "totalDuration" : 0, + "totalGCTime" : 0, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : false, + "maxMemory" : 956615884, + "addTime" : "2018-04-19T23:55:49.826GMT", + "executorLogs" : { + "stdout" : "http://node6340.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000009/edlu/stdout?start=-4096", + "stderr" : "http://node6340.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000009/edlu/stderr?start=-4096" + }, + "memoryMetrics" : { + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 956615884, + "totalOffHeapStorageMemory" : 0 + }, + "blacklistedInStages" : [ ] +}, { + "id" : "6", + "hostPort" : "node6644.grid.company.com:8445", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 1, + "maxTasks" : 1, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 0, + "totalTasks" : 0, + "totalDuration" : 0, + "totalGCTime" : 0, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : false, + "maxMemory" : 956615884, + "addTime" : "2018-04-19T23:55:47.549GMT", + "executorLogs" : { + "stdout" : "http://node6644.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000008/edlu/stdout?start=-4096", + "stderr" : "http://node6644.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000008/edlu/stderr?start=-4096" + }, + "memoryMetrics" : { + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 956615884, + "totalOffHeapStorageMemory" : 0 + }, + "blacklistedInStages" : [ ] +}, { + "id" : "5", + "hostPort" : "node2477.grid.company.com:20123", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 1, + "maxTasks" : 1, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 1, + "totalTasks" : 1, + "totalDuration" : 9252, + "totalGCTime" : 920, + "totalInputBytes" : 36838295, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 355051, + "isBlacklisted" : false, + "maxMemory" : 956615884, + "addTime" : "2018-04-19T23:55:43.160GMT", + 
"executorLogs" : { + "stdout" : "http://node2477.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000007/edlu/stdout?start=-4096", + "stderr" : "http://node2477.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000007/edlu/stderr?start=-4096" + }, + "memoryMetrics" : { + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 956615884, + "totalOffHeapStorageMemory" : 0 + }, + "blacklistedInStages" : [ ] +}, { + "id" : "4", + "hostPort" : "node4243.grid.company.com:16084", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 1, + "maxTasks" : 1, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 3, + "totalTasks" : 3, + "totalDuration" : 15645, + "totalGCTime" : 405, + "totalInputBytes" : 87272855, + "totalShuffleRead" : 438675, + "totalShuffleWrite" : 26773039, + "isBlacklisted" : false, + "maxMemory" : 956615884, + "addTime" : "2018-04-19T23:55:12.278GMT", + "executorLogs" : { + "stdout" : "http://node4243.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000006/edlu/stdout?start=-4096", + "stderr" : "http://node4243.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000006/edlu/stderr?start=-4096" + }, + "memoryMetrics" : { + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 956615884, + "totalOffHeapStorageMemory" : 0 + }, + "blacklistedInStages" : [ ], + "peakMemoryMetrics" : { + "OnHeapStorageMemory" : 63104457, + "JVMOffHeapMemory" : 95657456, + "OffHeapExecutionMemory" : 0, + "OnHeapUnifiedMemory" : 100853193, + "OnHeapExecutionMemory" : 37748736, + "OffHeapUnifiedMemory" : 0, + "DirectPoolMemory" : 126261, + "MappedPoolMemory" : 0, + "JVMHeapMemory" : 518613056, + "OffHeapStorageMemory" : 0 + } +}, { + "id" : "3", + "hostPort" : "node0998.grid.company.com:45265", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 1, + "maxTasks" : 1, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 1, + "totalTasks" : 1, + "totalDuration" : 14491, + "totalGCTime" : 342, + "totalInputBytes" : 50409514, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 31362123, + "isBlacklisted" : false, + "maxMemory" : 956615884, + "addTime" : "2018-04-19T23:55:12.088GMT", + "executorLogs" : { + "stdout" : "http://node0998.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000005/edlu/stdout?start=-4096", + "stderr" : "http://node0998.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000005/edlu/stderr?start=-4096" + }, + "memoryMetrics" : { + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 956615884, + "totalOffHeapStorageMemory" : 0 + }, + "blacklistedInStages" : [ ], + "peakMemoryMetrics" : { + "OnHeapStorageMemory" : 69535048, + "JVMOffHeapMemory" : 90709624, + "OffHeapExecutionMemory" : 0, + "OnHeapUnifiedMemory" : 69535048, + "OnHeapExecutionMemory" : 0, + "OffHeapUnifiedMemory" : 0, + "DirectPoolMemory" : 87796, + "MappedPoolMemory" : 0, + "JVMHeapMemory" : 726805712, + "OffHeapStorageMemory" : 0 + } +}, { + "id" : "2", + "hostPort" : "node4045.grid.company.com:29262", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 1, + "maxTasks" : 1, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 1, + "totalTasks" : 1, + "totalDuration" : 
14113, + "totalGCTime" : 326, + "totalInputBytes" : 50423423, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 22950296, + "isBlacklisted" : false, + "maxMemory" : 956615884, + "addTime" : "2018-04-19T23:55:12.471GMT", + "executorLogs" : { + "stdout" : "http://node4045.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000004/edlu/stdout?start=-4096", + "stderr" : "http://node4045.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000004/edlu/stderr?start=-4096" + }, + "memoryMetrics" : { + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 956615884, + "totalOffHeapStorageMemory" : 0 + }, + "blacklistedInStages" : [ ], + "peakMemoryMetrics" : { + "OnHeapStorageMemory" : 58468944, + "JVMOffHeapMemory" : 91208368, + "OffHeapExecutionMemory" : 0, + "OnHeapUnifiedMemory" : 58468944, + "OnHeapExecutionMemory" : 0, + "OffHeapUnifiedMemory" : 0, + "DirectPoolMemory" : 87796, + "MappedPoolMemory" : 0, + "JVMHeapMemory" : 595946552, + "OffHeapStorageMemory" : 0 + } +}, { + "id" : "1", + "hostPort" : "node1404.grid.company.com:34043", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 1, + "maxTasks" : 1, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 3, + "totalTasks" : 3, + "totalDuration" : 15665, + "totalGCTime" : 471, + "totalInputBytes" : 98905018, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 20594744, + "isBlacklisted" : false, + "maxMemory" : 956615884, + "addTime" : "2018-04-19T23:55:11.695GMT", + "executorLogs" : { + "stdout" : "http://node1404.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000002/edlu/stdout?start=-4096", + "stderr" : "http://node1404.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000002/edlu/stderr?start=-4096" + }, + "memoryMetrics" : { + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 956615884, + "totalOffHeapStorageMemory" : 0 + }, + "blacklistedInStages" : [ ], + "peakMemoryMetrics" : { + "OnHeapStorageMemory" : 47962185, + "JVMOffHeapMemory" : 100519936, + "OffHeapExecutionMemory" : 0, + "OnHeapUnifiedMemory" : 47962185, + "OnHeapExecutionMemory" : 0, + "OffHeapUnifiedMemory" : 0, + "DirectPoolMemory" : 98230, + "MappedPoolMemory" : 0, + "JVMHeapMemory" : 755008624, + "OffHeapStorageMemory" : 0 + } +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json index 79950b0dc6486..9e1e65a358815 100644 --- a/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json @@ -1,4 +1,19 @@ [ { + "id" : "application_1506645932520_24630151", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-04-19T23:54:42.734GMT", + "endTime" : "2018-04-19T23:56:29.134GMT", + "lastUpdated" : "", + "duration" : 106400, + "sparkUser" : "edlu", + "completed" : true, + "appSparkVersion" : "2.4.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1524182082734, + "endTimeEpoch" : 1524182189134 + } ] +}, { "id" : "application_1516285256255_0012", "name" : "Spark shell", "attempts" : [ { @@ -28,19 +43,4 @@ "startTimeEpoch" : 1515492942372, "endTimeEpoch" : 1515493477606 } ] -}, { - "id" : "app-20161116163331-0000", - "name" : "Spark shell", - 
"attempts" : [ { - "startTime" : "2016-11-16T22:33:29.916GMT", - "endTime" : "2016-11-16T22:33:40.587GMT", - "lastUpdated" : "", - "duration" : 10671, - "sparkUser" : "jose", - "completed" : true, - "appSparkVersion" : "2.1.0-SNAPSHOT", - "lastUpdatedEpoch" : 0, - "startTimeEpoch" : 1479335609916, - "endTimeEpoch" : 1479335620587 - } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json index 7d60977dcd4fe..28c6bf1b3e01e 100644 --- a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json @@ -1,4 +1,19 @@ [ { + "id" : "application_1506645932520_24630151", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-04-19T23:54:42.734GMT", + "endTime" : "2018-04-19T23:56:29.134GMT", + "lastUpdated" : "", + "duration" : 106400, + "sparkUser" : "edlu", + "completed" : true, + "appSparkVersion" : "2.4.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1524182082734, + "endTimeEpoch" : 1524182189134 + } ] +}, { "id" : "application_1516285256255_0012", "name" : "Spark shell", "attempts" : [ { diff --git a/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json index dfbfd8aedcc23..f547b79f47e1a 100644 --- a/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json @@ -1,4 +1,19 @@ [ { + "id" : "application_1506645932520_24630151", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-04-19T23:54:42.734GMT", + "endTime" : "2018-04-19T23:56:29.134GMT", + "lastUpdated" : "", + "duration" : 106400, + "sparkUser" : "edlu", + "completed" : true, + "appSparkVersion" : "2.4.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1524182082734, + "endTimeEpoch" : 1524182189134 + } ] +}, { "id" : "application_1516285256255_0012", "name" : "Spark shell", "attempts" : [ { @@ -101,4 +116,4 @@ "startTimeEpoch" : 1430917380880, "endTimeEpoch" : 1430917380890 } ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/spark-events/application_1506645932520_24630151 b/core/src/test/resources/spark-events/application_1506645932520_24630151 new file mode 100644 index 0000000000000..c48ed741c56e0 --- /dev/null +++ b/core/src/test/resources/spark-events/application_1506645932520_24630151 @@ -0,0 +1,63 @@ +{"Event":"SparkListenerLogStart","Spark Version":"2.4.0-SNAPSHOT"} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"driver","Host":"node0033.grid.company.com","Port":60749},"Maximum Memory":1043437977,"Timestamp":1524182105107,"Maximum Onheap Memory":1043437977,"Maximum Offheap Memory":0} +{"Event":"SparkListenerEnvironmentUpdate","JVM Information":{"Java Home":"/usr/java/jdk1.8.0_31/jre","Java Version":"1.8.0_31 (Oracle Corporation)","Scala Version":"version 2.11.8"},"Spark 
Properties":{"spark.jars.ivySettings":"/export/apps/spark/commonconf/ivysettings.xml","spark.serializer":"org.apache.spark.serializer.KryoSerializer","spark.driver.host":"node0033.grid.company.com","spark.dynamicAllocation.sustainedSchedulerBacklogTimeout":"5","spark.eventLog.enabled":"true","spark.ui.port":"0","spark.driver.port":"57705","spark.shuffle.service.enabled":"true","spark.ui.acls.enable":"true","spark.reducer.maxSizeInFlight":"48m","spark.yarn.queue":"spark_default","spark.repl.class.uri":"spark://node0033.grid.company.com:57705/classes","spark.jars":"","spark.yarn.historyServer.address":"clustersh01.grid.company.com:18080","spark.memoryOverhead.multiplier.percent":"10","spark.repl.class.outputDir":"/grid/a/mapred/tmp/spark-21b68b4b-c1db-460e-a228-b87545d870f1/repl-58778a76-04c1-434d-bfb7-9a9b83afe718","spark.dynamicAllocation.cachedExecutorIdleTimeout":"1200","spark.yarn.access.namenodes":"hdfs://clusternn02.grid.company.com:9000","spark.app.name":"Spark shell","spark.dynamicAllocation.schedulerBacklogTimeout":"5","spark.yarn.security.credentials.hive.enabled":"false","spark.yarn.am.cores":"1","spark.memoryOverhead.min":"384","spark.scheduler.mode":"FIFO","spark.driver.memory":"2G","spark.executor.instances":"4","spark.isolated.classloader.additional.classes.prefix":"com_company_","spark.logConf":"true","spark.ui.showConsoleProgress":"true","spark.user.priority.jars":"*********(redacted)","spark.isolated.classloader":"true","spark.sql.sources.schemaStringLengthThreshold":"40000","spark.yarn.secondary.jars":"spark-avro_2.11-3.2.0.21.jar,grid-topology-1.0.jar","spark.reducer.maxBlocksInFlightPerAddress":"100","spark.dynamicAllocation.maxExecutors":"900","spark.yarn.appMasterEnv.LD_LIBRARY_PATH":"/export/apps/hadoop/latest/lib/native","spark.executor.id":"driver","spark.yarn.am.memory":"2G","spark.driver.cores":"1","spark.search.packages":"com.company.dali:dali-data-spark,com.company.spark-common:spark-common","spark.min.mem.vore.ratio":"5","spark.sql.sources.partitionOverwriteMode":"DYNAMIC","spark.submit.deployMode":"client","spark.yarn.maxAppAttempts":"1","spark.master":"yarn","spark.default.packages":"com.company.dali:dali-data-spark:8.+?classifier=all,com.company.spark-common:spark-common_2.10:0.+?","spark.isolated.classloader.default.jar":"*dali-data-spark*","spark.authenticate":"true","spark.eventLog.usexattr":"true","spark.ui.filters":"org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter","spark.executor.memory":"2G","spark.home":"/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51","spark.reducer.maxReqsInFlight":"10","spark.eventLog.dir":"hdfs://clusternn02.grid.company.com:9000/system/spark-history","spark.dynamicAllocation.enabled":"true","spark.sql.catalogImplementation":"hive","spark.isolated.classes":"org.apache.hadoop.hive.ql.io.CombineHiveInputFormat$CombineHiveInputSplit","spark.eventLog.compress":"true","spark.executor.cores":"1","spark.version":"2.1.0","spark.driver.appUIAddress":"http://node0033.grid.company.com:8364","spark.repl.local.jars":"file:///export/home/edlu/spark-avro_2.11-3.2.0.21.jar,file:///export/apps/hadoop/site/lib/grid-topology-1.0.jar","spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.param.PROXY_HOSTS":"clusterwp01.grid.company.com","spark.min.memory-gb.size":"10","spark.dynamicAllocation.minExecutors":"1","spark.dynamicAllocation.initialExecutors":"3","spark.expressionencoder.org.apache.avro.specific.SpecificRecord":"com.databricks.spark.avro.AvroEncoder$","spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilt
er.param.PROXY_URI_BASES":"http://clusterwp01.grid.company.com:8080/proxy/application_1506645932520_24630151","spark.executorEnv.LD_LIBRARY_PATH":"/export/apps/hadoop/latest/lib/native","spark.dynamicAllocation.executorIdleTimeout":"150","spark.shell.auto.node.labeling":"true","spark.yarn.dist.jars":"file:///export/home/edlu/spark-avro_2.11-3.2.0.21.jar,file:///export/apps/hadoop/site/lib/grid-topology-1.0.jar","spark.app.id":"application_1506645932520_24630151","spark.ui.view.acls":"*"},"System Properties":{"java.io.tmpdir":"/tmp","line.separator":"\n","path.separator":":","sun.management.compiler":"HotSpot 64-Bit Tiered Compilers","SPARK_SUBMIT":"true","sun.cpu.endian":"little","java.specification.version":"1.8","java.vm.specification.name":"Java Virtual Machine Specification","java.vendor":"Oracle Corporation","java.vm.specification.version":"1.8","user.home":"*********(redacted)","file.encoding.pkg":"sun.io","sun.nio.ch.bugLevel":"","sun.arch.data.model":"64","sun.boot.library.path":"/usr/java/jdk1.8.0_31/jre/lib/amd64","user.dir":"*********(redacted)","java.library.path":"/usr/java/packages/lib/amd64:/usr/lib64:/lib64:/lib:/usr/lib","sun.cpu.isalist":"","os.arch":"amd64","java.vm.version":"25.31-b07","java.endorsed.dirs":"/usr/java/jdk1.8.0_31/jre/lib/endorsed","java.runtime.version":"1.8.0_31-b13","java.vm.info":"mixed mode","java.ext.dirs":"/usr/java/jdk1.8.0_31/jre/lib/ext:/usr/java/packages/lib/ext","java.runtime.name":"Java(TM) SE Runtime Environment","file.separator":"/","java.class.version":"52.0","scala.usejavacp":"true","java.specification.name":"Java Platform API Specification","sun.boot.class.path":"/usr/java/jdk1.8.0_31/jre/lib/resources.jar:/usr/java/jdk1.8.0_31/jre/lib/rt.jar:/usr/java/jdk1.8.0_31/jre/lib/sunrsasign.jar:/usr/java/jdk1.8.0_31/jre/lib/jsse.jar:/usr/java/jdk1.8.0_31/jre/lib/jce.jar:/usr/java/jdk1.8.0_31/jre/lib/charsets.jar:/usr/java/jdk1.8.0_31/jre/lib/jfr.jar:/usr/java/jdk1.8.0_31/jre/classes","file.encoding":"UTF-8","user.timezone":"*********(redacted)","java.specification.vendor":"Oracle Corporation","sun.java.launcher":"SUN_STANDARD","os.version":"2.6.32-504.16.2.el6.x86_64","sun.os.patch.level":"unknown","java.vm.specification.vendor":"Oracle Corporation","user.country":"*********(redacted)","sun.jnu.encoding":"UTF-8","user.language":"*********(redacted)","java.vendor.url":"*********(redacted)","java.awt.printerjob":"sun.print.PSPrinterJob","java.awt.graphicsenv":"sun.awt.X11GraphicsEnvironment","awt.toolkit":"sun.awt.X11.XToolkit","os.name":"Linux","java.vm.vendor":"Oracle Corporation","java.vendor.url.bug":"*********(redacted)","user.name":"*********(redacted)","java.vm.name":"Java HotSpot(TM) 64-Bit Server VM","sun.java.command":"org.apache.spark.deploy.SparkSubmit --master yarn --deploy-mode client --class org.apache.spark.repl.Main --name Spark shell --jars /export/home/edlu/spark-avro_2.11-3.2.0.21.jar,/export/apps/hadoop/site/lib/grid-topology-1.0.jar --num-executors 4 spark-shell","java.home":"/usr/java/jdk1.8.0_31/jre","java.version":"1.8.0_31","sun.io.unicode.encoding":"UnicodeLittle"},"Classpath Entries":{"/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/guice-servlet-3.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jackson-mapper-asl-1.9.13.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/derby-10.12.1.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/htrace-core-3.0.4.jar":"System 
Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/api-asn1-api-1.0.0-M20.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/scala-reflect-2.11.8.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/datanucleus-rdbms-3.2.9.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-graphx_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/api-util-1.0.0-M20.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-yarn-client-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/base64-2.3.8.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-auth-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/validation-api-1.1.0.Final.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hk2-utils-2.4.0-b34.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/zstd-jni-1.3.2-2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-yarn-api-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/objenesis-2.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/conf/":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/httpclient-4.5.4.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/kryo-shaded-3.0.3.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/scala-library-2.11.8.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-net-3.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/xz-1.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/json4s-jackson_2.11-3.5.3.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/javax.servlet-api-3.1.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jersey-server-1.9.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jackson-annotations-2.6.7.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/parquet-hadoop-1.8.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/activation-1.1.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spire_2.11-0.13.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/arpack_combined_all-0.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/libthrift-0.9.3.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/aircompressor-0.8.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/parquet-jackson-1.8.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hk2-api-2.4.0-b34.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/asm-3.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/apacheds-kerberos-codec-2.0.0-M15.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-hive_2.11-2.4.0-SNAPSHOT.jar":"System 
Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/ivy-2.4.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/javax.inject-2.4.0-b34.jar":"System Classpath","/export/apps/hadoop/site/etc/hadoop/":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/snappy-java-1.1.7.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/arrow-format-0.8.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/netty-all-4.1.17.Final.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/avro-ipc-1.7.7.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/xmlenc-0.52.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jdo-api-3.0.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/curator-client-2.7.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/antlr-runtime-3.4.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/pyrolite-4.13.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/scala-xml_2.11-1.0.5.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-catalyst_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-collections-3.2.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/slf4j-api-1.7.16.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/stream-2.7.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/parquet-format-2.3.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/arrow-vector-0.8.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-yarn-server-web-proxy-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/htrace-core-3.1.0-incubating.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-sketch_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jersey-common-2.22.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hppc-0.7.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jackson-core-asl-1.9.13.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-sql_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/univocity-parsers-2.5.9.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-math3-3.4.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-compiler-3.0.8.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-beanutils-1.7.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/java-xmlbuilder-1.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/javax.inject-1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-annotations-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/netty-3.9.9.Final.jar":"System 
Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/aopalliance-repackaged-2.4.0-b34.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/zookeeper-3.4.6.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/guice-3.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/scala-compiler-2.11.8.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/eigenbase-properties-1.1.5.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/aopalliance-1.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-yarn_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/JavaEWAH-0.3.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jsr305-1.3.9.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/libfb303-0.9.3.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/javax.annotation-api-1.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-yarn-server-common-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-digester-1.8.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/metrics-jvm-3.1.5.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/curator-framework-2.7.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/javax.ws.rs-api-2.0.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/paranamer-2.8.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/janino-3.0.8.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-mapreduce-client-core-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jersey-server-2.22.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/orc-core-1.4.3-nohive.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jsch-0.1.42.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/calcite-linq4j-1.2.0-incubating.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-unsafe_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-codec-1.10.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jtransforms-2.4.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/lz4-java-1.4.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/datanucleus-core-3.2.10.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/flatbuffers-1.2.0-3f79e055.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hive-exec-1.2.1.spark2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/avro-mapred-1.7.7-hadoop2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/stax-api-1.0.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/core-1.1.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/leveldbjni-all-1.8.jar":"System 
Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/datanucleus-api-jdo-3.2.6.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jackson-databind-2.6.7.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-dbcp-1.4.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jackson-module-scala_2.11-2.6.7.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-lang3-3.5.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spire-macros_2.11-0.13.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jackson-module-paranamer-2.7.9.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/slf4j-log4j12-1.7.16.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/chill-java-0.8.4.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jodd-core-3.5.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-pool-1.5.4.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/osgi-resource-locator-1.0.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/minlog-1.3.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-mapreduce-client-common-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/gson-2.2.4.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/py4j-0.10.6.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-streaming_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jackson-core-2.6.7.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/calcite-avatica-1.2.0-incubating.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/machinist_2.11-0.6.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/avro-1.7.7.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-beanutils-core-1.8.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/apacheds-i18n-2.0.0-M15.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jersey-media-jaxb-2.22.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/snappy-0.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-mapreduce-client-app-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/parquet-hadoop-bundle-1.6.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jul-to-slf4j-1.7.16.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/metrics-graphite-3.1.5.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jcl-over-slf4j-1.7.16.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/metrics-core-3.1.5.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-mllib-local_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/arrow-memory-0.8.0.jar":"System 
Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/breeze_2.11-0.13.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jersey-guava-2.22.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-client-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/xercesImpl-2.9.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-tags_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/javolution-5.5.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jetty-6.1.26.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/joda-time-2.9.3.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/antlr-2.7.7.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-mapreduce-client-jobclient-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-lang-2.6.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/compress-lzf-1.0.3.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-crypto-1.0.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jersey-core-1.9.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/curator-recipes-2.7.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hk2-locator-2.4.0-b34.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/guava-14.0.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jackson-jaxrs-1.9.13.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-core_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jetty-sslengine-6.1.26.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-network-common_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-launcher_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/json4s-ast_2.11-3.5.3.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/antlr4-runtime-4.7.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jetty-util-6.1.26.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jaxb-api-2.2.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-io-2.4.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/parquet-encoding-1.8.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/httpcore-4.4.8.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/macro-compat_2.11-1.1.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jackson-xc-1.9.13.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/xbean-asm5-shaded-4.4.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/breeze-macros_2.11-0.13.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/protobuf-java-2.5.0.jar":"System 
Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/json4s-scalap_2.11-3.5.3.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-mllib_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-configuration-1.6.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-compress-1.4.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/json4s-core_2.11-3.5.3.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/orc-mapreduce-1.4.3-nohive.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/ST4-4.0.4.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/calcite-core-1.2.0-incubating.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-mapreduce-client-shuffle-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-common-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-repl_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jersey-container-servlet-2.22.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/opencsv-2.3.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-logging-1.1.3.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/shapeless_2.11-2.3.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-cli-1.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jersey-client-2.22.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-yarn-common-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-hdfs-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/log4j-1.2.17.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/parquet-column-1.8.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hive-metastore-1.2.1.spark2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/RoaringBitmap-0.5.11.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/chill_2.11-0.8.4.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jersey-container-servlet-core-2.22.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/stringtemplate-3.2.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/parquet-common-1.8.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-network-shuffle_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-kvstore_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/stax-api-1.0-2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jta-1.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/javassist-3.18.1-GA.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-httpclient-3.1.jar":"System 
Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jets3t-0.9.4.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/apache-log4j-extras-1.2.17.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/metrics-json-3.1.5.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/bcprov-jdk15on-1.58.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/oro-2.0.8.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/bonecp-0.8.0.RELEASE.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jsp-api-2.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/scala-parser-combinators_2.11-1.0.4.jar":"System Classpath"}} +{"Event":"SparkListenerApplicationStart","App Name":"Spark shell","App ID":"application_1506645932520_24630151","Timestamp":1524182082734,"User":"edlu"} +{"Event":"SparkListenerExecutorAdded","Timestamp":1524182111695,"Executor ID":"1","Executor Info":{"Host":"node1404.grid.company.com","Total Cores":1,"Log Urls":{"stdout":"http://node1404.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000002/edlu/stdout?start=-4096","stderr":"http://node1404.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000002/edlu/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"1","Host":"node1404.grid.company.com","Port":34043},"Maximum Memory":956615884,"Timestamp":1524182111795,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerExecutorAdded","Timestamp":1524182112088,"Executor ID":"3","Executor Info":{"Host":"node0998.grid.company.com","Total Cores":1,"Log Urls":{"stdout":"http://node0998.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000005/edlu/stdout?start=-4096","stderr":"http://node0998.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000005/edlu/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"3","Host":"node0998.grid.company.com","Port":45265},"Maximum Memory":956615884,"Timestamp":1524182112208,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerExecutorAdded","Timestamp":1524182112278,"Executor ID":"4","Executor Info":{"Host":"node4243.grid.company.com","Total Cores":1,"Log Urls":{"stdout":"http://node4243.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000006/edlu/stdout?start=-4096","stderr":"http://node4243.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000006/edlu/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"4","Host":"node4243.grid.company.com","Port":16084},"Maximum Memory":956615884,"Timestamp":1524182112408,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerExecutorAdded","Timestamp":1524182112471,"Executor ID":"2","Executor Info":{"Host":"node4045.grid.company.com","Total Cores":1,"Log Urls":{"stdout":"http://node4045.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000004/edlu/stdout?start=-4096","stderr":"http://node4045.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000004/edlu/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager 
ID":{"Executor ID":"2","Host":"node4045.grid.company.com","Port":29262},"Maximum Memory":956615884,"Timestamp":1524182112578,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart","executionId":0,"description":"createOrReplaceTempView at :40","details":"org.apache.spark.sql.Dataset.createOrReplaceTempView(Dataset.scala:3033)\n$line44.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:40)\n$line44.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:45)\n$line44.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:47)\n$line44.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:49)\n$line44.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:51)\n$line44.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:53)\n$line44.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:55)\n$line44.$read$$iw$$iw$$iw$$iw$$iw.(:57)\n$line44.$read$$iw$$iw$$iw$$iw.(:59)\n$line44.$read$$iw$$iw$$iw.(:61)\n$line44.$read$$iw$$iw.(:63)\n$line44.$read$$iw.(:65)\n$line44.$read.(:67)\n$line44.$read$.(:71)\n$line44.$read$.()\n$line44.$eval$.$print$lzycompute(:7)\n$line44.$eval$.$print(:6)\n$line44.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)","physicalPlanDescription":"== Parsed Logical Plan ==\nCreateViewCommand `apps`, false, true, LocalTempView\n +- AnalysisBarrier\n +- Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, cast(endTime#6 as date) AS endDate#28]\n +- Relation[appId#0,attemptId#1,name#2,mode#3,completed#4,duration#5L,endTime#6,endTimeEpoch#7L,lastUpdated#8,lastUpdatedEpoch#9L,sparkUser#10,startTime#11,startTimeEpoch#12L,appSparkVersion#13] avro\n\n== Analyzed Logical Plan ==\nCreateViewCommand `apps`, false, true, LocalTempView\n +- AnalysisBarrier\n +- Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, cast(endTime#6 as date) AS endDate#28]\n +- Relation[appId#0,attemptId#1,name#2,mode#3,completed#4,duration#5L,endTime#6,endTimeEpoch#7L,lastUpdated#8,lastUpdatedEpoch#9L,sparkUser#10,startTime#11,startTimeEpoch#12L,appSparkVersion#13] avro\n\n== Optimized Logical Plan ==\nCreateViewCommand `apps`, false, true, LocalTempView\n +- AnalysisBarrier\n +- Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, cast(endTime#6 as date) AS endDate#28]\n +- Relation[appId#0,attemptId#1,name#2,mode#3,completed#4,duration#5L,endTime#6,endTimeEpoch#7L,lastUpdated#8,lastUpdatedEpoch#9L,sparkUser#10,startTime#11,startTimeEpoch#12L,appSparkVersion#13] avro\n\n== Physical Plan ==\nExecute CreateViewCommand\n +- CreateViewCommand `apps`, false, true, LocalTempView\n +- AnalysisBarrier\n +- Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, cast(endTime#6 as date) AS endDate#28]\n +- Relation[appId#0,attemptId#1,name#2,mode#3,completed#4,duration#5L,endTime#6,endTimeEpoch#7L,lastUpdated#8,lastUpdatedEpoch#9L,sparkUser#10,startTime#11,startTimeEpoch#12L,appSparkVersion#13] avro","sparkPlanInfo":{"nodeName":"Execute CreateViewCommand","simpleString":"Execute 
CreateViewCommand","children":[],"metrics":[]},"time":1524182125829} +{"Event":"org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd","executionId":0,"time":1524182125832} +{"Event":"org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart","executionId":1,"description":"createOrReplaceTempView at :40","details":"org.apache.spark.sql.Dataset.createOrReplaceTempView(Dataset.scala:3033)\n$line48.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:40)\n$line48.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:45)\n$line48.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:47)\n$line48.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:49)\n$line48.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:51)\n$line48.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:53)\n$line48.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:55)\n$line48.$read$$iw$$iw$$iw$$iw$$iw.(:57)\n$line48.$read$$iw$$iw$$iw$$iw.(:59)\n$line48.$read$$iw$$iw$$iw.(:61)\n$line48.$read$$iw$$iw.(:63)\n$line48.$read$$iw.(:65)\n$line48.$read.(:67)\n$line48.$read$.(:71)\n$line48.$read$.()\n$line48.$eval$.$print$lzycompute(:7)\n$line48.$eval$.$print(:6)\n$line48.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)","physicalPlanDescription":"== Parsed Logical Plan ==\nCreateViewCommand `sys_props`, false, true, LocalTempView\n +- AnalysisBarrier\n +- Aggregate [appId#137], [appId#137, first(if ((key#148 <=> azkaban.link.workflow.url)) value#149 else cast(null as string), true) AS azkaban.link.workflow.url#159, first(if ((key#148 <=> azkaban.link.execution.url)) value#149 else cast(null as string), true) AS azkaban.link.execution.url#161, first(if ((key#148 <=> azkaban.link.job.url)) value#149 else cast(null as string), true) AS azkaban.link.job.url#163, first(if ((key#148 <=> user.name)) value#149 else cast(null as string), true) AS user.name#165]\n +- Project [appId#137, col#145.key AS key#148, col#145.value AS value#149]\n +- Project [appId#137, col#145]\n +- Generate explode(systemProperties#135), false, [col#145]\n +- Relation[runtime#133,sparkProperties#134,systemProperties#135,classpathEntries#136,appId#137,attemptId#138] avro\n\n== Analyzed Logical Plan ==\nCreateViewCommand `sys_props`, false, true, LocalTempView\n +- AnalysisBarrier\n +- Aggregate [appId#137], [appId#137, first(if ((key#148 <=> azkaban.link.workflow.url)) value#149 else cast(null as string), true) AS azkaban.link.workflow.url#159, first(if ((key#148 <=> azkaban.link.execution.url)) value#149 else cast(null as string), true) AS azkaban.link.execution.url#161, first(if ((key#148 <=> azkaban.link.job.url)) value#149 else cast(null as string), true) AS azkaban.link.job.url#163, first(if ((key#148 <=> user.name)) value#149 else cast(null as string), true) AS user.name#165]\n +- Project [appId#137, col#145.key AS key#148, col#145.value AS value#149]\n +- Project [appId#137, col#145]\n +- Generate explode(systemProperties#135), false, [col#145]\n +- Relation[runtime#133,sparkProperties#134,systemProperties#135,classpathEntries#136,appId#137,attemptId#138] avro\n\n== Optimized Logical Plan ==\nCreateViewCommand `sys_props`, false, true, LocalTempView\n +- AnalysisBarrier\n +- Aggregate [appId#137], [appId#137, first(if ((key#148 <=> azkaban.link.workflow.url)) value#149 else cast(null as string), true) AS azkaban.link.workflow.url#159, first(if ((key#148 <=> azkaban.link.execution.url)) value#149 else cast(null as string), true) AS azkaban.link.execution.url#161, first(if ((key#148 <=> azkaban.link.job.url)) value#149 else cast(null as string), true) AS 
azkaban.link.job.url#163, first(if ((key#148 <=> user.name)) value#149 else cast(null as string), true) AS user.name#165]\n +- Project [appId#137, col#145.key AS key#148, col#145.value AS value#149]\n +- Project [appId#137, col#145]\n +- Generate explode(systemProperties#135), false, [col#145]\n +- Relation[runtime#133,sparkProperties#134,systemProperties#135,classpathEntries#136,appId#137,attemptId#138] avro\n\n== Physical Plan ==\nExecute CreateViewCommand\n +- CreateViewCommand `sys_props`, false, true, LocalTempView\n +- AnalysisBarrier\n +- Aggregate [appId#137], [appId#137, first(if ((key#148 <=> azkaban.link.workflow.url)) value#149 else cast(null as string), true) AS azkaban.link.workflow.url#159, first(if ((key#148 <=> azkaban.link.execution.url)) value#149 else cast(null as string), true) AS azkaban.link.execution.url#161, first(if ((key#148 <=> azkaban.link.job.url)) value#149 else cast(null as string), true) AS azkaban.link.job.url#163, first(if ((key#148 <=> user.name)) value#149 else cast(null as string), true) AS user.name#165]\n +- Project [appId#137, col#145.key AS key#148, col#145.value AS value#149]\n +- Project [appId#137, col#145]\n +- Generate explode(systemProperties#135), false, [col#145]\n +- Relation[runtime#133,sparkProperties#134,systemProperties#135,classpathEntries#136,appId#137,attemptId#138] avro","sparkPlanInfo":{"nodeName":"Execute CreateViewCommand","simpleString":"Execute CreateViewCommand","children":[],"metrics":[]},"time":1524182128463} +{"Event":"org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd","executionId":1,"time":1524182128463} +{"Event":"org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart","executionId":2,"description":"show at :40","details":"org.apache.spark.sql.Dataset.show(Dataset.scala:691)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:40)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:45)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:47)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:49)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:51)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:53)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:55)\n$line50.$read$$iw$$iw$$iw$$iw$$iw.(:57)\n$line50.$read$$iw$$iw$$iw$$iw.(:59)\n$line50.$read$$iw$$iw$$iw.(:61)\n$line50.$read$$iw$$iw.(:63)\n$line50.$read$$iw.(:65)\n$line50.$read.(:67)\n$line50.$read$.(:71)\n$line50.$read$.()\n$line50.$eval$.$print$lzycompute(:7)\n$line50.$eval$.$print(:6)\n$line50.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)","physicalPlanDescription":"== Parsed Logical Plan ==\nGlobalLimit 21\n+- LocalLimit 21\n +- AnalysisBarrier\n +- Project [cast(appId#0 as string) AS appId#397, cast(attemptId#1 as string) AS attemptId#398, cast(name#2 as string) AS name#399, cast(mode#3 as string) AS mode#400, cast(completed#4 as string) AS completed#401, cast(duration#5L as string) AS duration#402, cast(endTime#6 as string) AS endTime#403, cast(endTimeEpoch#7L as string) AS endTimeEpoch#404, cast(lastUpdated#8 as string) AS lastUpdated#405, cast(lastUpdatedEpoch#9L as string) AS lastUpdatedEpoch#406, cast(sparkUser#10 as string) AS sparkUser#407, cast(startTime#11 as string) AS startTime#408, cast(startTimeEpoch#12L as string) AS startTimeEpoch#409, cast(appSparkVersion#13 as string) AS appSparkVersion#410, cast(endDate#28 as string) AS endDate#411, cast(azkaban.link.workflow.url#159 as string) AS azkaban.link.workflow.url#412, cast(azkaban.link.execution.url#161 as string) AS 
azkaban.link.execution.url#413, cast(azkaban.link.job.url#163 as string) AS azkaban.link.job.url#414, cast(user.name#165 as string) AS user.name#415]\n +- Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28, azkaban.link.workflow.url#159, azkaban.link.execution.url#161, azkaban.link.job.url#163, user.name#165]\n +- Join LeftOuter, (appId#0 = appId#137)\n :- Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, cast(endTime#6 as date) AS endDate#28]\n : +- Relation[appId#0,attemptId#1,name#2,mode#3,completed#4,duration#5L,endTime#6,endTimeEpoch#7L,lastUpdated#8,lastUpdatedEpoch#9L,sparkUser#10,startTime#11,startTimeEpoch#12L,appSparkVersion#13] avro\n +- Aggregate [appId#137], [appId#137, first(if ((key#148 <=> azkaban.link.workflow.url)) value#149 else cast(null as string), true) AS azkaban.link.workflow.url#159, first(if ((key#148 <=> azkaban.link.execution.url)) value#149 else cast(null as string), true) AS azkaban.link.execution.url#161, first(if ((key#148 <=> azkaban.link.job.url)) value#149 else cast(null as string), true) AS azkaban.link.job.url#163, first(if ((key#148 <=> user.name)) value#149 else cast(null as string), true) AS user.name#165]\n +- Project [appId#137, col#145.key AS key#148, col#145.value AS value#149]\n +- Project [appId#137, col#145]\n +- Generate explode(systemProperties#135), false, [col#145]\n +- Relation[runtime#133,sparkProperties#134,systemProperties#135,classpathEntries#136,appId#137,attemptId#138] avro\n\n== Analyzed Logical Plan ==\nappId: string, attemptId: string, name: string, mode: string, completed: string, duration: string, endTime: string, endTimeEpoch: string, lastUpdated: string, lastUpdatedEpoch: string, sparkUser: string, startTime: string, startTimeEpoch: string, appSparkVersion: string, endDate: string, azkaban.link.workflow.url: string, azkaban.link.execution.url: string, azkaban.link.job.url: string, user.name: string\nGlobalLimit 21\n+- LocalLimit 21\n +- Project [cast(appId#0 as string) AS appId#397, cast(attemptId#1 as string) AS attemptId#398, cast(name#2 as string) AS name#399, cast(mode#3 as string) AS mode#400, cast(completed#4 as string) AS completed#401, cast(duration#5L as string) AS duration#402, cast(endTime#6 as string) AS endTime#403, cast(endTimeEpoch#7L as string) AS endTimeEpoch#404, cast(lastUpdated#8 as string) AS lastUpdated#405, cast(lastUpdatedEpoch#9L as string) AS lastUpdatedEpoch#406, cast(sparkUser#10 as string) AS sparkUser#407, cast(startTime#11 as string) AS startTime#408, cast(startTimeEpoch#12L as string) AS startTimeEpoch#409, cast(appSparkVersion#13 as string) AS appSparkVersion#410, cast(endDate#28 as string) AS endDate#411, cast(azkaban.link.workflow.url#159 as string) AS azkaban.link.workflow.url#412, cast(azkaban.link.execution.url#161 as string) AS azkaban.link.execution.url#413, cast(azkaban.link.job.url#163 as string) AS azkaban.link.job.url#414, cast(user.name#165 as string) AS user.name#415]\n +- Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28, azkaban.link.workflow.url#159, azkaban.link.execution.url#161, 
azkaban.link.job.url#163, user.name#165]\n +- Join LeftOuter, (appId#0 = appId#137)\n :- Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, cast(endTime#6 as date) AS endDate#28]\n : +- Relation[appId#0,attemptId#1,name#2,mode#3,completed#4,duration#5L,endTime#6,endTimeEpoch#7L,lastUpdated#8,lastUpdatedEpoch#9L,sparkUser#10,startTime#11,startTimeEpoch#12L,appSparkVersion#13] avro\n +- Aggregate [appId#137], [appId#137, first(if ((key#148 <=> azkaban.link.workflow.url)) value#149 else cast(null as string), true) AS azkaban.link.workflow.url#159, first(if ((key#148 <=> azkaban.link.execution.url)) value#149 else cast(null as string), true) AS azkaban.link.execution.url#161, first(if ((key#148 <=> azkaban.link.job.url)) value#149 else cast(null as string), true) AS azkaban.link.job.url#163, first(if ((key#148 <=> user.name)) value#149 else cast(null as string), true) AS user.name#165]\n +- Project [appId#137, col#145.key AS key#148, col#145.value AS value#149]\n +- Project [appId#137, col#145]\n +- Generate explode(systemProperties#135), false, [col#145]\n +- Relation[runtime#133,sparkProperties#134,systemProperties#135,classpathEntries#136,appId#137,attemptId#138] avro\n\n== Optimized Logical Plan ==\nGlobalLimit 21\n+- LocalLimit 21\n +- Project [appId#0, attemptId#1, name#2, mode#3, cast(completed#4 as string) AS completed#401, cast(duration#5L as string) AS duration#402, endTime#6, cast(endTimeEpoch#7L as string) AS endTimeEpoch#404, lastUpdated#8, cast(lastUpdatedEpoch#9L as string) AS lastUpdatedEpoch#406, sparkUser#10, startTime#11, cast(startTimeEpoch#12L as string) AS startTimeEpoch#409, appSparkVersion#13, cast(endDate#28 as string) AS endDate#411, azkaban.link.workflow.url#159, azkaban.link.execution.url#161, azkaban.link.job.url#163, user.name#165]\n +- InMemoryRelation [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28, azkaban.link.workflow.url#159, azkaban.link.execution.url#161, azkaban.link.job.url#163, user.name#165], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas)\n +- *(5) Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28, azkaban.link.workflow.url#159, azkaban.link.execution.url#161, azkaban.link.job.url#163, user.name#165]\n +- SortMergeJoin [appId#0], [appId#137], LeftOuter\n :- *(1) Sort [appId#0 ASC NULLS FIRST], false, 0\n : +- Exchange hashpartitioning(appId#0, 200)\n : +- InMemoryTableScan [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28]\n : +- InMemoryRelation [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas)\n : +- *(1) Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, 
startTimeEpoch#12L, appSparkVersion#13, cast(endTime#6 as date) AS endDate#28]\n : +- *(1) FileScan avro [appId#0,attemptId#1,name#2,mode#3,completed#4,duration#5L,endTime#6,endTimeEpoch#7L,lastUpdated#8,lastUpdatedEpoch#9L,sparkUser#10,startTime#11,startTimeEpoch#12L,appSparkVersion#13] Batched: false, Format: com.databricks.spark.avro.DefaultSource@7006b304, Location: InMemoryFileIndex[hdfs://clusternn01.grid.company.com:9000/data/hadoopdev/sparkmetrics/ltx1-..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct azkaban.link.workflow.url)) value#149 else null, true), first(if ((key#148 <=> azkaban.link.execution.url)) value#149 else null, true), first(if ((key#148 <=> azkaban.link.job.url)) value#149 else null, true), first(if ((key#148 <=> user.name)) value#149 else null, true)], output=[appId#137, azkaban.link.workflow.url#159, azkaban.link.execution.url#161, azkaban.link.job.url#163, user.name#165])\n +- *(4) Sort [appId#137 ASC NULLS FIRST], false, 0\n +- Exchange hashpartitioning(appId#137, 200)\n +- SortAggregate(key=[appId#137], functions=[partial_first(if ((key#148 <=> azkaban.link.workflow.url)) value#149 else null, true), partial_first(if ((key#148 <=> azkaban.link.execution.url)) value#149 else null, true), partial_first(if ((key#148 <=> azkaban.link.job.url)) value#149 else null, true), partial_first(if ((key#148 <=> user.name)) value#149 else null, true)], output=[appId#137, first#273, valueSet#274, first#275, valueSet#276, first#277, valueSet#278, first#279, valueSet#280])\n +- *(3) Sort [appId#137 ASC NULLS FIRST], false, 0\n +- *(3) Project [appId#137, col#145.key AS key#148, col#145.value AS value#149]\n +- Generate explode(systemProperties#135), [appId#137], false, [col#145]\n +- *(2) FileScan avro [systemProperties#135,appId#137] Batched: false, Format: com.databricks.spark.avro.DefaultSource@485d3d1, Location: InMemoryFileIndex[hdfs://clusternn01.grid.company.com:9000/data/hadoopdev/sparkmetrics/ltx1-..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct>,appId:string>\n\n== Physical Plan ==\nCollectLimit 21\n+- *(1) LocalLimit 21\n +- *(1) Project [appId#0, attemptId#1, name#2, mode#3, cast(completed#4 as string) AS completed#401, cast(duration#5L as string) AS duration#402, endTime#6, cast(endTimeEpoch#7L as string) AS endTimeEpoch#404, lastUpdated#8, cast(lastUpdatedEpoch#9L as string) AS lastUpdatedEpoch#406, sparkUser#10, startTime#11, cast(startTimeEpoch#12L as string) AS startTimeEpoch#409, appSparkVersion#13, cast(endDate#28 as string) AS endDate#411, azkaban.link.workflow.url#159, azkaban.link.execution.url#161, azkaban.link.job.url#163, user.name#165]\n +- InMemoryTableScan [appId#0, appSparkVersion#13, attemptId#1, azkaban.link.execution.url#161, azkaban.link.job.url#163, azkaban.link.workflow.url#159, completed#4, duration#5L, endDate#28, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, mode#3, name#2, sparkUser#10, startTime#11, startTimeEpoch#12L, user.name#165]\n +- InMemoryRelation [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28, azkaban.link.workflow.url#159, azkaban.link.execution.url#161, azkaban.link.job.url#163, user.name#165], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas)\n +- *(5) Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, 
startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28, azkaban.link.workflow.url#159, azkaban.link.execution.url#161, azkaban.link.job.url#163, user.name#165]\n +- SortMergeJoin [appId#0], [appId#137], LeftOuter\n :- *(1) Sort [appId#0 ASC NULLS FIRST], false, 0\n : +- Exchange hashpartitioning(appId#0, 200)\n : +- InMemoryTableScan [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28]\n : +- InMemoryRelation [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas)\n : +- *(1) Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, cast(endTime#6 as date) AS endDate#28]\n : +- *(1) FileScan avro [appId#0,attemptId#1,name#2,mode#3,completed#4,duration#5L,endTime#6,endTimeEpoch#7L,lastUpdated#8,lastUpdatedEpoch#9L,sparkUser#10,startTime#11,startTimeEpoch#12L,appSparkVersion#13] Batched: false, Format: com.databricks.spark.avro.DefaultSource@7006b304, Location: InMemoryFileIndex[hdfs://clusternn01.grid.company.com:9000/data/hadoopdev/sparkmetrics/ltx1-..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct azkaban.link.workflow.url)) value#149 else null, true), first(if ((key#148 <=> azkaban.link.execution.url)) value#149 else null, true), first(if ((key#148 <=> azkaban.link.job.url)) value#149 else null, true), first(if ((key#148 <=> user.name)) value#149 else null, true)], output=[appId#137, azkaban.link.workflow.url#159, azkaban.link.execution.url#161, azkaban.link.job.url#163, user.name#165])\n +- *(4) Sort [appId#137 ASC NULLS FIRST], false, 0\n +- Exchange hashpartitioning(appId#137, 200)\n +- SortAggregate(key=[appId#137], functions=[partial_first(if ((key#148 <=> azkaban.link.workflow.url)) value#149 else null, true), partial_first(if ((key#148 <=> azkaban.link.execution.url)) value#149 else null, true), partial_first(if ((key#148 <=> azkaban.link.job.url)) value#149 else null, true), partial_first(if ((key#148 <=> user.name)) value#149 else null, true)], output=[appId#137, first#273, valueSet#274, first#275, valueSet#276, first#277, valueSet#278, first#279, valueSet#280])\n +- *(3) Sort [appId#137 ASC NULLS FIRST], false, 0\n +- *(3) Project [appId#137, col#145.key AS key#148, col#145.value AS value#149]\n +- Generate explode(systemProperties#135), [appId#137], false, [col#145]\n +- *(2) FileScan avro [systemProperties#135,appId#137] Batched: false, Format: com.databricks.spark.avro.DefaultSource@485d3d1, Location: InMemoryFileIndex[hdfs://clusternn01.grid.company.com:9000/data/hadoopdev/sparkmetrics/ltx1-..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct>,appId:string>","sparkPlanInfo":{"nodeName":"CollectLimit","simpleString":"CollectLimit 21","children":[{"nodeName":"WholeStageCodegen","simpleString":"WholeStageCodegen","children":[{"nodeName":"LocalLimit","simpleString":"LocalLimit 21","children":[{"nodeName":"Project","simpleString":"Project [appId#0, attemptId#1, name#2, mode#3, cast(completed#4 as string) AS completed#401, cast(duration#5L as string) AS duration#402, endTime#6, cast(endTimeEpoch#7L as string) AS endTimeEpoch#404, 
lastUpdated#8, cast(lastUpdatedEpoch#9L as string) AS lastUpdatedEpoch#406, sparkUser#10, startTime#11, cast(startTimeEpoch#12L as string) AS startTimeEpoch#409, appSparkVersion#13, cast(endDate#28 as string) AS endDate#411, azkaban.link.workflow.url#159, azkaban.link.execution.url#161, azkaban.link.job.url#163, user.name#165]","children":[{"nodeName":"InputAdapter","simpleString":"InputAdapter","children":[{"nodeName":"InMemoryTableScan","simpleString":"InMemoryTableScan [appId#0, appSparkVersion#13, attemptId#1, azkaban.link.execution.url#161, azkaban.link.job.url#163, azkaban.link.workflow.url#159, completed#4, duration#5L, endDate#28, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, mode#3, name#2, sparkUser#10, startTime#11, startTimeEpoch#12L, user.name#165]","children":[],"metrics":[{"name":"number of output rows","accumulatorId":35,"metricType":"sum"},{"name":"scan time total (min, med, max)","accumulatorId":36,"metricType":"timing"}]}],"metrics":[]}],"metrics":[]}],"metrics":[]}],"metrics":[{"name":"duration total (min, med, max)","accumulatorId":34,"metricType":"timing"}]}],"metrics":[]},"time":1524182129952} +{"Event":"SparkListenerJobStart","Job ID":0,"Submission Time":1524182130194,"Stage Infos":[{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"cache at :41","Number of Tasks":4,"RDD Info":[{"RDD ID":6,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"11\",\"name\":\"Exchange\"}","Callsite":"cache at :41","Parent IDs":[5],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"FileScanRDD","Scope":"{\"id\":\"0\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :39","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":2,"Name":"*(1) Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, cast(endTime#6 as date) AS endDate#28]\n+- *(1) FileScan avro [appId#0,attemptId#1,name#2,mode#3,completed#4,duration#5L,endTime#6,endTimeEpoch#7L,lastUpdated#8,lastUpdatedEpoch#9L,sparkUser#10,startTime#11,startTimeEpoch#12L,appSparkVersion#13] Batched: false, Format: com.databricks.spark.avro.DefaultSource@7006b304, Location: InMemoryFileIndex[hdfs://clusternn01.grid.company.com:9000/data/hadoopdev/sparkmetrics/ltx1-..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct:39","Parent IDs":[1],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"0\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :39","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":5,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"InMemoryTableScan\"}","Callsite":"cache at :41","Parent IDs":[4],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD 
ID":4,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"InMemoryTableScan\"}","Callsite":"cache at :41","Parent IDs":[2],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.sql.Dataset.cache(Dataset.scala:2912)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:41)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:46)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:48)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:50)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:52)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:54)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:56)\n$line49.$read$$iw$$iw$$iw$$iw$$iw.(:58)\n$line49.$read$$iw$$iw$$iw$$iw.(:60)\n$line49.$read$$iw$$iw$$iw.(:62)\n$line49.$read$$iw$$iw.(:64)\n$line49.$read$$iw.(:66)\n$line49.$read.(:68)\n$line49.$read$.(:72)\n$line49.$read$.()\n$line49.$eval$.$print$lzycompute(:7)\n$line49.$eval$.$print(:6)\n$line49.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)","Accumulables":[]},{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"cache at :41","Number of Tasks":4,"RDD Info":[{"RDD ID":14,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"17\",\"name\":\"Exchange\"}","Callsite":"cache at :41","Parent IDs":[13],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":10,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"24\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[9],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":9,"Name":"FileScanRDD","Scope":"{\"id\":\"24\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":12,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"19\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[11],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":11,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"23\",\"name\":\"Generate\"}","Callsite":"cache at :41","Parent IDs":[10],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":13,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"18\",\"name\":\"SortAggregate\"}","Callsite":"cache at :41","Parent IDs":[12],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent 
IDs":[],"Details":"org.apache.spark.sql.Dataset.cache(Dataset.scala:2912)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:41)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:46)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:48)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:50)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:52)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:54)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:56)\n$line49.$read$$iw$$iw$$iw$$iw$$iw.(:58)\n$line49.$read$$iw$$iw$$iw$$iw.(:60)\n$line49.$read$$iw$$iw$$iw.(:62)\n$line49.$read$$iw$$iw.(:64)\n$line49.$read$$iw.(:66)\n$line49.$read.(:68)\n$line49.$read$.(:72)\n$line49.$read$.()\n$line49.$eval$.$print$lzycompute(:7)\n$line49.$eval$.$print(:6)\n$line49.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)","Accumulables":[]},{"Stage ID":2,"Stage Attempt ID":0,"Stage Name":"show at :40","Number of Tasks":1,"RDD Info":[{"RDD ID":26,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"33\",\"name\":\"map\"}","Callsite":"show at :40","Parent IDs":[25],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":25,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"32\",\"name\":\"mapPartitionsInternal\"}","Callsite":"show at :40","Parent IDs":[24],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":8,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"8\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[7],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":24,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"27\",\"name\":\"WholeStageCodegen\"}","Callsite":"show at :40","Parent IDs":[23],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":22,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"31\",\"name\":\"InMemoryTableScan\"}","Callsite":"show at :40","Parent IDs":[20],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":20,"Name":"*(5) Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28, azkaban.link.workflow.url#159, azkaban.link.execution.url#161, azkaban.link.job.url#163, user.name#165]\n+- SortMergeJoin [appId#0], [appId#137], LeftOuter\n :- *(1) Sort [appId#0 ASC NULLS FIRST], false, 0\n : +- Exchange hashpartitioning(appId#0, 200)\n : +- InMemoryTableScan [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28]\n : +- InMemoryRelation [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28], true, 10000, StorageLevel(disk, 
memory, deserialized, 1 rep...","Scope":"{\"id\":\"26\",\"name\":\"mapPartitionsInternal\"}","Callsite":"cache at <console>:41","Parent IDs":[19],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":23,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"31\",\"name\":\"InMemoryTableScan\"}","Callsite":"show at <console>:40","Parent IDs":[22],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":18,"Name":"ZippedPartitionsRDD2","Scope":"{\"id\":\"7\",\"name\":\"SortMergeJoin\"}","Callsite":"cache at <console>:41","Parent IDs":[8,17],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":17,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"13\",\"name\":\"SortAggregate\"}","Callsite":"cache at <console>:41","Parent IDs":[16],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":7,"Name":"ShuffledRowRDD","Scope":"{\"id\":\"11\",\"name\":\"Exchange\"}","Callsite":"cache at <console>:41","Parent IDs":[6],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":16,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"14\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at <console>:41","Parent IDs":[15],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":15,"Name":"ShuffledRowRDD","Scope":"{\"id\":\"17\",\"name\":\"Exchange\"}","Callsite":"cache at <console>:41","Parent IDs":[14],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":19,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"4\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at <console>:41","Parent IDs":[18],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0,1],"Details":"org.apache.spark.sql.Dataset.show(Dataset.scala:691)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.<init>(<console>:40)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.<init>(<console>:45)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.<init>(<console>:47)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.<init>(<console>:49)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.<init>(<console>:51)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.<init>(<console>:53)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw.<init>(<console>:55)\n$line50.$read$$iw$$iw$$iw$$iw$$iw.<init>(<console>:57)\n$line50.$read$$iw$$iw$$iw$$iw.<init>(<console>:59)\n$line50.$read$$iw$$iw$$iw.<init>(<console>:61)\n$line50.$read$$iw$$iw.<init>(<console>:63)\n$line50.$read$$iw.<init>(<console>:65)\n$line50.$read.<init>(<console>:67)\n$line50.$read$.<init>(<console>:71)\n$line50.$read$.<clinit>(<console>)\n$line50.$eval$.$print$lzycompute(<console>:7)\n$line50.$eval$.$print(<console>:6)\n$line50.$eval.$print(<console>)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)","Accumulables":[]}],"Stage IDs":[0,1,2],"Properties":{"spark.sql.execution.id":"2"}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":0,"Stage
Attempt ID":0,"Stage Name":"cache at <console>:41","Number of Tasks":4,"RDD Info":[{"RDD ID":6,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"11\",\"name\":\"Exchange\"}","Callsite":"cache at <console>:41","Parent IDs":[5],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"FileScanRDD","Scope":"{\"id\":\"0\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at <console>:39","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":2,"Name":"*(1) Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, cast(endTime#6 as date) AS endDate#28]\n+- *(1) FileScan avro [appId#0,attemptId#1,name#2,mode#3,completed#4,duration#5L,endTime#6,endTimeEpoch#7L,lastUpdated#8,lastUpdatedEpoch#9L,sparkUser#10,startTime#11,startTimeEpoch#12L,appSparkVersion#13] Batched: false, Format: com.databricks.spark.avro.DefaultSource@7006b304, Location: InMemoryFileIndex[hdfs://clusternn01.grid.company.com:9000/data/hadoopdev/sparkmetrics/ltx1-..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<...","Callsite":"cache at <console>:39","Parent IDs":[1],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"0\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at <console>:39","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":5,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"InMemoryTableScan\"}","Callsite":"cache at <console>:41","Parent IDs":[4],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":4,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"InMemoryTableScan\"}","Callsite":"cache at <console>:41","Parent IDs":[2],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.sql.Dataset.cache(Dataset.scala:2912)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.<init>(<console>:41)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.<init>(<console>:46)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.<init>(<console>:48)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.<init>(<console>:50)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.<init>(<console>:52)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.<init>(<console>:54)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw.<init>(<console>:56)\n$line49.$read$$iw$$iw$$iw$$iw$$iw.<init>(<console>:58)\n$line49.$read$$iw$$iw$$iw$$iw.<init>(<console>:60)\n$line49.$read$$iw$$iw$$iw.<init>(<console>:62)\n$line49.$read$$iw$$iw.<init>(<console>:64)\n$line49.$read$$iw.<init>(<console>:66)\n$line49.$read.<init>(<console>:68)\n$line49.$read$.<init>(<console>:72)\n$line49.$read$.<clinit>(<console>)\n$line49.$eval$.$print$lzycompute(<console>:7)\n$line49.$eval$.$print(<console>:6)\n$line49.$eval.$print(<console>)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)","Submission Time":1524182130229,"Accumulables":[]},"Properties":{"spark.sql.execution.id":"2"}}
+{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"cache at <console>:41","Number of Tasks":4,"RDD Info":[{"RDD ID":14,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"17\",\"name\":\"Exchange\"}","Callsite":"cache at <console>:41","Parent IDs":[13],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":10,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"24\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at <console>:41","Parent IDs":[9],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":9,"Name":"FileScanRDD","Scope":"{\"id\":\"24\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at <console>:41","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":12,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"19\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at <console>:41","Parent IDs":[11],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":11,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"23\",\"name\":\"Generate\"}","Callsite":"cache at <console>:41","Parent IDs":[10],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":13,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"18\",\"name\":\"SortAggregate\"}","Callsite":"cache at <console>:41","Parent IDs":[12],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.sql.Dataset.cache(Dataset.scala:2912)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.<init>(<console>:41)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.<init>(<console>:46)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.<init>(<console>:48)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.<init>(<console>:50)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.<init>(<console>:52)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.<init>(<console>:54)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw.<init>(<console>:56)\n$line49.$read$$iw$$iw$$iw$$iw$$iw.<init>(<console>:58)\n$line49.$read$$iw$$iw$$iw$$iw.<init>(<console>:60)\n$line49.$read$$iw$$iw$$iw.<init>(<console>:62)\n$line49.$read$$iw$$iw.<init>(<console>:64)\n$line49.$read$$iw.<init>(<console>:66)\n$line49.$read.<init>(<console>:68)\n$line49.$read$.<init>(<console>:72)\n$line49.$read$.<clinit>(<console>)\n$line49.$eval$.$print$lzycompute(<console>:7)\n$line49.$eval$.$print(<console>:6)\n$line49.$eval.$print(<console>)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)","Submission Time":1524182130328,"Accumulables":[]},"Properties":{"spark.sql.execution.id":"2"}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1524182130331,"Executor ID":"2","Host":"node4045.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":1,"Index":1,"Attempt":0,"Launch Time":1524182130349,"Executor ID":"3","Host":"node0998.grid.company.com","Locality":"ANY","Speculative":false,"Getting 
Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":2,"Index":2,"Attempt":0,"Launch Time":1524182130350,"Executor ID":"4","Host":"node4243.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":3,"Index":3,"Attempt":0,"Launch Time":1524182130350,"Executor ID":"1","Host":"node1404.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":4,"Index":0,"Attempt":0,"Launch Time":1524182142251,"Executor ID":"1","Host":"node1404.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":3,"Index":3,"Attempt":0,"Launch Time":1524182130350,"Executor ID":"1","Host":"node1404.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":1524182142286,"Failed":false,"Killed":false,"Accumulables":[{"ID":7,"Name":"data size total (min, med, max)","Update":"154334487","Value":"154334486","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":16,"Name":"number of output rows","Update":"466636","Value":"466636","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":1,"Name":"number of output rows","Update":"466636","Value":"466636","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":5,"Name":"duration total (min, med, max)","Update":"19666","Value":"19665","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":59,"Name":"internal.metrics.input.recordsRead","Update":466636,"Value":466636,"Internal":true,"Count Failed Values":true},{"ID":58,"Name":"internal.metrics.input.bytesRead","Update":37809697,"Value":37809697,"Internal":true,"Count Failed Values":true},{"ID":57,"Name":"internal.metrics.shuffle.write.writeTime","Update":91545212,"Value":91545212,"Internal":true,"Count Failed Values":true},{"ID":56,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":466636,"Value":466636,"Internal":true,"Count Failed Values":true},{"ID":55,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":20002743,"Value":20002743,"Internal":true,"Count Failed Values":true},{"ID":43,"Name":"internal.metrics.resultSerializationTime","Update":2,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":42,"Name":"internal.metrics.jvmGCTime","Update":407,"Value":407,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.resultSize","Update":1856,"Value":1856,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.executorCpuTime","Update":9020410971,"Value":9020410971,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.executorRunTime","Update":11146,"Value":11146,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.executorDeserializeCpuTime","Update":574344183,"Value":574344183,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.executorDeserializeTime","Update":714,"Value":714,"Internal":true,"Count 
Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":714,"Executor Deserialize CPU Time":574344183,"Executor Run Time":11146,"Executor CPU Time":9020410971,"Result Size":1856,"JVM GC Time":407,"Result Serialization Time":2,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":20002743,"Shuffle Write Time":91545212,"Shuffle Records Written":466636},"Input Metrics":{"Bytes Read":37809697,"Records Read":466636},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":5,"Index":1,"Attempt":0,"Launch Time":1524182142997,"Executor ID":"4","Host":"node4243.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":2,"Index":2,"Attempt":0,"Launch Time":1524182130350,"Executor ID":"4","Host":"node4243.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":1524182143009,"Failed":false,"Killed":false,"Accumulables":[{"ID":7,"Name":"data size total (min, med, max)","Update":"206421303","Value":"360755789","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":16,"Name":"number of output rows","Update":"624246","Value":"1090882","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":1,"Name":"number of output rows","Update":"624246","Value":"1090882","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":5,"Name":"duration total (min, med, max)","Update":"20604","Value":"40269","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":59,"Name":"internal.metrics.input.recordsRead","Update":624246,"Value":1090882,"Internal":true,"Count Failed Values":true},{"ID":58,"Name":"internal.metrics.input.bytesRead","Update":50423609,"Value":88233306,"Internal":true,"Count Failed Values":true},{"ID":57,"Name":"internal.metrics.shuffle.write.writeTime","Update":104125550,"Value":195670762,"Internal":true,"Count Failed Values":true},{"ID":56,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":624246,"Value":1090882,"Internal":true,"Count Failed Values":true},{"ID":55,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":26424033,"Value":46426776,"Internal":true,"Count Failed Values":true},{"ID":43,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":42,"Name":"internal.metrics.jvmGCTime","Update":374,"Value":781,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.resultSize","Update":1856,"Value":3712,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.executorCpuTime","Update":11039226628,"Value":20059637599,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.executorRunTime","Update":11978,"Value":23124,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.executorDeserializeCpuTime","Update":526915936,"Value":1101260119,"Internal":true,"Count Failed 
Values":true},{"ID":37,"Name":"internal.metrics.executorDeserializeTime","Update":622,"Value":1336,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":622,"Executor Deserialize CPU Time":526915936,"Executor Run Time":11978,"Executor CPU Time":11039226628,"Result Size":1856,"JVM GC Time":374,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":26424033,"Shuffle Write Time":104125550,"Shuffle Records Written":624246},"Input Metrics":{"Bytes Read":50423609,"Records Read":624246},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerExecutorAdded","Timestamp":1524182143160,"Executor ID":"5","Executor Info":{"Host":"node2477.grid.company.com","Total Cores":1,"Log Urls":{"stdout":"http://node2477.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000007/edlu/stdout?start=-4096","stderr":"http://node2477.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000007/edlu/stderr?start=-4096"}}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":6,"Index":2,"Attempt":0,"Launch Time":1524182143166,"Executor ID":"5","Host":"node2477.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"5","Host":"node2477.grid.company.com","Port":20123},"Maximum Memory":956615884,"Timestamp":1524182143406,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":7,"Index":3,"Attempt":0,"Launch Time":1524182144237,"Executor ID":"1","Host":"node1404.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":4,"Index":0,"Attempt":0,"Launch Time":1524182142251,"Executor ID":"1","Host":"node1404.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":1524182144246,"Failed":false,"Killed":false,"Accumulables":[{"ID":8,"Name":"data size total (min, med, max)","Update":"1920975","Value":"1920974","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":23,"Name":"number of output rows","Update":"3562","Value":"3562","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":25,"Name":"peak memory total (min, med, max)","Update":"41943039","Value":"41943038","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":24,"Name":"sort time total (min, med, max)","Update":"38","Value":"37","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":27,"Name":"duration total (min, med, max)","Update":"1813","Value":"1812","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":28,"Name":"number of output rows","Update":"195602","Value":"195602","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":29,"Name":"number of output rows","Update":"3563","Value":"3563","Internal":true,"Count Failed 
Values":true,"Metadata":"sql"},{"ID":33,"Name":"duration total (min, med, max)","Update":"1558","Value":"1557","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":84,"Name":"internal.metrics.input.recordsRead","Update":3563,"Value":3563,"Internal":true,"Count Failed Values":true},{"ID":83,"Name":"internal.metrics.input.bytesRead","Update":36845111,"Value":36845111,"Internal":true,"Count Failed Values":true},{"ID":82,"Name":"internal.metrics.shuffle.write.writeTime","Update":27318908,"Value":27318908,"Internal":true,"Count Failed Values":true},{"ID":81,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3562,"Value":3562,"Internal":true,"Count Failed Values":true},{"ID":80,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":349287,"Value":349287,"Internal":true,"Count Failed Values":true},{"ID":71,"Name":"internal.metrics.peakExecutionMemory","Update":41943040,"Value":41943040,"Internal":true,"Count Failed Values":true},{"ID":67,"Name":"internal.metrics.jvmGCTime","Update":33,"Value":33,"Internal":true,"Count Failed Values":true},{"ID":66,"Name":"internal.metrics.resultSize","Update":2394,"Value":2394,"Internal":true,"Count Failed Values":true},{"ID":65,"Name":"internal.metrics.executorCpuTime","Update":1498974375,"Value":1498974375,"Internal":true,"Count Failed Values":true},{"ID":64,"Name":"internal.metrics.executorRunTime","Update":1922,"Value":1922,"Internal":true,"Count Failed Values":true},{"ID":63,"Name":"internal.metrics.executorDeserializeCpuTime","Update":49547405,"Value":49547405,"Internal":true,"Count Failed Values":true},{"ID":62,"Name":"internal.metrics.executorDeserializeTime","Update":56,"Value":56,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":56,"Executor Deserialize CPU Time":49547405,"Executor Run Time":1922,"Executor CPU Time":1498974375,"Result Size":2394,"JVM GC Time":33,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":349287,"Shuffle Write Time":27318908,"Shuffle Records Written":3562},"Input Metrics":{"Bytes Read":36845111,"Records Read":3563},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1524182130331,"Executor ID":"2","Host":"node4045.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":1524182144444,"Failed":false,"Killed":false,"Accumulables":[{"ID":7,"Name":"data size total (min, med, max)","Update":"204058975","Value":"564814764","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":16,"Name":"number of output rows","Update":"616897","Value":"1707779","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":1,"Name":"number of output rows","Update":"616897","Value":"1707779","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":5,"Name":"duration total (min, med, max)","Update":"23365","Value":"63634","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":59,"Name":"internal.metrics.input.recordsRead","Update":616897,"Value":1707779,"Internal":true,"Count Failed 
Values":true},{"ID":58,"Name":"internal.metrics.input.bytesRead","Update":50423423,"Value":138656729,"Internal":true,"Count Failed Values":true},{"ID":57,"Name":"internal.metrics.shuffle.write.writeTime","Update":105575962,"Value":301246724,"Internal":true,"Count Failed Values":true},{"ID":56,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":616897,"Value":1707779,"Internal":true,"Count Failed Values":true},{"ID":55,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":22950296,"Value":69377072,"Internal":true,"Count Failed Values":true},{"ID":43,"Name":"internal.metrics.resultSerializationTime","Update":2,"Value":5,"Internal":true,"Count Failed Values":true},{"ID":42,"Name":"internal.metrics.jvmGCTime","Update":326,"Value":1107,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.resultSize","Update":1856,"Value":5568,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.executorCpuTime","Update":11931694025,"Value":31991331624,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.executorRunTime","Update":13454,"Value":36578,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.executorDeserializeCpuTime","Update":531799977,"Value":1633060096,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.executorDeserializeTime","Update":594,"Value":1930,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":594,"Executor Deserialize CPU Time":531799977,"Executor Run Time":13454,"Executor CPU Time":11931694025,"Result Size":1856,"JVM GC Time":326,"Result Serialization Time":2,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":22950296,"Shuffle Write Time":105575962,"Shuffle Records Written":616897},"Input Metrics":{"Bytes Read":50423423,"Records Read":616897},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":1,"Index":1,"Attempt":0,"Launch Time":1524182130349,"Executor ID":"3","Host":"node0998.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":1524182144840,"Failed":false,"Killed":false,"Accumulables":[{"ID":7,"Name":"data size total (min, med, max)","Update":"207338935","Value":"772153699","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":16,"Name":"number of output rows","Update":"626277","Value":"2334056","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":1,"Name":"number of output rows","Update":"626277","Value":"2334056","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":5,"Name":"duration total (min, med, max)","Update":"24254","Value":"87888","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":59,"Name":"internal.metrics.input.recordsRead","Update":626277,"Value":2334056,"Internal":true,"Count Failed Values":true},{"ID":58,"Name":"internal.metrics.input.bytesRead","Update":50409514,"Value":189066243,"Internal":true,"Count Failed Values":true},{"ID":57,"Name":"internal.metrics.shuffle.write.writeTime","Update":106963069,"Value":408209793,"Internal":true,"Count Failed 
Values":true},{"ID":56,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":626277,"Value":2334056,"Internal":true,"Count Failed Values":true},{"ID":55,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":31362123,"Value":100739195,"Internal":true,"Count Failed Values":true},{"ID":43,"Name":"internal.metrics.resultSerializationTime","Update":2,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":42,"Name":"internal.metrics.jvmGCTime","Update":342,"Value":1449,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.resultSize","Update":1856,"Value":7424,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.executorCpuTime","Update":12267596062,"Value":44258927686,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.executorRunTime","Update":13858,"Value":50436,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.executorDeserializeCpuTime","Update":519573839,"Value":2152633935,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.executorDeserializeTime","Update":573,"Value":2503,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":573,"Executor Deserialize CPU Time":519573839,"Executor Run Time":13858,"Executor CPU Time":12267596062,"Result Size":1856,"JVM GC Time":342,"Result Serialization Time":2,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":31362123,"Shuffle Write Time":106963069,"Shuffle Records Written":626277},"Input Metrics":{"Bytes Read":50409514,"Records Read":626277},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"driver","Stage ID":0,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":592412824,"JVMOffHeapMemory":202907152,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":905801,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":905801,"OffHeapUnifiedMemory":0,"DirectPoolMemory":355389,"MappedPoolMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"2","Stage ID":0,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":523121272,"JVMOffHeapMemory":88280720,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":52050147,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":52050147,"OffHeapUnifiedMemory":0,"DirectPoolMemory":87796,"MappedPoolMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"1","Stage ID":0,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":214174608,"JVMOffHeapMemory":91548704,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":47399168,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":47399168,"OffHeapUnifiedMemory":0,"DirectPoolMemory":87796,"MappedPoolMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"4","Stage ID":0,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":518613056,"JVMOffHeapMemory":95657456,"OnHeapExecutionMemory":37748736,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":63104457,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":100853193,"OffHeapUnifiedMemory":0,"DirectPoolMemory":126261,"MappedPoolMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"3","Stage ID":0,"Stage Attempt 
ID":0,"Executor Metrics":{"JVMHeapMemory":726805712,"JVMOffHeapMemory":90709624,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":69535048,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":69535048,"OffHeapUnifiedMemory":0,"DirectPoolMemory":87796,"MappedPoolMemory":0}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"cache at :41","Number of Tasks":4,"RDD Info":[{"RDD ID":6,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"11\",\"name\":\"Exchange\"}","Callsite":"cache at :41","Parent IDs":[5],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"FileScanRDD","Scope":"{\"id\":\"0\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :39","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":2,"Name":"*(1) Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, cast(endTime#6 as date) AS endDate#28]\n+- *(1) FileScan avro [appId#0,attemptId#1,name#2,mode#3,completed#4,duration#5L,endTime#6,endTimeEpoch#7L,lastUpdated#8,lastUpdatedEpoch#9L,sparkUser#10,startTime#11,startTimeEpoch#12L,appSparkVersion#13] Batched: false, Format: com.databricks.spark.avro.DefaultSource@7006b304, Location: InMemoryFileIndex[hdfs://clusternn01.grid.company.com:9000/data/hadoopdev/sparkmetrics/ltx1-..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct:39","Parent IDs":[1],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"0\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :39","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":5,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"InMemoryTableScan\"}","Callsite":"cache at :41","Parent IDs":[4],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":4,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"InMemoryTableScan\"}","Callsite":"cache at :41","Parent IDs":[2],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent 
IDs":[],"Details":"org.apache.spark.sql.Dataset.cache(Dataset.scala:2912)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:41)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:46)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:48)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:50)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:52)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:54)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:56)\n$line49.$read$$iw$$iw$$iw$$iw$$iw.(:58)\n$line49.$read$$iw$$iw$$iw$$iw.(:60)\n$line49.$read$$iw$$iw$$iw.(:62)\n$line49.$read$$iw$$iw.(:64)\n$line49.$read$$iw.(:66)\n$line49.$read.(:68)\n$line49.$read$.(:72)\n$line49.$read$.()\n$line49.$eval$.$print$lzycompute(:7)\n$line49.$eval$.$print(:6)\n$line49.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)","Submission Time":1524182130229,"Completion Time":1524182144852,"Accumulables":[{"ID":41,"Name":"internal.metrics.resultSize","Value":7424,"Internal":true,"Count Failed Values":true},{"ID":59,"Name":"internal.metrics.input.recordsRead","Value":2334056,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.executorDeserializeCpuTime","Value":2152633935,"Internal":true,"Count Failed Values":true},{"ID":56,"Name":"internal.metrics.shuffle.write.recordsWritten","Value":2334056,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"duration total (min, med, max)","Value":"87888","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":55,"Name":"internal.metrics.shuffle.write.bytesWritten","Value":100739195,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.executorCpuTime","Value":44258927686,"Internal":true,"Count Failed Values":true},{"ID":58,"Name":"internal.metrics.input.bytesRead","Value":189066243,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"data size total (min, med, max)","Value":"772153699","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":16,"Name":"number of output rows","Value":"2334056","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":43,"Name":"internal.metrics.resultSerializationTime","Value":7,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"number of output rows","Value":"2334056","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":37,"Name":"internal.metrics.executorDeserializeTime","Value":2503,"Internal":true,"Count Failed Values":true},{"ID":57,"Name":"internal.metrics.shuffle.write.writeTime","Value":408209793,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.executorRunTime","Value":50436,"Internal":true,"Count Failed Values":true},{"ID":42,"Name":"internal.metrics.jvmGCTime","Value":1449,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":5,"Index":1,"Attempt":0,"Launch Time":1524182142997,"Executor ID":"4","Host":"node4243.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":1524182145327,"Failed":false,"Killed":false,"Accumulables":[{"ID":8,"Name":"data size total (min, med, max)","Update":"1953295","Value":"3874269","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":23,"Name":"number of output rows","Update":"3575","Value":"7137","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":25,"Name":"peak memory total (min, med, 
max)","Update":"41943039","Value":"83886077","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":24,"Name":"sort time total (min, med, max)","Update":"49","Value":"86","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":27,"Name":"duration total (min, med, max)","Update":"2002","Value":"3814","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":28,"Name":"number of output rows","Update":"196587","Value":"392189","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":29,"Name":"number of output rows","Update":"3575","Value":"7138","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":33,"Name":"duration total (min, med, max)","Update":"1755","Value":"3312","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":84,"Name":"internal.metrics.input.recordsRead","Update":3575,"Value":7138,"Internal":true,"Count Failed Values":true},{"ID":83,"Name":"internal.metrics.input.bytesRead","Update":36849246,"Value":73694357,"Internal":true,"Count Failed Values":true},{"ID":82,"Name":"internal.metrics.shuffle.write.writeTime","Update":32035583,"Value":59354491,"Internal":true,"Count Failed Values":true},{"ID":81,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3575,"Value":7137,"Internal":true,"Count Failed Values":true},{"ID":80,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":349006,"Value":698293,"Internal":true,"Count Failed Values":true},{"ID":71,"Name":"internal.metrics.peakExecutionMemory","Update":41943040,"Value":83886080,"Internal":true,"Count Failed Values":true},{"ID":67,"Name":"internal.metrics.jvmGCTime","Update":31,"Value":64,"Internal":true,"Count Failed Values":true},{"ID":66,"Name":"internal.metrics.resultSize","Update":2394,"Value":4788,"Internal":true,"Count Failed Values":true},{"ID":65,"Name":"internal.metrics.executorCpuTime","Update":1785119941,"Value":3284094316,"Internal":true,"Count Failed Values":true},{"ID":64,"Name":"internal.metrics.executorRunTime","Update":2182,"Value":4104,"Internal":true,"Count Failed Values":true},{"ID":63,"Name":"internal.metrics.executorDeserializeCpuTime","Update":71500541,"Value":121047946,"Internal":true,"Count Failed Values":true},{"ID":62,"Name":"internal.metrics.executorDeserializeTime","Update":136,"Value":192,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":136,"Executor Deserialize CPU Time":71500541,"Executor Run Time":2182,"Executor CPU Time":1785119941,"Result Size":2394,"JVM GC Time":31,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":349006,"Shuffle Write Time":32035583,"Shuffle Records Written":3575},"Input Metrics":{"Bytes Read":36849246,"Records Read":3575},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":7,"Index":3,"Attempt":0,"Launch Time":1524182144237,"Executor ID":"1","Host":"node1404.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":1524182145971,"Failed":false,"Killed":false,"Accumulables":[{"ID":8,"Name":"data size total (min, med, 
max)","Update":"1337999","Value":"5212268","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":23,"Name":"number of output rows","Update":"2435","Value":"9572","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":25,"Name":"peak memory total (min, med, max)","Update":"37748735","Value":"121634812","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":24,"Name":"sort time total (min, med, max)","Update":"9","Value":"95","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":27,"Name":"duration total (min, med, max)","Update":"1703","Value":"5517","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":28,"Name":"number of output rows","Update":"133759","Value":"525948","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":29,"Name":"number of output rows","Update":"2435","Value":"9573","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":33,"Name":"duration total (min, med, max)","Update":"1609","Value":"4921","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":84,"Name":"internal.metrics.input.recordsRead","Update":2435,"Value":9573,"Internal":true,"Count Failed Values":true},{"ID":83,"Name":"internal.metrics.input.bytesRead","Update":24250210,"Value":97944567,"Internal":true,"Count Failed Values":true},{"ID":82,"Name":"internal.metrics.shuffle.write.writeTime","Update":20055909,"Value":79410400,"Internal":true,"Count Failed Values":true},{"ID":81,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":2435,"Value":9572,"Internal":true,"Count Failed Values":true},{"ID":80,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":242714,"Value":941007,"Internal":true,"Count Failed Values":true},{"ID":71,"Name":"internal.metrics.peakExecutionMemory","Update":37748736,"Value":121634816,"Internal":true,"Count Failed Values":true},{"ID":67,"Name":"internal.metrics.jvmGCTime","Update":31,"Value":95,"Internal":true,"Count Failed Values":true},{"ID":66,"Name":"internal.metrics.resultSize","Update":2394,"Value":7182,"Internal":true,"Count Failed Values":true},{"ID":65,"Name":"internal.metrics.executorCpuTime","Update":896878991,"Value":4180973307,"Internal":true,"Count Failed Values":true},{"ID":64,"Name":"internal.metrics.executorRunTime","Update":1722,"Value":5826,"Internal":true,"Count Failed Values":true},{"ID":63,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2787355,"Value":123835301,"Internal":true,"Count Failed Values":true},{"ID":62,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":195,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":3,"Executor Deserialize CPU Time":2787355,"Executor Run Time":1722,"Executor CPU Time":896878991,"Result Size":2394,"JVM GC Time":31,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":242714,"Shuffle Write Time":20055909,"Shuffle Records Written":2435},"Input Metrics":{"Bytes Read":24250210,"Records Read":2435},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerExecutorAdded","Timestamp":1524182147549,"Executor ID":"6","Executor Info":{"Host":"node6644.grid.company.com","Total Cores":1,"Log 
Urls":{"stdout":"http://node6644.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000008/edlu/stdout?start=-4096","stderr":"http://node6644.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000008/edlu/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"6","Host":"node6644.grid.company.com","Port":8445},"Maximum Memory":956615884,"Timestamp":1524182147706,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerExecutorAdded","Timestamp":1524182149826,"Executor ID":"7","Executor Info":{"Host":"node6340.grid.company.com","Total Cores":1,"Log Urls":{"stdout":"http://node6340.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000009/edlu/stdout?start=-4096","stderr":"http://node6340.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000009/edlu/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"7","Host":"node6340.grid.company.com","Port":5933},"Maximum Memory":956615884,"Timestamp":1524182149983,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":6,"Index":2,"Attempt":0,"Launch Time":1524182143166,"Executor ID":"5","Host":"node2477.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":1524182152418,"Failed":false,"Killed":false,"Accumulables":[{"ID":8,"Name":"data size total (min, med, max)","Update":"1910103","Value":"7122371","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":23,"Name":"number of output rows","Update":"3541","Value":"13113","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":25,"Name":"peak memory total (min, med, max)","Update":"41943039","Value":"163577851","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":24,"Name":"sort time total (min, med, max)","Update":"48","Value":"143","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":27,"Name":"duration total (min, med, max)","Update":"6093","Value":"11610","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":28,"Name":"number of output rows","Update":"194553","Value":"720501","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":29,"Name":"number of output rows","Update":"3541","Value":"13114","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":33,"Name":"duration total (min, med, max)","Update":"5951","Value":"10872","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":84,"Name":"internal.metrics.input.recordsRead","Update":3541,"Value":13114,"Internal":true,"Count Failed Values":true},{"ID":83,"Name":"internal.metrics.input.bytesRead","Update":36838295,"Value":134782862,"Internal":true,"Count Failed Values":true},{"ID":82,"Name":"internal.metrics.shuffle.write.writeTime","Update":49790497,"Value":129200897,"Internal":true,"Count Failed Values":true},{"ID":81,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3541,"Value":13113,"Internal":true,"Count Failed Values":true},{"ID":80,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":355051,"Value":1296058,"Internal":true,"Count Failed Values":true},{"ID":71,"Name":"internal.metrics.peakExecutionMemory","Update":41943040,"Value":163577856,"Internal":true,"Count Failed 
Values":true},{"ID":68,"Name":"internal.metrics.resultSerializationTime","Update":2,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":67,"Name":"internal.metrics.jvmGCTime","Update":920,"Value":1015,"Internal":true,"Count Failed Values":true},{"ID":66,"Name":"internal.metrics.resultSize","Update":2437,"Value":9619,"Internal":true,"Count Failed Values":true},{"ID":65,"Name":"internal.metrics.executorCpuTime","Update":5299274511,"Value":9480247818,"Internal":true,"Count Failed Values":true},{"ID":64,"Name":"internal.metrics.executorRunTime","Update":7847,"Value":13673,"Internal":true,"Count Failed Values":true},{"ID":63,"Name":"internal.metrics.executorDeserializeCpuTime","Update":687811857,"Value":811647158,"Internal":true,"Count Failed Values":true},{"ID":62,"Name":"internal.metrics.executorDeserializeTime","Update":1037,"Value":1232,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":1037,"Executor Deserialize CPU Time":687811857,"Executor Run Time":7847,"Executor CPU Time":5299274511,"Result Size":2437,"JVM GC Time":920,"Result Serialization Time":2,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":355051,"Shuffle Write Time":49790497,"Shuffle Records Written":3541},"Input Metrics":{"Bytes Read":36838295,"Records Read":3541},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"driver","Stage ID":1,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":629553808,"JVMOffHeapMemory":205304696,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":905801,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":905801,"OffHeapUnifiedMemory":0,"DirectPoolMemory":397602,"MappedPoolMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"2","Stage ID":1,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":595946552,"JVMOffHeapMemory":91208368,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":58468944,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":58468944,"OffHeapUnifiedMemory":0,"DirectPoolMemory":87796,"MappedPoolMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"1","Stage ID":1,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":755008624,"JVMOffHeapMemory":100519936,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":47962185,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":47962185,"OffHeapUnifiedMemory":0,"DirectPoolMemory":98230,"MappedPoolMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"4","Stage ID":1,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":518613056,"JVMOffHeapMemory":95657456,"OnHeapExecutionMemory":37748736,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":63104457,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":100853193,"OffHeapUnifiedMemory":0,"DirectPoolMemory":126261,"MappedPoolMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"3","Stage ID":1,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":726805712,"JVMOffHeapMemory":90709624,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":69535048,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":69535048,"OffHeapUnifiedMemory":0,"DirectPoolMemory":87796,"MappedPoolMemory":0}} 
+{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"cache at :41","Number of Tasks":4,"RDD Info":[{"RDD ID":14,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"17\",\"name\":\"Exchange\"}","Callsite":"cache at :41","Parent IDs":[13],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":10,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"24\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[9],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":9,"Name":"FileScanRDD","Scope":"{\"id\":\"24\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":12,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"19\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[11],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":11,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"23\",\"name\":\"Generate\"}","Callsite":"cache at :41","Parent IDs":[10],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":13,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"18\",\"name\":\"SortAggregate\"}","Callsite":"cache at :41","Parent IDs":[12],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.sql.Dataset.cache(Dataset.scala:2912)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:41)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:46)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:48)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:50)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:52)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:54)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:56)\n$line49.$read$$iw$$iw$$iw$$iw$$iw.(:58)\n$line49.$read$$iw$$iw$$iw$$iw.(:60)\n$line49.$read$$iw$$iw$$iw.(:62)\n$line49.$read$$iw$$iw.(:64)\n$line49.$read$$iw.(:66)\n$line49.$read.(:68)\n$line49.$read$.(:72)\n$line49.$read$.()\n$line49.$eval$.$print$lzycompute(:7)\n$line49.$eval$.$print(:6)\n$line49.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)","Submission Time":1524182130328,"Completion Time":1524182152419,"Accumulables":[{"ID":83,"Name":"internal.metrics.input.bytesRead","Value":134782862,"Internal":true,"Count Failed Values":true},{"ID":23,"Name":"number of output rows","Value":"13113","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":68,"Name":"internal.metrics.resultSerializationTime","Value":2,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"data size total (min, med, max)","Value":"7122371","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":62,"Name":"internal.metrics.executorDeserializeTime","Value":1232,"Internal":true,"Count Failed 
Values":true},{"ID":80,"Name":"internal.metrics.shuffle.write.bytesWritten","Value":1296058,"Internal":true,"Count Failed Values":true},{"ID":71,"Name":"internal.metrics.peakExecutionMemory","Value":163577856,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"number of output rows","Value":"13114","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":65,"Name":"internal.metrics.executorCpuTime","Value":9480247818,"Internal":true,"Count Failed Values":true},{"ID":64,"Name":"internal.metrics.executorRunTime","Value":13673,"Internal":true,"Count Failed Values":true},{"ID":82,"Name":"internal.metrics.shuffle.write.writeTime","Value":129200897,"Internal":true,"Count Failed Values":true},{"ID":67,"Name":"internal.metrics.jvmGCTime","Value":1015,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"peak memory total (min, med, max)","Value":"163577851","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":28,"Name":"number of output rows","Value":"720501","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":63,"Name":"internal.metrics.executorDeserializeCpuTime","Value":811647158,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"duration total (min, med, max)","Value":"11610","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":81,"Name":"internal.metrics.shuffle.write.recordsWritten","Value":13113,"Internal":true,"Count Failed Values":true},{"ID":84,"Name":"internal.metrics.input.recordsRead","Value":13114,"Internal":true,"Count Failed Values":true},{"ID":66,"Name":"internal.metrics.resultSize","Value":9619,"Internal":true,"Count Failed Values":true},{"ID":24,"Name":"sort time total (min, med, max)","Value":"143","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":33,"Name":"duration total (min, med, max)","Value":"10872","Internal":true,"Count Failed Values":true,"Metadata":"sql"}]}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":2,"Stage Attempt ID":0,"Stage Name":"show at :40","Number of Tasks":1,"RDD Info":[{"RDD ID":26,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"33\",\"name\":\"map\"}","Callsite":"show at :40","Parent IDs":[25],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":25,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"32\",\"name\":\"mapPartitionsInternal\"}","Callsite":"show at :40","Parent IDs":[24],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":8,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"8\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[7],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":24,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"27\",\"name\":\"WholeStageCodegen\"}","Callsite":"show at :40","Parent IDs":[23],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":22,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"31\",\"name\":\"InMemoryTableScan\"}","Callsite":"show at :40","Parent IDs":[20],"Storage Level":{"Use Disk":false,"Use 
Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":20,"Name":"*(5) Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28, azkaban.link.workflow.url#159, azkaban.link.execution.url#161, azkaban.link.job.url#163, user.name#165]\n+- SortMergeJoin [appId#0], [appId#137], LeftOuter\n :- *(1) Sort [appId#0 ASC NULLS FIRST], false, 0\n : +- Exchange hashpartitioning(appId#0, 200)\n : +- InMemoryTableScan [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28]\n : +- InMemoryRelation [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28], true, 10000, StorageLevel(disk, memory, deserialized, 1 rep...","Scope":"{\"id\":\"26\",\"name\":\"mapPartitionsInternal\"}","Callsite":"cache at :41","Parent IDs":[19],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":23,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"31\",\"name\":\"InMemoryTableScan\"}","Callsite":"show at :40","Parent IDs":[22],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":18,"Name":"ZippedPartitionsRDD2","Scope":"{\"id\":\"7\",\"name\":\"SortMergeJoin\"}","Callsite":"cache at :41","Parent IDs":[8,17],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":17,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"13\",\"name\":\"SortAggregate\"}","Callsite":"cache at :41","Parent IDs":[16],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":7,"Name":"ShuffledRowRDD","Scope":"{\"id\":\"11\",\"name\":\"Exchange\"}","Callsite":"cache at :41","Parent IDs":[6],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":16,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"14\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[15],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":15,"Name":"ShuffledRowRDD","Scope":"{\"id\":\"17\",\"name\":\"Exchange\"}","Callsite":"cache at :41","Parent IDs":[14],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":19,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"4\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[18],"Storage Level":{"Use Disk":false,"Use 
Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0,1],"Details":"org.apache.spark.sql.Dataset.show(Dataset.scala:691)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:40)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:45)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:47)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:49)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:51)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:53)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:55)\n$line50.$read$$iw$$iw$$iw$$iw$$iw.(:57)\n$line50.$read$$iw$$iw$$iw$$iw.(:59)\n$line50.$read$$iw$$iw$$iw.(:61)\n$line50.$read$$iw$$iw.(:63)\n$line50.$read$$iw.(:65)\n$line50.$read.(:67)\n$line50.$read$.(:71)\n$line50.$read$.()\n$line50.$eval$.$print$lzycompute(:7)\n$line50.$eval$.$print(:6)\n$line50.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)","Submission Time":1524182152430,"Accumulables":[]},"Properties":{"spark.sql.execution.id":"2"}} +{"Event":"SparkListenerTaskStart","Stage ID":2,"Stage Attempt ID":0,"Task Info":{"Task ID":8,"Index":0,"Attempt":0,"Launch Time":1524182152447,"Executor ID":"4","Host":"node4243.grid.company.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":2,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":8,"Index":0,"Attempt":0,"Launch Time":1524182152447,"Executor ID":"4","Host":"node4243.grid.company.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1524182153103,"Failed":false,"Killed":false,"Accumulables":[{"ID":34,"Name":"duration total (min, med, max)","Update":"1","Value":"0","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":35,"Name":"number of output rows","Update":"6928","Value":"6928","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":10,"Name":"duration total (min, med, max)","Update":"452","Value":"451","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":11,"Name":"number of output rows","Update":"10945","Value":"10945","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":18,"Name":"number of output rows","Update":"62","Value":"62","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":20,"Name":"peak memory total (min, med, max)","Update":"33619967","Value":"33619966","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":22,"Name":"duration total (min, med, max)","Update":"323","Value":"322","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":13,"Name":"peak memory total (min, med, max)","Update":"34078719","Value":"34078718","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":12,"Name":"sort time total (min, med, max)","Update":"10","Value":"9","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":15,"Name":"duration total (min, med, max)","Update":"367","Value":"366","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":104,"Name":"internal.metrics.shuffle.read.recordsRead","Update":11007,"Value":11007,"Internal":true,"Count Failed Values":true},{"ID":103,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed 
Values":true},{"ID":102,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":124513,"Value":124513,"Internal":true,"Count Failed Values":true},{"ID":101,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":100,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":314162,"Value":314162,"Internal":true,"Count Failed Values":true},{"ID":99,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":2,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":98,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":6,"Value":6,"Internal":true,"Count Failed Values":true},{"ID":96,"Name":"internal.metrics.peakExecutionMemory","Update":67698688,"Value":67698688,"Internal":true,"Count Failed Values":true},{"ID":91,"Name":"internal.metrics.resultSize","Update":4642,"Value":4642,"Internal":true,"Count Failed Values":true},{"ID":90,"Name":"internal.metrics.executorCpuTime","Update":517655714,"Value":517655714,"Internal":true,"Count Failed Values":true},{"ID":89,"Name":"internal.metrics.executorRunTime","Update":589,"Value":589,"Internal":true,"Count Failed Values":true},{"ID":88,"Name":"internal.metrics.executorDeserializeCpuTime","Update":45797784,"Value":45797784,"Internal":true,"Count Failed Values":true},{"ID":87,"Name":"internal.metrics.executorDeserializeTime","Update":50,"Value":50,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":50,"Executor Deserialize CPU Time":45797784,"Executor Run Time":589,"Executor CPU Time":517655714,"Result Size":4642,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":6,"Local Blocks Fetched":2,"Fetch Wait Time":0,"Remote Bytes Read":314162,"Remote Bytes Read To Disk":0,"Local Bytes Read":124513,"Total Records Read":11007},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":2,"Stage Attempt ID":0,"Stage Name":"show at :40","Number of Tasks":1,"RDD Info":[{"RDD ID":26,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"33\",\"name\":\"map\"}","Callsite":"show at :40","Parent IDs":[25],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":25,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"32\",\"name\":\"mapPartitionsInternal\"}","Callsite":"show at :40","Parent IDs":[24],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":8,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"8\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[7],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":24,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"27\",\"name\":\"WholeStageCodegen\"}","Callsite":"show at :40","Parent IDs":[23],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory 
Size":0,"Disk Size":0},{"RDD ID":22,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"31\",\"name\":\"InMemoryTableScan\"}","Callsite":"show at :40","Parent IDs":[20],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":20,"Name":"*(5) Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28, azkaban.link.workflow.url#159, azkaban.link.execution.url#161, azkaban.link.job.url#163, user.name#165]\n+- SortMergeJoin [appId#0], [appId#137], LeftOuter\n :- *(1) Sort [appId#0 ASC NULLS FIRST], false, 0\n : +- Exchange hashpartitioning(appId#0, 200)\n : +- InMemoryTableScan [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28]\n : +- InMemoryRelation [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28], true, 10000, StorageLevel(disk, memory, deserialized, 1 rep...","Scope":"{\"id\":\"26\",\"name\":\"mapPartitionsInternal\"}","Callsite":"cache at :41","Parent IDs":[19],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":23,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"31\",\"name\":\"InMemoryTableScan\"}","Callsite":"show at :40","Parent IDs":[22],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":18,"Name":"ZippedPartitionsRDD2","Scope":"{\"id\":\"7\",\"name\":\"SortMergeJoin\"}","Callsite":"cache at :41","Parent IDs":[8,17],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":17,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"13\",\"name\":\"SortAggregate\"}","Callsite":"cache at :41","Parent IDs":[16],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":7,"Name":"ShuffledRowRDD","Scope":"{\"id\":\"11\",\"name\":\"Exchange\"}","Callsite":"cache at :41","Parent IDs":[6],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":16,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"14\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[15],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":15,"Name":"ShuffledRowRDD","Scope":"{\"id\":\"17\",\"name\":\"Exchange\"}","Callsite":"cache at :41","Parent IDs":[14],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk 
Size":0},{"RDD ID":19,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"4\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[18],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0,1],"Details":"org.apache.spark.sql.Dataset.show(Dataset.scala:691)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:40)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:45)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:47)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:49)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:51)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:53)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:55)\n$line50.$read$$iw$$iw$$iw$$iw$$iw.(:57)\n$line50.$read$$iw$$iw$$iw$$iw.(:59)\n$line50.$read$$iw$$iw$$iw.(:61)\n$line50.$read$$iw$$iw.(:63)\n$line50.$read$$iw.(:65)\n$line50.$read.(:67)\n$line50.$read$.(:71)\n$line50.$read$.()\n$line50.$eval$.$print$lzycompute(:7)\n$line50.$eval$.$print(:6)\n$line50.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)","Submission Time":1524182152430,"Completion Time":1524182153104,"Accumulables":[{"ID":101,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Value":0,"Internal":true,"Count Failed Values":true},{"ID":104,"Name":"internal.metrics.shuffle.read.recordsRead","Value":11007,"Internal":true,"Count Failed Values":true},{"ID":35,"Name":"number of output rows","Value":"6928","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":89,"Name":"internal.metrics.executorRunTime","Value":589,"Internal":true,"Count Failed Values":true},{"ID":98,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Value":6,"Internal":true,"Count Failed Values":true},{"ID":11,"Name":"number of output rows","Value":"10945","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":20,"Name":"peak memory total (min, med, max)","Value":"33619966","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":91,"Name":"internal.metrics.resultSize","Value":4642,"Internal":true,"Count Failed Values":true},{"ID":100,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Value":314162,"Internal":true,"Count Failed Values":true},{"ID":13,"Name":"peak memory total (min, med, max)","Value":"34078718","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":22,"Name":"duration total (min, med, max)","Value":"322","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":103,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Value":0,"Internal":true,"Count Failed Values":true},{"ID":88,"Name":"internal.metrics.executorDeserializeCpuTime","Value":45797784,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"duration total (min, med, max)","Value":"0","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":10,"Name":"duration total (min, med, max)","Value":"451","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":87,"Name":"internal.metrics.executorDeserializeTime","Value":50,"Internal":true,"Count Failed Values":true},{"ID":96,"Name":"internal.metrics.peakExecutionMemory","Value":67698688,"Internal":true,"Count Failed Values":true},{"ID":90,"Name":"internal.metrics.executorCpuTime","Value":517655714,"Internal":true,"Count Failed Values":true},{"ID":99,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Value":2,"Internal":true,"Count Failed 
Values":true},{"ID":18,"Name":"number of output rows","Value":"62","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":12,"Name":"sort time total (min, med, max)","Value":"9","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":102,"Name":"internal.metrics.shuffle.read.localBytesRead","Value":124513,"Internal":true,"Count Failed Values":true},{"ID":15,"Name":"duration total (min, med, max)","Value":"366","Internal":true,"Count Failed Values":true,"Metadata":"sql"}]}} +{"Event":"SparkListenerJobEnd","Job ID":0,"Completion Time":1524182153112,"Job Result":{"Result":"JobSucceeded"}} +{"Event":"org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd","executionId":2,"time":1524182153139} +{"Event":"SparkListenerUnpersistRDD","RDD ID":2} +{"Event":"SparkListenerUnpersistRDD","RDD ID":20} +{"Event":"SparkListenerApplicationEnd","Timestamp":1524182189134} diff --git a/core/src/test/scala/org/apache/spark/BarrierStageOnSubmittedSuite.scala b/core/src/test/scala/org/apache/spark/BarrierStageOnSubmittedSuite.scala new file mode 100644 index 0000000000000..d49ab4aa7df12 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/BarrierStageOnSubmittedSuite.scala @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.apache.spark.rdd.{PartitionPruningRDD, RDD} +import org.apache.spark.scheduler.BarrierJobAllocationFailed._ +import org.apache.spark.scheduler.DAGScheduler +import org.apache.spark.util.ThreadUtils + +/** + * This test suite covers all the cases that shall fail fast on job submitted that contains one + * of more barrier stages. 
+ */ +class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext { + + private def createSparkContext(conf: Option[SparkConf] = None): SparkContext = { + new SparkContext(conf.getOrElse( + new SparkConf() + .setMaster("local[4]") + .setAppName("test"))) + } + + private def testSubmitJob( + sc: SparkContext, + rdd: RDD[Int], + partitions: Option[Seq[Int]] = None, + message: String): Unit = { + val futureAction = sc.submitJob( + rdd, + (iter: Iterator[Int]) => iter.toArray, + partitions.getOrElse(0 until rdd.partitions.length), + { case (_, _) => return }: (Int, Array[Int]) => Unit, + { return } + ) + + val error = intercept[SparkException] { + ThreadUtils.awaitResult(futureAction, 5 seconds) + }.getCause.getMessage + assert(error.contains(message)) + } + + test("submit a barrier ResultStage that contains PartitionPruningRDD") { + sc = createSparkContext() + val prunedRdd = new PartitionPruningRDD(sc.parallelize(1 to 10, 4), index => index > 1) + val rdd = prunedRdd + .barrier() + .mapPartitions(iter => iter) + testSubmitJob(sc, rdd, + message = ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) + } + + test("submit a barrier ShuffleMapStage that contains PartitionPruningRDD") { + sc = createSparkContext() + val prunedRdd = new PartitionPruningRDD(sc.parallelize(1 to 10, 4), index => index > 1) + val rdd = prunedRdd + .barrier() + .mapPartitions(iter => iter) + .repartition(2) + .map(x => x + 1) + testSubmitJob(sc, rdd, + message = ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) + } + + test("submit a barrier stage that doesn't contain PartitionPruningRDD") { + sc = createSparkContext() + val prunedRdd = new PartitionPruningRDD(sc.parallelize(1 to 10, 4), index => index > 1) + val rdd = prunedRdd + .repartition(2) + .barrier() + .mapPartitions(iter => iter) + // Should be able to submit job and run successfully. + val result = rdd.collect().sorted + assert(result === Seq(6, 7, 8, 9, 10)) + } + + test("submit a barrier stage with partial partitions") { + sc = createSparkContext() + val rdd = sc.parallelize(1 to 10, 4) + .barrier() + .mapPartitions(iter => iter) + testSubmitJob(sc, rdd, Some(Seq(1, 3)), + message = ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) + } + + test("submit a barrier stage with union()") { + sc = createSparkContext() + val rdd1 = sc.parallelize(1 to 10, 2) + .barrier() + .mapPartitions(iter => iter) + val rdd2 = sc.parallelize(1 to 20, 2) + val rdd3 = rdd1 + .union(rdd2) + .map(x => x * 2) + // Fail the job on submit because the barrier RDD (rdd1) may not be assigned Task 0. + testSubmitJob(sc, rdd3, + message = ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) + } + + test("submit a barrier stage with coalesce()") { + sc = createSparkContext() + val rdd = sc.parallelize(1 to 10, 4) + .barrier() + .mapPartitions(iter => iter) + .coalesce(1) + // Fail the job on submit because the barrier RDD requires 4 tasks to run, but the stage + // only launches 1 task.
+ testSubmitJob(sc, rdd, + message = ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) + } + + test("submit a barrier stage that contains an RDD that depends on multiple barrier RDDs") { + sc = createSparkContext() + val rdd1 = sc.parallelize(1 to 10, 4) + .barrier() + .mapPartitions(iter => iter) + val rdd2 = sc.parallelize(11 to 20, 4) + .barrier() + .mapPartitions(iter => iter) + val rdd3 = rdd1 + .zip(rdd2) + .map(x => x._1 + x._2) + testSubmitJob(sc, rdd3, + message = ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) + } + + test("submit a barrier stage with zip()") { + sc = createSparkContext() + val rdd1 = sc.parallelize(1 to 10, 4) + .barrier() + .mapPartitions(iter => iter) + val rdd2 = sc.parallelize(11 to 20, 4) + val rdd3 = rdd1 + .zip(rdd2) + .map(x => x._1 + x._2) + // Should be able to submit job and run successfully. + val result = rdd3.collect().sorted + assert(result === Seq(12, 14, 16, 18, 20, 22, 24, 26, 28, 30)) + } + + test("submit a barrier ResultStage with dynamic resource allocation enabled") { + val conf = new SparkConf() + .set("spark.dynamicAllocation.enabled", "true") + .set("spark.dynamicAllocation.testing", "true") + .setMaster("local[4]") + .setAppName("test") + sc = createSparkContext(Some(conf)) + + val rdd = sc.parallelize(1 to 10, 4) + .barrier() + .mapPartitions(iter => iter) + testSubmitJob(sc, rdd, + message = ERROR_MESSAGE_RUN_BARRIER_WITH_DYN_ALLOCATION) + } + + test("submit a barrier ShuffleMapStage with dynamic resource allocation enabled") { + val conf = new SparkConf() + .set("spark.dynamicAllocation.enabled", "true") + .set("spark.dynamicAllocation.testing", "true") + .setMaster("local[4]") + .setAppName("test") + sc = createSparkContext(Some(conf)) + + val rdd = sc.parallelize(1 to 10, 4) + .barrier() + .mapPartitions(iter => iter) + .repartition(2) + .map(x => x + 1) + testSubmitJob(sc, rdd, + message = ERROR_MESSAGE_RUN_BARRIER_WITH_DYN_ALLOCATION) + } + + test("submit a barrier ResultStage that requires more slots than current total under local " + + "mode") { + val conf = new SparkConf() + // Shorten the time interval between two failed checks to make the test fail faster. + .set("spark.scheduler.barrier.maxConcurrentTasksCheck.interval", "1s") + // Reduce max check failures allowed to make the test fail faster. + .set("spark.scheduler.barrier.maxConcurrentTasksCheck.maxFailures", "3") + .setMaster("local[4]") + .setAppName("test") + sc = createSparkContext(Some(conf)) + val rdd = sc.parallelize(1 to 10, 5) + .barrier() + .mapPartitions(iter => iter) + testSubmitJob(sc, rdd, + message = ERROR_MESSAGE_BARRIER_REQUIRE_MORE_SLOTS_THAN_CURRENT_TOTAL_NUMBER) + } + + test("submit a barrier ShuffleMapStage that requires more slots than current total under " + + "local mode") { + val conf = new SparkConf() + // Shorten the time interval between two failed checks to make the test fail faster. + .set("spark.scheduler.barrier.maxConcurrentTasksCheck.interval", "1s") + // Reduce max check failures allowed to make the test fail faster. 
+ .set("spark.scheduler.barrier.maxConcurrentTasksCheck.maxFailures", "3") + .setMaster("local[4]") + .setAppName("test") + sc = createSparkContext(Some(conf)) + val rdd = sc.parallelize(1 to 10, 5) + .barrier() + .mapPartitions(iter => iter) + .repartition(2) + .map(x => x + 1) + testSubmitJob(sc, rdd, + message = ERROR_MESSAGE_BARRIER_REQUIRE_MORE_SLOTS_THAN_CURRENT_TOTAL_NUMBER) + } + + test("submit a barrier ResultStage that requires more slots than current total under " + + "local-cluster mode") { + val conf = new SparkConf() + .set("spark.task.cpus", "2") + // Shorten the time interval between two failed checks to make the test fail faster. + .set("spark.scheduler.barrier.maxConcurrentTasksCheck.interval", "1s") + // Reduce max check failures allowed to make the test fail faster. + .set("spark.scheduler.barrier.maxConcurrentTasksCheck.maxFailures", "3") + .setMaster("local-cluster[4, 3, 1024]") + .setAppName("test") + sc = createSparkContext(Some(conf)) + val rdd = sc.parallelize(1 to 10, 5) + .barrier() + .mapPartitions(iter => iter) + testSubmitJob(sc, rdd, + message = ERROR_MESSAGE_BARRIER_REQUIRE_MORE_SLOTS_THAN_CURRENT_TOTAL_NUMBER) + } + + test("submit a barrier ShuffleMapStage that requires more slots than current total under " + + "local-cluster mode") { + val conf = new SparkConf() + .set("spark.task.cpus", "2") + // Shorten the time interval between two failed checks to make the test fail faster. + .set("spark.scheduler.barrier.maxConcurrentTasksCheck.interval", "1s") + // Reduce max check failures allowed to make the test fail faster. + .set("spark.scheduler.barrier.maxConcurrentTasksCheck.maxFailures", "3") + .setMaster("local-cluster[4, 3, 1024]") + .setAppName("test") + sc = createSparkContext(Some(conf)) + val rdd = sc.parallelize(1 to 10, 5) + .barrier() + .mapPartitions(iter => iter) + .repartition(2) + .map(x => x + 1) + testSubmitJob(sc, rdd, + message = ERROR_MESSAGE_BARRIER_REQUIRE_MORE_SLOTS_THAN_CURRENT_TOTAL_NUMBER) + } +} diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 28ea0c6f0bdba..629a323042ff2 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.Matchers import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.{Millis, Span} +import org.apache.spark.internal.config import org.apache.spark.security.EncryptionFunSuite import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.io.ChunkedByteBuffer @@ -154,6 +155,21 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex sc.parallelize(1 to 10).count() } + private def testCaching(testName: String, conf: SparkConf, storageLevel: StorageLevel): Unit = { + test(testName) { + testCaching(conf, storageLevel) + } + if (storageLevel.replication > 1) { + // also try with block replication as a stream + val uploadStreamConf = new SparkConf() + uploadStreamConf.setAll(conf.getAll) + uploadStreamConf.set(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM, 1L) + test(s"$testName (with replication as stream)") { + testCaching(uploadStreamConf, storageLevel) + } + } + } + private def testCaching(conf: SparkConf, storageLevel: StorageLevel): Unit = { sc = new SparkContext(conf.setMaster(clusterUrl).setAppName("test")) TestUtils.waitUntilExecutorsUp(sc, 2, 30000) @@ -169,7 +185,10 @@ class DistributedSuite extends 
SparkFunSuite with Matchers with LocalSparkContex val blockManager = SparkEnv.get.blockManager val blockTransfer = blockManager.blockTransferService val serializerManager = SparkEnv.get.serializerManager - blockManager.master.getLocations(blockId).foreach { cmId => + val locations = blockManager.master.getLocations(blockId) + assert(locations.size === storageLevel.replication, + s"; got ${locations.size} replicas instead of ${storageLevel.replication}") + locations.foreach { cmId => val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId, blockId.toString, null) val deserialized = serializerManager.dataDeserializeStream(blockId, @@ -189,8 +208,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex "caching in memory and disk, replicated" -> StorageLevel.MEMORY_AND_DISK_2, "caching in memory and disk, serialized, replicated" -> StorageLevel.MEMORY_AND_DISK_SER_2 ).foreach { case (testName, storageLevel) => - encryptionTest(testName) { conf => - testCaching(conf, storageLevel) + encryptionTestHelper(testName) { case (name, conf) => + testCaching(name, conf, storageLevel) } } diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 3cfb0a9feb32b..5c718cb654ce8 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -24,6 +24,7 @@ import org.mockito.Mockito.{mock, never, verify, when} import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.config import org.apache.spark.scheduler._ import org.apache.spark.scheduler.ExternalClusterManager import org.apache.spark.scheduler.cluster.ExecutorInfo @@ -1092,7 +1093,7 @@ class ExecutorAllocationManagerSuite val maxExecutors = 2 val conf = new SparkConf() .set("spark.dynamicAllocation.enabled", "true") - .set("spark.shuffle.service.enabled", "true") + .set(config.SHUFFLE_SERVICE_ENABLED.key, "true") .set("spark.dynamicAllocation.minExecutors", minExecutors.toString) .set("spark.dynamicAllocation.maxExecutors", maxExecutors.toString) .set("spark.dynamicAllocation.initialExecutors", initialExecutors.toString) @@ -1376,6 +1377,8 @@ private class DummyLocalSchedulerBackend (sc: SparkContext, sb: SchedulerBackend override def defaultParallelism(): Int = sb.defaultParallelism() + override def maxNumConcurrentTasks(): Int = sb.maxNumConcurrentTasks() + override def killExecutorsOnHost(host: String): Boolean = { false } diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 472952addf353..462d5f5604ae3 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark import org.scalatest.BeforeAndAfterAll +import org.apache.spark.internal.config import org.apache.spark.network.TransportContext import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.server.TransportServer @@ -42,8 +43,8 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { server = transportContext.createServer() conf.set("spark.shuffle.manager", "sort") - conf.set("spark.shuffle.service.enabled", "true") - 
conf.set("spark.shuffle.service.port", server.getPort.toString) + conf.set(config.SHUFFLE_SERVICE_ENABLED.key, "true") + conf.set(config.SHUFFLE_SERVICE_PORT.key, server.getPort.toString) } override def afterAll() { diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index a441b9c8ab97a..81b18c71f30ee 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -19,10 +19,12 @@ package org.apache.spark import java.io._ import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets import java.util.zip.GZIPOutputStream import scala.io.Source +import com.google.common.io.Files import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io._ @@ -299,6 +301,25 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { } } + test("SPARK-22357 test binaryFiles minPartitions") { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local") + .set("spark.files.openCostInBytes", "0") + .set("spark.default.parallelism", "1")) + + val tempDir = Utils.createTempDir() + val tempDirPath = tempDir.getAbsolutePath + + for (i <- 0 until 8) { + val tempFile = new File(tempDir, s"part-0000$i") + Files.write("someline1 in file1\nsomeline2 in file1\nsomeline3 in file1", tempFile, + StandardCharsets.UTF_8) + } + + for (p <- Seq(1, 2, 8)) { + assert(sc.binaryFiles(tempDirPath, minPartitions = p).getNumPartitions === p) + } + } + test("fixed record length binary file as byte array") { sc = new SparkContext("local", "test") val testOutput = Array[Byte](1, 2, 3, 4, 5, 6) diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index b705556e54b14..de479db5fbc0f 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -28,7 +28,7 @@ import org.mockito.Matchers._ import org.mockito.Mockito.{mock, spy, verify, when} import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} -import org.apache.spark.executor.TaskMetrics +import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ @@ -77,7 +77,7 @@ class HeartbeatReceiverSuite heartbeatReceiverClock = new ManualClock heartbeatReceiver = new HeartbeatReceiver(sc, heartbeatReceiverClock) heartbeatReceiverRef = sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver) - when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true) + when(scheduler.executorHeartbeatReceived(any(), any(), any(), any())).thenReturn(true) } /** @@ -213,8 +213,10 @@ class HeartbeatReceiverSuite executorShouldReregister: Boolean): Unit = { val metrics = TaskMetrics.empty val blockManagerId = BlockManagerId(executorId, "localhost", 12345) + val executorUpdates = new ExecutorMetrics(Array(123456L, 543L, 12345L, 1234L, 123L, + 12L, 432L, 321L, 654L, 765L)) val response = heartbeatReceiverRef.askSync[HeartbeatResponse]( - Heartbeat(executorId, Array(1L -> metrics.accumulators()), blockManagerId)) + Heartbeat(executorId, Array(1L -> metrics.accumulators()), blockManagerId, executorUpdates)) if (executorShouldReregister) { assert(response.reregisterBlockManager) } else { @@ -223,7 +225,8 @@ class 
HeartbeatReceiverSuite verify(scheduler).executorHeartbeatReceived( Matchers.eq(executorId), Matchers.eq(Array(1L -> metrics.accumulators())), - Matchers.eq(blockManagerId)) + Matchers.eq(blockManagerId), + Matchers.eq(executorUpdates)) } } diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala index 1dd89bcbe36bc..05aaaa11451b4 100644 --- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala @@ -29,7 +29,7 @@ trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self override def beforeAll() { super.beforeAll() - InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()) + InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE) } override def afterEach() { diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 21f481d477242..e79739692fe13 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -62,9 +62,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(1000L, 10000L))) + Array(1000L, 10000L), 10)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(10000L, 1000L))) + Array(10000L, 1000L), 10)) val statuses = tracker.getMapSizesByExecutorId(10, 0) assert(statuses.toSet === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), @@ -84,9 +84,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize10000))) + Array(compressedSize1000, compressedSize10000), 10)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000))) + Array(compressedSize10000, compressedSize1000), 10)) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty) assert(0 == tracker.getNumCachedSerializedBroadcast) @@ -107,9 +107,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize1000, compressedSize1000))) + Array(compressedSize1000, compressedSize1000, compressedSize1000), 10)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000, compressedSize1000))) + Array(compressedSize10000, compressedSize1000, compressedSize1000), 10)) assert(0 == tracker.getNumCachedSerializedBroadcast) // As if we had two simultaneous fetch failures @@ -145,7 +145,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("a", "hostA", 1000), Array(1000L))) + BlockManagerId("a", "hostA", 
1000), Array(1000L), 10)) slaveTracker.updateEpoch(masterTracker.getEpoch) assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) @@ -182,7 +182,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { // Message size should be ~123B, and no exception should be thrown masterTracker.registerShuffle(10, 1) masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0))) + BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0), 0)) val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) when(rpcCallContext.senderAddress).thenReturn(senderAddress) @@ -216,11 +216,11 @@ class MapOutputTrackerSuite extends SparkFunSuite { // on hostB with output size 3 tracker.registerShuffle(10, 3) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(2L))) + Array(2L), 1)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(2L))) + Array(2L), 1)) tracker.registerMapOutput(10, 2, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(3L))) + Array(3L), 1)) // When the threshold is 50%, only host A should be returned as a preferred location // as it has 4 out of 7 bytes of output. @@ -260,7 +260,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerShuffle(20, 100) (0 until 100).foreach { i => masterTracker.registerMapOutput(20, i, new CompressedMapStatus( - BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) + BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 0)) } val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) @@ -309,9 +309,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(size0, size1000, size0, size10000))) + Array(size0, size1000, size0, size10000), 1)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(size10000, size0, size1000, size0))) + Array(size10000, size0, size1000, size0), 1)) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq === Seq( diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index ced5a06516f75..456f97b535ef6 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -208,7 +208,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC val pairs2: RDD[MutablePair[Int, String]] = sc.parallelize(data2, 2) val results = new SubtractedRDD(pairs1, pairs2, new HashPartitioner(2)).collect() results should have length (1) - // substracted rdd return results as Tuple2 + // subtracted rdd return results as Tuple2 results(0) should be ((3, 33)) } @@ -391,6 +391,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC assert(mapOutput2.isDefined) assert(mapOutput1.get.location === mapOutput2.get.location) assert(mapOutput1.get.getSizeForBlock(0) === mapOutput1.get.getSizeForBlock(0)) + assert(mapOutput1.get.numberOfOutput === mapOutput2.get.numberOfOutput) // register one of the map outputs -- doesn't matter which 
one mapOutput1.foreach { case mapStatus => diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index ce9f2be1c02dd..e1666a35271d3 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -627,6 +627,51 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu assert(exc.getCause() != null) stream.close() } + + test("support barrier execution mode under local mode") { + val conf = new SparkConf().setAppName("test").setMaster("local[2]") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(Seq(1, 2, 3, 4), 2) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + // If we don't get the expected taskInfos, the job shall abort due to stage failure. + if (context.getTaskInfos().length != 2) { + throw new SparkException("Expected taskInfos length is 2, actual length is " + + s"${context.getTaskInfos().length}.") + } + context.barrier() + it + } + rdd2.collect() + + eventually(timeout(10.seconds)) { + assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) + } + } + + test("support barrier execution mode under local-cluster mode") { + val conf = new SparkConf() + .setMaster("local-cluster[3, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + + val rdd = sc.makeRDD(Seq(1, 2, 3, 4), 2) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + // If we don't get the expected taskInfos, the job shall abort due to stage failure. + if (context.getTaskInfos().length != 2) { + throw new SparkException("Expected taskInfos length is 2, actual length is " + + s"${context.getTaskInfos().length}.") + } + context.barrier() + it + } + rdd2.collect() + + eventually(timeout(10.seconds)) { + assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) + } + } } object SparkContextSuite { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala deleted file mode 100644 index ab24a76e20a30..0000000000000 --- a/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -package org.apache.spark.deploy - -import java.security.PrivilegedExceptionAction - -import scala.util.Random - -import org.apache.hadoop.fs.FileStatus -import org.apache.hadoop.fs.permission.{FsAction, FsPermission} -import org.apache.hadoop.security.UserGroupInformation -import org.scalatest.Matchers - -import org.apache.spark.SparkFunSuite - -class SparkHadoopUtilSuite extends SparkFunSuite with Matchers { - test("check file permission") { - import FsAction._ - val testUser = s"user-${Random.nextInt(100)}" - val testGroups = Array(s"group-${Random.nextInt(100)}") - val testUgi = UserGroupInformation.createUserForTesting(testUser, testGroups) - - testUgi.doAs(new PrivilegedExceptionAction[Void] { - override def run(): Void = { - val sparkHadoopUtil = new SparkHadoopUtil - - // If file is owned by user and user has access permission - var status = fileStatus(testUser, testGroups.head, READ_WRITE, READ_WRITE, NONE) - sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) - sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) - - // If file is owned by user but user has no access permission - status = fileStatus(testUser, testGroups.head, NONE, READ_WRITE, NONE) - sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) - sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) - - val otherUser = s"test-${Random.nextInt(100)}" - val otherGroup = s"test-${Random.nextInt(100)}" - - // If file is owned by user's group and user's group has access permission - status = fileStatus(otherUser, testGroups.head, NONE, READ_WRITE, NONE) - sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) - sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) - - // If file is owned by user's group but user's group has no access permission - status = fileStatus(otherUser, testGroups.head, READ_WRITE, NONE, NONE) - sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) - sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) - - // If file is owned by other user and this user has access permission - status = fileStatus(otherUser, otherGroup, READ_WRITE, READ_WRITE, READ_WRITE) - sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) - sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) - - // If file is owned by other user but this user has no access permission - status = fileStatus(otherUser, otherGroup, READ_WRITE, READ_WRITE, NONE) - sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) - sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) - - null - } - }) - } - - private def fileStatus( - owner: String, - group: String, - userAction: FsAction, - groupAction: FsAction, - otherAction: FsAction): FileStatus = { - new FileStatus(0L, - false, - 0, - 0L, - 0L, - 0L, - new FsPermission(userAction, groupAction, otherAction), - owner, - group, - null) - } -} diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 545c8d0423dc3..9eae3605d0738 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -995,20 +995,24 @@ class SparkSubmitSuite } test("download remote resource if it is not supported by yarn service") { - testRemoteResources(enableHttpFs = false, blacklistHttpFs = false) + testRemoteResources(enableHttpFs = false) } test("avoid 
downloading remote resource if it is supported by yarn service") { - testRemoteResources(enableHttpFs = true, blacklistHttpFs = false) + testRemoteResources(enableHttpFs = true) } test("force download from blacklisted schemes") { - testRemoteResources(enableHttpFs = true, blacklistHttpFs = true) + testRemoteResources(enableHttpFs = true, blacklistSchemes = Seq("http")) + } + + test("force download for all the schemes") { + testRemoteResources(enableHttpFs = true, blacklistSchemes = Seq("*")) } private def testRemoteResources( enableHttpFs: Boolean, - blacklistHttpFs: Boolean): Unit = { + blacklistSchemes: Seq[String] = Nil): Unit = { val hadoopConf = new Configuration() updateConfWithFakeS3Fs(hadoopConf) if (enableHttpFs) { @@ -1025,8 +1029,8 @@ class SparkSubmitSuite val tmpHttpJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpDir) val tmpHttpJarPath = s"http://${new File(tmpHttpJar.toURI).getAbsolutePath}" - val forceDownloadArgs = if (blacklistHttpFs) { - Seq("--conf", "spark.yarn.dist.forceDownloadSchemes=http") + val forceDownloadArgs = if (blacklistSchemes.nonEmpty) { + Seq("--conf", s"spark.yarn.dist.forceDownloadSchemes=${blacklistSchemes.mkString(",")}") } else { Nil } @@ -1044,14 +1048,19 @@ class SparkSubmitSuite val jars = conf.get("spark.yarn.dist.jars").split(",").toSet - // The URI of remote S3 resource should still be remote. - assert(jars.contains(tmpS3JarPath)) + def isSchemeBlacklisted(scheme: String) = { + blacklistSchemes.contains("*") || blacklistSchemes.contains(scheme) + } - if (enableHttpFs && !blacklistHttpFs) { + if (!isSchemeBlacklisted("s3")) { + assert(jars.contains(tmpS3JarPath)) + } + + if (enableHttpFs && blacklistSchemes.isEmpty) { // If Http FS is supported by yarn service, the URI of remote http resource should // still be remote. assert(jars.contains(tmpHttpJarPath)) - } else { + } else if (!enableHttpFs || isSchemeBlacklisted("http")) { // If Http FS is not supported by yarn service, or http scheme is configured to be force // downloading, the URI of remote http resource should be changed to a local one. val jarName = new File(tmpHttpJar.toURI).getName @@ -1135,6 +1144,53 @@ class SparkSubmitSuite conf1.get(PY_FILES.key) should be (s"s3a://${pyFile.getAbsolutePath}") conf1.get("spark.submit.pyFiles") should (startWith("/")) } + + test("handles natural line delimiters in --properties-file and --conf uniformly") { + val delimKey = "spark.my.delimiter." 
+ val LF = "\n" + val CR = "\r" + + val lineFeedFromCommandLine = s"${delimKey}lineFeedFromCommandLine" -> LF + val leadingDelimKeyFromFile = s"${delimKey}leadingDelimKeyFromFile" -> s"${LF}blah" + val trailingDelimKeyFromFile = s"${delimKey}trailingDelimKeyFromFile" -> s"blah${CR}" + val infixDelimFromFile = s"${delimKey}infixDelimFromFile" -> s"${CR}blah${LF}" + val nonDelimSpaceFromFile = s"${delimKey}nonDelimSpaceFromFile" -> " blah\f" + + val testProps = Seq(leadingDelimKeyFromFile, trailingDelimKeyFromFile, infixDelimFromFile, + nonDelimSpaceFromFile) + + val props = new java.util.Properties() + val propsFile = File.createTempFile("test-spark-conf", ".properties", + Utils.createTempDir()) + val propsOutputStream = new FileOutputStream(propsFile) + try { + testProps.foreach { case (k, v) => props.put(k, v) } + props.store(propsOutputStream, "test whitespace") + } finally { + propsOutputStream.close() + } + + val clArgs = Seq( + "--class", "org.SomeClass", + "--conf", s"${lineFeedFromCommandLine._1}=${lineFeedFromCommandLine._2}", + "--conf", "spark.master=yarn", + "--properties-file", propsFile.getPath, + "thejar.jar") + + val appArgs = new SparkSubmitArguments(clArgs) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs) + + Seq( + lineFeedFromCommandLine, + leadingDelimKeyFromFile, + trailingDelimKeyFromFile, + infixDelimFromFile + ).foreach { case (k, v) => + conf.get(k) should be (v) + } + + conf.get(nonDelimSpaceFromFile._1) should be ("blah") + } } object SparkSubmitSuite extends SparkFunSuite with TimeLimits { diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index 27cc47496c805..a1d2a1283db14 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -458,7 +458,7 @@ class StandaloneDynamicAllocationSuite val initialExecutorLimit = 1 val myConf = appConf .set("spark.dynamicAllocation.enabled", "true") - .set("spark.shuffle.service.enabled", "true") + .set(config.SHUFFLE_SERVICE_ENABLED.key, "true") .set("spark.dynamicAllocation.initialExecutors", initialExecutorLimit.toString) sc = new SparkContext(myConf) val appId = sc.applicationId diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 77b239489d489..b4eba755eccbf 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -29,9 +29,11 @@ import scala.language.postfixOps import com.google.common.io.{ByteStreams, Files} import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hdfs.DistributedFileSystem +import org.apache.hadoop.security.AccessControlException import org.json4s.jackson.JsonMethods._ -import org.mockito.Matchers.any -import org.mockito.Mockito.{mock, spy, verify} +import org.mockito.ArgumentMatcher +import org.mockito.Matchers.{any, argThat} +import org.mockito.Mockito.{doThrow, mock, spy, verify, when} import org.scalatest.BeforeAndAfter import org.scalatest.Matchers import org.scalatest.concurrent.Eventually._ @@ -818,6 +820,42 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } + test("SPARK-24948: blacklist files we don't have read permission 
on") { + val clock = new ManualClock(1533132471) + val provider = new FsHistoryProvider(createTestConf(), clock) + val accessDenied = newLogFile("accessDenied", None, inProgress = false) + writeFile(accessDenied, true, None, + SparkListenerApplicationStart("accessDenied", Some("accessDenied"), 1L, "test", None)) + val accessGranted = newLogFile("accessGranted", None, inProgress = false) + writeFile(accessGranted, true, None, + SparkListenerApplicationStart("accessGranted", Some("accessGranted"), 1L, "test", None), + SparkListenerApplicationEnd(5L)) + val mockedFs = spy(provider.fs) + doThrow(new AccessControlException("Cannot read accessDenied file")).when(mockedFs).open( + argThat(new ArgumentMatcher[Path]() { + override def matches(path: Any): Boolean = { + path.asInstanceOf[Path].getName.toLowerCase == "accessdenied" + } + })) + val mockedProvider = spy(provider) + when(mockedProvider.fs).thenReturn(mockedFs) + updateAndCheck(mockedProvider) { list => + list.size should be(1) + } + writeFile(accessDenied, true, None, + SparkListenerApplicationStart("accessDenied", Some("accessDenied"), 1L, "test", None), + SparkListenerApplicationEnd(5L)) + // Doing 2 times in order to check the blacklist filter too + updateAndCheck(mockedProvider) { list => + list.size should be(1) + } + val accessDeniedPath = new Path(accessDenied.getPath) + assert(mockedProvider.isBlacklisted(accessDeniedPath)) + clock.advance(24 * 60 * 60 * 1000 + 1) // add a bit more than 1d + mockedProvider.cleanLogs() + assert(!mockedProvider.isBlacklisted(accessDeniedPath)) + } + /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. Example: diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 11b29121739a4..11a2db81f7c6d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -82,6 +82,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers .set("spark.history.fs.update.interval", "0") .set("spark.testing", "true") .set(LOCAL_STORE_DIR, storeDir.getAbsolutePath()) + .set("spark.eventLog.logStageExecutorMetrics.enabled", "true") conf.setAll(extraConf) provider = new FsHistoryProvider(conf) provider.checkForLogs() @@ -128,6 +129,8 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers "succeeded&failed job list json" -> "applications/local-1422981780767/jobs?status=succeeded&status=failed", "executor list json" -> "applications/local-1422981780767/executors", + "executor list with executor metrics json" -> + "applications/application_1506645932520_24630151/executors", "stage list json" -> "applications/local-1422981780767/stages", "complete stage list json" -> "applications/local-1422981780767/stages?status=complete", "failed stage list json" -> "applications/local-1422981780767/stages?status=failed", diff --git a/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala index 69a460fbc7dba..f4558aa3eb893 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala @@ -53,8 +53,11 @@ class MasterWebUISuite extends SparkFunSuite with BeforeAndAfterAll { } override def 
afterAll() { - masterWebUI.stop() - super.afterAll() + try { + masterWebUI.stop() + } finally { + super.afterAll() + } } test("kill application") { diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 1a7bebe2c53cd..77a7668d3a1d1 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -275,6 +275,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug executorId = "", name = "", index = 0, + partitionId = 0, addedFiles = Map[String, Long](), addedJars = Map[String, Long](), properties = new Properties, diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala new file mode 100644 index 0000000000000..817dc082b7d38 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.input + +import java.io.{DataOutputStream, File, FileOutputStream} + +import scala.collection.immutable.IndexedSeq + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +/** + * Tests the correctness of + * [[org.apache.spark.input.WholeTextFileInputFormat WholeTextFileInputFormat]]. A temporary + * directory containing files is created as fake input and is deleted at the end. + */ +class WholeTextFileInputFormatSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { + private var sc: SparkContext = _ + + override def beforeAll() { + super.beforeAll() + val conf = new SparkConf() + sc = new SparkContext("local", "test", conf) + } + + override def afterAll() { + try { + sc.stop() + } finally { + super.afterAll() + } + } + + private def createNativeFile(inputDir: File, fileName: String, contents: Array[Byte], + compress: Boolean) = { + val path = s"${inputDir.toString}/$fileName" + val out = new DataOutputStream(new FileOutputStream(path)) + out.write(contents, 0, contents.length) + out.close() + } + + test("for small files minimum split size per node and per rack should be less than or equal to " + + "maximum split size.") { + var dir: File = null + try { + dir = Utils.createTempDir() + logInfo(s"Local disk address is ${dir.toString}.") + + // Set the minsize per node and rack to be larger than the size of the input file.
+ sc.hadoopConfiguration.setLong( + "mapreduce.input.fileinputformat.split.minsize.per.node", 123456) + sc.hadoopConfiguration.setLong( + "mapreduce.input.fileinputformat.split.minsize.per.rack", 123456) + + WholeTextFileInputFormatSuite.files.foreach { case (filename, contents) => + createNativeFile(dir, filename, contents, false) + } + // ensure spark job runs successfully without exceptions from the CombineFileInputFormat + assert(sc.wholeTextFiles(dir.toString).count == 3) + } finally { + Utils.deleteRecursively(dir) + } + } +} + +/** + * Files to be tested are defined here. + */ +object WholeTextFileInputFormatSuite { + private val testWords: IndexedSeq[Byte] = "Spark is easy to use.\n".map(_.toByte) + + private val fileNames = Array("part-00000", "part-00001", "part-00002") + private val fileLengths = Array(10, 100, 1000) + + private val files = fileLengths.zip(fileNames).map { case (upperBound, filename) => + filename -> Stream.continually(testWords.toList.toStream).flatten.take(upperBound).toArray + }.toMap +} diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferFileRegionSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferFileRegionSuite.scala new file mode 100644 index 0000000000000..a6b0654204f34 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferFileRegionSuite.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.io + +import java.nio.ByteBuffer +import java.nio.channels.WritableByteChannel + +import scala.util.Random + +import org.mockito.Mockito.when +import org.scalatest.BeforeAndAfterEach +import org.scalatest.mockito.MockitoSugar + +import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite} +import org.apache.spark.internal.config +import org.apache.spark.util.io.ChunkedByteBuffer + +class ChunkedByteBufferFileRegionSuite extends SparkFunSuite with MockitoSugar + with BeforeAndAfterEach { + + override protected def beforeEach(): Unit = { + super.beforeEach() + val conf = new SparkConf() + val env = mock[SparkEnv] + SparkEnv.set(env) + when(env.conf).thenReturn(conf) + } + + override protected def afterEach(): Unit = { + SparkEnv.set(null) + } + + private def generateChunkedByteBuffer(nChunks: Int, perChunk: Int): ChunkedByteBuffer = { + val bytes = (0 until nChunks).map { chunkIdx => + val bb = ByteBuffer.allocate(perChunk) + (0 until perChunk).foreach { idx => + bb.put((chunkIdx * perChunk + idx).toByte) + } + bb.position(0) + bb + }.toArray + new ChunkedByteBuffer(bytes) + } + + test("transferTo can stop and resume correctly") { + SparkEnv.get.conf.set(config.BUFFER_WRITE_CHUNK_SIZE, 9L) + val cbb = generateChunkedByteBuffer(4, 10) + val fileRegion = cbb.toNetty + + val targetChannel = new LimitedWritableByteChannel(40) + + var pos = 0L + // write the fileregion to the channel, but with the transfer limited at various spots along + // the way. + + // limit to within the first chunk + targetChannel.acceptNBytes = 5 + pos = fileRegion.transferTo(targetChannel, pos) + assert(targetChannel.pos === 5) + + // a little bit further within the first chunk + targetChannel.acceptNBytes = 2 + pos += fileRegion.transferTo(targetChannel, pos) + assert(targetChannel.pos === 7) + + // past the first chunk, into the 2nd + targetChannel.acceptNBytes = 6 + pos += fileRegion.transferTo(targetChannel, pos) + assert(targetChannel.pos === 13) + + // right to the end of the 2nd chunk + targetChannel.acceptNBytes = 7 + pos += fileRegion.transferTo(targetChannel, pos) + assert(targetChannel.pos === 20) + + // rest of 2nd chunk, all of 3rd, some of 4th + targetChannel.acceptNBytes = 15 + pos += fileRegion.transferTo(targetChannel, pos) + assert(targetChannel.pos === 35) + + // now till the end + targetChannel.acceptNBytes = 5 + pos += fileRegion.transferTo(targetChannel, pos) + assert(targetChannel.pos === 40) + + // calling again at the end should be OK + targetChannel.acceptNBytes = 20 + fileRegion.transferTo(targetChannel, pos) + assert(targetChannel.pos === 40) + } + + test(s"transfer to with random limits") { + val rng = new Random() + val seed = System.currentTimeMillis() + logInfo(s"seed = $seed") + rng.setSeed(seed) + val chunkSize = 1e4.toInt + SparkEnv.get.conf.set(config.BUFFER_WRITE_CHUNK_SIZE, rng.nextInt(chunkSize).toLong) + + val cbb = generateChunkedByteBuffer(50, chunkSize) + val fileRegion = cbb.toNetty + val transferLimit = 1e5.toInt + val targetChannel = new LimitedWritableByteChannel(transferLimit) + while (targetChannel.pos < cbb.size) { + val nextTransferSize = rng.nextInt(transferLimit) + targetChannel.acceptNBytes = nextTransferSize + fileRegion.transferTo(targetChannel, targetChannel.pos) + } + assert(0 === fileRegion.transferTo(targetChannel, targetChannel.pos)) + } + + /** + * This mocks a channel which only accepts a limited number of bytes at a time. It also verifies + * the written data matches our expectations as the data is received. 
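+ * (One concrete invariant, stated as an assumption read off generateChunkedByteBuffer above:
+ * the byte at global offset p of the buffer is always p.toByte, which is exactly what the
+ * write() override checks chunk by chunk via (pos + idx).toByte.)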
+ */ + private class LimitedWritableByteChannel(maxWriteSize: Int) extends WritableByteChannel { + val bytes = new Array[Byte](maxWriteSize) + var acceptNBytes = 0 + var pos = 0 + + override def write(src: ByteBuffer): Int = { + val length = math.min(acceptNBytes, src.remaining()) + src.get(bytes, 0, length) + acceptNBytes -= length + // verify we got the right data + (0 until length).foreach { idx => + assert(bytes(idx) === (pos + idx).toByte, s"; wrong data at ${pos + idx}") + } + pos += length + length + } + + override def isOpen: Boolean = true + + override def close(): Unit = {} + } + +} diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala index 2107559572d78..ff117b1c21cb1 100644 --- a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala @@ -34,7 +34,7 @@ class ChunkedByteBufferSuite extends SparkFunSuite with SharedSparkContext { assert(emptyChunkedByteBuffer.getChunks().isEmpty) assert(emptyChunkedByteBuffer.toArray === Array.empty) assert(emptyChunkedByteBuffer.toByteBuffer.capacity() === 0) - assert(emptyChunkedByteBuffer.toNetty.capacity() === 0) + assert(emptyChunkedByteBuffer.toNetty.count() === 0) emptyChunkedByteBuffer.toInputStream(dispose = false).close() emptyChunkedByteBuffer.toInputStream(dispose = true).close() } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala new file mode 100644 index 0000000000000..d57ea4d5501e3 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rdd + +import org.apache.spark.{SharedSparkContext, SparkFunSuite} + +class RDDBarrierSuite extends SparkFunSuite with SharedSparkContext { + + test("create an RDDBarrier") { + val rdd = sc.parallelize(1 to 10, 4) + assert(rdd.isBarrier() === false) + + val rdd2 = rdd.barrier().mapPartitions(iter => iter) + assert(rdd2.isBarrier() === true) + } + + test("create an RDDBarrier in the middle of a chain of RDDs") { + val rdd = sc.parallelize(1 to 10, 4).map(x => x * 2) + val rdd2 = rdd.barrier().mapPartitions(iter => iter).map(x => (x, x + 1)) + assert(rdd2.isBarrier() === true) + } + + test("RDDBarrier with shuffle") { + val rdd = sc.parallelize(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions(iter => iter).repartition(2) + assert(rdd2.isBarrier() === false) + } +} diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 5148ce05bd918..b143a468a1baf 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -443,7 +443,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { map{x => List(x)}.toList, "Tried coalescing 9 partitions to 20 but didn't get 9 back") } - test("coalesced RDDs with partial locality") { + test("coalesced RDDs with partial locality") { // Make an RDD that has some locality preferences and some without. This can happen // with UnionRDD val data = sc.makeRDD((1 to 9).map(i => { @@ -846,6 +846,28 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { assert(partitions(1) === Seq((1, 3), (3, 8), (3, 8))) } + test("cartesian on empty RDD") { + val a = sc.emptyRDD[Int] + val b = sc.parallelize(1 to 3) + val cartesian_result = Array.empty[(Int, Int)] + assert(a.cartesian(a).collect().toList === cartesian_result) + assert(a.cartesian(b).collect().toList === cartesian_result) + assert(b.cartesian(a).collect().toList === cartesian_result) + } + + test("cartesian on non-empty RDDs") { + val a = sc.parallelize(1 to 3) + val b = sc.parallelize(2 to 4) + val c = sc.parallelize(1 to 1) + val a_cartesian_b = + Array((1, 2), (1, 3), (1, 4), (2, 2), (2, 3), (2, 4), (3, 2), (3, 3), (3, 4)) + val a_cartesian_c = Array((1, 1), (2, 1), (3, 1)) + val c_cartesian_a = Array((1, 1), (1, 2), (1, 3)) + assert(a.cartesian[Int](b).collect().toList.sorted === a_cartesian_b) + assert(a.cartesian[Int](c).collect().toList.sorted === a_cartesian_c) + assert(c.cartesian[Int](a).collect().toList.sorted === c_cartesian_a) + } + test("intersection") { val all = sc.parallelize(1 to 10) val evens = sc.parallelize(2 to 10 by 2) diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala new file mode 100644 index 0000000000000..36dd620a56853 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import scala.util.Random + +import org.apache.spark._ + +class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { + + test("global sync by barrier() call") { + val conf = new SparkConf() + // Init local cluster here so each barrier task runs in a separated process, thus `barrier()` + // call is actually useful. + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + // Sleep for a random time before global sync. + Thread.sleep(Random.nextInt(1000)) + context.barrier() + Seq(System.currentTimeMillis()).iterator + } + + val times = rdd2.collect() + // All the tasks shall finish global sync within a short time slot. + assert(times.max - times.min <= 1000) + } + + test("support multiple barrier() call within a single task") { + val conf = new SparkConf() + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + // Sleep for a random time before global sync. + Thread.sleep(Random.nextInt(1000)) + context.barrier() + val time1 = System.currentTimeMillis() + // Sleep for a random time between two global syncs. + Thread.sleep(Random.nextInt(1000)) + context.barrier() + val time2 = System.currentTimeMillis() + Seq((time1, time2)).iterator + } + + val times = rdd2.collect() + // All the tasks shall finish the first round of global sync within a short time slot. + val times1 = times.map(_._1) + assert(times1.max - times1.min <= 1000) + + // All the tasks shall finish the second round of global sync within a short time slot. 
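+      // (Why the 1000 ms bound is plausible, as a rough argument rather than a guarantee:
+      // barrier() releases every task at essentially the same instant once the last straggler
+      // arrives, so the collected timestamps differ only by post-barrier jitter, while the
+      // independent Random.nextInt(1000) sleeps alone could otherwise spread them by up to a
+      // second per round.)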
+ val times2 = times.map(_._2) + assert(times2.max - times2.min <= 1000) + } + + test("throw exception on barrier() call timeout") { + val conf = new SparkConf() + .set("spark.barrier.sync.timeout", "1") + .set("spark.test.noStageRetry", "true") + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + // Task 3 shall sleep 2000ms to ensure barrier() call timeout + if (context.taskAttemptId == 3) { + Thread.sleep(2000) + } + context.barrier() + it + } + + val error = intercept[SparkException] { + rdd2.collect() + }.getMessage + assert(error.contains("The coordinator didn't get all barrier sync requests")) + assert(error.contains("within 1 second(s)")) + } + + test("throw exception if barrier() call doesn't happen on every task") { + val conf = new SparkConf() + .set("spark.barrier.sync.timeout", "1") + .set("spark.test.noStageRetry", "true") + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + if (context.taskAttemptId != 0) { + context.barrier() + } + it + } + + val error = intercept[SparkException] { + rdd2.collect() + }.getMessage + assert(error.contains("The coordinator didn't get all barrier sync requests")) + assert(error.contains("within 1 second(s)")) + } + + test("throw exception if the number of barrier() calls are not the same on every task") { + val conf = new SparkConf() + .set("spark.barrier.sync.timeout", "1") + .set("spark.test.noStageRetry", "true") + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + try { + if (context.taskAttemptId == 0) { + // Due to some non-obvious reason, the code can trigger an Exception and skip the + // following statements within the try ... catch block, including the first barrier() + // call. 
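+          // (Reading of this setup: task 0 skips the barrier() call inside the try block, so
+          // that round of sync can never complete within the 1 second timeout configured
+          // above; the resulting coordinator error is what the intercept below asserts.)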
+ throw new SparkException("test") + } + context.barrier() + } catch { + case e: Exception => // Do nothing + } + context.barrier() + it + } + + val error = intercept[SparkException] { + rdd2.collect() + }.getMessage + assert(error.contains("The coordinator didn't get all barrier sync requests")) + assert(error.contains("within 1 second(s)")) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala index d3bbfd11d406d..fe22d70850c7d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala @@ -24,7 +24,6 @@ import org.apache.spark.internal.config class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorMockBackend]{ val badHost = "host-0" - val duration = Duration(10, SECONDS) /** * This backend just always fails if the task is executed on a bad host, but otherwise succeeds diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala index 04cccc67e328e..80c9c6f0422a8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala @@ -17,10 +17,18 @@ package org.apache.spark.scheduler +import java.util.concurrent.atomic.AtomicBoolean + +import scala.concurrent.duration._ + +import org.scalatest.concurrent.Eventually + import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite} +import org.apache.spark.rdd.RDD import org.apache.spark.util.{RpcUtils, SerializableBuffer} -class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext { +class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext + with Eventually { test("serialized task larger than max RPC message size") { val conf = new SparkConf @@ -38,4 +46,83 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo assert(smaller.size === 4) } + test("compute max number of concurrent tasks can be launched") { + val conf = new SparkConf() + .setMaster("local-cluster[4, 3, 1024]") + .setAppName("test") + sc = new SparkContext(conf) + eventually(timeout(10.seconds)) { + // Ensure all executors have been launched. + assert(sc.getExecutorIds().length == 4) + } + assert(sc.maxNumConcurrentTasks() == 12) + } + + test("compute max number of concurrent tasks can be launched when spark.task.cpus > 1") { + val conf = new SparkConf() + .set("spark.task.cpus", "2") + .setMaster("local-cluster[4, 3, 1024]") + .setAppName("test") + sc = new SparkContext(conf) + eventually(timeout(10.seconds)) { + // Ensure all executors have been launched. + assert(sc.getExecutorIds().length == 4) + } + // Each executor can only launch one task since `spark.task.cpus` is 2. 
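+      // (Slot arithmetic, assuming maxNumConcurrentTasks() divides cores per executor and
+      // floors: 4 executors * (3 cores / 2 cpus-per-task) = 4 * 1 = 4 slots here, versus
+      // 4 * (3 / 1) = 12 in the test above.)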
+    assert(sc.maxNumConcurrentTasks() == 4)
+  }
+
+  test("compute max number of concurrent tasks can be launched when some executors are busy") {
+    val conf = new SparkConf()
+      .set("spark.task.cpus", "2")
+      .setMaster("local-cluster[4, 3, 1024]")
+      .setAppName("test")
+    sc = new SparkContext(conf)
+    val rdd = sc.parallelize(1 to 10, 4).mapPartitions { iter =>
+      Thread.sleep(5000)
+      iter
+    }
+    val taskStarted = new AtomicBoolean(false)
+    val taskEnded = new AtomicBoolean(false)
+    val listener = new SparkListener() {
+      override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+        taskStarted.set(true)
+      }
+
+      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+        taskEnded.set(true)
+      }
+    }
+
+    try {
+      sc.addSparkListener(listener)
+      eventually(timeout(10.seconds)) {
+        // Ensure all executors have been launched.
+        assert(sc.getExecutorIds().length == 4)
+      }
+
+      // Submit a job to trigger some tasks on active executors.
+      testSubmitJob(sc, rdd)
+
+      eventually(timeout(10.seconds)) {
+        // Ensure some tasks have started and no task finished, so some executors must be busy.
+        assert(taskStarted.get())
+        assert(!taskEnded.get())
+        // Assert we count in slots on both busy and free executors.
+        assert(sc.maxNumConcurrentTasks() == 4)
+      }
+    } finally {
+      sc.removeSparkListener(listener)
+    }
+  }
+
+  private def testSubmitJob(sc: SparkContext, rdd: RDD[Int]): Unit = {
+    sc.submitJob(
+      rdd,
+      (iter: Iterator[Int]) => iter.toArray,
+      0 until rdd.partitions.length,
+      { case (_, _) => return }: (Int, Array[Int]) => Unit,
+      { return }
+    )
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 2987170bf5026..365eab0668ab2 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -30,7 +30,9 @@ import org.scalatest.time.SpanSugar._
 
 import org.apache.spark._
 import org.apache.spark.broadcast.BroadcastManager
-import org.apache.spark.rdd.RDD
+import org.apache.spark.executor.ExecutorMetrics
+import org.apache.spark.internal.config
+import org.apache.spark.rdd.{DeterministicLevel, RDD}
 import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
 import org.apache.spark.shuffle.{FetchFailedException, MetadataFetchFailedException}
 import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
@@ -56,6 +58,20 @@ class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler)
 }
 
+class MyCheckpointRDD(
+    sc: SparkContext,
+    numPartitions: Int,
+    dependencies: List[Dependency[_]],
+    locations: Seq[Seq[String]] = Nil,
+    @(transient @param) tracker: MapOutputTrackerMaster = null,
+    indeterminate: Boolean = false)
+  extends MyRDD(sc, numPartitions, dependencies, locations, tracker, indeterminate) {
+
+  // Allow doCheckpoint() on this RDD.
+  override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
+    Iterator.empty
+}
+
 /**
  * An RDD for passing to DAGScheduler. These RDDs will use the dependencies and
  * preferredLocations (if any) that are passed to them.
They are deliberately not executable @@ -70,7 +86,8 @@ class MyRDD( numPartitions: Int, dependencies: List[Dependency[_]], locations: Seq[Seq[String]] = Nil, - @(transient @param) tracker: MapOutputTrackerMaster = null) + @(transient @param) tracker: MapOutputTrackerMaster = null, + indeterminate: Boolean = false) extends RDD[(Int, Int)](sc, dependencies) with Serializable { override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = @@ -80,6 +97,10 @@ class MyRDD( override def index: Int = i }).toArray + override protected def getOutputDeterministicLevel = { + if (indeterminate) DeterministicLevel.INDETERMINATE else super.getOutputDeterministicLevel + } + override def getPreferredLocations(partition: Partition): Seq[String] = { if (locations.isDefinedAt(partition.index)) { locations(partition.index) @@ -120,7 +141,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi override def executorHeartbeatReceived( execId: String, accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])], - blockManagerId: BlockManagerId): Boolean = true + blockManagerId: BlockManagerId, + executorUpdates: ExecutorMetrics): Boolean = true override def submitTasks(taskSet: TaskSet) = { // normally done by TaskSetManager taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch) @@ -131,6 +153,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi } override def killTaskAttempt( taskId: Long, interruptThread: Boolean, reason: String): Boolean = false + override def killAllTaskAttempts( + stageId: Int, interruptThread: Boolean, reason: String): Unit = {} override def setDAGScheduler(dagScheduler: DAGScheduler) = {} override def defaultParallelism() = 2 override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} @@ -213,7 +237,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi } private def init(testConf: SparkConf): Unit = { - sc = new SparkContext("local", "DAGSchedulerSuite", testConf) + sc = new SparkContext("local[2]", "DAGSchedulerSuite", testConf) sparkListener.submittedStageInfos.clear() sparkListener.successfulStages.clear() sparkListener.failedStages.clear() @@ -404,7 +428,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // reset the test context with the right shuffle service config afterEach() val conf = new SparkConf() - conf.set("spark.shuffle.service.enabled", "true") + conf.set(config.SHUFFLE_SERVICE_ENABLED.key, "true") conf.set("spark.files.fetchFailure.unRegisterOutputOnHost", "true") init(conf) runEvent(ExecutorAdded("exec-hostA1", "hostA")) @@ -421,17 +445,17 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // map stage1 completes successfully, with one task on each executor complete(taskSets(0), Seq( (Success, - MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))), + MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2), 1)), (Success, - MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))), + MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2), 1)), (Success, makeMapStatus("hostB", 1)) )) // map stage2 completes successfully, with one task on each executor complete(taskSets(1), Seq( (Success, - MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))), + MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2), 1)), (Success, - 
MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))), + MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2), 1)), (Success, makeMapStatus("hostB", 1)) )) // make sure our test setup is correct @@ -629,12 +653,17 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi taskId: Long, interruptThread: Boolean, reason: String): Boolean = { throw new UnsupportedOperationException } + override def killAllTaskAttempts( + stageId: Int, interruptThread: Boolean, reason: String): Unit = { + throw new UnsupportedOperationException + } override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {} override def defaultParallelism(): Int = 2 override def executorHeartbeatReceived( execId: String, accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])], - blockManagerId: BlockManagerId): Boolean = true + blockManagerId: BlockManagerId, + executorMetrics: ExecutorMetrics): Boolean = true override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} override def workerRemoved(workerId: String, host: String, message: String): Unit = {} override def applicationAttemptId(): Option[String] = None @@ -722,7 +751,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // reset the test context with the right shuffle service config afterEach() val conf = new SparkConf() - conf.set("spark.shuffle.service.enabled", shuffleServiceOn.toString) + conf.set(config.SHUFFLE_SERVICE_ENABLED.key, shuffleServiceOn.toString) init(conf) assert(sc.env.blockManager.externalShuffleServiceEnabled == shuffleServiceOn) @@ -1055,6 +1084,91 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assert(sparkListener.failedStages.size == 1) } + test("Retry all the tasks on a resubmitted attempt of a barrier stage caused by FetchFailure") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions(iter => iter) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) + submit(reduceRdd, Array(0, 1)) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) + assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty)) + + // The first result task fails, with a fetch failure for the output from the first mapper. + runEvent(makeCompletionEvent( + taskSets(1).tasks(0), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), + null)) + assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq(0, 1))) + + scheduler.resubmitFailedStages() + // Complete the map stage. + completeShuffleMapStageSuccessfully(0, 1, numShufflePartitions = 2) + + // Complete the result stage. 
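+    // (Property being exercised, as an assumption about the barrier-stage contract: because
+    // all tasks of a barrier stage run together, one fetch failure invalidates every map
+    // output of the stage -- which is why findMissingPartitions() flipped from Seq.empty to
+    // Seq(0, 1) above and the whole map stage was rerun rather than just the lost partition.)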
+    completeNextResultStageWithSuccess(1, 1)
+
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    assertDataStructuresEmpty()
+  }
+
+  test("Retry all the tasks on a resubmitted attempt of a barrier stage caused by TaskKilled") {
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions(iter => iter)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
+    submit(reduceRdd, Array(0, 1))
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", reduceRdd.partitions.length))))
+    assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq(1)))
+
+    // The second map task fails with TaskKilled.
+    runEvent(makeCompletionEvent(
+      taskSets(0).tasks(1),
+      TaskKilled("test"),
+      null))
+    assert(sparkListener.failedStages === Seq(0))
+    assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq(0, 1)))
+
+    scheduler.resubmitFailedStages()
+    // Complete the map stage.
+    completeShuffleMapStageSuccessfully(0, 1, numShufflePartitions = 2)
+
+    // Complete the result stage.
+    completeNextResultStageWithSuccess(1, 0)
+
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    assertDataStructuresEmpty()
+  }
+
+  test("Fail the job if a barrier ResultTask failed") {
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
+      .barrier()
+      .mapPartitions(iter => iter)
+    submit(reduceRdd, Array(0, 1))
+
+    // Complete the map stage.
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", 2)),
+      (Success, makeMapStatus("hostA", 2))))
+    assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty))
+
+    // The first ResultTask fails.
+    runEvent(makeCompletionEvent(
+      taskSets(1).tasks(0),
+      TaskKilled("test"),
+      null))
+
+    // Assert the stage has been cancelled.
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    assert(failure.getMessage.startsWith("Job aborted due to stage failure: Could not recover " +
+      "from a failed barrier ResultStage."))
+  }
+
   /**
    * This tests the case where another FetchFailed comes in while the map stage is getting
    * re-run.
@@ -2322,9 +2436,6 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
 
     // Runs a job that encounters a single fetch failure but succeeds on the second attempt
     def runJobWithTemporaryFetchFailure: Unit = {
-      object FailThisAttempt {
-        val _fail = new AtomicBoolean(true)
-      }
       val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).map(x => (x, 1)).groupByKey()
       val shuffleHandle =
         rdd1.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle
@@ -2404,7 +2515,11 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     runEvent(makeCompletionEvent(
       taskSets(1).tasks(1), Success, makeMapStatus("hostA", 2)))
 
-    // Both tasks in rddB should be resubmitted, because none of them has succeeded truely.
+    // task(stageId=1, stageAttemptId=1, partitionId=1) should be marked completed when
+    // task(stageId=1, stageAttemptId=0, partitionId=1) finished.
+    // Ideally we would verify that, but there is no way to reach into the task scheduler to verify it.
+
+    // Both tasks in rddB should be resubmitted, because none of them has succeeded truly.
     // Complete the task(stageId=1, stageAttemptId=1, partitionId=0) successfully.
// Task(stageId=1, stageAttemptId=1, partitionId=1) of this new active stage attempt // is still running. @@ -2413,19 +2528,21 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi runEvent(makeCompletionEvent( taskSets(3).tasks(0), Success, makeMapStatus("hostB", 2))) - // There should be no new attempt of stage submitted, - // because task(stageId=1, stageAttempt=1, partitionId=1) is still running in - // the current attempt (and hasn't completed successfully in any earlier attempts). - assert(taskSets.size === 4) + // At this point there should be no active task set for stageId=1 and we need + // to resubmit because the output from (stageId=1, stageAttemptId=0, partitionId=1) + // was ignored due to executor failure + assert(taskSets.size === 5) + assert(taskSets(4).stageId === 1 && taskSets(4).stageAttemptId === 2 + && taskSets(4).tasks.size === 1) - // Complete task(stageId=1, stageAttempt=1, partitionId=1) successfully. + // Complete task(stageId=1, stageAttempt=2, partitionId=1) successfully. runEvent(makeCompletionEvent( - taskSets(3).tasks(1), Success, makeMapStatus("hostB", 2))) + taskSets(4).tasks(0), Success, makeMapStatus("hostB", 2))) // Now the ResultStage should be submitted, because all of the tasks of rddB have // completed successfully on alive executors. - assert(taskSets.size === 5 && taskSets(4).tasks(0).isInstanceOf[ResultTask[_, _]]) - complete(taskSets(4), Seq( + assert(taskSets.size === 6 && taskSets(5).tasks(0).isInstanceOf[ResultTask[_, _]]) + complete(taskSets(5), Seq( (Success, 1), (Success, 1))) } @@ -2460,6 +2577,231 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi } } + test("Barrier task failures from the same stage attempt don't trigger multiple stage retries") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions(iter => iter) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) + submit(reduceRdd, Array(0, 1)) + + val mapStageId = 0 + def countSubmittedMapStageAttempts(): Int = { + sparkListener.submittedStageInfos.count(_.stageId == mapStageId) + } + + // The map stage should have been submitted. + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(countSubmittedMapStageAttempts() === 1) + + // The first map task fails with TaskKilled. + runEvent(makeCompletionEvent( + taskSets(0).tasks(0), + TaskKilled("test"), + null)) + assert(sparkListener.failedStages === Seq(0)) + + // The second map task fails with TaskKilled. + runEvent(makeCompletionEvent( + taskSets(0).tasks(1), + TaskKilled("test"), + null)) + + // Trigger resubmission of the failed map stage. + runEvent(ResubmitFailedStages) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + + // Another attempt for the map stage should have been submitted, resulting in 2 total attempts. 
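+    // (Expected dedup, as an assumption about the mechanism: both kills came from stage
+    // attempt 0, and the scheduler appears to key its barrier-stage rollback on the failed
+    // attempt, so the second TaskKilled folds into the resubmission already triggered.)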
+    assert(countSubmittedMapStageAttempts() === 2)
+  }
+
+  test("Barrier task failures from a previous stage attempt don't trigger stage retry") {
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions(iter => iter)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
+    submit(reduceRdd, Array(0, 1))
+
+    val mapStageId = 0
+    def countSubmittedMapStageAttempts(): Int = {
+      sparkListener.submittedStageInfos.count(_.stageId == mapStageId)
+    }
+
+    // The map stage should have been submitted.
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    assert(countSubmittedMapStageAttempts() === 1)
+
+    // The first map task fails with TaskKilled.
+    runEvent(makeCompletionEvent(
+      taskSets(0).tasks(0),
+      TaskKilled("test"),
+      null))
+    assert(sparkListener.failedStages === Seq(0))
+
+    // Trigger resubmission of the failed map stage.
+    runEvent(ResubmitFailedStages)
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+
+    // Another attempt for the map stage should have been submitted, resulting in 2 total attempts.
+    assert(countSubmittedMapStageAttempts() === 2)
+
+    // The second map task fails with TaskKilled.
+    runEvent(makeCompletionEvent(
+      taskSets(0).tasks(1),
+      TaskKilled("test"),
+      null))
+
+    // The second map task failure doesn't trigger stage retry.
+    runEvent(ResubmitFailedStages)
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    assert(countSubmittedMapStageAttempts() === 2)
+  }
+
+  test("SPARK-23207: retry all the succeeding stages when the map stage is indeterminate") {
+    val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true)
+
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2))
+    val shuffleId1 = shuffleDep1.shuffleId
+    val shuffleMapRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker = mapOutputTracker)
+
+    val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(2))
+    val shuffleId2 = shuffleDep2.shuffleId
+    val finalRdd = new MyRDD(sc, 2, List(shuffleDep2), tracker = mapOutputTracker)
+
+    submit(finalRdd, Array(0, 1))
+
+    // Finish the first shuffle map stage.
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", 2)),
+      (Success, makeMapStatus("hostB", 2))))
+    assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty))
+
+    // Finish the second shuffle map stage.
+    complete(taskSets(1), Seq(
+      (Success, makeMapStatus("hostC", 2)),
+      (Success, makeMapStatus("hostD", 2))))
+    assert(mapOutputTracker.findMissingPartitions(shuffleId2) === Some(Seq.empty))
+
+    // The first task of the final stage failed with fetch failure.
+    runEvent(makeCompletionEvent(
+      taskSets(2).tasks(0),
+      FetchFailed(makeBlockManagerId("hostC"), shuffleId2, 0, 0, "ignored"),
+      null))
+
+    val failedStages = scheduler.failedStages.toSeq
+    assert(failedStages.length == 2)
+    // Shuffle blocks of "hostC" are lost, so the first task of `shuffleMapRdd2` needs to retry.
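+    // (Background, as an assumption about the SPARK-23207 model: an INDETERMINATE stage may
+    // produce different data when rerun, so once any of its output is lost, every stage
+    // downstream of it has to be completely retried; a ResultStage that has already committed
+    // some of its tasks cannot be rolled back that way, hence the job-level failure asserted
+    // at the end of this test.)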
+    assert(failedStages.collect {
+      case stage: ShuffleMapStage if stage.shuffleDep.shuffleId == shuffleId2 => stage
+    }.head.findMissingPartitions() == Seq(0))
+    // The result stage is still waiting for its 2 tasks to complete.
+    assert(failedStages.collect {
+      case stage: ResultStage => stage
+    }.head.findMissingPartitions() == Seq(0, 1))
+
+    scheduler.resubmitFailedStages()
+
+    // The first task of the `shuffleMapRdd2` failed with fetch failure.
+    runEvent(makeCompletionEvent(
+      taskSets(3).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0, 0, "ignored"),
+      null))
+
+    // The job should fail because Spark can't rollback the shuffle map stage.
+    assert(failure != null && failure.getMessage.contains("Spark cannot rollback"))
+  }
+
+  private def assertResultStageFailToRollback(mapRdd: MyRDD): Unit = {
+    val shuffleDep = new ShuffleDependency(mapRdd, new HashPartitioner(2))
+    val shuffleId = shuffleDep.shuffleId
+    val finalRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
+
+    submit(finalRdd, Array(0, 1))
+
+    completeShuffleMapStageSuccessfully(taskSets.length - 1, 0, numShufflePartitions = 2)
+    assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty))
+
+    // Finish the first task of the result stage.
+    runEvent(makeCompletionEvent(
+      taskSets.last.tasks(0), Success, 42,
+      Seq.empty, createFakeTaskInfoWithId(0)))
+
+    // Fail the second task with FetchFailed.
+    runEvent(makeCompletionEvent(
+      taskSets.last.tasks(1),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+      null))
+
+    // The job should fail because Spark can't rollback the result stage.
+    assert(failure != null && failure.getMessage.contains("Spark cannot rollback"))
+  }
+
+  test("SPARK-23207: cannot rollback a result stage") {
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil, indeterminate = true)
+    assertResultStageFailToRollback(shuffleMapRdd)
+  }
+
+  test("SPARK-23207: local checkpoint fail to rollback (checkpointed before)") {
+    val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil, indeterminate = true)
+    shuffleMapRdd.localCheckpoint()
+    shuffleMapRdd.doCheckpoint()
+    assertResultStageFailToRollback(shuffleMapRdd)
+  }
+
+  test("SPARK-23207: local checkpoint fail to rollback (checkpointing now)") {
+    val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil, indeterminate = true)
+    shuffleMapRdd.localCheckpoint()
+    assertResultStageFailToRollback(shuffleMapRdd)
+  }
+
+  private def assertResultStageNotRollbacked(mapRdd: MyRDD): Unit = {
+    val shuffleDep = new ShuffleDependency(mapRdd, new HashPartitioner(2))
+    val shuffleId = shuffleDep.shuffleId
+    val finalRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
+
+    submit(finalRdd, Array(0, 1))
+
+    completeShuffleMapStageSuccessfully(taskSets.length - 1, 0, numShufflePartitions = 2)
+    assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty))
+
+    // Finish the first task of the result stage.
+    runEvent(makeCompletionEvent(
+      taskSets.last.tasks(0), Success, 42,
+      Seq.empty, createFakeTaskInfoWithId(0)))
+
+    // Fail the second task with FetchFailed.
+    runEvent(makeCompletionEvent(
+      taskSets.last.tasks(1),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+      null))
+
+    assert(failure == null, "job should not fail")
+    val failedStages = scheduler.failedStages.toSeq
+    assert(failedStages.length == 2)
+    // Shuffle blocks of "hostA" are lost, so the first task of `mapRdd` needs to retry.
+ assert(failedStages.collect { + case stage: ShuffleMapStage if stage.shuffleDep.shuffleId == shuffleId => stage + }.head.findMissingPartitions() == Seq(0)) + // The first task of result stage remains completed. + assert(failedStages.collect { + case stage: ResultStage => stage + }.head.findMissingPartitions() == Seq(1)) + } + + test("SPARK-23207: reliable checkpoint can avoid rollback (checkpointed before)") { + sc.setCheckpointDir(Utils.createTempDir().getCanonicalPath) + val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil, indeterminate = true) + shuffleMapRdd.checkpoint() + shuffleMapRdd.doCheckpoint() + assertResultStageNotRollbacked(shuffleMapRdd) + } + + test("SPARK-23207: reliable checkpoint fail to rollback (checkpointing now)") { + sc.setCheckpointDir(Utils.createTempDir().getCanonicalPath) + val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil, indeterminate = true) + shuffleMapRdd.checkpoint() + assertResultStageFailToRollback(shuffleMapRdd) + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID. @@ -2515,8 +2857,12 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi object DAGSchedulerSuite { def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2): MapStatus = - MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes)) + MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes), 1) def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345) } + +object FailThisAttempt { + val _fail = new AtomicBoolean(true) +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index a9e92fa07b9dd..cecd6996df7bd 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -19,7 +19,9 @@ package org.apache.spark.scheduler import java.io.{File, FileOutputStream, InputStream, IOException} +import scala.collection.immutable.Map import scala.collection.mutable +import scala.collection.mutable.Set import scala.io.Source import org.apache.hadoop.fs.Path @@ -29,11 +31,14 @@ import org.scalatest.BeforeAndAfter import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} import org.apache.spark.internal.Logging import org.apache.spark.io._ -import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.metrics.{ExecutorMetricType, MetricsSystem} +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.{JsonProtocol, Utils} + /** * Test whether EventLoggingListener logs events properly. 
 *
@@ -43,6 +48,7 @@ import org.apache.spark.util.{JsonProtocol, Utils}
  */
 class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfter
   with Logging {
+  import EventLoggingListenerSuite._
 
   private val fileSystem = Utils.getHadoopFileSystem("/",
@@ -137,6 +143,10 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit
     "a fine:mind$dollar{bills}.1", None, Some("lz4")))
   }
 
+  test("Executor metrics update") {
+    testStageExecutorMetricsEventLogging()
+  }
+
   /* ----------------- *
    * Actual test logic *
    * ----------------- */
@@ -251,6 +261,214 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit
     }
   }
 
+  /**
+   * Test stage executor metrics logging functionality. This checks that peak
+   * values from SparkListenerExecutorMetricsUpdate events during a stage are
+   * logged in a StageExecutorMetrics event for each executor at stage completion.
+   */
+  private def testStageExecutorMetricsEventLogging() {
+    val conf = getLoggingConf(testDirPath, None)
+    val logName = "stageExecutorMetrics-test"
+    val eventLogger = new EventLoggingListener(logName, None, testDirPath.toUri(), conf)
+    val listenerBus = new LiveListenerBus(conf)
+
+    // Events to post.
+    val events = Array(
+      SparkListenerApplicationStart("executionMetrics", None,
+        1L, "update", None),
+      createExecutorAddedEvent(1),
+      createExecutorAddedEvent(2),
+      createStageSubmittedEvent(0),
+      // receive 3 metric updates from each executor with just stage 0 running,
+      // with different peak updates for each executor
+      createExecutorMetricsUpdateEvent(1,
+        new ExecutorMetrics(Array(4000L, 50L, 20L, 0L, 40L, 0L, 60L, 0L, 70L, 20L))),
+      createExecutorMetricsUpdateEvent(2,
+        new ExecutorMetrics(Array(1500L, 50L, 20L, 0L, 0L, 0L, 20L, 0L, 70L, 0L))),
+      // exec 1: new stage 0 peaks for metrics at indexes: 2, 4, 6
+      createExecutorMetricsUpdateEvent(1,
+        new ExecutorMetrics(Array(4000L, 50L, 50L, 0L, 50L, 0L, 100L, 0L, 70L, 20L))),
+      // exec 2: new stage 0 peaks for metrics at indexes: 0, 4, 6
+      createExecutorMetricsUpdateEvent(2,
+        new ExecutorMetrics(Array(2000L, 50L, 10L, 0L, 10L, 0L, 30L, 0L, 70L, 0L))),
+      // exec 1: new stage 0 peaks for metrics at indexes: 5, 7
+      createExecutorMetricsUpdateEvent(1,
+        new ExecutorMetrics(Array(2000L, 40L, 50L, 0L, 40L, 10L, 90L, 10L, 50L, 0L))),
+      // exec 2: new stage 0 peaks for metrics at indexes: 0, 5, 6, 7, 8
+      createExecutorMetricsUpdateEvent(2,
+        new ExecutorMetrics(Array(3500L, 50L, 15L, 0L, 10L, 10L, 35L, 10L, 80L, 0L))),
+      // now start stage 1, one more metric update for each executor, and new
+      // peaks for some stage 1 metrics (as listed), initialize stage 1 peaks
+      createStageSubmittedEvent(1),
+      // exec 1: new stage 0 peaks for metrics at indexes: 0, 3, 7; initialize stage 1 peaks
+      createExecutorMetricsUpdateEvent(1,
+        new ExecutorMetrics(Array(5000L, 30L, 50L, 20L, 30L, 10L, 80L, 30L, 50L, 0L))),
+      // exec 2: new stage 0 peaks for metrics at indexes: 0, 1, 2, 3, 6, 7, 9;
+      // initialize stage 1 peaks
+      createExecutorMetricsUpdateEvent(2,
+        new ExecutorMetrics(Array(7000L, 70L, 50L, 20L, 0L, 10L, 50L, 30L, 10L, 40L))),
+      // complete stage 0, and 3 more updates for each executor with just
+      // stage 1 running
+      createStageCompletedEvent(0),
+      // exec 1: new stage 1 peaks for metrics at indexes: 0, 1, 3
+      createExecutorMetricsUpdateEvent(1,
+        new ExecutorMetrics(Array(6000L, 70L, 20L, 30L, 10L, 0L, 30L, 30L, 30L, 0L))),
+      // exec 2: new stage 1 peaks for metrics at indexes: 3, 4, 7, 8
+
createExecutorMetricsUpdateEvent(2, + new ExecutorMetrics(Array(5500L, 30L, 20L, 40L, 10L, 0L, 30L, 40L, 40L, 20L))), + // exec 1: new stage 1 peaks for metrics at indexes: 0, 4, 5, 7 + createExecutorMetricsUpdateEvent(1, + new ExecutorMetrics(Array(7000L, 70L, 5L, 25L, 60L, 30L, 65L, 55L, 30L, 0L))), + // exec 2: new stage 1 peak for metrics at index: 7 + createExecutorMetricsUpdateEvent(2, + new ExecutorMetrics(Array(5500L, 40L, 25L, 30L, 10L, 30L, 35L, 60L, 0L, 20L))), + // exec 1: no new stage 1 peaks + createExecutorMetricsUpdateEvent(1, + new ExecutorMetrics(Array(5500L, 70L, 15L, 20L, 55L, 20L, 70L, 40L, 20L, 0L))), + createExecutorRemovedEvent(1), + // exec 2: new stage 1 peak for metrics at index: 6 + createExecutorMetricsUpdateEvent(2, + new ExecutorMetrics(Array(4000L, 20L, 25L, 30L, 10L, 30L, 35L, 60L, 0L, 0L))), + createStageCompletedEvent(1), + SparkListenerApplicationEnd(1000L)) + + // play the events for the event logger + eventLogger.start() + listenerBus.start(Mockito.mock(classOf[SparkContext]), Mockito.mock(classOf[MetricsSystem])) + listenerBus.addToEventLogQueue(eventLogger) + events.foreach(event => listenerBus.post(event)) + listenerBus.stop() + eventLogger.stop() + + // expected StageExecutorMetrics, for the given stage id and executor id + val expectedMetricsEvents: Map[(Int, String), SparkListenerStageExecutorMetrics] = + Map( + ((0, "1"), + new SparkListenerStageExecutorMetrics("1", 0, 0, + new ExecutorMetrics(Array(5000L, 50L, 50L, 20L, 50L, 10L, 100L, 30L, 70L, 20L)))), + ((0, "2"), + new SparkListenerStageExecutorMetrics("2", 0, 0, + new ExecutorMetrics(Array(7000L, 70L, 50L, 20L, 10L, 10L, 50L, 30L, 80L, 40L)))), + ((1, "1"), + new SparkListenerStageExecutorMetrics("1", 1, 0, + new ExecutorMetrics(Array(7000L, 70L, 50L, 30L, 60L, 30L, 80L, 55L, 50L, 0L)))), + ((1, "2"), + new SparkListenerStageExecutorMetrics("2", 1, 0, + new ExecutorMetrics(Array(7000L, 70L, 50L, 40L, 10L, 30L, 50L, 60L, 40L, 40L))))) + + // Verify the log file contains the expected events. + // Posted events should be logged, except for ExecutorMetricsUpdate events -- these + // are consolidated, and the peak values for each stage are logged at stage end. 
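+    // (Line-count check for the assert below: 1 SparkListenerLogStart + 1 ApplicationStart
+    // + 2 ExecutorAdded + 2 StageSubmitted + 2 stages x 2 executors of StageExecutorMetrics
+    // + 2 StageCompleted + 1 ExecutorRemoved + 1 ApplicationEnd = 14 lines, with every raw
+    // ExecutorMetricsUpdate post dropped.)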
+    val logData = EventLoggingListener.openEventLog(new Path(eventLogger.logPath), fileSystem)
+    try {
+      val lines = readLines(logData)
+      val logStart = SparkListenerLogStart(SPARK_VERSION)
+      assert(lines.size === 14)
+      assert(lines(0).contains("SparkListenerLogStart"))
+      assert(lines(1).contains("SparkListenerApplicationStart"))
+      assert(JsonProtocol.sparkEventFromJson(parse(lines(0))) === logStart)
+      var logIdx = 1
+      events.foreach { event =>
+        event match {
+          case metricsUpdate: SparkListenerExecutorMetricsUpdate =>
+          case stageCompleted: SparkListenerStageCompleted =>
+            val execIds = Set[String]()
+            (1 to 2).foreach { _ =>
+              val execId = checkStageExecutorMetrics(lines(logIdx),
+                stageCompleted.stageInfo.stageId, expectedMetricsEvents)
+              execIds += execId
+              logIdx += 1
+            }
+            assert(execIds.size == 2) // check that each executor was logged
+            checkEvent(lines(logIdx), event)
+            logIdx += 1
+          case _ =>
+            checkEvent(lines(logIdx), event)
+            logIdx += 1
+        }
+      }
+    } finally {
+      logData.close()
+    }
+  }
+
+  private def createStageSubmittedEvent(stageId: Int) = {
+    SparkListenerStageSubmitted(new StageInfo(stageId, 0, stageId.toString, 0,
+      Seq.empty, Seq.empty, "details"))
+  }
+
+  private def createStageCompletedEvent(stageId: Int) = {
+    SparkListenerStageCompleted(new StageInfo(stageId, 0, stageId.toString, 0,
+      Seq.empty, Seq.empty, "details"))
+  }
+
+  private def createExecutorAddedEvent(executorId: Int) = {
+    SparkListenerExecutorAdded(0L, executorId.toString, new ExecutorInfo("host1", 1, Map.empty))
+  }
+
+  private def createExecutorRemovedEvent(executorId: Int) = {
+    SparkListenerExecutorRemoved(0L, executorId.toString, "test")
+  }
+
+  private def createExecutorMetricsUpdateEvent(
+      executorId: Int,
+      executorMetrics: ExecutorMetrics): SparkListenerExecutorMetricsUpdate = {
+    val taskMetrics = TaskMetrics.empty
+    taskMetrics.incDiskBytesSpilled(111)
+    taskMetrics.incMemoryBytesSpilled(222)
+    val accum = Array((333L, 1, 1, taskMetrics.accumulators().map(AccumulatorSuite.makeInfo)))
+    SparkListenerExecutorMetricsUpdate(executorId.toString, accum, Some(executorMetrics))
+  }
+
+  /** Check that the Spark history log line matches the expected event. */
+  private def checkEvent(line: String, event: SparkListenerEvent): Unit = {
+    assert(line.contains(event.getClass.toString.split("\\.").last))
+    val parsed = JsonProtocol.sparkEventFromJson(parse(line))
+    assert(parsed.getClass === event.getClass)
+    (event, parsed) match {
+      case (expected: SparkListenerStageSubmitted, actual: SparkListenerStageSubmitted) =>
+        // accumulables can be different, so only check the stage Id
+        assert(expected.stageInfo.stageId == actual.stageInfo.stageId)
+      case (expected: SparkListenerStageCompleted, actual: SparkListenerStageCompleted) =>
+        // accumulables can be different, so only check the stage Id
+        assert(expected.stageInfo.stageId == actual.stageInfo.stageId)
+      case (expected: SparkListenerEvent, actual: SparkListenerEvent) =>
+        assert(expected === actual)
+    }
+  }
+
+  /**
+   * Check that the Spark history log line is a StageExecutorMetrics event, and matches the
+   * expected value for the stage and executor.
+   *
+   * @param line the Spark history log line
+   * @param stageId the stage ID the StageExecutorMetrics event is associated with
+   * @param expectedEvents map of expected SparkListenerStageExecutorMetrics events, for
+   *                       (stageId, executorId)
+   */
+  private def checkStageExecutorMetrics(
+      line: String,
+      stageId: Int,
+      expectedEvents: Map[(Int, String), SparkListenerStageExecutorMetrics]): String = {
+    JsonProtocol.sparkEventFromJson(parse(line)) match {
+      case executorMetrics: SparkListenerStageExecutorMetrics =>
+        expectedEvents.get((stageId, executorMetrics.execId)) match {
+          case Some(expectedMetrics) =>
+            assert(executorMetrics.execId === expectedMetrics.execId)
+            assert(executorMetrics.stageId === expectedMetrics.stageId)
+            assert(executorMetrics.stageAttemptId === expectedMetrics.stageAttemptId)
+            ExecutorMetricType.values.foreach { metricType =>
+              assert(executorMetrics.executorMetrics.getMetricValue(metricType) ===
+                expectedMetrics.executorMetrics.getMetricValue(metricType))
+            }
+          case None =>
+            fail(s"missing expected metrics for stage $stageId, " +
+              s"executor ${executorMetrics.execId}")
+        }
+        executorMetrics.execId
+      case _ =>
+        fail("expecting SparkListenerStageExecutorMetrics")
+    }
+  }
+
   private def readLines(in: InputStream): Seq[String] = {
     Source.fromInputStream(in).getLines().toSeq
   }
@@ -299,6 +517,7 @@ object EventLoggingListenerSuite {
       conf.set("spark.eventLog.compress", "true")
       conf.set("spark.io.compression.codec", codec)
     }
+    conf.set("spark.eventLog.logStageExecutorMetrics.enabled", "true")
     conf
   }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
index a4e4ea7cd2894..0621c98d41184 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.scheduler
 
 import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.executor.ExecutorMetrics
 import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
 import org.apache.spark.storage.BlockManagerId
 import org.apache.spark.util.AccumulatorV2
@@ -69,6 +70,7 @@ private class DummySchedulerBackend extends SchedulerBackend {
   def stop() {}
   def reviveOffers() {}
   def defaultParallelism(): Int = 1
+  def maxNumConcurrentTasks(): Int = 0
 }
 
 private class DummyTaskScheduler extends TaskScheduler {
@@ -81,6 +83,8 @@ private class DummyTaskScheduler extends TaskScheduler {
   override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = {}
   override def killTaskAttempt(
     taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
+  override def killAllTaskAttempts(
+    stageId: Int, interruptThread: Boolean, reason: String): Unit = {}
   override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {}
   override def defaultParallelism(): Int = 2
   override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
@@ -89,5 +93,6 @@ private class DummyTaskScheduler extends TaskScheduler {
   def executorHeartbeatReceived(
     execId: String,
     accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])],
-    blockManagerId: BlockManagerId): Boolean = true
+    blockManagerId: BlockManagerId,
+    executorMetrics: ExecutorMetrics): Boolean = true
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
index 109d4a0a870b8..b29d32f7b35c5 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
@@ -27,8 +27,10 @@ class FakeTask(
     partitionId: Int,
     prefLocs: Seq[TaskLocation] = Nil,
     serializedTaskMetrics: Array[Byte] =
-      SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array())
-  extends Task[Int](stageId, 0, partitionId, new Properties, serializedTaskMetrics) {
+      SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array(),
+    isBarrier: Boolean = false)
+  extends Task[Int](stageId, 0, partitionId, new Properties, serializedTaskMetrics,
+    isBarrier = isBarrier) {
 
   override def runTask(context: TaskContext): Int = 0
   override def preferredLocations: Seq[TaskLocation] = prefLocs
@@ -74,4 +76,22 @@ object FakeTask {
     }
     new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null)
   }
+
+  def createBarrierTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
+    createBarrierTaskSet(numTasks, stageId = 0, stageAttemptId = 0, prefLocs: _*)
+  }
+
+  def createBarrierTaskSet(
+      numTasks: Int,
+      stageId: Int,
+      stageAttemptId: Int,
+      prefLocs: Seq[TaskLocation]*): TaskSet = {
+    if (prefLocs.size != 0 && prefLocs.size != numTasks) {
+      throw new IllegalArgumentException("Wrong number of task locations")
+    }
+    val tasks = Array.tabulate[Task[_]](numTasks) { i =>
+      new FakeTask(stageId, i, if (prefLocs.size != 0) prefLocs(i) else Nil, isBarrier = true)
+    }
+    new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null)
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
index 354e6386fa60e..555e48bd28aa0 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
@@ -60,7 +60,7 @@ class MapStatusSuite extends SparkFunSuite {
       stddev <- Seq(0.0, 0.01, 0.5, 1.0)
     ) {
       val sizes = Array.fill[Long](numSizes)(abs(round(Random.nextGaussian() * stddev)) + mean)
-      val status = MapStatus(BlockManagerId("a", "b", 10), sizes)
+      val status = MapStatus(BlockManagerId("a", "b", 10), sizes, 1)
       val status1 = compressAndDecompressMapStatus(status)
       for (i <- 0 until numSizes) {
         if (sizes(i) != 0) {
@@ -74,7 +74,7 @@ class MapStatusSuite extends SparkFunSuite {
 
   test("large tasks should use " + classOf[HighlyCompressedMapStatus].getName) {
     val sizes = Array.fill[Long](2001)(150L)
-    val status = MapStatus(null, sizes)
+    val status = MapStatus(null, sizes, 1)
     assert(status.isInstanceOf[HighlyCompressedMapStatus])
     assert(status.getSizeForBlock(10) === 150L)
     assert(status.getSizeForBlock(50) === 150L)
@@ -86,7 +86,7 @@ class MapStatusSuite extends SparkFunSuite {
     val sizes = Array.tabulate[Long](3000) { i => i.toLong }
     val avg = sizes.sum / sizes.count(_ != 0)
     val loc = BlockManagerId("a", "b", 10)
-    val status = MapStatus(loc, sizes)
+    val status = MapStatus(loc, sizes, 1)
     val status1 = compressAndDecompressMapStatus(status)
     assert(status1.isInstanceOf[HighlyCompressedMapStatus])
     assert(status1.location == loc)
@@ -108,7 +108,7 @@ class MapStatusSuite extends SparkFunSuite {
     val smallBlockSizes = sizes.filter(n => n > 0 && n < threshold)
     val avg = smallBlockSizes.sum / smallBlockSizes.length
     val loc = BlockManagerId("a", "b", 10)
-    val status = MapStatus(loc, sizes)
+    val status = MapStatus(loc, sizes, 1)
     val status1 = compressAndDecompressMapStatus(status)
     assert(status1.isInstanceOf[HighlyCompressedMapStatus])
     assert(status1.location == loc)
@@ -164,7 +164,7 @@ class MapStatusSuite extends SparkFunSuite {
SparkFunSuite { SparkEnv.set(env) // Value of element in sizes is equal to the corresponding index. val sizes = (0L to 2000L).toArray - val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes) + val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes, 1) val arrayStream = new ByteArrayOutputStream(102400) val objectOutputStream = new ObjectOutputStream(arrayStream) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) @@ -196,19 +196,19 @@ class MapStatusSuite extends SparkFunSuite { SparkEnv.set(env) val sizes = Array.fill[Long](500)(150L) // Test default value - val status = MapStatus(null, sizes) + val status = MapStatus(null, sizes, 1) assert(status.isInstanceOf[CompressedMapStatus]) // Test Non-positive values for (s <- -1 to 0) { assertThrows[IllegalArgumentException] { conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s) - val status = MapStatus(null, sizes) + val status = MapStatus(null, sizes, 1) } } // Test positive values Seq(1, 100, 499, 500, 501).foreach { s => conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s) - val status = MapStatus(null, sizes) + val status = MapStatus(null, sizes, 1) if(sizes.length > s) { assert(status.isInstanceOf[HighlyCompressedMapStatus]) } else { diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index e24d550a62665..d1113c7e0b103 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -22,6 +22,7 @@ import java.net.URI import java.util.concurrent.atomic.AtomicInteger import org.apache.hadoop.fs.Path +import org.json4s.JsonAST.JValue import org.json4s.jackson.JsonMethods._ import org.scalatest.BeforeAndAfter @@ -217,7 +218,9 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp // Verify the same events are replayed in the same order assert(sc.eventLogger.isDefined) - val originalEvents = sc.eventLogger.get.loggedEvents + val originalEvents = sc.eventLogger.get.loggedEvents.filter { e => + !JsonProtocol.sparkEventFromJson(e).isInstanceOf[SparkListenerStageExecutorMetrics] + } val replayedEvents = eventMonster.loggedEvents originalEvents.zip(replayedEvents).foreach { case (e1, e2) => // Don't compare the JSON here because accumulators in StageInfo may be out of order diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index 75ea409e16b4b..ff0f99b5c94d0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -51,6 +51,9 @@ abstract class SchedulerIntegrationSuite[T <: MockBackend: ClassTag] extends Spa var taskScheduler: TestTaskScheduler = null var scheduler: DAGScheduler = null var backend: T = _ + // Even though the tests aren't doing much, occasionally we see flakiness from pauses over + // a second (probably from GC?)
so we leave a long timeout in here + val duration = Duration(10, SECONDS) override def beforeEach(): Unit = { if (taskScheduler != null) { @@ -385,6 +388,8 @@ private[spark] abstract class MockBackend( }.toIndexedSeq } + override def maxNumConcurrentTasks(): Int = 0 + /** * This is called by the scheduler whenever it has tasks it would like to schedule, when a task * completes (which will be in a result-getter thread), and by the reviveOffers thread for delay @@ -398,7 +403,8 @@ private[spark] abstract class MockBackend( // get the task now, since that requires a lock on TaskSchedulerImpl, to prevent individual // tests from introducing a race if they need it. val newTasks = newTaskDescriptions.map { taskDescription => - val taskSet = taskScheduler.taskIdToTaskSetManager(taskDescription.taskId).taskSet + val taskSet = + Option(taskScheduler.taskIdToTaskSetManager.get(taskDescription.taskId).taskSet).get + val task = taskSet.tasks(taskDescription.index) (taskDescription, task) } @@ -536,7 +542,6 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor } withBackend(runBackend _) { val jobFuture = submit(new MockRDD(sc, 10, Nil), (0 until 10).toArray) - val duration = Duration(1, SECONDS) awaitJobTermination(jobFuture, duration) } assert(results === (0 until 10).map { _ -> 42 }.toMap) @@ -589,7 +594,6 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor } withBackend(runBackend _) { val jobFuture = submit(d, (0 until 30).toArray) - val duration = Duration(1, SECONDS) awaitJobTermination(jobFuture, duration) } assert(results === (0 until 30).map { idx => idx -> (4321 + idx) }.toMap) @@ -631,7 +635,6 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor } withBackend(runBackend _) { val jobFuture = submit(shuffledRdd, (0 until 10).toArray) - val duration = Duration(1, SECONDS) awaitJobTermination(jobFuture, duration) } assertDataStructuresEmpty() @@ -646,7 +649,6 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor } withBackend(runBackend _) { val jobFuture = submit(new MockRDD(sc, 10, Nil), (0 until 10).toArray) - val duration = Duration(1, SECONDS) awaitJobTermination(jobFuture, duration) assert(failure.getMessage.contains("test task failure")) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala index 97487ce1d2ca8..ba62eec0522db 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala @@ -62,6 +62,7 @@ class TaskDescriptionSuite extends SparkFunSuite { executorId = "testExecutor", name = "task for test", index = 19, + partitionId = 1, originalFiles, originalJars, originalProperties, @@ -77,6 +78,7 @@ class TaskDescriptionSuite extends SparkFunSuite { assert(decodedTaskDescription.executorId === originalTaskDescription.executorId) assert(decodedTaskDescription.name === originalTaskDescription.name) assert(decodedTaskDescription.index === originalTaskDescription.index) + assert(decodedTaskDescription.partitionId === originalTaskDescription.partitionId) assert(decodedTaskDescription.addedFiles.equals(originalFiles)) assert(decodedTaskDescription.addedJars.equals(originalJars)) assert(decodedTaskDescription.properties.equals(originalTaskDescription.properties)) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 33f2ea1c94e75..9e1d13e369ad9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -36,6 +36,7 @@ class FakeSchedulerBackend extends SchedulerBackend { def stop() {} def reviveOffers() {} def defaultParallelism(): Int = 1 + def maxNumConcurrentTasks(): Int = 0 } class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfterEach @@ -62,7 +63,6 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B } override def afterEach(): Unit = { - super.afterEach() if (taskScheduler != null) { taskScheduler.stop() taskScheduler = null @@ -71,6 +71,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B dagScheduler.stop() dagScheduler = null } + super.afterEach() } def setupScheduler(confs: (String, String)*): TaskSchedulerImpl = { @@ -247,7 +248,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B taskScheduler.submitTasks(attempt2) val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten assert(1 === taskDescriptions3.length) - val mgr = taskScheduler.taskIdToTaskSetManager.get(taskDescriptions3(0).taskId).get + val mgr = Option(taskScheduler.taskIdToTaskSetManager.get(taskDescriptions3(0).taskId)).get assert(mgr.taskSet.stageAttemptId === 1) assert(!failedTaskSet) } @@ -285,7 +286,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(10 === taskDescriptions3.length) taskDescriptions3.foreach { task => - val mgr = taskScheduler.taskIdToTaskSetManager.get(task.taskId).get + val mgr = Option(taskScheduler.taskIdToTaskSetManager.get(task.taskId)).get assert(mgr.taskSet.stageAttemptId === 1) } assert(!failedTaskSet) @@ -723,7 +724,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B // only schedule one task because of locality assert(taskDescs.size === 1) - val mgr = taskScheduler.taskIdToTaskSetManager.get(taskDescs(0).taskId).get + val mgr = Option(taskScheduler.taskIdToTaskSetManager.get(taskDescs(0).taskId)).get assert(mgr.myLocalityLevels.toSet === Set(TaskLocality.NODE_LOCAL, TaskLocality.ANY)) // we should know about both executors, even though we only scheduled tasks on one of them assert(taskScheduler.getExecutorsAliveOnHost("host0") === Some(Set("executor0"))) @@ -1021,4 +1022,118 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B verify(blacklist).updateBlacklistForSuccessfulTaskSet(meq(0), meq(stageAttempt), anyObject()) } } + + test("don't schedule for a barrier taskSet if available slots are less than pending tasks") { + val taskCpus = 2 + val taskScheduler = setupScheduler("spark.task.cpus" -> taskCpus.toString) + + val numFreeCores = 3 + val workerOffers = IndexedSeq( + new WorkerOffer("executor0", "host0", numFreeCores, Some("192.168.0.101:49625")), + new WorkerOffer("executor1", "host1", numFreeCores, Some("192.168.0.101:49627"))) + val attempt1 = FakeTask.createBarrierTaskSet(3) + + // submit attempt 1, offer some resources, since the available slots are less than pending + // tasks, don't schedule barrier tasks on the resource offer. 
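+ // (Editor's sketch, not part of the original patch: the slot arithmetic this test
+ // relies on, assuming slots are counted as cores per offer divided by task cpus:
+ // val slots = workerOffers.map(_.cores / taskCpus).sum // (3 / 2) + (3 / 2) = 2
+ // 2 slots < 3 pending barrier tasks, so resourceOffers should launch nothing.)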
+ taskScheduler.submitTasks(attempt1) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(0 === taskDescriptions.length) + } + + test("schedule tasks for a barrier taskSet if all tasks can be launched together") { + val taskCpus = 2 + val taskScheduler = setupScheduler("spark.task.cpus" -> taskCpus.toString) + + val numFreeCores = 3 + val workerOffers = IndexedSeq( + new WorkerOffer("executor0", "host0", numFreeCores, Some("192.168.0.101:49625")), + new WorkerOffer("executor1", "host1", numFreeCores, Some("192.168.0.101:49627")), + new WorkerOffer("executor2", "host2", numFreeCores, Some("192.168.0.101:49629"))) + val attempt1 = FakeTask.createBarrierTaskSet(3) + + // submit attempt 1, offer some resources, all tasks get launched together + taskScheduler.submitTasks(attempt1) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(3 === taskDescriptions.length) + } + + test("cancelTasks shall kill all the running tasks and fail the stage") { + val taskScheduler = setupScheduler() + + taskScheduler.initialize(new FakeSchedulerBackend { + override def killTask( + taskId: Long, + executorId: String, + interruptThread: Boolean, + reason: String): Unit = { + // Since we only submit one stage attempt, the following call is sufficient to mark the + // task as killed. + taskScheduler.taskSetManagerForAttempt(0, 0).get.runningTasksSet.remove(taskId) + } + }) + + val attempt1 = FakeTask.createTaskSet(10, 0) + taskScheduler.submitTasks(attempt1) + + val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 1), + new WorkerOffer("executor1", "host1", 1)) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(2 === taskDescriptions.length) + val tsm = taskScheduler.taskSetManagerForAttempt(0, 0).get + assert(2 === tsm.runningTasks) + + taskScheduler.cancelTasks(0, false) + assert(0 === tsm.runningTasks) + assert(tsm.isZombie) + assert(taskScheduler.taskSetManagerForAttempt(0, 0).isEmpty) + } + + test("killAllTaskAttempts shall kill all the running tasks and not fail the stage") { + val taskScheduler = setupScheduler() + + taskScheduler.initialize(new FakeSchedulerBackend { + override def killTask( + taskId: Long, + executorId: String, + interruptThread: Boolean, + reason: String): Unit = { + // Since we only submit one stage attempt, the following call is sufficient to mark the + // task as killed. 
+ taskScheduler.taskSetManagerForAttempt(0, 0).get.runningTasksSet.remove(taskId) + } + }) + + val attempt1 = FakeTask.createTaskSet(10, 0) + taskScheduler.submitTasks(attempt1) + + val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 1), + new WorkerOffer("executor1", "host1", 1)) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(2 === taskDescriptions.length) + val tsm = taskScheduler.taskSetManagerForAttempt(0, 0).get + assert(2 === tsm.runningTasks) + + taskScheduler.killAllTaskAttempts(0, false, "test") + assert(0 === tsm.runningTasks) + assert(!tsm.isZombie) + assert(taskScheduler.taskSetManagerForAttempt(0, 0).isDefined) + } + + test("mark taskset for a barrier stage as zombie in case a task fails") { + val taskScheduler = setupScheduler() + + val attempt = FakeTask.createBarrierTaskSet(3) + taskScheduler.submitTasks(attempt) + + val tsm = taskScheduler.taskSetManagerForAttempt(0, 0).get + val offers = (0 until 3).map{ idx => + WorkerOffer(s"exec-$idx", s"host-$idx", 1, Some(s"192.168.0.101:4962$idx")) + } + taskScheduler.resourceOffers(offers) + assert(tsm.runningTasks === 3) + + // Fail a task from the stage attempt. + tsm.handleFailedTask(tsm.taskAttempts.head.head.taskId, TaskState.FAILED, TaskKilled("test")) + assert(tsm.isZombie) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index ca6a7e5db3b17..d264adaef90a5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -178,12 +178,12 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg } override def afterEach(): Unit = { - super.afterEach() if (sched != null) { sched.dagScheduler.stop() sched.stop() sched = null } + super.afterEach() } @@ -1365,10 +1365,241 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(taskOption4.get.addedJars === addedJarsMidTaskSet) } + test("[SPARK-24677] Avoid NoSuchElementException from MedianHeap") { + val conf = new SparkConf().set("spark.speculation", "true") + sc = new SparkContext("local", "test", conf) + // Set the speculation multiplier to be 0 so speculative tasks are launched immediately + sc.conf.set("spark.speculation.multiplier", "0.0") + sc.conf.set("spark.speculation.quantile", "0.1") + sc.conf.set("spark.speculation", "true") + + sched = new FakeTaskScheduler(sc) + sched.initialize(new FakeSchedulerBackend()) + + val dagScheduler = new FakeDAGScheduler(sc, sched) + sched.setDAGScheduler(dagScheduler) + + val taskSet1 = FakeTask.createTaskSet(10) + val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet1.tasks.map { task => + task.metrics.internalAccums + } + + sched.submitTasks(taskSet1) + sched.resourceOffers( + (0 until 10).map { idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }) + + val taskSetManager1 = sched.taskSetManagerForAttempt(0, 0).get + + // fail fetch + taskSetManager1.handleFailedTask( + taskSetManager1.taskAttempts.head.head.taskId, TaskState.FAILED, + FetchFailed(null, 0, 0, 0, "fetch failed")) + + assert(taskSetManager1.isZombie) + assert(taskSetManager1.runningTasks === 9) + + val taskSet2 = FakeTask.createTaskSet(10, stageAttemptId = 1) + sched.submitTasks(taskSet2) + sched.resourceOffers( + (11 until 20).map { idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }) + + // Complete 2 of the tasks and leave 8 tasks running
+ for (id <- Set(0, 1)) { + taskSetManager1.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id))) + assert(sched.endedTasks(id) === Success) + } + + val taskSetManager2 = sched.taskSetManagerForAttempt(0, 1).get + assert(!taskSetManager2.successfulTaskDurations.isEmpty()) + taskSetManager2.checkSpeculatableTasks(0) + } + + + test("SPARK-24755 Executor loss can cause task to not be resubmitted") { + val conf = new SparkConf().set("spark.speculation", "true") + sc = new SparkContext("local", "test", conf) + // Set the speculation multiplier to be 0 so speculative tasks are launched immediately + sc.conf.set("spark.speculation.multiplier", "0.0") + + sc.conf.set("spark.speculation.quantile", "0.5") + sc.conf.set("spark.speculation", "true") + + var killTaskCalled = false + sched = new FakeTaskScheduler(sc, ("exec1", "host1"), + ("exec2", "host2"), ("exec3", "host3")) + sched.initialize(new FakeSchedulerBackend() { + override def killTask( + taskId: Long, + executorId: String, + interruptThread: Boolean, + reason: String): Unit = { + // Check that this is the only killTask event in this case, which is triggered by + // task 2.1 completing. + assert(taskId === 2) + assert(executorId === "exec3") + assert(interruptThread) + assert(reason === "another attempt succeeded") + killTaskCalled = true + } + }) + + // Keep track of the index of tasks that are resubmitted, + // so that the test can check that the task is resubmitted correctly + var resubmittedTasks = new mutable.HashSet[Int] + val dagScheduler = new FakeDAGScheduler(sc, sched) { + override def taskEnded( + task: Task[_], + reason: TaskEndReason, + result: Any, + accumUpdates: Seq[AccumulatorV2[_, _]], + taskInfo: TaskInfo): Unit = { + super.taskEnded(task, reason, result, accumUpdates, taskInfo) + reason match { + case Resubmitted => resubmittedTasks += taskInfo.index + case _ => + } + } + } + sched.dagScheduler.stop() + sched.setDAGScheduler(dagScheduler) + + val taskSet = FakeTask.createShuffleMapTaskSet(4, 0, 0, + Seq(TaskLocation("host1", "exec1")), + Seq(TaskLocation("host1", "exec1")), + Seq(TaskLocation("host3", "exec3")), + Seq(TaskLocation("host2", "exec2"))) + + val clock = new ManualClock() + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) + val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task => + task.metrics.internalAccums + } + // Offer resources for 4 tasks to start + for ((exec, host) <- Seq( + "exec1" -> "host1", + "exec1" -> "host1", + "exec3" -> "host3", + "exec2" -> "host2")) { + val taskOption = manager.resourceOffer(exec, host, NO_PREF) + assert(taskOption.isDefined) + val task = taskOption.get + assert(task.executorId === exec) + // Add an extra assert to make sure task 2.0 is running on exec3 + if (task.index == 2) { + assert(task.attemptNumber === 0) + assert(task.executorId === "exec3") + } + } + assert(sched.startedTasks.toSet === Set(0, 1, 2, 3)) + clock.advance(1) + // Complete 2 of the tasks and leave 2 tasks running + for (id <- Set(0, 1)) { + manager.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id))) + assert(sched.endedTasks(id) === Success) + } + + // checkSpeculatableTasks checks that the task runtime is greater than the threshold for + // speculating. Since we use a threshold of 0 for speculation, tasks need to be running for + // > 0ms, so advance the clock by 1ms here.
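+ // (Editor's note, a hedged reading of the confs set above: with
+ // spark.speculation.quantile = 0.5, speculation starts once 0.5 * 4 = 2 of the 4
+ // tasks have succeeded, and spark.speculation.multiplier = 0.0 makes the
+ // speculatable threshold 0.0 * medianDuration = 0ms, so any still-running task
+ // qualifies.)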
+ clock.advance(1) + assert(manager.checkSpeculatableTasks(0)) + assert(sched.speculativeTasks.toSet === Set(2, 3)) + + // Offer resource to start the speculative attempt for the running task 2.0 + val taskOption = manager.resourceOffer("exec2", "host2", ANY) + assert(taskOption.isDefined) + val task4 = taskOption.get + assert(task4.index === 2) + assert(task4.taskId === 4) + assert(task4.executorId === "exec2") + assert(task4.attemptNumber === 1) + // Complete the speculative attempt for the running task + manager.handleSuccessfulTask(4, createTaskResult(2, accumUpdatesByTask(2))) + // Make sure schedBackend.killTask(2, "exec3", true, "another attempt succeeded") gets called + assert(killTaskCalled) + + assert(resubmittedTasks.isEmpty) + // Host 2 is lost, meaning we lost the map output from task 4 + manager.executorLost("exec2", "host2", SlaveLost()) + // Make sure that task with index 2 is re-submitted + assert(resubmittedTasks.contains(2)) + + } + private def createTaskResult( id: Int, accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = { val valueSer = SparkEnv.get.serializer.newInstance() new DirectTaskResult[Int](valueSer.serialize(id), accumUpdates) } + + test("SPARK-13343 speculative tasks that didn't commit shouldn't be marked as success") { + sc = new SparkContext("local", "test") + sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + val taskSet = FakeTask.createTaskSet(4) + // Set the speculation multiplier to be 0 so speculative tasks are launched immediately + sc.conf.set("spark.speculation.multiplier", "0.0") + sc.conf.set("spark.speculation", "true") + val clock = new ManualClock() + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) + val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task => + task.metrics.internalAccums + } + // Offer resources for 4 tasks to start + for ((k, v) <- List( + "exec1" -> "host1", + "exec1" -> "host1", + "exec2" -> "host2", + "exec2" -> "host2")) { + val taskOption = manager.resourceOffer(k, v, NO_PREF) + assert(taskOption.isDefined) + val task = taskOption.get + assert(task.executorId === k) + } + assert(sched.startedTasks.toSet === Set(0, 1, 2, 3)) + clock.advance(1) + // Complete 3 of the tasks and leave 1 task running + for (id <- Set(0, 1, 2)) { + manager.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id))) + assert(sched.endedTasks(id) === Success) + } + // checkSpeculatableTasks checks that the task runtime is greater than the threshold for + // speculating. Since we use a threshold of 0 for speculation, tasks need to be running for + // > 0ms, so advance the clock by 1ms here. + clock.advance(1) + assert(manager.checkSpeculatableTasks(0)) + assert(sched.speculativeTasks.toSet === Set(3)) + + // Offer resource to start the speculative attempt for the running task + val taskOption5 = manager.resourceOffer("exec1", "host1", NO_PREF) + assert(taskOption5.isDefined) + val task5 = taskOption5.get + assert(task5.index === 3) + assert(task5.taskId === 4) + assert(task5.executorId === "exec1") + assert(task5.attemptNumber === 1) + sched.backend = mock(classOf[SchedulerBackend]) + sched.dagScheduler.stop() + sched.dagScheduler = mock(classOf[DAGScheduler]) + // Complete one attempt for the running task + val result = createTaskResult(3, accumUpdatesByTask(3)) + manager.handleSuccessfulTask(3, result) + // There is a race between the scheduler asking to kill the other task, and that task + // actually finishing.
We simulate what happens if the other task finishes before we kill it. + verify(sched.backend).killTask(4, "exec1", true, "another attempt succeeded") + manager.handleSuccessfulTask(4, result) + + val info3 = manager.taskInfos(3) + val info4 = manager.taskInfos(4) + assert(info3.successful) + assert(info4.killed) + verify(sched.dagScheduler).taskEnded( + manager.tasks(3), + TaskKilled("Finish but did not commit due to another attempt succeeded"), + null, + Seq.empty, + info4) + verify(sched.dagScheduler).taskEnded(manager.tasks(3), Success, result.value(), + result.accumUpdates, info3) + } } diff --git a/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala b/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala index 3f52dc41abf6d..be6b8a6b5b108 100644 --- a/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala @@ -28,11 +28,15 @@ trait EncryptionFunSuite { * for the test to modify the provided SparkConf. */ final protected def encryptionTest(name: String)(fn: SparkConf => Unit) { + encryptionTestHelper(name) { case (name, conf) => + test(name)(fn(conf)) + } + } + + final protected def encryptionTestHelper(name: String)(fn: (String, SparkConf) => Unit): Unit = { Seq(false, true).foreach { encrypt => - test(s"$name (encryption = ${ if (encrypt) "on" else "off" })") { - val conf = new SparkConf().set(IO_ENCRYPTION_ENABLED, encrypt) - fn(conf) - } + val conf = new SparkConf().set(IO_ENCRYPTION_ENABLED, encrypt) + fn(s"$name (encryption = ${ if (encrypt) "on" else "off" })", conf) } } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index fc78655bf52ec..36912441c03bd 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -345,7 +345,8 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { val denseBlockSizes = new Array[Long](5000) val sparseBlockSizes = Array[Long](0L, 1L, 0L, 2L) Seq(denseBlockSizes, sparseBlockSizes).foreach { blockSizes => - ser.serialize(HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes)) + ser.serialize( + HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes, 1)) } } @@ -411,6 +412,26 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { assert(!ser2.getAutoReset) } + test("SPARK-25176 ClassCastException when writing a Map after previously " + + "reading a Map with different generic type") { + // This test uses the example in https://github.com/EsotericSoftware/kryo/issues/384 + import java.util._ + val ser = new KryoSerializer(new SparkConf).newInstance().asInstanceOf[KryoSerializerInstance] + + class MapHolder { + private val mapOne = new HashMap[Int, String] + private val mapTwo = this.mapOne + } + + val serializedMapHolder = ser.serialize(new MapHolder) + ser.deserialize[MapHolder](serializedMapHolder) + + val stringMap = new HashMap[Int, List[String]] + stringMap.put(1, new ArrayList[String]) + val serializedMap = ser.serialize[Map[Int, List[String]]](stringMap) + ser.deserialize[HashMap[Int, List[String]]](serializedMap) + } + private def testSerializerInstanceReuse(autoReset: Boolean, referenceTracking: Boolean): Unit = { val conf = new SparkConf(loadDefaults = false) .set("spark.kryo.referenceTracking", 
referenceTracking.toString) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala new file mode 100644 index 0000000000000..b9f0e873375b0 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort + +import java.lang.{Long => JLong} + +import org.mockito.Mockito.when +import org.scalatest.mockito.MockitoSugar + +import org.apache.spark._ +import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} +import org.apache.spark.memory._ +import org.apache.spark.unsafe.Platform + +class ShuffleExternalSorterSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { + + test("nested spill should be no-op") { + val conf = new SparkConf() + .setMaster("local[1]") + .setAppName("ShuffleExternalSorterSuite") + .set("spark.testing", "true") + .set("spark.testing.memory", "1600") + .set("spark.memory.fraction", "1") + sc = new SparkContext(conf) + + val memoryManager = UnifiedMemoryManager(conf, 1) + + var shouldAllocate = false + + // Mock `TaskMemoryManager` to allocate free memory when `shouldAllocate` is true. + // This will trigger a nested spill and expose issues if we don't handle this case properly. + val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) { + override def acquireExecutionMemory(required: Long, consumer: MemoryConsumer): Long = { + // ExecutionMemoryPool.acquireMemory will wait until there are 400 bytes for a task to use. + // So we leave 400 bytes for the task. 
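+ // (Editor's note, an assumption about where the 400 comes from: ExecutionMemoryPool
+ // blocks until each of N active task attempts can get at least 1 / (2N) of the pool,
+ // and with spark.testing.memory = 1600 and 2 task attempts that floor is
+ // 1600 / (2 * 2) = 400 bytes.)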
+ if (shouldAllocate && + memoryManager.maxHeapMemory - memoryManager.executionMemoryUsed > 400) { + val acquireExecutionMemoryMethod = + memoryManager.getClass.getMethods.filter(_.getName == "acquireExecutionMemory").head + acquireExecutionMemoryMethod.invoke( + memoryManager, + JLong.valueOf( + memoryManager.maxHeapMemory - memoryManager.executionMemoryUsed - 400), + JLong.valueOf(1L), // taskAttemptId + MemoryMode.ON_HEAP + ).asInstanceOf[java.lang.Long] + } + super.acquireExecutionMemory(required, consumer) + } + } + val taskContext = mock[TaskContext] + val taskMetrics = new TaskMetrics + when(taskContext.taskMetrics()).thenReturn(taskMetrics) + val sorter = new ShuffleExternalSorter( + taskMemoryManager, + sc.env.blockManager, + taskContext, + 100, // initialSize - This will require ShuffleInMemorySorter to acquire at least 800 bytes + 1, // numPartitions + conf, + new ShuffleWriteMetrics) + val inMemSorter = { + val field = sorter.getClass.getDeclaredField("inMemSorter") + field.setAccessible(true) + field.get(sorter).asInstanceOf[ShuffleInMemorySorter] + } + // Allocate memory to make the next "insertRecord" call trigger a spill. + val bytes = new Array[Byte](1) + while (inMemSorter.hasSpaceForAnotherRecord) { + sorter.insertRecord(bytes, Platform.BYTE_ARRAY_OFFSET, 1, 0) + } + + // This flag will make the mocked TaskMemoryManager acquire free memory released by spill to + // trigger a nested spill. + shouldAllocate = true + + // Should throw `SparkOutOfMemoryError` as there is not enough memory: `ShuffleInMemorySorter` + // will try to acquire 800 bytes but there are only 400 bytes available. + // + // Before the fix, a nested spill may use a released page and this causes two tasks to access + // the same memory page. When a task reads memory written by another task, many types of + // failures may happen. Here are some examples we have seen: + // + // - JVM crash. (This is easy to reproduce in the unit test as we fill newly allocated and + // deallocated memory with 0xa5 and 0x5a bytes which usually point to an invalid memory + // address) + // - java.lang.IllegalArgumentException: Comparison method violates its general contract!
+ // - java.lang.NullPointerException + // at org.apache.spark.memory.TaskMemoryManager.getPage(TaskMemoryManager.java:384) + // - java.lang.UnsupportedOperationException: Cannot grow BufferHolder by size -536870912 + // because the size after growing exceeds size limitation 2147483632 + intercept[SparkOutOfMemoryError] { + sorter.insertRecord(bytes, Platform.BYTE_ARRAY_OFFSET, 1, 0) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 1cd71955ad4d9..0b2bbd2fa8a78 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -22,18 +22,19 @@ import java.lang.{Integer => JInteger, Long => JLong} import java.util.{Arrays, Date, Properties} import scala.collection.JavaConverters._ +import scala.collection.immutable.Map import scala.reflect.{classTag, ClassTag} import org.scalatest.BeforeAndAfter import org.apache.spark._ -import org.apache.spark.executor.TaskMetrics +import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} +import org.apache.spark.metrics.ExecutorMetricType import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster._ import org.apache.spark.status.api.v1 import org.apache.spark.storage._ import org.apache.spark.util.Utils -import org.apache.spark.util.kvstore._ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { @@ -215,7 +216,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { check[TaskDataWrapper](task.taskId) { wrapper => assert(wrapper.taskId === task.taskId) assert(wrapper.stageId === stages.head.stageId) - assert(wrapper.stageAttemptId === stages.head.attemptId) + assert(wrapper.stageAttemptId === stages.head.attemptNumber) assert(wrapper.index === task.index) assert(wrapper.attempt === task.attemptNumber) assert(wrapper.launchTime === task.launchTime) @@ -258,7 +259,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { executorId = execIds.head, taskFailures = 2, stageId = stages.head.stageId, - stageAttemptId = stages.head.attemptId)) + stageAttemptId = stages.head.attemptNumber)) val executorStageSummaryWrappers = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage") @@ -284,7 +285,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { hostId = "2.example.com", // this is where the second executor is hosted executorFailures = 1, stageId = stages.head.stageId, - stageAttemptId = stages.head.attemptId)) + stageAttemptId = stages.head.attemptNumber)) val executorStageSummaryWrappersForNode = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage") @@ -468,7 +469,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { hostId = "1.example.com", executorFailures = 1, stageId = stages.last.stageId, - stageAttemptId = stages.last.attemptId)) + stageAttemptId = stages.last.attemptNumber)) check[ExecutorSummaryWrapper](execIds.head) { exec => assert(exec.info.blacklistedInStages === Set(stages.last.stageId)) @@ -881,12 +882,41 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(dist.memoryRemaining === maxMemory - rdd2b1.memSize - rdd1b2.memSize ) } + // Add block1 of rdd1 back to bm 1. 
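+ // (Editor's note: AppStatusListener appears to maintain rddBlocks, memoryUsed and
+ // diskUsed as running per-executor sums over block updates, so re-adding this block
+ // should bring executor 1 back to three tracked blocks, as verified below.)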
+ listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo(bm1, rdd1b1.blockId, level, rdd1b1.memSize, rdd1b1.diskSize))) + + check[ExecutorSummaryWrapper](bm1.executorId) { exec => + assert(exec.info.rddBlocks === 3L) + assert(exec.info.memoryUsed === rdd1b1.memSize + rdd1b2.memSize + rdd2b1.memSize) + assert(exec.info.diskUsed === rdd1b1.diskSize + rdd1b2.diskSize + rdd2b1.diskSize) + } + // Unpersist RDD1. listener.onUnpersistRDD(SparkListenerUnpersistRDD(rdd1b1.rddId)) intercept[NoSuchElementException] { check[RDDStorageInfoWrapper](rdd1b1.rddId) { _ => () } } + // executor1 now only contains block1 from rdd2. + check[ExecutorSummaryWrapper](bm1.executorId) { exec => + assert(exec.info.rddBlocks === 1L) + assert(exec.info.memoryUsed === rdd2b1.memSize) + assert(exec.info.diskUsed === rdd2b1.diskSize) + } + + // Unpersist RDD2. + listener.onUnpersistRDD(SparkListenerUnpersistRDD(rdd2b1.rddId)) + intercept[NoSuchElementException] { + check[RDDStorageInfoWrapper](rdd2b1.rddId) { _ => () } + } + + check[ExecutorSummaryWrapper](bm1.executorId) { exec => + assert(exec.info.rddBlocks === 0L) + assert(exec.info.memoryUsed === 0) + assert(exec.info.diskUsed === 0) + } + // Update a StreamBlock. val stream1 = StreamBlockId(1, 1L) listener.onBlockUpdated(SparkListenerBlockUpdated( @@ -963,17 +993,17 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // task end event. time += 1 val task = createTasks(1, Array("1")).head - listener.onTaskStart(SparkListenerTaskStart(dropped.stageId, dropped.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(dropped.stageId, dropped.attemptNumber, task)) time += 1 task.markFinished(TaskState.FINISHED, time) val metrics = TaskMetrics.empty metrics.setExecutorRunTime(42L) - listener.onTaskEnd(SparkListenerTaskEnd(dropped.stageId, dropped.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(dropped.stageId, dropped.attemptNumber, "taskType", Success, task, metrics)) new AppStatusStore(store) - .taskSummary(dropped.stageId, dropped.attemptId, Array(0.25d, 0.50d, 0.75d)) + .taskSummary(dropped.stageId, dropped.attemptNumber, Array(0.25d, 0.50d, 0.75d)) assert(store.count(classOf[CachedQuantile], "stage", key(dropped)) === 3) stages.drop(1).foreach { s => @@ -1190,6 +1220,61 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(appStore.asOption(appStore.lastStageAttempt(3)) === None) } + test("SPARK-24415: update metrics for tasks that finish late") { + val listener = new AppStatusListener(store, conf, true) + + val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1") + val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2") + + // Start job + listener.onJobStart(SparkListenerJobStart(1, time, Seq(stage1, stage2), null)) + + // Start 2 stages + listener.onStageSubmitted(SparkListenerStageSubmitted(stage1, new Properties())) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage2, new Properties())) + + // Start 2 Tasks + val tasks = createTasks(2, Array("1")) + tasks.foreach { task => + listener.onTaskStart(SparkListenerTaskStart(stage1.stageId, stage1.attemptNumber, task)) + } + + // Task 1 Finished + time += 1 + tasks(0).markFinished(TaskState.FINISHED, time) + listener.onTaskEnd( + SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", Success, tasks(0), null)) + + // Stage 1 Completed + stage1.failureReason = Some("Failed") + listener.onStageCompleted(SparkListenerStageCompleted(stage1)) + + // Stop job 1 + time += 1 + 
listener.onJobEnd(SparkListenerJobEnd(1, time, JobSucceeded)) + + // Task 2 Killed + time += 1 + tasks(1).markFinished(TaskState.FINISHED, time) + listener.onTaskEnd( + SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", + TaskKilled(reason = "Killed"), tasks(1), null)) + + // Ensure killed task metrics are updated + val allStages = store.view(classOf[StageDataWrapper]).reverse().asScala.map(_.info) + val failedStages = allStages.filter(_.status == v1.StageStatus.FAILED) + assert(failedStages.size == 1) + assert(failedStages.head.numKilledTasks == 1) + assert(failedStages.head.numCompleteTasks == 1) + + val allJobs = store.view(classOf[JobDataWrapper]).reverse().asScala.map(_.info) + assert(allJobs.size == 1) + assert(allJobs.head.numKilledTasks == 1) + assert(allJobs.head.numCompletedTasks == 1) + assert(allJobs.head.numActiveStages == 1) + assert(allJobs.head.numFailedStages == 1) + } + test("driver logs") { val listener = new AppStatusListener(store, conf, true) @@ -1208,6 +1293,130 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } } + test("executor metrics updates") { + val listener = new AppStatusListener(store, conf, true) + + val driver = BlockManagerId(SparkContext.DRIVER_IDENTIFIER, "localhost", 42) + + listener.onExecutorAdded(createExecutorAddedEvent(1)) + listener.onExecutorAdded(createExecutorAddedEvent(2)) + listener.onStageSubmitted(createStageSubmittedEvent(0)) + // receive 3 metric updates from each executor with just stage 0 running, + // with different peak updates for each executor + listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(1, + Array(4000L, 50L, 20L, 0L, 40L, 0L, 60L, 0L, 70L, 20L))) + listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(2, + Array(1500L, 50L, 20L, 0L, 0L, 0L, 20L, 0L, 70L, 0L))) + // exec 1: new stage 0 peaks for metrics at indexes: 2, 4, 6 + listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(1, + Array(4000L, 50L, 50L, 0L, 50L, 0L, 100L, 0L, 70L, 20L))) + // exec 2: new stage 0 peaks for metrics at indexes: 0, 4, 6 + listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(2, + Array(2000L, 50L, 10L, 0L, 10L, 0L, 30L, 0L, 70L, 0L))) + // exec 1: new stage 0 peaks for metrics at indexes: 5, 7 + listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(1, + Array(2000L, 40L, 50L, 0L, 40L, 10L, 90L, 10L, 50L, 0L))) + // exec 2: new stage 0 peaks for metrics at indexes: 0, 5, 6, 7, 8 + listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(2, + Array(3500L, 50L, 15L, 0L, 10L, 10L, 35L, 10L, 80L, 0L))) + // now start stage 1, one more metric update for each executor, and new + // peaks for some stage 1 metrics (as listed), initialize stage 1 peaks + listener.onStageSubmitted(createStageSubmittedEvent(1)) + // exec 1: new stage 0 peaks for metrics at indexes: 0, 3, 7 + listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(1, + Array(5000L, 30L, 50L, 20L, 30L, 10L, 80L, 30L, 50L, 0L))) + // exec 2: new stage 0 peaks for metrics at indexes: 0, 1, 2, 3, 6, 7, 9 + listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(2, + Array(7000L, 80L, 50L, 20L, 0L, 10L, 50L, 30L, 10L, 40L))) + // complete stage 0, and 3 more updates for each executor with just + // stage 1 running + listener.onStageCompleted(createStageCompletedEvent(0)) + // exec 1: new stage 1 peaks for metrics at indexes: 0, 1, 3 + listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(1, + Array(6000L, 70L, 20L, 30L, 10L, 0L, 30L, 30L, 30L, 
0L))) + // exec 2: new stage 1 peaks for metrics at indexes: 3, 4, 7, 8 + listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(2, + Array(5500L, 30L, 20L, 40L, 10L, 0L, 30L, 40L, 40L, 20L))) + // exec 1: new stage 1 peaks for metrics at indexes: 0, 4, 5, 7 + listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(1, + Array(7000L, 70L, 5L, 25L, 60L, 30L, 65L, 55L, 30L, 0L))) + // exec 2: new stage 1 peak for metrics at index: 7 + listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(2, + Array(5500L, 40L, 25L, 30L, 10L, 30L, 35L, 60L, 0L, 20L))) + // exec 1: no new stage 1 peaks + listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(1, + Array(5500L, 70L, 15L, 20L, 55L, 20L, 70L, 40L, 20L, 0L))) + listener.onExecutorRemoved(createExecutorRemovedEvent(1)) + // exec 2: new stage 1 peak for metrics at index: 6 + listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(2, + Array(4000L, 20L, 25L, 30L, 10L, 30L, 35L, 60L, 0L, 0L))) + listener.onStageCompleted(createStageCompletedEvent(1)) + + // expected peak values for each executor + val expectedValues = Map( + "1" -> new ExecutorMetrics(Array(7000L, 70L, 50L, 30L, 60L, 30L, 100L, 55L, 70L, 20L)), + "2" -> new ExecutorMetrics(Array(7000L, 80L, 50L, 40L, 10L, 30L, 50L, 60L, 80L, 40L))) + + // check that the stored peak values match the expected values + expectedValues.foreach { case (id, metrics) => + check[ExecutorSummaryWrapper](id) { exec => + assert(exec.info.id === id) + exec.info.peakMemoryMetrics match { + case Some(actual) => + ExecutorMetricType.values.foreach { metricType => + assert(actual.getMetricValue(metricType) === metrics.getMetricValue(metricType)) + } + case _ => + assert(false) + } + } + } + } + + test("stage executor metrics") { + // simulate reading in StageExecutorMetrics events from the history log + val listener = new AppStatusListener(store, conf, false) + val driver = BlockManagerId(SparkContext.DRIVER_IDENTIFIER, "localhost", 42) + + listener.onExecutorAdded(createExecutorAddedEvent(1)) + listener.onExecutorAdded(createExecutorAddedEvent(2)) + listener.onStageSubmitted(createStageSubmittedEvent(0)) + listener.onStageSubmitted(createStageSubmittedEvent(1)) + listener.onStageExecutorMetrics(SparkListenerStageExecutorMetrics("1", 0, 0, + new ExecutorMetrics(Array(5000L, 50L, 50L, 20L, 50L, 10L, 100L, 30L, 70L, 20L)))) + listener.onStageExecutorMetrics(SparkListenerStageExecutorMetrics("2", 0, 0, + new ExecutorMetrics(Array(7000L, 70L, 50L, 20L, 10L, 10L, 50L, 30L, 80L, 40L)))) + listener.onStageCompleted(createStageCompletedEvent(0)) + // executor 1 is removed before stage 1 has finished, the stage executor metrics + // are logged afterwards and should still be used to update the executor metrics. 
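+ // (Editor's sketch of the expected bookkeeping: the stored executor peak for each
+ // metric is the max over its per-stage records, e.g. executor 1's first metric is
+ // max(5000L, 7000L) = 7000L, which matches the expectedValues map below.)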
+ listener.onExecutorRemoved(createExecutorRemovedEvent(1)) + listener.onStageExecutorMetrics(SparkListenerStageExecutorMetrics("1", 1, 0, + new ExecutorMetrics(Array(7000L, 70L, 50L, 30L, 60L, 30L, 80L, 55L, 50L, 0L)))) + listener.onStageExecutorMetrics(SparkListenerStageExecutorMetrics("2", 1, 0, + new ExecutorMetrics(Array(7000L, 80L, 50L, 40L, 10L, 30L, 50L, 60L, 40L, 40L)))) + listener.onStageCompleted(createStageCompletedEvent(1)) + + // expected peak values for each executor + val expectedValues = Map( + "1" -> new ExecutorMetrics(Array(7000L, 70L, 50L, 30L, 60L, 30L, 100L, 55L, 70L, 20L)), + "2" -> new ExecutorMetrics(Array(7000L, 80L, 50L, 40L, 10L, 30L, 50L, 60L, 80L, 40L))) + + // check that the stored peak values match the expected values + for ((id, metrics) <- expectedValues) { + check[ExecutorSummaryWrapper](id) { exec => + assert(exec.info.id === id) + exec.info.peakMemoryMetrics match { + case Some(actual) => + ExecutorMetricType.values.foreach { metricType => + assert(actual.getMetricValue(metricType) === metrics.getMetricValue(metricType)) + } + case _ => + assert(false) + } + } + } + } + private def key(stage: StageInfo): Array[Int] = Array(stage.stageId, stage.attemptNumber) private def check[T: ClassTag](key: Any)(fn: T => Unit): Unit = { @@ -1245,4 +1454,37 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } + /** Create a stage submitted event for the specified stage Id. */ + private def createStageSubmittedEvent(stageId: Int) = { + SparkListenerStageSubmitted(new StageInfo(stageId, 0, stageId.toString, 0, + Seq.empty, Seq.empty, "details")) + } + + /** Create a stage completed event for the specified stage Id. */ + private def createStageCompletedEvent(stageId: Int) = { + SparkListenerStageCompleted(new StageInfo(stageId, 0, stageId.toString, 0, + Seq.empty, Seq.empty, "details")) + } + + /** Create an executor added event for the specified executor Id. */ + private def createExecutorAddedEvent(executorId: Int) = { + SparkListenerExecutorAdded(0L, executorId.toString, new ExecutorInfo("host1", 1, Map.empty)) + } + + /** Create an executor removed event for the specified executor Id. */ + private def createExecutorRemovedEvent(executorId: Int) = { + SparkListenerExecutorRemoved(10L, executorId.toString, "test") + } + + /** Create an executor metrics update event, with the specified executor metrics values.
*/ + private def createExecutorMetricsUpdateEvent( + executorId: Int, + executorMetrics: Array[Long]): SparkListenerExecutorMetricsUpdate = { + val taskMetrics = TaskMetrics.empty + taskMetrics.incDiskBytesSpilled(111) + taskMetrics.incMemoryBytesSpilled(222) + val accum = Array((333L, 1, 1, taskMetrics.accumulators().map(AccumulatorSuite.makeInfo))) + SparkListenerExecutorMetricsUpdate(executorId.toString, accum, + Some(new ExecutorMetrics(executorMetrics))) + } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index b19d8ebf72c61..dbee1f60d7af0 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -1377,8 +1377,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val (server, shufflePort) = Utils.startServiceOnPort(candidatePort, newShuffleServer, conf, "ShuffleServer") - conf.set("spark.shuffle.service.enabled", "true") - conf.set("spark.shuffle.service.port", shufflePort.toString) + conf.set(SHUFFLE_SERVICE_ENABLED.key, "true") + conf.set(SHUFFLE_SERVICE_PORT.key, shufflePort.toString) conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "40") conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "1") var e = intercept[SparkException] { @@ -1422,6 +1422,19 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(mockBlockTransferService.tempFileManager === store.remoteBlockTempFileManager) } + test("query locations of blockIds") { + val mockBlockManagerMaster = mock(classOf[BlockManagerMaster]) + val blockLocations = Seq(BlockManagerId("1", "host1", 100), BlockManagerId("2", "host2", 200)) + when(mockBlockManagerMaster.getLocations(mc.any[Array[BlockId]])) + .thenReturn(Array(blockLocations)) + val env = mock(classOf[SparkEnv]) + + val blockIds: Array[BlockId] = Array(StreamBlockId(1, 2)) + val locs = BlockManager.blockIdsToLocations(blockIds, env, mockBlockManagerMaster) + val expectedLocs = Seq("executor_host1_1", "executor_host2_2") + assert(locs(blockIds(0)) == expectedLocs) + } + class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService { var numCalls = 0 var tempFileManager: TempFileManager = null diff --git a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala index efdd02fff7871..eec961a491101 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala @@ -24,6 +24,7 @@ import com.google.common.io.{ByteStreams, Files} import io.netty.channel.FileRegion import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.internal.config import org.apache.spark.network.util.{ByteArrayWritableChannel, JavaUtils} import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.util.Utils @@ -94,7 +95,7 @@ class DiskStoreSuite extends SparkFunSuite { test("blocks larger than 2gb") { val conf = new SparkConf() - .set("spark.storage.memoryMapLimitForTests", "10k" ) + .set(config.MEMORY_MAP_LIMIT_FOR_TESTS.key, "10k") val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) val diskStore = new DiskStore(conf, diskBlockManager, new SecurityManager(conf)) @@ -194,8 +195,8 @@ class DiskStoreSuite extends SparkFunSuite { val region = data.toNetty().asInstanceOf[FileRegion] val 
byteChannel = new ByteArrayWritableChannel(data.size.toInt) - while (region.transfered() < region.count()) { - region.transferTo(byteChannel, region.transfered()) + while (region.transferred() < region.count()) { + region.transferTo(byteChannel, region.transferred()) } byteChannel.close() diff --git a/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala index b21c91f75d5c7..42828506895a7 100644 --- a/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala @@ -22,8 +22,8 @@ import org.apache.spark._ class FlatmapIteratorSuite extends SparkFunSuite with LocalSparkContext { /* Tests the ability of Spark to deal with user provided iterators from flatMap * calls, that may generate more data than available memory. In any - * memory based persistance Spark will unroll the iterator into an ArrayBuffer - * for caching, however in the case that the use defines DISK_ONLY persistance, + * memory based persistence Spark will unroll the iterator into an ArrayBuffer + * for caching, however in the case that the user defines DISK_ONLY persistence, * the iterator will be fed directly to the serializer and written to disk. * * This also tests the ObjectOutputStream reset rate. When serializing using the diff --git a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala index fe0a9a471a651..94c79388e3639 100644 --- a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala +++ b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala @@ -165,7 +165,6 @@ class AccumulatorV2Suite extends SparkFunSuite { } test("LegacyAccumulatorWrapper with AccumulatorParam that has no equals/hashCode") { - class MyData(val i: Int) extends Serializable val param = new AccumulatorParam[MyData] { override def zero(initialValue: MyData): MyData = new MyData(0) override def addInPlace(r1: MyData, r2: MyData): MyData = new MyData(r1.i + r2.i) @@ -182,3 +181,5 @@ class AccumulatorV2Suite extends SparkFunSuite { ser.serialize(acc) } } + +class MyData(val i: Int) extends Serializable diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 9a19baee9569e..3c6660800f170 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.util import java.io.NotSerializableException +import scala.language.reflectiveCalls + import org.apache.spark.{SparkContext, SparkException, SparkFunSuite, TaskContext} import org.apache.spark.LocalSparkContext._ import org.apache.spark.partial.CountEvaluator @@ -121,6 +123,7 @@ class ClosureCleanerSuite extends SparkFunSuite { } test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 1") { + assume(!ClosureCleanerSuite2.supportsLMFs) val concreteObject = new TestAbstractClass { val n2 = 222 val s2 = "bbb" @@ -141,6 +144,7 @@ class ClosureCleanerSuite extends SparkFunSuite { } test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 2") { + assume(!ClosureCleanerSuite2.supportsLMFs) val concreteObject = new TestAbstractClass2 { val n2 = 222 val s2 = "bbb" @@ -154,6 +158,7 @@ class ClosureCleanerSuite extends SparkFunSuite { } test("SPARK-22328: multiple
outer classes have the same parent class") { + assume(!ClosureCleanerSuite2.supportsLMFs) val concreteObject = new TestAbstractClass2 { val innerObject = new TestAbstractClass2 { diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala index 278fada83d78c..96da8ec3b2a1c 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala @@ -145,6 +145,7 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri } test("get inner closure classes") { + assume(!ClosureCleanerSuite2.supportsLMFs) val closure1 = () => 1 val closure2 = () => { () => 1 } val closure3 = (i: Int) => { @@ -171,6 +172,7 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri } test("get outer classes and objects") { + assume(!ClosureCleanerSuite2.supportsLMFs) val localValue = someSerializableValue val closure1 = () => 1 val closure2 = () => localValue @@ -207,6 +209,7 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri } test("get outer classes and objects with nesting") { + assume(!ClosureCleanerSuite2.supportsLMFs) val localValue = someSerializableValue val test1 = () => { @@ -258,6 +261,7 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri } test("find accessed fields") { + assume(!ClosureCleanerSuite2.supportsLMFs) val localValue = someSerializableValue val closure1 = () => 1 val closure2 = () => localValue @@ -296,6 +300,7 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri } test("find accessed fields with nesting") { + assume(!ClosureCleanerSuite2.supportsLMFs) val localValue = someSerializableValue val test1 = () => { @@ -538,17 +543,22 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri // As before, this closure is neither serializable nor cleanable verifyCleaning(inner1, serializableBefore = false, serializableAfter = false) - // This closure is no longer serializable because it now has a pointer to the outer closure, - // which is itself not serializable because it has a pointer to the ClosureCleanerSuite2. - // If we do not clean transitively, we will not null out this indirect reference. - verifyCleaning( - inner2, serializableBefore = false, serializableAfter = false, transitive = false) - - // If we clean transitively, we will find that method `a` does not actually reference the - // outer closure's parent (i.e. the ClosureCleanerSuite), so we can additionally null out - // the outer closure's parent pointer. This will make `inner2` serializable. - verifyCleaning( - inner2, serializableBefore = false, serializableAfter = true, transitive = true) + if (ClosureCleanerSuite2.supportsLMFs) { + verifyCleaning( + inner2, serializableBefore = true, serializableAfter = true) + } else { + // This closure is no longer serializable because it now has a pointer to the outer closure, + // which is itself not serializable because it has a pointer to the ClosureCleanerSuite2. + // If we do not clean transitively, we will not null out this indirect reference. + verifyCleaning( + inner2, serializableBefore = false, serializableAfter = false, transitive = false) + + // If we clean transitively, we will find that method `a` does not actually reference the + // outer closure's parent (i.e. 
the ClosureCleanerSuite), so we can additionally null out + // the outer closure's parent pointer. This will make `inner2` serializable. + verifyCleaning( + inner2, serializableBefore = false, serializableAfter = true, transitive = true) + } } // Same as above, but with more levels of nesting @@ -565,4 +575,25 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri test6()()() } + test("verify nested non-LMF closures") { + assume(ClosureCleanerSuite2.supportsLMFs) + class A1(val f: Int => Int) + class A2(val f: Int => Int => Int) + class B extends A1(x => x*x) + class C extends A2(x => new B().f) + val closure1 = new B().f + val closure2 = new C().f + // serializable already + verifyCleaning(closure1, serializableBefore = true, serializableAfter = true) + // brings in deps that can't be cleaned + verifyCleaning(closure2, serializableBefore = false, serializableAfter = false) + } +} + +object ClosureCleanerSuite2 { + // Scala 2.12 allows better interop with Java 8 via lambda syntax. This is supported + // by implementing FunctionN classes in Scala's standard library as Single Abstract + // Method (SAM) types. Lambdas are implemented via the invokedynamic instruction and + // the use of the LambdaMetaFactory (LMF) mechanism. + val supportsLMFs = scala.util.Properties.versionString.contains("2.12") }
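An aside on the supportsLMFs flag defined above: under Scala 2.12, closure bodies compile to invokedynamic call sites bootstrapped by java.lang.invoke.LambdaMetaFactory instead of anonymous FunctionN subclasses, so many closures no longer capture an enclosing $outer reference and need little or no cleaning, which is what the assume() guards throughout these suites account for. A standalone sketch (plain Scala, not part of the patch; the object name LmfProbe is illustrative) that makes the two encodings visible:

    object LmfProbe {
      def main(args: Array[String]): Unit = {
        val f = (i: Int) => i + 1
        // On 2.12+ this prints a JVM-synthesized name containing "$$Lambda$";
        // on 2.11 it prints a compiler-generated class such as "LmfProbe$$anonfun$1".
        println(f.getClass.getName)
        println(scala.util.Properties.versionString)
      }
    }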
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 74b72d940eeef..1e0d2af9a4711 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -30,6 +30,7 @@ import org.scalatest.exceptions.TestFailedException import org.apache.spark._ import org.apache.spark.executor._ +import org.apache.spark.metrics.ExecutorMetricType import org.apache.spark.rdd.RDDOperationScope import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo @@ -94,11 +95,17 @@ class JsonProtocolSuite extends SparkFunSuite { makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, hasHadoopInput = true, hasOutput = true) .accumulators().map(AccumulatorSuite.makeInfo) .zipWithIndex.map { case (a, i) => a.copy(id = i) } - SparkListenerExecutorMetricsUpdate("exec3", Seq((1L, 2, 3, accumUpdates))) + val executorUpdates = new ExecutorMetrics( + Array(543L, 123456L, 12345L, 1234L, 123L, 12L, 432L, 321L, 654L, 765L)) + SparkListenerExecutorMetricsUpdate("exec3", Seq((1L, 2, 3, accumUpdates)), + Some(executorUpdates)) } val blockUpdated = SparkListenerBlockUpdated(BlockUpdatedInfo(BlockManagerId("Stars", "In your multitude...", 300), RDDBlockId(0, 0), StorageLevel.MEMORY_ONLY, 100L, 0L)) + val stageExecutorMetrics = + SparkListenerStageExecutorMetrics("1", 2, 3, + new ExecutorMetrics(Array(543L, 123456L, 12345L, 1234L, 123L, 12L, 432L, 321L, 654L, 765L))) testEvent(stageSubmitted, stageSubmittedJsonString) testEvent(stageCompleted, stageCompletedJsonString) @@ -124,6 +131,7 @@ class JsonProtocolSuite extends SparkFunSuite { testEvent(nodeUnblacklisted, nodeUnblacklistedJsonString) testEvent(executorMetricsUpdate, executorMetricsUpdateJsonString) testEvent(blockUpdated, blockUpdatedJsonString) + testEvent(stageExecutorMetrics, stageExecutorMetricsJsonString) } test("Dependent Classes") { @@ -419,6 +427,30 @@ class JsonProtocolSuite extends SparkFunSuite { exceptionFailure.accumUpdates, oldExceptionFailure.accumUpdates, (x, y) => x == y) } + test("ExecutorMetricsUpdate backward compatibility: executor metrics update") { + // executorMetricsUpdate was added in 2.4.0. + val executorMetricsUpdate = makeExecutorMetricsUpdate("1", true, true) + val oldExecutorMetricsUpdateJson = + JsonProtocol.executorMetricsUpdateToJson(executorMetricsUpdate) + .removeField( _._1 == "Executor Metrics Updated") + val expectedExecutorMetricsUpdate = makeExecutorMetricsUpdate("1", true, false) + assertEquals(expectedExecutorMetricsUpdate, + JsonProtocol.executorMetricsUpdateFromJson(oldExecutorMetricsUpdateJson)) + } + + test("executorMetricsFromJson backward compatibility: handle missing metrics") { + // any missing metrics should be set to 0 + val executorMetrics = new ExecutorMetrics( + Array(12L, 23L, 45L, 67L, 78L, 89L, 90L, 123L, 456L, 789L)) + val oldExecutorMetricsJson = + JsonProtocol.executorMetricsToJson(executorMetrics) + .removeField( _._1 == "MappedPoolMemory") + val expectedExecutorMetrics = new ExecutorMetrics( + Array(12L, 23L, 45L, 67L, 78L, 89L, 90L, 123L, 456L, 0L)) + assertEquals(expectedExecutorMetrics, + JsonProtocol.executorMetricsFromJson(oldExecutorMetricsJson)) + } + test("AccumulableInfo value de/serialization") { import InternalAccumulator._ val blocks = Seq[(BlockId, BlockStatus)]( @@ -435,7 +467,6 @@ class JsonProtocolSuite extends SparkFunSuite { testAccumValue(Some("anything"), blocks, JString(blocks.toString)) testAccumValue(Some("anything"), 123, JString("123")) } - } @@ -565,6 +596,13 @@ private[spark] object JsonProtocolSuite extends Assertions { assert(stageAttemptId1 === stageAttemptId2) assertSeqEquals[AccumulableInfo](updates1, updates2, (a, b) => a.equals(b)) }) + assertOptionEquals(e1.executorUpdates, e2.executorUpdates, + (e1: ExecutorMetrics, e2: ExecutorMetrics) => assertEquals(e1, e2)) + case (e1: SparkListenerStageExecutorMetrics, e2: SparkListenerStageExecutorMetrics) => + assert(e1.execId === e2.execId) + assert(e1.stageId === e2.stageId) + assert(e1.stageAttemptId === e2.stageAttemptId) + assertEquals(e1.executorMetrics, e2.executorMetrics) case (e1, e2) => assert(e1 === e2) case _ => fail("Events don't match in types!") @@ -715,6 +753,12 @@ private[spark] object JsonProtocolSuite extends Assertions { assertStackTraceElementEquals) } + private def assertEquals(metrics1: ExecutorMetrics, metrics2: ExecutorMetrics) { + ExecutorMetricType.values.foreach { metricType => + assert(metrics1.getMetricValue(metricType) === metrics2.getMetricValue(metricType)) + } + } + private def assertJsonStringEquals(expected: String, actual: String, metadata: String) { val expectedJson = pretty(parse(expected)) val actualJson = pretty(parse(actual)) @@ -765,7 +809,6 @@ private[spark] object JsonProtocolSuite extends Assertions { assert(ste1 === ste2) } - /** ----------------------------------- * | Util methods for constructing events | * ------------------------------------ */ @@ -820,6 +863,27 @@ private[spark] object JsonProtocolSuite extends Assertions { new AccumulableInfo(id, Some(s"Accumulable$id"), Some(s"delta$id"), Some(s"val$id"), internal, countFailedValues, metadata) + /** Creates a SparkListenerExecutorMetricsUpdate event */ + private def makeExecutorMetricsUpdate( + execId: String, + includeTaskMetrics: Boolean, + includeExecutorMetrics: Boolean): SparkListenerExecutorMetricsUpdate = { + val taskMetrics = + if (includeTaskMetrics) { + Seq((1L, 1, 1, Seq(makeAccumulableInfo(1, false, false, None), + makeAccumulableInfo(2, false, false, None)))) + } else { + Seq() + } + val executorMetricsUpdate = + if (includeExecutorMetrics) { + Some(new
ExecutorMetrics(Array(123456L, 543L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L))) + } else { + None + } + SparkListenerExecutorMetricsUpdate(execId, taskMetrics, executorMetricsUpdate) + } + /** * Creates a TaskMetrics object describing a task that read data from Hadoop (if hasHadoopInput is * set to true) or read data from a shuffle otherwise. @@ -2007,7 +2071,42 @@ private[spark] object JsonProtocolSuite extends Assertions { | } | ] | } - | ] + | ], + | "Executor Metrics Updated" : { + | "JVMHeapMemory" : 543, + | "JVMOffHeapMemory" : 123456, + | "OnHeapExecutionMemory" : 12345, + | "OffHeapExecutionMemory" : 1234, + | "OnHeapStorageMemory" : 123, + | "OffHeapStorageMemory" : 12, + | "OnHeapUnifiedMemory" : 432, + | "OffHeapUnifiedMemory" : 321, + | "DirectPoolMemory" : 654, + | "MappedPoolMemory" : 765 + | } + | + |} + """.stripMargin + + private val stageExecutorMetricsJsonString = + """ + |{ + | "Event": "SparkListenerStageExecutorMetrics", + | "Executor ID": "1", + | "Stage ID": 2, + | "Stage Attempt ID": 3, + | "Executor Metrics" : { + | "JVMHeapMemory" : 543, + | "JVMOffHeapMemory" : 123456, + | "OnHeapExecutionMemory" : 12345, + | "OffHeapExecutionMemory" : 1234, + | "OnHeapStorageMemory" : 123, + | "OffHeapStorageMemory" : 12, + | "OnHeapUnifiedMemory" : 432, + | "OffHeapUnifiedMemory" : 321, + | "DirectPoolMemory" : 654, + | "MappedPoolMemory" : 765 + | } |} """.stripMargin diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala index ae3b3d829f1bb..604f1e1ca3101 100644 --- a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala @@ -133,4 +133,37 @@ class ThreadUtilsSuite extends SparkFunSuite { "stack trace contains unexpected references to ThreadUtils" ) } + + test("parmap should be interruptible") { + val t = new Thread() { + setDaemon(true) + + override def run() { + try { + // "par" is uninterruptible. The following will keep running even if the thread is + // interrupted. We should prefer to use "ThreadUtils.parmap". 
+ // + // (1 to 10).par.flatMap { i => + // Thread.sleep(100000) + // 1 to i + // } + // + ThreadUtils.parmap(1 to 10, "test", 2) { i => + Thread.sleep(100000) + 1 to i + }.flatten + } catch { + case _: InterruptedException => // expected + } + } + } + t.start() + eventually(timeout(10.seconds)) { + assert(t.isAlive) + } + t.interrupt() + eventually(timeout(10.seconds)) { + assert(!t.isAlive) + } + } } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 418d2f9b88500..39f4fba78583f 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -1184,6 +1184,55 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(Utils.getSimpleName(classOf[MalformedClassObject.MalformedClass]) === "UtilsSuite$MalformedClassObject$MalformedClass") } + + test("stringHalfWidth") { + // scalastyle:off nonascii + assert(Utils.stringHalfWidth(null) == 0) + assert(Utils.stringHalfWidth("") == 0) + assert(Utils.stringHalfWidth("ab c") == 4) + assert(Utils.stringHalfWidth("1098") == 4) + assert(Utils.stringHalfWidth("mø") == 2) + assert(Utils.stringHalfWidth("γύρ") == 3) + assert(Utils.stringHalfWidth("pê") == 2) + assert(Utils.stringHalfWidth("ー") == 2) + assert(Utils.stringHalfWidth("测") == 2) + assert(Utils.stringHalfWidth("か") == 2) + assert(Utils.stringHalfWidth("걸") == 2) + assert(Utils.stringHalfWidth("à") == 1) + assert(Utils.stringHalfWidth("焼") == 2) + assert(Utils.stringHalfWidth("羍む") == 4) + assert(Utils.stringHalfWidth("뺭ᾘ") == 3) + assert(Utils.stringHalfWidth("\u0967\u0968\u0969") == 3) + // scalastyle:on nonascii + } + + test("trimExceptCRLF standalone") { + val crlfSet = Set("\r", "\n") + val nonPrintableButCRLF = (0 to 32).map(_.toChar.toString).toSet -- crlfSet + + // identity for CRLF + crlfSet.foreach { s => assert(Utils.trimExceptCRLF(s) === s) } + + // empty for other non-printables + nonPrintableButCRLF.foreach { s => assert(Utils.trimExceptCRLF(s) === "") } + + // identity for a printable string + assert(Utils.trimExceptCRLF("a") === "a") + + // identity for strings with CRLF + crlfSet.foreach { s => + assert(Utils.trimExceptCRLF(s"${s}a") === s"${s}a") + assert(Utils.trimExceptCRLF(s"a${s}") === s"a${s}") + assert(Utils.trimExceptCRLF(s"b${s}b") === s"b${s}b") + } + + // trim nonPrintableButCRLF except when inside a string + nonPrintableButCRLF.foreach { s => + assert(Utils.trimExceptCRLF(s"${s}a") === "a") + assert(Utils.trimExceptCRLF(s"a${s}") === "a") + assert(Utils.trimExceptCRLF(s"b${s}b") === s"b${s}b") + } + } } private class SimpleExtension
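The trimExceptCRLF assertions above pin down a small contract: strip leading and trailing characters at or below '\u0020' except carriage return and line feed, and leave interior characters untouched. A minimal reference sketch consistent with those assertions (an assumed reading of the contract, not Spark's actual implementation):

    def trimExceptCRLF(s: String): String = {
      // a character survives edge trimming if it is printable or is CR/LF
      def keep(c: Char): Boolean = c > ' ' || c == '\r' || c == '\n'
      val from = s.indexWhere(keep)
      if (from < 0) "" else s.substring(from, s.lastIndexWhere(keep) + 1)
    }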
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 35312f2d71131..8a2f2ffe0acf1 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -18,13 +18,21 @@ package org.apache.spark.util.collection import scala.collection.mutable.ArrayBuffer +import scala.ref.WeakReference + +import org.scalatest.Matchers +import org.scalatest.concurrent.Eventually import org.apache.spark._ import org.apache.spark.internal.config._ import org.apache.spark.io.CompressionCodec import org.apache.spark.memory.MemoryTestingUtils +import org.apache.spark.util.CompletionIterator -class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { +class ExternalAppendOnlyMapSuite extends SparkFunSuite + with LocalSparkContext + with Eventually + with Matchers { import TestUtils.{assertNotSpilled, assertSpilled} private val allCompressionCodecs = CompressionCodec.ALL_COMPRESSION_CODECS @@ -414,7 +422,112 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { sc.stop() } - test("external aggregation updates peak execution memory") { + test("SPARK-22713 spill during iteration leaks internal map") { + val size = 1000 + val conf = createSparkConf(loadDefaults = true) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + val map = createExternalMap[Int] + + map.insertAll((0 until size).iterator.map(i => (i / 10, i))) + assert(map.numSpills == 0, "map was not supposed to spill") + + val it = map.iterator + assert(it.isInstanceOf[CompletionIterator[_, _]]) + // org.apache.spark.util.collection.AppendOnlyMap.destructiveSortedIterator returns + // an instance of an anonymous Iterator class. + + val underlyingMapRef = WeakReference(map.currentMap) + + { + // direct asserts introduced some macro generated code that held a reference to the map + val tmpIsNull = null == underlyingMapRef.get.orNull + assert(!tmpIsNull) + } + + val first50Keys = for (_ <- 0 until 50) yield { + val (k, vs) = it.next + val sortedVs = vs.sorted + assert(sortedVs.seq == (0 until 10).map(10 * k + _)) + k + } + assert(map.numSpills == 0) + map.spill(Long.MaxValue, null) + // these asserts try to show that we're no longer holding references to the underlying map. + // it'd be nice to use something like + // https://github.com/scala/scala/blob/2.13.x/test/junit/scala/tools/testing/AssertUtil.scala + // (lines 69-89) + // assert(map.currentMap == null) + eventually { + System.gc() + // direct asserts introduced some macro generated code that held a reference to the map + val tmpIsNull = null == underlyingMapRef.get.orNull + assert(tmpIsNull) + } + + + val next50Keys = for (_ <- 0 until 50) yield { + val (k, vs) = it.next + val sortedVs = vs.sorted + assert(sortedVs.seq == (0 until 10).map(10 * k + _)) + k + } + assert(!it.hasNext) + val keys = (first50Keys ++ next50Keys).sorted + assert(keys == (0 until 100)) + }
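The test above leans on a weak-reference idiom that is easy to miss in the noise: hold only a WeakReference to the internal map, drop every strong reference, then poll under System.gc() until the referent is collected. Distilled to its essentials (plain Scala, standard library only; the names strong and weak are illustrative):

    import scala.ref.WeakReference

    var strong = new Array[Byte](1 << 20) // the only strong reference
    val weak = WeakReference(strong)
    assert(weak.get.isDefined)            // referent still strongly reachable
    strong = null                         // drop the strong reference
    System.gc()                           // best-effort request, hence the eventually {} polling above
    // weak.get should now become None once the collector actually runs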
+ + test("drop all references to the underlying map once the iterator is exhausted") { + val size = 1000 + val conf = createSparkConf(loadDefaults = true) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + val map = createExternalMap[Int] + + map.insertAll((0 until size).iterator.map(i => (i / 10, i))) + assert(map.numSpills == 0, "map was not supposed to spill") + + val underlyingMapRef = WeakReference(map.currentMap) + + { + // direct asserts introduced some macro generated code that held a reference to the map + val tmpIsNull = null == underlyingMapRef.get.orNull + assert(!tmpIsNull) + } + + val it = map.iterator + assert(it.isInstanceOf[CompletionIterator[_, _]]) + + + val keys = it.map { + case (k, vs) => + val sortedVs = vs.sorted + assert(sortedVs.seq == (0 until 10).map(10 * k + _)) + k + } + .toList + .sorted + + assert(it.isEmpty) + assert(keys == (0 until 100).toList) + + assert(map.numSpills == 0) + // these asserts try to show that we're no longer holding references to the underlying map. + // it'd be nice to use something like + // https://github.com/scala/scala/blob/2.13.x/test/junit/scala/tools/testing/AssertUtil.scala + // (lines 69-89) + assert(map.currentMap == null) + + eventually { + Thread.sleep(500) + System.gc() + // direct asserts introduced some macro generated code that held a reference to the map + val tmpIsNull = null == underlyingMapRef.get.orNull + assert(tmpIsNull) + } + + assert(it.toList.isEmpty) + } + + test("SPARK-22713 external aggregation updates peak execution memory") { val spillThreshold = 1000 val conf = createSparkConf(loadDefaults = false) .set("spark.shuffle.spill.numElementsForceSpillThreshold", spillThreshold.toString) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 3e56db5ea116a..47173b89e91e2 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark._ import org.apache.spark.memory.MemoryTestingUtils import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.unsafe.array.LongArray -import org.apache.spark.unsafe.memory.OnHeapMemoryBlock +import org.apache.spark.unsafe.memory.MemoryBlock import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordPointerAndKeyPrefix, UnsafeSortDataFormat} class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { @@ -105,8 +105,9 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { // the form [150000000, 150000001, 150000002, ...., 300000000, 0, 1, 2, ..., 149999999] // that can trigger copyRange() in TimSort.mergeLo() or TimSort.mergeHi() val ref = Array.tabulate[Long](size) { i => if (i < size / 2) size / 2 + i else i } - val buf = new LongArray(OnHeapMemoryBlock.fromArray(ref)) - val tmpBuf = new LongArray(new OnHeapMemoryBlock((size/2) * 8L)) + val buf = new LongArray(MemoryBlock.fromLongArray(ref)) + val tmp = new Array[Long](size/2) + val tmpBuf = new LongArray(MemoryBlock.fromLongArray(tmp)) new Sorter(new UnsafeSortDataFormat(tmpBuf)).sort( buf, 0, size, new Comparator[RecordPointerAndKeyPrefix] {
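The OpenHashMapSuite test just below turns on a subtlety of @specialized values: casting null to a primitive type unboxes it to that type's zero value, so a map specialized on a primitive cannot distinguish a missing key from a stored zero through apply() alone. The underlying language behavior in two lines (plain Scala, independent of Spark):

    println(null.asInstanceOf[Long])   // prints 0: null unboxes to the zero value
    println(null.asInstanceOf[Double]) // prints 0.0
    // hence key presence must be checked with contains(key), not map(key)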
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala index 08a3200288981..151235dd0fb90 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala @@ -194,4 +194,50 @@ class OpenHashMapSuite extends SparkFunSuite with Matchers { val numInvalidValues = map.iterator.count(_._2 == 0) assertResult(0)(numInvalidValues) } + + test("distinguish between the 0/0.0/0L and null") { + val specializedMap1 = new OpenHashMap[String, Long] + specializedMap1("a") = null.asInstanceOf[Long] + specializedMap1("b") = 0L + assert(specializedMap1.contains("a")) + assert(!specializedMap1.contains("c")) + // null.asInstanceOf[Long] will return 0L + assert(specializedMap1("a") === 0L) + assert(specializedMap1("b") === 0L) + // If the data type is in @specialized annotation, and + // the `key` is not contained, the `map(key)` will return 0 + assert(specializedMap1("c") === 0L) + + val specializedMap2 = new OpenHashMap[String, Double] + specializedMap2("a") = null.asInstanceOf[Double] + specializedMap2("b") = 0.toDouble + assert(specializedMap2.contains("a")) + assert(!specializedMap2.contains("c")) + // null.asInstanceOf[Double] will return 0.0 + assert(specializedMap2("a") === 0.0) + assert(specializedMap2("b") === 0.0) + assert(specializedMap2("c") === 0.0) + + val map1 = new OpenHashMap[String, Short] + map1("a") = null.asInstanceOf[Short] + map1("b") = 0.toShort + assert(map1.contains("a")) + assert(!map1.contains("c")) + // null.asInstanceOf[Short] will return 0 + assert(map1("a") === 0) + assert(map1("b") === 0) + // If the data type is not in @specialized annotation, and + // the `key` is not contained, the `map(key)` will return null + assert(map1("c") === null) + + val map2 = new OpenHashMap[String, Float] + map2("a") = null.asInstanceOf[Float] + map2("b") = 0.toFloat + assert(map2.contains("a")) + assert(!map2.contains("c")) + // null.asInstanceOf[Float] will return 0.0 + assert(map2("a") === 0.0) + assert(map2("b") === 0.0) + assert(map2("c") === null) + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala index 210bc5c099742..b887f937a9da9 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala @@ -112,6 +112,80 @@ class OpenHashSetSuite extends SparkFunSuite with Matchers { assert(!set.contains(10000L)) } + test("primitive float") { + val set = new OpenHashSet[Float] + assert(set.size === 0) + assert(!set.contains(10.1F)) + assert(!set.contains(50.5F)) + assert(!set.contains(999.9F)) + assert(!set.contains(10000.1F)) + + set.add(10.1F) + assert(set.size === 1) + assert(set.contains(10.1F)) + assert(!set.contains(50.5F)) + assert(!set.contains(999.9F)) + assert(!set.contains(10000.1F)) + + set.add(50.5F) + assert(set.size === 2) + assert(set.contains(10.1F)) + assert(set.contains(50.5F)) + assert(!set.contains(999.9F)) + assert(!set.contains(10000.1F)) + + set.add(999.9F) + assert(set.size === 3) + assert(set.contains(10.1F)) + assert(set.contains(50.5F)) + assert(set.contains(999.9F)) + assert(!set.contains(10000.1F)) + + set.add(50.5F) + assert(set.size === 3) + assert(set.contains(10.1F)) + assert(set.contains(50.5F)) + assert(set.contains(999.9F)) + assert(!set.contains(10000.1F)) + } + + test("primitive double") { + val set = new OpenHashSet[Double] + assert(set.size === 0) + assert(!set.contains(10.1D)) + assert(!set.contains(50.5D)) + assert(!set.contains(999.9D)) + assert(!set.contains(10000.1D)) + + set.add(10.1D) + assert(set.size === 1) + assert(set.contains(10.1D)) + assert(!set.contains(50.5D)) + assert(!set.contains(999.9D)) + assert(!set.contains(10000.1D)) + + set.add(50.5D) + assert(set.size === 2) + assert(set.contains(10.1D)) + assert(set.contains(50.5D)) + assert(!set.contains(999.9D)) + assert(!set.contains(10000.1D)) + + set.add(999.9D) + assert(set.size === 3) + assert(set.contains(10.1D)) + assert(set.contains(50.5D)) + assert(set.contains(999.9D)) + assert(!set.contains(10000.1D)) + + set.add(50.5D) + assert(set.size === 3) + assert(set.contains(10.1D)) + assert(set.contains(50.5D)) + assert(set.contains(999.9D)) + assert(!set.contains(10000.1D)) + } + test("non-primitive") { val set = new OpenHashSet[String] assert(set.size === 0) diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala index
ddf3740e76a7a..d5956ea32096a 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala @@ -27,7 +27,7 @@ import com.google.common.primitives.Ints import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging import org.apache.spark.unsafe.array.LongArray -import org.apache.spark.unsafe.memory.OnHeapMemoryBlock +import org.apache.spark.unsafe.memory.MemoryBlock import org.apache.spark.util.collection.Sorter import org.apache.spark.util.random.XORShiftRandom @@ -78,14 +78,14 @@ class RadixSortSuite extends SparkFunSuite with Logging { private def generateTestData(size: Long, rand: => Long): (Array[JLong], LongArray) = { val ref = Array.tabulate[Long](Ints.checkedCast(size)) { i => rand } val extended = ref ++ Array.fill[Long](Ints.checkedCast(size))(0) - (ref.map(i => new JLong(i)), new LongArray(OnHeapMemoryBlock.fromArray(extended))) + (ref.map(i => new JLong(i)), new LongArray(MemoryBlock.fromLongArray(extended))) } private def generateKeyPrefixTestData(size: Long, rand: => Long): (LongArray, LongArray) = { val ref = Array.tabulate[Long](Ints.checkedCast(size * 2)) { i => rand } val extended = ref ++ Array.fill[Long](Ints.checkedCast(size * 2))(0) - (new LongArray(OnHeapMemoryBlock.fromArray(ref)), - new LongArray(OnHeapMemoryBlock.fromArray(extended))) + (new LongArray(MemoryBlock.fromLongArray(ref)), + new LongArray(MemoryBlock.fromLongArray(extended))) } private def collectToArray(array: LongArray, offset: Int, length: Long): Array[Long] = { @@ -110,7 +110,7 @@ class RadixSortSuite extends SparkFunSuite with Logging { } private def referenceKeyPrefixSort(buf: LongArray, lo: Long, hi: Long, refCmp: PrefixComparator) { - val sortBuffer = new LongArray(new OnHeapMemoryBlock(buf.size() * 8L)) + val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt))) new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort( buf, Ints.checkedCast(lo), Ints.checkedCast(hi), new Comparator[RecordPointerAndKeyPrefix] { override def compare( diff --git a/data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg b/data/mllib/images/origin/kittens/29.5.a_b_EGDP022204.jpg similarity index 100% rename from data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg rename to data/mllib/images/origin/kittens/29.5.a_b_EGDP022204.jpg diff --git a/data/mllib/images/kittens/54893.jpg b/data/mllib/images/origin/kittens/54893.jpg similarity index 100% rename from data/mllib/images/kittens/54893.jpg rename to data/mllib/images/origin/kittens/54893.jpg diff --git a/data/mllib/images/kittens/DP153539.jpg b/data/mllib/images/origin/kittens/DP153539.jpg similarity index 100% rename from data/mllib/images/kittens/DP153539.jpg rename to data/mllib/images/origin/kittens/DP153539.jpg diff --git a/data/mllib/images/kittens/DP802813.jpg b/data/mllib/images/origin/kittens/DP802813.jpg similarity index 100% rename from data/mllib/images/kittens/DP802813.jpg rename to data/mllib/images/origin/kittens/DP802813.jpg diff --git a/data/mllib/images/kittens/not-image.txt b/data/mllib/images/origin/kittens/not-image.txt similarity index 100% rename from data/mllib/images/kittens/not-image.txt rename to data/mllib/images/origin/kittens/not-image.txt diff --git a/data/mllib/images/origin/license.txt b/data/mllib/images/origin/license.txt new file mode 100644 index 0000000000000..052f302c4670a --- /dev/null +++ b/data/mllib/images/origin/license.txt @@ -0,0 +1,13 @@ +The images in 
the folder "kittens" are under the creative commons CC0 license, or no rights reserved: +https://creativecommons.org/share-your-work/public-domain/cc0/ +The images are taken from: +https://ccsearch.creativecommons.org/image/detail/WZnbJSJ2-dzIDiuUUdto3Q== +https://ccsearch.creativecommons.org/image/detail/_TlKu_rm_QrWlR0zthQTXA== +https://ccsearch.creativecommons.org/image/detail/OPNnHJb6q37rSZ5o_L5JHQ== +https://ccsearch.creativecommons.org/image/detail/B2CVP_j5KjwZm7UAVJ3Hvw== + +The chr30.4.184.jpg and grayscale.jpg images are also under the CC0 license, taken from: +https://ccsearch.creativecommons.org/image/detail/8eO_qqotBfEm2UYxirLntw== + +The image under "multi-channel" directory is under the CC BY-SA 4.0 license cropped from: +https://en.wikipedia.org/wiki/Alpha_compositing#/media/File:Hue_alpha_falloff.png diff --git a/data/mllib/images/multi-channel/BGRA.png b/data/mllib/images/origin/multi-channel/BGRA.png similarity index 100% rename from data/mllib/images/multi-channel/BGRA.png rename to data/mllib/images/origin/multi-channel/BGRA.png diff --git a/data/mllib/images/multi-channel/BGRA_alpha_60.png b/data/mllib/images/origin/multi-channel/BGRA_alpha_60.png similarity index 100% rename from data/mllib/images/multi-channel/BGRA_alpha_60.png rename to data/mllib/images/origin/multi-channel/BGRA_alpha_60.png diff --git a/data/mllib/images/multi-channel/chr30.4.184.jpg b/data/mllib/images/origin/multi-channel/chr30.4.184.jpg similarity index 100% rename from data/mllib/images/multi-channel/chr30.4.184.jpg rename to data/mllib/images/origin/multi-channel/chr30.4.184.jpg diff --git a/data/mllib/images/multi-channel/grayscale.jpg b/data/mllib/images/origin/multi-channel/grayscale.jpg similarity index 100% rename from data/mllib/images/multi-channel/grayscale.jpg rename to data/mllib/images/origin/multi-channel/grayscale.jpg diff --git a/data/mllib/images/partitioned/cls=kittens/date=2018-01/29.5.a_b_EGDP022204.jpg b/data/mllib/images/partitioned/cls=kittens/date=2018-01/29.5.a_b_EGDP022204.jpg new file mode 100644 index 0000000000000..435e7dfd6a459 Binary files /dev/null and b/data/mllib/images/partitioned/cls=kittens/date=2018-01/29.5.a_b_EGDP022204.jpg differ diff --git a/data/mllib/images/partitioned/cls=kittens/date=2018-01/not-image.txt b/data/mllib/images/partitioned/cls=kittens/date=2018-01/not-image.txt new file mode 100644 index 0000000000000..283e5e936f231 --- /dev/null +++ b/data/mllib/images/partitioned/cls=kittens/date=2018-01/not-image.txt @@ -0,0 +1 @@ +not an image diff --git a/data/mllib/images/partitioned/cls=kittens/date=2018-02/54893.jpg b/data/mllib/images/partitioned/cls=kittens/date=2018-02/54893.jpg new file mode 100644 index 0000000000000..825630cc40288 Binary files /dev/null and b/data/mllib/images/partitioned/cls=kittens/date=2018-02/54893.jpg differ diff --git a/data/mllib/images/partitioned/cls=kittens/date=2018-02/DP153539.jpg b/data/mllib/images/partitioned/cls=kittens/date=2018-02/DP153539.jpg new file mode 100644 index 0000000000000..571efe933ccfc Binary files /dev/null and b/data/mllib/images/partitioned/cls=kittens/date=2018-02/DP153539.jpg differ diff --git a/data/mllib/images/partitioned/cls=kittens/date=2018-02/DP802813.jpg b/data/mllib/images/partitioned/cls=kittens/date=2018-02/DP802813.jpg new file mode 100644 index 0000000000000..2d123594b7af7 Binary files /dev/null and b/data/mllib/images/partitioned/cls=kittens/date=2018-02/DP802813.jpg differ diff --git a/data/mllib/images/partitioned/cls=multichannel/date=2018-01/BGRA.png 
b/data/mllib/images/partitioned/cls=multichannel/date=2018-01/BGRA.png new file mode 100644 index 0000000000000..a944c6cdb066d Binary files /dev/null and b/data/mllib/images/partitioned/cls=multichannel/date=2018-01/BGRA.png differ diff --git a/data/mllib/images/partitioned/cls=multichannel/date=2018-01/BGRA_alpha_60.png b/data/mllib/images/partitioned/cls=multichannel/date=2018-01/BGRA_alpha_60.png new file mode 100644 index 0000000000000..913637cd2828a Binary files /dev/null and b/data/mllib/images/partitioned/cls=multichannel/date=2018-01/BGRA_alpha_60.png differ diff --git a/data/mllib/images/partitioned/cls=multichannel/date=2018-02/chr30.4.184.jpg b/data/mllib/images/partitioned/cls=multichannel/date=2018-02/chr30.4.184.jpg new file mode 100644 index 0000000000000..7068b97deb344 Binary files /dev/null and b/data/mllib/images/partitioned/cls=multichannel/date=2018-02/chr30.4.184.jpg differ diff --git a/data/mllib/images/partitioned/cls=multichannel/date=2018-02/grayscale.jpg b/data/mllib/images/partitioned/cls=multichannel/date=2018-02/grayscale.jpg new file mode 100644 index 0000000000000..621cdd11e2b92 Binary files /dev/null and b/data/mllib/images/partitioned/cls=multichannel/date=2018-02/grayscale.jpg differ diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 23b24212b4d29..777950016801d 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -11,6 +11,10 @@ cache .rat-excludes .*md derby.log +licenses/* +licenses-binary/* +LICENSE +NOTICE TAGS RELEASE control @@ -77,6 +81,7 @@ app-20180109111548-0000 app-20161115172038-0000 app-20161116163331-0000 application_1516285256255_0012 +application_1506645932520_24630151 local-1422981759269 local-1422981780767 local-1425081759269 diff --git a/dev/appveyor-install-dependencies.ps1 b/dev/appveyor-install-dependencies.ps1 index e6afb18558852..8a04b621f8ce4 100644 --- a/dev/appveyor-install-dependencies.ps1 +++ b/dev/appveyor-install-dependencies.ps1 @@ -81,7 +81,7 @@ if (!(Test-Path $tools)) { # ========================== Maven Push-Location $tools -$mavenVer = "3.3.9" +$mavenVer = "3.5.4" Start-FileDownload "https://archive.apache.org/dist/maven/maven-3/$mavenVer/binaries/apache-maven-$mavenVer-bin.zip" "maven.zip" # extract diff --git a/dev/create-release/generate-contributors.py b/dev/create-release/generate-contributors.py index 131d81c8a75cf..d9135173419ae 100755 --- a/dev/create-release/generate-contributors.py +++ b/dev/create-release/generate-contributors.py @@ -67,7 +67,7 @@ print("Release tag: %s" % RELEASE_TAG) print("Previous release tag: %s" % PREVIOUS_RELEASE_TAG) print("Number of commits in this range: %s" % len(new_commits)) -print +print("") def print_indented(_list): @@ -88,10 +88,10 @@ def print_indented(_list): def is_release(commit_title): - return re.findall("\[release\]", commit_title.lower()) or \ - "preparing spark release" in commit_title.lower() or \ - "preparing development version" in commit_title.lower() or \ - "CHANGES.txt" in commit_title + return ("[release]" in commit_title.lower() or + "preparing spark release" in commit_title.lower() or + "preparing development version" in commit_title.lower() or + "CHANGES.txt" in commit_title) def is_maintenance(commit_title): diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 24a62a8f4c7d3..73610a3335910 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -178,12 +178,17 @@ if [[ "$1" == "package" ]]; then SHA512 spark-$SPARK_VERSION.tgz > spark-$SPARK_VERSION.tgz.sha512 rm -rf 
spark-$SPARK_VERSION + ZINC_PORT=3035 + # Updated for each binary build make_binary_release() { NAME=$1 - FLAGS=$2 - ZINC_PORT=$3 - BUILD_PACKAGE=$4 + FLAGS="$MVN_EXTRA_OPTS -B $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES $2" + BUILD_PACKAGE=$3 + + # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds + # share the same Zinc server. + ZINC_PORT=$((ZINC_PORT + 1)) echo "Building binary dist $NAME" cp -r spark spark-$SPARK_VERSION-bin-$NAME @@ -255,20 +260,39 @@ if [[ "$1" == "package" ]]; then spark-$SPARK_VERSION-bin-$NAME.tgz.sha512 } - # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds - # share the same Zinc server. - if ! make_binary_release "hadoop2.6" "$MVN_EXTRA_OPTS -B -Phadoop-2.6 $HIVE_PROFILES $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3035" "withr"; then - error "Failed to build hadoop2.6 package. Check logs for details." - fi + # List of binary packages built. Populates two associative arrays, where the key is the "name" of + # the package being built, and the values are respectively the needed maven arguments for building + # the package, and any extra package needed for that particular combination. + # + # In dry run mode, only build the first one. The keys in BINARY_PKGS_ARGS are used as the + # list of packages to be built, so it's ok for things to be missing in BINARY_PKGS_EXTRA. + + declare -A BINARY_PKGS_ARGS + BINARY_PKGS_ARGS["hadoop2.7"]="-Phadoop-2.7 $HIVE_PROFILES" if ! is_dry_run; then - if ! make_binary_release "hadoop2.7" "$MVN_EXTRA_OPTS -B -Phadoop-2.7 $HIVE_PROFILES $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3036" "withpip"; then - error "Failed to build hadoop2.7 package. Check logs for details." - fi - if ! make_binary_release "without-hadoop" "$MVN_EXTRA_OPTS -B -Phadoop-provided $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3037"; then - error "Failed to build without-hadoop package. Check logs for details." + BINARY_PKGS_ARGS["hadoop2.6"]="-Phadoop-2.6 $HIVE_PROFILES" + BINARY_PKGS_ARGS["without-hadoop"]="-Pwithout-hadoop" + if [[ $SPARK_VERSION < "2.2." ]]; then + BINARY_PKGS_ARGS["hadoop2.4"]="-Phadoop-2.4 $HIVE_PROFILES" + BINARY_PKGS_ARGS["hadoop2.3"]="-Phadoop-2.3 $HIVE_PROFILES" fi fi + declare -A BINARY_PKGS_EXTRA + BINARY_PKGS_EXTRA["hadoop2.7"]="withpip" + if ! is_dry_run; then + BINARY_PKGS_EXTRA["hadoop2.6"]="withr" + fi + + echo "Packages to build: ${!BINARY_PKGS_ARGS[@]}" + for key in ${!BINARY_PKGS_ARGS[@]}; do + args=${BINARY_PKGS_ARGS[$key]} + extra=${BINARY_PKGS_EXTRA[$key]} + if ! make_binary_release "$key" "$args" "$extra"; then + error "Failed to build $key package. Check logs for details." + fi + done + rm -rf spark-$SPARK_VERSION-bin-*/ if ! 
is_dry_run; then diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py index 32f6cbb29f0be..f273b337fdb4e 100755 --- a/dev/create-release/releaseutils.py +++ b/dev/create-release/releaseutils.py @@ -49,13 +49,16 @@ print("Install using 'sudo pip install unidecode'") sys.exit(-1) +if sys.version < '3': + input = raw_input # noqa + # Contributors list file name contributors_file_name = "contributors.txt" # Prompt the user to answer yes or no until they do so def yesOrNoPrompt(msg): - response = raw_input("%s [y/n]: " % msg) + response = input("%s [y/n]: " % msg) while response != "y" and response != "n": return yesOrNoPrompt(msg) return response == "y" @@ -149,7 +152,11 @@ def get_commits(tag): if not is_valid_author(author): author = github_username # Guard against special characters - author = unidecode.unidecode(unicode(author, "UTF-8")).strip() + try: # Python 2 + author = unicode(author, "UTF-8") + except NameError: # Python 3 + author = str(author) + author = unidecode.unidecode(author).strip() commit = Commit(_hash, author, title, pr_number) commits.append(commit) return commits @@ -228,7 +235,7 @@ def translate_component(component, commit_hash, warnings): # Parse components in the commit message # The returned components are already filtered and translated def find_components(commit, commit_hash): - components = re.findall("\[\w*\]", commit.lower()) + components = re.findall(r"\[\w*\]", commit.lower()) components = [translate_component(c, commit_hash) for c in components if c in known_components] return components diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 96e9c27210d05..62ae04dbc255f 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -2,7 +2,7 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar -aircompressor-0.8.jar +aircompressor-0.10.jar antlr-2.7.7.jar antlr-runtime-3.4.jar antlr4-runtime-4.7.jar @@ -14,30 +14,28 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar -arrow-format-0.8.0.jar -arrow-memory-0.8.0.jar -arrow-vector-0.8.0.jar +arrow-format-0.10.0.jar +arrow-memory-0.10.0.jar +arrow-vector-0.10.0.jar automaton-1.11-8.jar -avro-1.7.7.jar -avro-ipc-1.7.7.jar -avro-mapred-1.7.7-hadoop2.jar -base64-2.3.8.jar -bcprov-jdk15on-1.58.jar +avro-1.8.2.jar +avro-ipc-1.8.2.jar +avro-mapred-1.8.2-hadoop2.jar bonecp-0.8.0.RELEASE.jar breeze-macros_2.11-0.13.2.jar breeze_2.11-0.13.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar -chill-java-0.8.4.jar -chill_2.11-0.8.4.jar +chill-java-0.9.3.jar +chill_2.11-0.9.3.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar commons-codec-1.10.jar commons-collections-3.2.2.jar -commons-compiler-3.0.8.jar -commons-compress-1.4.1.jar +commons-compiler-3.0.9.jar +commons-compress-1.8.1.jar commons-configuration-1.6.jar commons-crypto-1.0.0.jar commons-dbcp-1.4.jar @@ -86,8 +84,8 @@ hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar hppc-0.7.2.jar htrace-core-3.0.4.jar -httpclient-4.5.4.jar -httpcore-4.4.8.jar +httpclient-4.5.6.jar +httpcore-4.4.10.jar ivy-2.4.0.jar jackson-annotations-2.6.7.jar jackson-core-2.6.7.jar @@ -100,8 +98,7 @@ jackson-module-jaxb-annotations-2.6.7.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar -janino-3.0.8.jar -java-xmlbuilder-1.1.jar +janino-3.0.9.jar javassist-3.18.1-GA.jar 
javax.annotation-api-1.2.jar javax.inject-1.jar @@ -119,10 +116,9 @@ jersey-container-servlet-core-2.22.2.jar jersey-guava-2.22.2.jar jersey-media-jaxb-2.22.2.jar jersey-server-2.22.2.jar -jets3t-0.9.4.jar jetty-6.1.26.jar jetty-util-6.1.26.jar -jline-2.14.3.jar +jline-2.14.6.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar @@ -134,7 +130,7 @@ jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar jul-to-slf4j-1.7.16.jar -kryo-shaded-3.0.3.jar +kryo-shaded-4.0.2.jar kubernetes-client-3.0.0.jar kubernetes-model-2.0.0.jar leveldbjni-all-1.8.jar @@ -153,12 +149,13 @@ metrics-jvm-3.1.5.jar minlog-1.3.0.jar netty-3.9.9.Final.jar netty-all-4.1.17.Final.jar -objenesis-2.1.jar +objenesis-2.5.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.4-nohive.jar -orc-mapreduce-1.4.4-nohive.jar +orc-core-1.5.2-nohive.jar +orc-mapreduce-1.5.2-nohive.jar +orc-shims-1.5.2.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar @@ -190,12 +187,12 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-2.6.3.jar +univocity-parsers-2.7.3.jar validation-api-1.1.0.Final.jar -xbean-asm5-shaded-4.4.jar +xbean-asm6-shaded-4.8.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar -xz-1.0.jar +xz-1.5.jar zjsonpatch-0.3.0.jar zookeeper-3.4.6.jar zstd-jni-1.3.2-2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 4a6ee027ec355..dcb5d63aeff4d 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -2,7 +2,7 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar -aircompressor-0.8.jar +aircompressor-0.10.jar antlr-2.7.7.jar antlr-runtime-3.4.jar antlr4-runtime-4.7.jar @@ -14,30 +14,28 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar -arrow-format-0.8.0.jar -arrow-memory-0.8.0.jar -arrow-vector-0.8.0.jar +arrow-format-0.10.0.jar +arrow-memory-0.10.0.jar +arrow-vector-0.10.0.jar automaton-1.11-8.jar -avro-1.7.7.jar -avro-ipc-1.7.7.jar -avro-mapred-1.7.7-hadoop2.jar -base64-2.3.8.jar -bcprov-jdk15on-1.58.jar +avro-1.8.2.jar +avro-ipc-1.8.2.jar +avro-mapred-1.8.2-hadoop2.jar bonecp-0.8.0.RELEASE.jar breeze-macros_2.11-0.13.2.jar breeze_2.11-0.13.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar -chill-java-0.8.4.jar -chill_2.11-0.8.4.jar +chill-java-0.9.3.jar +chill_2.11-0.9.3.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar commons-codec-1.10.jar commons-collections-3.2.2.jar -commons-compiler-3.0.8.jar -commons-compress-1.4.1.jar +commons-compiler-3.0.9.jar +commons-compress-1.8.1.jar commons-configuration-1.6.jar commons-crypto-1.0.0.jar commons-dbcp-1.4.jar @@ -86,8 +84,8 @@ hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar hppc-0.7.2.jar htrace-core-3.1.0-incubating.jar -httpclient-4.5.4.jar -httpcore-4.4.8.jar +httpclient-4.5.6.jar +httpcore-4.4.10.jar ivy-2.4.0.jar jackson-annotations-2.6.7.jar jackson-core-2.6.7.jar @@ -100,8 +98,7 @@ jackson-module-jaxb-annotations-2.6.7.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar -janino-3.0.8.jar -java-xmlbuilder-1.1.jar +janino-3.0.9.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar javax.inject-1.jar @@ -119,10 +116,9 @@ jersey-container-servlet-core-2.22.2.jar jersey-guava-2.22.2.jar jersey-media-jaxb-2.22.2.jar jersey-server-2.22.2.jar -jets3t-0.9.4.jar jetty-6.1.26.jar jetty-util-6.1.26.jar 
-jline-2.14.3.jar +jline-2.14.6.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar @@ -135,7 +131,7 @@ jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar jul-to-slf4j-1.7.16.jar -kryo-shaded-3.0.3.jar +kryo-shaded-4.0.2.jar kubernetes-client-3.0.0.jar kubernetes-model-2.0.0.jar leveldbjni-all-1.8.jar @@ -154,12 +150,13 @@ metrics-jvm-3.1.5.jar minlog-1.3.0.jar netty-3.9.9.Final.jar netty-all-4.1.17.Final.jar -objenesis-2.1.jar +objenesis-2.5.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.4-nohive.jar -orc-mapreduce-1.4.4-nohive.jar +orc-core-1.5.2-nohive.jar +orc-mapreduce-1.5.2-nohive.jar +orc-shims-1.5.2.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar @@ -191,12 +188,12 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-2.6.3.jar +univocity-parsers-2.7.3.jar validation-api-1.1.0.Final.jar -xbean-asm5-shaded-4.4.jar +xbean-asm6-shaded-4.8.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar -xz-1.0.jar +xz-1.5.jar zjsonpatch-0.3.0.jar zookeeper-3.4.6.jar zstd-jni-1.3.2-2.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index e0b560c8ec71f..641b4a15ad7cd 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -4,7 +4,7 @@ RoaringBitmap-0.5.11.jar ST4-4.0.4.jar accessors-smart-1.2.jar activation-1.1.1.jar -aircompressor-0.8.jar +aircompressor-0.10.jar antlr-2.7.7.jar antlr-runtime-3.4.jar antlr4-runtime-4.7.jar @@ -12,29 +12,27 @@ aopalliance-1.0.jar aopalliance-repackaged-2.4.0-b34.jar apache-log4j-extras-1.2.17.jar arpack_combined_all-0.1.jar -arrow-format-0.8.0.jar -arrow-memory-0.8.0.jar -arrow-vector-0.8.0.jar +arrow-format-0.10.0.jar +arrow-memory-0.10.0.jar +arrow-vector-0.10.0.jar automaton-1.11-8.jar -avro-1.7.7.jar -avro-ipc-1.7.7.jar -avro-mapred-1.7.7-hadoop2.jar -base64-2.3.8.jar -bcprov-jdk15on-1.58.jar +avro-1.8.2.jar +avro-ipc-1.8.2.jar +avro-mapred-1.8.2-hadoop2.jar bonecp-0.8.0.RELEASE.jar breeze-macros_2.11-0.13.2.jar breeze_2.11-0.13.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar -chill-java-0.8.4.jar -chill_2.11-0.8.4.jar +chill-java-0.9.3.jar +chill_2.11-0.9.3.jar commons-beanutils-1.9.3.jar commons-cli-1.2.jar commons-codec-1.10.jar commons-collections-3.2.2.jar -commons-compiler-3.0.8.jar -commons-compress-1.4.1.jar +commons-compiler-3.0.9.jar +commons-compress-1.8.1.jar commons-configuration2-2.1.1.jar commons-crypto-1.0.0.jar commons-daemon-1.0.13.jar @@ -85,8 +83,8 @@ hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar hppc-0.7.2.jar htrace-core4-4.1.0-incubating.jar -httpclient-4.5.4.jar -httpcore-4.4.8.jar +httpclient-4.5.6.jar +httpcore-4.4.10.jar ivy-2.4.0.jar jackson-annotations-2.6.7.jar jackson-core-2.6.7.jar @@ -99,8 +97,7 @@ jackson-mapper-asl-1.9.13.jar jackson-module-jaxb-annotations-2.6.7.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar -janino-3.0.8.jar -java-xmlbuilder-1.1.jar +janino-3.0.9.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar javax.inject-1.jar @@ -119,10 +116,9 @@ jersey-container-servlet-core-2.22.2.jar jersey-guava-2.22.2.jar jersey-media-jaxb-2.22.2.jar jersey-server-2.22.2.jar -jets3t-0.9.4.jar -jetty-webapp-9.3.20.v20170531.jar -jetty-xml-9.3.20.v20170531.jar -jline-2.14.3.jar +jetty-webapp-9.3.24.v20180605.jar +jetty-xml-9.3.24.v20180605.jar +jline-2.14.6.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar @@ -150,7 +146,7 @@ kerby-config-1.0.1.jar kerby-pkix-1.0.1.jar 
kerby-util-1.0.1.jar kerby-xdr-1.0.1.jar -kryo-shaded-3.0.3.jar +kryo-shaded-4.0.2.jar kubernetes-client-3.0.0.jar kubernetes-model-2.0.0.jar leveldbjni-all-1.8.jar @@ -171,13 +167,14 @@ mssql-jdbc-6.2.1.jre7.jar netty-3.9.9.Final.jar netty-all-4.1.17.Final.jar nimbus-jose-jwt-4.41.1.jar -objenesis-2.1.jar +objenesis-2.5.1.jar okhttp-2.7.5.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.4-nohive.jar -orc-mapreduce-1.4.4-nohive.jar +orc-core-1.5.2-nohive.jar +orc-mapreduce-1.5.2-nohive.jar +orc-shims-1.5.2.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar @@ -211,11 +208,11 @@ stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar token-provider-1.0.1.jar -univocity-parsers-2.6.3.jar +univocity-parsers-2.7.3.jar validation-api-1.1.0.Final.jar woodstox-core-5.0.3.jar -xbean-asm5-shaded-4.4.jar -xz-1.0.jar +xbean-asm6-shaded-4.8.jar +xz-1.5.jar zjsonpatch-0.3.0.jar zookeeper-3.4.9.jar zstd-jni-1.3.2-2.jar diff --git a/dev/lint-python b/dev/lint-python index f738af9c49763..e26bd4bd4517c 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -36,7 +36,7 @@ compile_status="${PIPESTATUS[0]}" # Get pycodestyle at runtime so that we don't rely on it being installed on the build server. # See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 # Updated to the latest official version of pep8. pep8 is formally renamed to pycodestyle. -PYCODESTYLE_VERSION="2.3.1" +PYCODESTYLE_VERSION="2.4.0" PYCODESTYLE_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pycodestyle-$PYCODESTYLE_VERSION.py" PYCODESTYLE_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/PyCQA/pycodestyle/$PYCODESTYLE_VERSION/pycodestyle.py" @@ -82,6 +82,23 @@ else rm "$PYCODESTYLE_REPORT_PATH" fi +# stop the build if there are Python syntax errors or undefined names +flake8 . --count --select=E901,E999,F821,F822,F823 --max-line-length=100 --show-source --statistics +flake8_status="${PIPESTATUS[0]}" + +if [ "$flake8_status" -eq 0 ]; then + lint_status=0 +else + lint_status=1 +fi + +if [ "$lint_status" -ne 0 ]; then + echo "flake8 checks failed." + exit "$lint_status" +else + echo "flake8 checks passed." +fi + # Check that the documentation builds acceptably, skip check if sphinx is not installed. 
if hash "$SPHINXBUILD" 2> /dev/null; then cd python/docs diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index 84233c64caa9c..778d376c12b56 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -192,6 +192,7 @@ fi if [ -d "$SPARK_HOME"/resource-managers/kubernetes/core/target/ ]; then mkdir -p "$DISTDIR/kubernetes/" cp -a "$SPARK_HOME"/resource-managers/kubernetes/docker/src/main/dockerfiles "$DISTDIR/kubernetes/" + cp -a "$SPARK_HOME"/resource-managers/kubernetes/integration-tests/tests "$DISTDIR/kubernetes/" fi # Copy examples and dependencies @@ -211,9 +212,10 @@ mkdir -p "$DISTDIR/examples/src/main" cp -r "$SPARK_HOME/examples/src/main" "$DISTDIR/examples/src/" # Copy license and ASF files -cp "$SPARK_HOME/LICENSE" "$DISTDIR" -cp -r "$SPARK_HOME/licenses" "$DISTDIR" -cp "$SPARK_HOME/NOTICE" "$DISTDIR" +cp "$SPARK_HOME/LICENSE-binary" "$DISTDIR/LICENSE" +mkdir -p "$DISTDIR/licenses" +cp -r "$SPARK_HOME/licenses-binary" "$DISTDIR/licenses" +cp "$SPARK_HOME/NOTICE-binary" "$DISTDIR/NOTICE" if [ -e "$SPARK_HOME/CHANGES.txt" ]; then cp "$SPARK_HOME/CHANGES.txt" "$DISTDIR" diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 7f46a1c8f6a7c..cca6f405e89ac 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -39,6 +39,9 @@ except ImportError: JIRA_IMPORTED = False +if sys.version < '3': + input = raw_input # noqa + # Location of your Spark git development area SPARK_HOME = os.environ.get("SPARK_HOME", os.getcwd()) # Remote name which points to the Gihub site @@ -95,7 +98,7 @@ def run_cmd(cmd): def continue_maybe(prompt): - result = raw_input("\n%s (y/n): " % prompt) + result = input("\n%s (y/n): " % prompt) if result.lower() != "y": fail("Okay, exiting") @@ -134,11 +137,16 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): '--pretty=format:%an <%ae>']).split("\n") distinct_authors = sorted(set(commit_authors), key=lambda x: commit_authors.count(x), reverse=True) - primary_author = raw_input( + primary_author = input( "Enter primary author in the format of \"name \" [%s]: " % distinct_authors[0]) if primary_author == "": primary_author = distinct_authors[0] + else: + # When primary author is specified manually, de-dup it from author list and + # put it at the head of author list. + distinct_authors = list(filter(lambda x: x != primary_author, distinct_authors)) + distinct_authors.insert(0, primary_author) commits = run_cmd(['git', 'log', 'HEAD..%s' % pr_branch_name, '--pretty=format:%h [%an] %s']).split("\n\n") @@ -151,13 +159,10 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): # to people every time someone creates a public fork of Spark. merge_message_flags += ["-m", body.replace("@", "")] - authors = "\n".join(["Author: %s" % a for a in distinct_authors]) - - merge_message_flags += ["-m", authors] + committer_name = run_cmd("git config --get user.name").strip() + committer_email = run_cmd("git config --get user.email").strip() if had_conflicts: - committer_name = run_cmd("git config --get user.name").strip() - committer_email = run_cmd("git config --get user.email").strip() message = "This patch had conflicts when merged, resolved by\nCommitter: %s <%s>" % ( committer_name, committer_email) merge_message_flags += ["-m", message] @@ -165,6 +170,14 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): # The string "Closes #%s" string is required for GitHub to correctly close the PR merge_message_flags += ["-m", "Closes #%s from %s." 
% (pr_num, pr_repo_desc)] + authors = "Authored-by:" if len(distinct_authors) == 1 else "Lead-authored-by:" + authors += " %s" % (distinct_authors.pop(0)) + if len(distinct_authors) > 0: + authors += "\n" + "\n".join(["Co-authored-by: %s" % a for a in distinct_authors]) + authors += "\n" + "Signed-off-by: %s <%s>" % (committer_name, committer_email) + + merge_message_flags += ["-m", authors] + run_cmd(['git', 'commit', '--author="%s"' % primary_author] + merge_message_flags) continue_maybe("Merge complete (local ref %s). Push to %s?" % ( @@ -184,7 +197,7 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): def cherry_pick(pr_num, merge_hash, default_branch): - pick_ref = raw_input("Enter a branch name [%s]: " % default_branch) + pick_ref = input("Enter a branch name [%s]: " % default_branch) if pick_ref == "": pick_ref = default_branch @@ -231,7 +244,7 @@ def resolve_jira_issue(merge_branches, comment, default_jira_id=""): asf_jira = jira.client.JIRA({'server': JIRA_API_BASE}, basic_auth=(JIRA_USERNAME, JIRA_PASSWORD)) - jira_id = raw_input("Enter a JIRA id [%s]: " % default_jira_id) + jira_id = input("Enter a JIRA id [%s]: " % default_jira_id) if jira_id == "": jira_id = default_jira_id @@ -261,7 +274,7 @@ def resolve_jira_issue(merge_branches, comment, default_jira_id=""): versions = sorted(versions, key=lambda x: x.name, reverse=True) versions = filter(lambda x: x.raw['released'] is False, versions) # Consider only x.y.z versions - versions = filter(lambda x: re.match('\d+\.\d+\.\d+', x.name), versions) + versions = filter(lambda x: re.match(r'\d+\.\d+\.\d+', x.name), versions) default_fix_versions = map(lambda x: fix_version_from_branch(x, versions).name, merge_branches) for v in default_fix_versions: @@ -276,7 +289,7 @@ def resolve_jira_issue(merge_branches, comment, default_jira_id=""): default_fix_versions = filter(lambda x: x != v, default_fix_versions) default_fix_versions = ",".join(default_fix_versions) - fix_versions = raw_input("Enter comma-separated fix version(s) [%s]: " % default_fix_versions) + fix_versions = input("Enter comma-separated fix version(s) [%s]: " % default_fix_versions) if fix_versions == "": fix_versions = default_fix_versions fix_versions = fix_versions.replace(" ", "").split(",") @@ -315,8 +328,8 @@ def choose_jira_assignee(issue, asf_jira): if author in commentors: annotations.append("Commentor") print("[%d] %s (%s)" % (idx, author.displayName, ",".join(annotations))) - raw_assignee = raw_input( - "Enter number of user, or userid, to assign to (blank to leave unassigned):") + raw_assignee = input( + "Enter number of user, or userid, to assign to (blank to leave unassigned):") if raw_assignee == "": return None else: @@ -328,6 +341,8 @@ def choose_jira_assignee(issue, asf_jira): assignee = asf_jira.user(raw_assignee) asf_jira.assign_issue(issue.key, assignee.key) return assignee + except KeyboardInterrupt: + raise except: traceback.print_exc() print("Error assigning JIRA, try again (or leave blank and fix manually)") @@ -359,8 +374,8 @@ def standardize_jira_ref(text): >>> standardize_jira_ref("[SPARK-979] a LRU scheduler for load balancing in TaskSchedulerImpl") '[SPARK-979] a LRU scheduler for load balancing in TaskSchedulerImpl' >>> standardize_jira_ref( - ... "SPARK-1094 Support MiMa for reporting binary compatibility accross versions.") - '[SPARK-1094] Support MiMa for reporting binary compatibility accross versions.' + ... 
"SPARK-1094 Support MiMa for reporting binary compatibility across versions.") + '[SPARK-1094] Support MiMa for reporting binary compatibility across versions.' >>> standardize_jira_ref("[WIP] [SPARK-1146] Vagrant support for Spark") '[SPARK-1146][WIP] Vagrant support for Spark' >>> standardize_jira_ref( @@ -388,7 +403,7 @@ def standardize_jira_ref(text): # Extract spark component(s): # Look for alphanumeric chars, spaces, dashes, periods, and/or commas - pattern = re.compile(r'(\[[\w\s,-\.]+\])', re.IGNORECASE) + pattern = re.compile(r'(\[[\w\s,.-]+\])', re.IGNORECASE) for component in pattern.findall(text): components.append(component.upper()) text = text.replace(component, '') @@ -423,12 +438,16 @@ def main(): os.chdir(SPARK_HOME) original_head = get_current_ref() + # Check this up front to avoid failing the JIRA update at the very end + if not JIRA_USERNAME or not JIRA_PASSWORD: + continue_maybe("The env-vars JIRA_USERNAME and/or JIRA_PASSWORD are not set. Continue?") + branches = get_json("%s/branches" % GITHUB_API_BASE) branch_names = filter(lambda x: x.startswith("branch-"), [x['name'] for x in branches]) # Assumes branch names can be sorted lexicographically latest_branch = sorted(branch_names, reverse=True)[0] - pr_num = raw_input("Which pull request would you like to merge? (e.g. 34): ") + pr_num = input("Which pull request would you like to merge? (e.g. 34): ") pr = get_json("%s/pulls/%s" % (GITHUB_API_BASE, pr_num)) pr_events = get_json("%s/issues/%s/events" % (GITHUB_API_BASE, pr_num)) @@ -440,7 +459,7 @@ def main(): print("I've re-written the title as follows to match the standard format:") print("Original: %s" % pr["title"]) print("Modified: %s" % modified_title) - result = raw_input("Would you like to use the modified title? (y/n): ") + result = input("Would you like to use the modified title? (y/n): ") if result.lower() == "y": title = modified_title print("Using modified title:") @@ -491,7 +510,7 @@ def main(): merge_hash = merge_pr(pr_num, target_ref, title, body, pr_repo_desc) pick_prompt = "Would you like to pick %s into another branch?" % merge_hash - while raw_input("\n%s (y/n): " % pick_prompt).lower() == "y": + while input("\n%s (y/n): " % pick_prompt).lower() == "y": merged_refs = merged_refs + [cherry_pick(pr_num, merge_hash, latest_branch)] if JIRA_IMPORTED: diff --git a/dev/requirements.txt b/dev/requirements.txt index 79782279f8fbd..3fdd3425ffcc2 100644 --- a/dev/requirements.txt +++ b/dev/requirements.txt @@ -1,4 +1,6 @@ +flake8==3.5.0 jira==1.0.3 PyGithub==1.26.0 Unidecode==0.04.19 pypandoc==1.3.3 +sphinx diff --git a/dev/run-pip-tests b/dev/run-pip-tests index 7271d1014e4ae..60cf4d8209416 100755 --- a/dev/run-pip-tests +++ b/dev/run-pip-tests @@ -52,7 +52,7 @@ if hash virtualenv 2>/dev/null && [ ! -n "$USE_CONDA" ]; then PYTHON_EXECS+=('python3') fi elif hash conda 2>/dev/null; then - echo "Using conda virtual enviroments" + echo "Using conda virtual environments" PYTHON_EXECS=('3.5') USE_CONDA=1 else @@ -88,7 +88,7 @@ for python in "${PYTHON_EXECS[@]}"; do virtualenv --python=$python "$VIRTUALENV_PATH" source "$VIRTUALENV_PATH"/bin/activate fi - # Upgrade pip & friends if using virutal env + # Upgrade pip & friends if using virtual env if [ ! 
-n "$USE_CONDA" ]; then pip install --upgrade pip pypandoc wheel numpy fi @@ -123,7 +123,7 @@ for python in "${PYTHON_EXECS[@]}"; do cd "$FWDIR" - # conda / virtualenv enviroments need to be deactivated differently + # conda / virtualenv environments need to be deactivated differently if [ -n "$USE_CONDA" ]; then source deactivate else diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 3960a0de62530..eca88f2391bf8 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -115,7 +115,8 @@ def run_tests(tests_timeout): os.path.join(SPARK_HOME, 'dev', 'run-tests')]).wait() failure_note_by_errcode = { - 1: 'executing the `dev/run-tests` script', # error to denote run-tests script failures + # error to denote run-tests script failures: + 1: 'executing the `dev/run-tests` script', ERROR_CODES["BLOCK_GENERAL"]: 'some tests', ERROR_CODES["BLOCK_RAT"]: 'RAT tests', ERROR_CODES["BLOCK_SCALA_STYLE"]: 'Scala style tests', @@ -130,7 +131,7 @@ def run_tests(tests_timeout): ERROR_CODES["BLOCK_PYSPARK_UNIT_TESTS"]: 'PySpark unit tests', ERROR_CODES["BLOCK_PYSPARK_PIP_TESTS"]: 'PySpark pip packaging tests', ERROR_CODES["BLOCK_SPARKR_UNIT_TESTS"]: 'SparkR unit tests', - ERROR_CODES["BLOCK_TIMEOUT"]: 'from timeout after a configured wait of \`%s\`' % ( + ERROR_CODES["BLOCK_TIMEOUT"]: 'from timeout after a configured wait of `%s`' % ( tests_timeout) } @@ -181,8 +182,9 @@ def main(): short_commit_hash = ghprb_actual_commit[0:7] # format: http://linux.die.net/man/1/timeout - # must be less than the timeout configured on Jenkins (currently 350m) - tests_timeout = "300m" + # must be less than the timeout configured on Jenkins. Usually Jenkins's timeout is higher + # then this. Please consult with the build manager or a committer when it should be increased. + tests_timeout = "400m" # Array to capture all test names to run on the pull request. These tests are represented # by their file equivalents in the dev/tests/ directory. diff --git a/dev/run-tests.py b/dev/run-tests.py index cd4590864b7d7..f534637b80d6b 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -110,7 +110,7 @@ def determine_modules_to_test(changed_modules): ['graphx', 'examples'] >>> x = [x.name for x in determine_modules_to_test([modules.sql])] >>> x # doctest: +NORMALIZE_WHITESPACE - ['sql', 'hive', 'mllib', 'sql-kafka-0-10', 'examples', 'hive-thriftserver', + ['sql', 'avro', 'hive', 'mllib', 'sql-kafka-0-10', 'examples', 'hive-thriftserver', 'pyspark-sql', 'repl', 'sparkr', 'pyspark-mllib', 'pyspark-ml'] """ modules_to_test = set() @@ -169,7 +169,7 @@ def determine_java_version(java_exe): # find raw version string, eg 'java version "1.8.0_25"' raw_version_str = next(x for x in raw_output_lines if " version " in x) - match = re.search('(\d+)\.(\d+)\.(\d+)', raw_version_str) + match = re.search(r'(\d+)\.(\d+)\.(\d+)', raw_version_str) major = int(match.group(1)) minor = int(match.group(2)) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index dfea762db98c6..2aa355504bf29 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -170,6 +170,16 @@ def __hash__(self): ] ) +avro = Module( + name="avro", + dependencies=[sql], + source_file_regexes=[ + "external/avro", + ], + sbt_test_goals=[ + "avro/test", + ] +) sql_kafka = Module( name="sql-kafka-0-10", diff --git a/dev/tox.ini b/dev/tox.ini index 28dad8f3b5c7c..6ec223b743b4e 100644 --- a/dev/tox.ini +++ b/dev/tox.ini @@ -14,6 +14,6 @@ # limitations under the License. 
[pycodestyle] -ignore=E402,E731,E241,W503,E226,E722,E741,E305 +ignore=E226,E241,E305,E402,E722,E731,E741,W503,W504 max-line-length=100 exclude=cloudpickle.py,heapq3.py,shared.py,python/docs/conf.py,work/*/*.py,python/.eggs/*,dist/* diff --git a/docs/README.md b/docs/README.md index dbea4d64c4298..fb67c4b3586d6 100644 --- a/docs/README.md +++ b/docs/README.md @@ -2,7 +2,7 @@ Welcome to the Spark documentation! This readme will walk you through navigating and building the Spark documentation, which is included here with the Spark source code. You can also find documentation specific to release versions of -Spark at http://spark.apache.org/documentation.html. +Spark at https://spark.apache.org/documentation.html. Read on to learn more about viewing documentation in plain text (i.e., markdown) or building the documentation yourself. Why build it yourself? So that you have the docs that correspond to @@ -22,8 +22,9 @@ $ sudo gem install jekyll jekyll-redirect-from pygments.rb $ sudo pip install Pygments # Following is needed only for generating API docs $ sudo pip install sphinx pypandoc mkdocs -$ sudo Rscript -e 'install.packages(c("knitr", "devtools", "testthat", "rmarkdown"), repos="http://cran.stat.ucla.edu/")' +$ sudo Rscript -e 'install.packages(c("knitr", "devtools", "rmarkdown"), repos="http://cran.stat.ucla.edu/")' $ sudo Rscript -e 'devtools::install_version("roxygen2", version = "5.0.1", repos="http://cran.stat.ucla.edu/")' +$ sudo Rscript -e 'devtools::install_version("testthat", version = "1.0.2", repos="http://cran.stat.ucla.edu/")' ``` Note: If you are on a system with both Ruby 1.9 and Ruby 2.0 you may need to replace gem with gem2.0. @@ -79,7 +80,7 @@ jekyll plugin to run `build/sbt unidoc` before building the site so if you haven may take some time as it generates all of the scaladoc and javadoc using [Unidoc](https://github.com/sbt/sbt-unidoc). The jekyll plugin also generates the PySpark docs using [Sphinx](http://sphinx-doc.org/), SparkR docs using [roxygen2](https://cran.r-project.org/web/packages/roxygen2/index.html) and SQL docs -using [MkDocs](http://www.mkdocs.org/). +using [MkDocs](https://www.mkdocs.org/). NOTE: To skip the step of building and copying over the Scala, Java, Python, R and SQL API docs, run `SKIP_API=1 jekyll build`. In addition, `SKIP_SCALADOC=1`, `SKIP_PYTHONDOC=1`, `SKIP_RDOC=1` and `SKIP_SQLDOC=1` can be used diff --git a/docs/_layouts/404.html b/docs/_layouts/404.html index 044654413f9c2..78f98b9ede5a7 100755 --- a/docs/_layouts/404.html +++ b/docs/_layouts/404.html @@ -151,7 +151,7 @@

Not found :(

- + diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index e5af5ae4561c7..88d549c3f1010 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -50,7 +50,7 @@ @@ -114,8 +114,8 @@
                <li><a href="hardware-provisioning.html">Hardware Provisioning</a></li>
                <li><a href="building-spark.html">Building Spark</a></li>
-               <li><a href="http://spark.apache.org/contributing.html">Contributing to Spark</a></li>
-               <li><a href="http://spark.apache.org/third-party-projects.html">Third Party Projects</a></li>
+               <li><a href="https://spark.apache.org/contributing.html">Contributing to Spark</a></li>
+               <li><a href="https://spark.apache.org/third-party-projects.html">Third Party Projects</a></li>
diff --git a/docs/avro-data-source-guide.md b/docs/avro-data-source-guide.md
new file mode 100644
index 0000000000000..d3b81f029d377
--- /dev/null
+++ b/docs/avro-data-source-guide.md
@@ -0,0 +1,380 @@
+---
+layout: global
+title: Apache Avro Data Source Guide
+---
+
+* This will become a table of contents (this text will be scraped).
+{:toc}
+
+Since the Spark 2.4 release, [Spark SQL](https://spark.apache.org/docs/latest/sql-programming-guide.html) provides built-in support for reading and writing Apache Avro data.
+
+## Deploying
+The `spark-avro` module is external and not included in `spark-submit` or `spark-shell` by default.
+
+As with any Spark application, `spark-submit` is used to launch your application. `spark-avro_{{site.SCALA_BINARY_VERSION}}`
+and its dependencies can be added directly to `spark-submit` using `--packages`, for example:
+
+    ./bin/spark-submit --packages org.apache.spark:spark-avro_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} ...
+
+For experimenting on `spark-shell`, you can also use `--packages` to add `org.apache.spark:spark-avro_{{site.SCALA_BINARY_VERSION}}` and its dependencies directly:
+
+    ./bin/spark-shell --packages org.apache.spark:spark-avro_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} ...
+
+See the [Application Submission Guide](submitting-applications.html) for more details about submitting applications with external dependencies.
+
+## Load and Save Functions
+
+Since the `spark-avro` module is external, there is no `.avro` API in
+`DataFrameReader` or `DataFrameWriter`.
+
+To load/save data in Avro format, you need to specify the data source option `format` as `avro` (or `org.apache.spark.sql.avro`).
    +
    +{% highlight scala %} + +val usersDF = spark.read.format("avro").load("examples/src/main/resources/users.avro") +usersDF.select("name", "favorite_color").write.format("avro").save("namesAndFavColors.avro") + +{% endhighlight %} +
    +
    +{% highlight java %} + +Dataset usersDF = spark.read().format("avro").load("examples/src/main/resources/users.avro"); +usersDF.select("name", "favorite_color").write().format("avro").save("namesAndFavColors.avro"); + +{% endhighlight %} +
    +
    +{% highlight python %} + +df = spark.read.format("avro").load("examples/src/main/resources/users.avro") +df.select("name", "favorite_color").write.format("avro").save("namesAndFavColors.avro") + +{% endhighlight %} +
    +
    +{% highlight r %} + +df <- read.df("examples/src/main/resources/users.avro", "avro") +write.df(select(df, "name", "favorite_color"), "namesAndFavColors.avro", "avro") + +{% endhighlight %} +
    +
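The package can also be attached when building a session programmatically; a minimal PySpark sketch (the version coordinate, app name and output path are illustrative):

{% highlight python %}
# Sketch: pull in the external spark-avro package at session startup
# instead of passing --packages to spark-submit/spark-shell.
from pyspark.sql import SparkSession

spark = (SparkSession.builder
         .appName("avro-example")
         .config("spark.jars.packages",
                 "org.apache.spark:spark-avro_2.11:2.4.0")  # illustrative coordinate
         .getOrCreate())

df = spark.read.format("avro").load("examples/src/main/resources/users.avro")
df.select("name", "favorite_color").write.format("avro").save("namesAndFavColors.avro")
{% endhighlight %}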
+
+## to_avro() and from_avro()
+The Avro package provides the function `to_avro` to encode a column as binary in Avro
+format, and `from_avro()` to decode Avro binary data into a column. Both functions transform one column into
+another column, and the input/output SQL data type can be a complex type or a primitive type.
+
+Using Avro records as columns is useful when reading from or writing to a streaming source like Kafka. Each
+Kafka key-value record is augmented with some metadata, such as the ingestion timestamp into Kafka, the offset in Kafka, etc.
+* If the "value" field that contains your data is in Avro, you can use `from_avro()` to extract your data, enrich it, clean it, and then push it downstream to Kafka again or write it out to a file.
+* `to_avro()` can be used to turn structs into Avro records. This method is particularly useful when you would like to re-encode multiple columns into a single one when writing data out to Kafka.
+
+Both functions are currently only available in Scala and Java.
+
    +
    +{% highlight scala %} +import org.apache.spark.sql.avro._ + +// `from_avro` requires Avro schema in JSON string format. +val jsonFormatSchema = new String(Files.readAllBytes(Paths.get("./examples/src/main/resources/user.avsc"))) + +val df = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .load() + +// 1. Decode the Avro data into a struct; +// 2. Filter by column `favorite_color`; +// 3. Encode the column `name` in Avro format. +val output = df + .select(from_avro('value, jsonFormatSchema) as 'user) + .where("user.favorite_color == \"red\"") + .select(to_avro($"user.name") as 'value) + +val query = output + .writeStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic2") + .start() + +{% endhighlight %} +
    +
    +{% highlight java %} +import org.apache.spark.sql.avro.*; + +// `from_avro` requires Avro schema in JSON string format. +String jsonFormatSchema = new String(Files.readAllBytes(Paths.get("./examples/src/main/resources/user.avsc"))); + +Dataset df = spark + .readStream() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .load(); + +// 1. Decode the Avro data into a struct; +// 2. Filter by column `favorite_color`; +// 3. Encode the column `name` in Avro format. +Dataset output = df + .select(from_avro(col("value"), jsonFormatSchema).as("user")) + .where("user.favorite_color == \"red\"") + .select(to_avro(col("user.name")).as("value")); + +StreamingQuery query = output + .writeStream() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic2") + .start(); + +{% endhighlight %} +
    +
+
+## Data Source Option
+
+Data source options of Avro can be set using the `.option` method on `DataFrameReader` or `DataFrameWriter`.
+
+| Property Name | Default | Meaning | Scope |
+| ------------- | ------- | ------- | ----- |
+| `avroSchema` | None | Optional Avro schema provided by a user in JSON format. The data type and naming of record fields should match the input Avro data or Catalyst data, otherwise the read/write action will fail. | read and write |
+| `recordName` | topLevelRecord | Top-level record name in the write result; required by the Avro spec. | write |
+| `recordNamespace` | "" | Record namespace in the write result. | write |
+| `ignoreExtension` | true | Controls whether files without the `.avro` extension are ignored on read. If the option is enabled, all files (with and without the `.avro` extension) are loaded. | read |
+| `compression` | snappy | Specifies the compression codec to use on write. Currently supported codecs are `uncompressed`, `snappy`, `deflate`, `bzip2` and `xz`. If the option is not set, the value of `spark.sql.avro.compression.codec` is used. | write |
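To make these options concrete, a minimal PySpark sketch (the paths and option values are hypothetical, and an active `spark` session is assumed):

{% highlight python %}
# Sketch: Avro data source options on read and write; paths are made up.
df = (spark.read.format("avro")
      .option("ignoreExtension", "false")  # load only files with the .avro extension
      .load("/tmp/events"))

(df.write.format("avro")
   .option("recordName", "Event")            # top-level record name in the output schema
   .option("recordNamespace", "com.example")
   .option("compression", "deflate")         # overrides spark.sql.avro.compression.codec
   .save("/tmp/events-deflate"))
{% endhighlight %}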
+
+## Configuration
+Configuration of Avro can be done using the `setConf` method on `SparkSession` or by running `SET key=value` commands using SQL.
+
+| Property Name | Default | Meaning |
+| ------------- | ------- | ------- |
+| `spark.sql.legacy.replaceDatabricksSparkAvro.enabled` | true | If set to true, the data source provider `com.databricks.spark.avro` is mapped to the built-in but external Avro data source module, for backward compatibility. |
+| `spark.sql.avro.compression.codec` | snappy | Compression codec used when writing Avro files. Supported codecs: `uncompressed`, `deflate`, `snappy`, `bzip2` and `xz`. The default codec is `snappy`. |
+| `spark.sql.avro.deflate.level` | -1 | Compression level for the deflate codec used when writing Avro files. Valid values are 1 to 9 inclusive, or -1. The default value is -1, which corresponds to level 6 in the current implementation. |
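For illustration, a short sketch of both forms from PySpark (assuming an active `spark` session):

{% highlight python %}
# Sketch: setting Avro configs programmatically and via SQL.
spark.conf.set("spark.sql.avro.compression.codec", "deflate")
spark.conf.set("spark.sql.avro.deflate.level", "5")

# Equivalent SQL form:
spark.sql("SET spark.sql.avro.compression.codec=deflate")
{% endhighlight %}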
+
+## Compatibility with Databricks spark-avro
+This Avro data source module is originally from, and compatible with, Databricks's open-source repository
+[spark-avro](https://github.com/databricks/spark-avro).
+
+By default, with the SQL configuration `spark.sql.legacy.replaceDatabricksSparkAvro.enabled` enabled, the data source provider `com.databricks.spark.avro` is
+mapped to this built-in Avro module. For Spark tables created with the `Provider` property set to `com.databricks.spark.avro` in the
+catalog meta store, the mapping is essential for loading these tables if you are using this built-in Avro module.
+
+Note that in Databricks's [spark-avro](https://github.com/databricks/spark-avro), the implicit classes
+`AvroDataFrameWriter` and `AvroDataFrameReader` were created for the shortcut function `.avro()`. In this
+built-in but external module, both implicit classes are removed. Please use `.format("avro")` in
+`DataFrameWriter` or `DataFrameReader` instead.
+
+If you prefer using your own build of the `spark-avro` jar file, you can simply disable the configuration
+`spark.sql.legacy.replaceDatabricksSparkAvro.enabled`, and use the option `--jars` when deploying your
+applications. Read the [Advanced Dependency Management](https://spark.apache.org/docs/latest/submitting-applications.html#advanced-dependency-management) section in the Application
+Submission Guide for more details.
+
+## Supported types for Avro -> Spark SQL conversion
+Currently Spark supports reading all [primitive types](https://avro.apache.org/docs/1.8.2/spec.html#schema_primitive) and [complex types](https://avro.apache.org/docs/1.8.2/spec.html#schema_complex) under records of Avro.
+
+| Avro type | Spark SQL type |
+| --------- | -------------- |
+| boolean | BooleanType |
+| int | IntegerType |
+| long | LongType |
+| float | FloatType |
+| double | DoubleType |
+| string | StringType |
+| enum | StringType |
+| fixed | BinaryType |
+| bytes | BinaryType |
+| record | StructType |
+| array | ArrayType |
+| map | MapType |
+| union | See below |
+
+In addition to the types listed above, it supports reading `union` types. The following three types are considered basic `union` types:
+
+1. `union(int, long)` will be mapped to LongType.
+2. `union(float, double)` will be mapped to DoubleType.
+3. `union(something, null)`, where something is any supported Avro type. This will be mapped to the same Spark SQL type as that of something, with nullable set to true.
+
+All other union types are considered complex. They will be mapped to StructType, where field names are member0, member1, etc., in accordance with the members of the union. This is consistent with the behavior when converting between Avro and Parquet.
+
+It also supports reading the following Avro [logical types](https://avro.apache.org/docs/1.8.2/spec.html#Logical+Types):
+
+| Avro logical type | Avro type | Spark SQL type |
+| ----------------- | --------- | -------------- |
+| date | int | DateType |
+| timestamp-millis | long | TimestampType |
+| timestamp-micros | long | TimestampType |
+| decimal | fixed | DecimalType |
+| decimal | bytes | DecimalType |
+At the moment, it ignores docs, aliases and other properties present in the Avro file.
+
+## Supported types for Spark SQL -> Avro conversion
+Spark supports writing all Spark SQL types to Avro. For most types, the mapping from Spark types to Avro types is straightforward (e.g. IntegerType gets converted to int); however, there are a few special cases, which are listed below:
+
+| Spark SQL type | Avro type | Avro logical type |
+| -------------- | --------- | ----------------- |
+| ByteType | int | |
+| ShortType | int | |
+| BinaryType | bytes | |
+| DateType | int | date |
+| TimestampType | long | timestamp-micros |
+| DecimalType | fixed | decimal |
+
+You can also specify the whole output Avro schema with the option `avroSchema`, so that Spark SQL types can be converted into other Avro types. The following conversions are not applied by default and require a user-specified Avro schema:
+
+| Spark SQL type | Avro type | Avro logical type |
+| -------------- | --------- | ----------------- |
+| BinaryType | fixed | |
+| StringType | enum | |
+| TimestampType | long | timestamp-millis |
+| DecimalType | bytes | decimal |
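A sketch of the last conversion in the table above, forcing `timestamp-millis` on write with a user-specified schema (the record layout, column name and path are hypothetical, and `df` is assumed to hold a TimestampType column `ts`):

{% highlight python %}
# Sketch: a user-provided Avro schema that maps TimestampType to timestamp-millis.
avro_schema = """
{
  "type": "record",
  "name": "topLevelRecord",
  "fields": [
    {"name": "ts", "type": {"type": "long", "logicalType": "timestamp-millis"}}
  ]
}
"""

(df.select("ts").write.format("avro")
   .option("avroSchema", avro_schema)
   .save("/tmp/millis"))
{% endhighlight %}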
    diff --git a/docs/building-spark.md b/docs/building-spark.md index c3bcd90ccc78f..1501f0bb84544 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -45,7 +45,7 @@ Other build examples can be found below. ## Building a Runnable Distribution To create a Spark distribution like those distributed by the -[Spark Downloads](http://spark.apache.org/downloads.html) page, and that is laid out so as +[Spark Downloads](https://spark.apache.org/downloads.html) page, and that is laid out so as to be runnable, use `./dev/make-distribution.sh` in the project root directory. It can be configured with Maven profile settings and so on like the direct Maven build. Example: @@ -164,7 +164,7 @@ prompt. Developers who compile Spark frequently may want to speed up compilation; e.g., by using Zinc (for developers who build with Maven) or by avoiding re-compilation of the assembly JAR (for developers who build with SBT). For more information about how to do this, refer to the -[Useful Developer Tools page](http://spark.apache.org/developer-tools.html#reducing-build-times). +[Useful Developer Tools page](https://spark.apache.org/developer-tools.html#reducing-build-times). ## Encrypted Filesystems @@ -182,7 +182,7 @@ to the `sharedSettings` val. See also [this PR](https://github.com/apache/spark/ ## IntelliJ IDEA or Eclipse For help in setting up IntelliJ IDEA or Eclipse for Spark development, and troubleshooting, refer to the -[Useful Developer Tools page](http://spark.apache.org/developer-tools.html). +[Useful Developer Tools page](https://spark.apache.org/developer-tools.html). # Running Tests @@ -203,7 +203,7 @@ The following is an example of a command to run the tests: ## Running Individual Tests For information about how to run individual tests, refer to the -[Useful Developer Tools page](http://spark.apache.org/developer-tools.html#running-individual-tests). +[Useful Developer Tools page](https://spark.apache.org/developer-tools.html#running-individual-tests). ## PySpark pip installable @@ -236,7 +236,8 @@ The run-tests script also can be limited to a specific Python version or a speci To run the SparkR tests you will need to install the [knitr](https://cran.r-project.org/package=knitr), [rmarkdown](https://cran.r-project.org/package=rmarkdown), [testthat](https://cran.r-project.org/package=testthat), [e1071](https://cran.r-project.org/package=e1071) and [survival](https://cran.r-project.org/package=survival) packages first: - R -e "install.packages(c('knitr', 'rmarkdown', 'testthat', 'e1071', 'survival'), repos='http://cran.us.r-project.org')" + R -e "install.packages(c('knitr', 'rmarkdown', 'devtools', 'e1071', 'survival'), repos='http://cran.us.r-project.org')" + R -e "devtools::install_version('testthat', version = '1.0.2', repos='http://cran.us.r-project.org')" You can run just the SparkR tests using the command: @@ -255,3 +256,19 @@ On Linux, this can be done by `sudo service docker start`. or ./build/sbt docker-integration-tests/test + +## Change Scala Version + +To build Spark using another supported Scala version, please change the major Scala version using (e.g. 2.12): + + ./dev/change-scala-version.sh 2.12 + +For Maven, please enable the profile (e.g. 2.12): + + ./build/mvn -Pscala-2.12 compile + +For SBT, specify a complete scala version using (e.g. 2.12.6): + + ./build/sbt -Dscala.version=2.12.6 + +Otherwise, the sbt-pom-reader plugin will use the `scala.version` specified in the spark-parent pom. 
diff --git a/docs/cloud-integration.md b/docs/cloud-integration.md index 18e8fe77bbdbe..36753f6373b55 100644 --- a/docs/cloud-integration.md +++ b/docs/cloud-integration.md @@ -104,7 +104,7 @@ Spark jobs must authenticate with the object stores to access data within them. and `AWS_SESSION_TOKEN` environment variables and sets the associated authentication options for the `s3n` and `s3a` connectors to Amazon S3. 1. In a Hadoop cluster, settings may be set in the `core-site.xml` file. -1. Authentication details may be manually added to the Spark configuration in `spark-default.conf` +1. Authentication details may be manually added to the Spark configuration in `spark-defaults.conf` 1. Alternatively, they can be programmatically set in the `SparkConf` instance used to configure the application's `SparkContext`. diff --git a/docs/configuration.md b/docs/configuration.md index 0c7c4472be643..782ccff667076 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -152,8 +152,9 @@ of the most common options to set are: spark.driver.memory 1g - Amount of memory to use for the driver process, i.e. where SparkContext is initialized, in MiB - unless otherwise specified (e.g. 1g, 2g). + Amount of memory to use for the driver process, i.e. where SparkContext is initialized, in the + same format as JVM memory strings with a size unit suffix ("k", "m", "g" or "t") + (e.g. 512m, 2g).
    Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. @@ -175,8 +176,20 @@ of the most common options to set are: spark.executor.memory 1g - Amount of memory to use per executor process, in MiB unless otherwise specified. - (e.g. 2g, 8g). + Amount of memory to use per executor process, in the same format as JVM memory strings with + a size unit suffix ("k", "m", "g" or "t") (e.g. 512m, 2g). + + + + spark.executor.pyspark.memory + Not set + + The amount of memory to be allocated to PySpark in each executor, in MiB + unless otherwise specified. If set, PySpark memory for an executor will be + limited to this amount. If not set, Spark will not limit Python's memory use + and it is up to the application to avoid exceeding the overhead memory space + shared with other non-JVM processes. When PySpark is run in YARN or Kubernetes, this memory + is added to executor resource requests. @@ -580,13 +593,15 @@ Apart from these, the following properties are also available, and may be useful spark.maxRemoteBlockSizeFetchToMem - Long.MaxValue + Int.MaxValue - 512 The remote block will be fetched to disk when size of the block is above this threshold in bytes. - This is to avoid a giant request takes too much memory. We can enable this config by setting - a specific value(e.g. 200m). Note this configuration will affect both shuffle fetch + This is to avoid a giant request that takes too much memory. By default, this is only enabled + for blocks > 2GB, as those cannot be fetched directly into memory, no matter what resources are + available. But it can be turned down to a much lower value (eg. 200m) to avoid using too much + memory on smaller blocks as well. Note this configuration will affect both shuffle fetch and block manager remote block fetch. For users who enabled external shuffle service, - this feature can only be worked when external shuffle service is newer than Spark 2.2. + this feature can only be used when external shuffle service is newer than Spark 2.2. @@ -731,6 +746,13 @@ Apart from these, the following properties are also available, and may be useful *Warning*: This will increase the size of the event log considerably. + + spark.eventLog.longForm.enabled + false + + If true, use the long form of call sites in the event log. Otherwise use the short form. + + spark.eventLog.compress false @@ -1215,6 +1237,15 @@ Apart from these, the following properties are also available, and may be useful if it is too small, BlockManager might take a performance hit. + + spark.broadcast.checksum + true + + Whether to enable checksum for broadcast. If enabled, broadcasts will include a checksum, which can + help detect corrupted blocks, at the cost of computing and sending a little more data. It's possible + to disable it if the network has other mechanisms to guarantee data won't be corrupted during broadcast. + + spark.executor.cores @@ -1816,7 +1847,7 @@ Apart from these, the following properties are also available, and may be useful executors w.r.t. full parallelism. Defaults to 1.0 to give maximum parallelism. 
0.5 will divide the target number of executors by 2 - The target number of executors computed by the dynamicAllocation can still be overriden + The target number of executors computed by the dynamicAllocation can still be overridden by the spark.dynamicAllocation.minExecutors and spark.dynamicAllocation.maxExecutors settings @@ -1974,6 +2005,14 @@ showDF(properties, numRows = 200, truncate = FALSE) for more details. + + spark.streaming.kafka.minRatePerPartition + 1 + + Minimum rate (number of records per second) at which data will be read from each Kafka + partition when using the new Kafka direct stream API. + + spark.streaming.kafka.maxRetries 1 @@ -2202,7 +2241,7 @@ Spark's classpath for each application. In a Spark cluster running on YARN, thes files are set cluster-wide, and cannot safely be changed by the application. The better choice is to use spark hadoop properties in the form of `spark.hadoop.*`. -They can be considered as same as normal spark properties which can be set in `$SPARK_HOME/conf/spark-default.conf` +They can be considered as same as normal spark properties which can be set in `$SPARK_HOME/conf/spark-defaults.conf` In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For instance, Spark allows you to simply create an empty conf and set spark/spark hadoop properties. diff --git a/docs/contributing-to-spark.md b/docs/contributing-to-spark.md index 9252545e4a129..ede5584a0cf99 100644 --- a/docs/contributing-to-spark.md +++ b/docs/contributing-to-spark.md @@ -5,4 +5,4 @@ title: Contributing to Spark The Spark team welcomes all forms of contributions, including bug reports, documentation or patches. For the newest information on how to contribute to the project, please read the -[Contributing to Spark guide](http://spark.apache.org/contributing.html). +[Contributing to Spark guide](https://spark.apache.org/contributing.html). diff --git a/docs/index.md b/docs/index.md index 2f009417fafb0..40f628b794c01 100644 --- a/docs/index.md +++ b/docs/index.md @@ -12,7 +12,7 @@ It also supports a rich set of higher-level tools including [Spark SQL](sql-prog # Downloading -Get Spark from the [downloads page](http://spark.apache.org/downloads.html) of the project website. This documentation is for Spark version {{site.SPARK_VERSION}}. Spark uses Hadoop's client libraries for HDFS and YARN. Downloads are pre-packaged for a handful of popular Hadoop versions. +Get Spark from the [downloads page](https://spark.apache.org/downloads.html) of the project website. This documentation is for Spark version {{site.SPARK_VERSION}}. Spark uses Hadoop's client libraries for HDFS and YARN. Downloads are pre-packaged for a handful of popular Hadoop versions. Users can also download a "Hadoop free" binary and run Spark with any Hadoop version [by augmenting Spark's classpath](hadoop-provided.html). Scala and Java users can include Spark in their projects using its Maven coordinates and in the future Python users can also install Spark from PyPI. 
@@ -111,7 +111,7 @@ options for deployment: * [Amazon EC2](https://github.com/amplab/spark-ec2): scripts that let you launch a cluster on EC2 in about 5 minutes * [Standalone Deploy Mode](spark-standalone.html): launch a standalone cluster quickly without a third-party cluster manager * [Mesos](running-on-mesos.html): deploy a private cluster using - [Apache Mesos](http://mesos.apache.org) + [Apache Mesos](https://mesos.apache.org) * [YARN](running-on-yarn.html): deploy Spark on top of Hadoop NextGen (YARN) * [Kubernetes](running-on-kubernetes.html): deploy Spark on top of Kubernetes @@ -127,20 +127,20 @@ options for deployment: * [Cloud Infrastructures](cloud-integration.html) * [OpenStack Swift](storage-openstack-swift.html) * [Building Spark](building-spark.html): build Spark using the Maven system -* [Contributing to Spark](http://spark.apache.org/contributing.html) -* [Third Party Projects](http://spark.apache.org/third-party-projects.html): related third party Spark projects +* [Contributing to Spark](https://spark.apache.org/contributing.html) +* [Third Party Projects](https://spark.apache.org/third-party-projects.html): related third party Spark projects **External Resources:** -* [Spark Homepage](http://spark.apache.org) -* [Spark Community](http://spark.apache.org/community.html) resources, including local meetups +* [Spark Homepage](https://spark.apache.org) +* [Spark Community](https://spark.apache.org/community.html) resources, including local meetups * [StackOverflow tag `apache-spark`](http://stackoverflow.com/questions/tagged/apache-spark) -* [Mailing Lists](http://spark.apache.org/mailing-lists.html): ask questions about Spark here +* [Mailing Lists](https://spark.apache.org/mailing-lists.html): ask questions about Spark here * [AMP Camps](http://ampcamp.berkeley.edu/): a series of training camps at UC Berkeley that featured talks and exercises about Spark, Spark Streaming, Mesos, and more. [Videos](http://ampcamp.berkeley.edu/6/), [slides](http://ampcamp.berkeley.edu/6/) and [exercises](http://ampcamp.berkeley.edu/6/exercises/) are available online for free. -* [Code Examples](http://spark.apache.org/examples.html): more are also available in the `examples` subfolder of Spark ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples), +* [Code Examples](https://spark.apache.org/examples.html): more are also available in the `examples` subfolder of Spark ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples), [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples), [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python), [R]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/r)) diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index da90342406c84..2316f175676ee 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -264,3 +264,11 @@ within it for the various settings. For example: A full example is also available in `conf/fairscheduler.xml.template`. Note that any pools not configured in the XML file will simply get default values for all settings (scheduling mode FIFO, weight 1, and minShare 0). 
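For reference, a job can also be assigned to a configured pool from application code through the `spark.scheduler.pool` local property; a minimal PySpark sketch (the pool name is illustrative, and `sc` is an active SparkContext):

{% highlight python %}
# Sketch: assigning jobs to a fair scheduler pool from PySpark.
sc.setLocalProperty("spark.scheduler.pool", "accounting")
# ... jobs submitted from this thread now run in the "accounting" pool ...
sc.setLocalProperty("spark.scheduler.pool", None)  # revert to the default pool
{% endhighlight %}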
+ +## Scheduling using JDBC Connections +To set a [Fair Scheduler](job-scheduling.html#fair-scheduler-pools) pool for a JDBC client session, +users can set the `spark.sql.thriftserver.scheduler.pool` variable: + +{% highlight SQL %} +SET spark.sql.thriftserver.scheduler.pool=accounting; +{% endhighlight %} diff --git a/docs/ml-features.md b/docs/ml-features.md index 7aed2341584fc..882b895a9d154 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -585,7 +585,11 @@ for more details on the API. ## StringIndexer `StringIndexer` encodes a string column of labels to a column of label indices. -The indices are in `[0, numLabels)`, ordered by label frequencies, so the most frequent label gets index `0`. +The indices are in `[0, numLabels)`, and four ordering options are supported: +"frequencyDesc": descending order by label frequency (most frequent label assigned 0), +"frequencyAsc": ascending order by label frequency (least frequent label assigned 0), +"alphabetDesc": descending alphabetical order, and "alphabetAsc": ascending alphabetical order +(default = "frequencyDesc"). The unseen labels will be put at index numLabels if user chooses to keep them. If the input column is numeric, we cast it to string and index the string values. When downstream pipeline components such as `Estimator` or @@ -1429,7 +1433,7 @@ for more details on the API. ## Imputer -The `Imputer` transformer completes missing values in a dataset, either using the mean or the +The `Imputer` estimator completes missing values in a dataset, either using the mean or the median of the columns in which the missing values are located. The input columns should be of `DoubleType` or `FloatType`. Currently `Imputer` does not support categorical features and possibly creates incorrect values for columns containing categorical features. Imputer can impute custom values @@ -1593,10 +1597,25 @@ Suppose `a` and `b` are double columns, we use the following simple examples to * `y ~ a + b + a:b - 1` means model `y ~ w1 * a + w2 * b + w3 * a * b` where `w1, w2, w3` are coefficients. `RFormula` produces a vector column of features and a double or string column of label. -Like when formulas are used in R for linear regression, string input columns will be one-hot encoded, and numeric columns will be cast to doubles. -If the label column is of type string, it will be first transformed to double with `StringIndexer`. +Like when formulas are used in R for linear regression, numeric columns will be cast to doubles. +As to string input columns, they will first be transformed with [StringIndexer](ml-features.html#stringindexer) using ordering determined by `stringOrderType`, +and the last category after ordering is dropped, then the doubles will be one-hot encoded. 
+
+Suppose a string feature column contains the values `{'b', 'a', 'b', 'a', 'c', 'b'}`; we can set `stringOrderType` to control the encoding:
+~~~
+stringOrderType | Category mapped to 0 by StringIndexer |  Category dropped by RFormula
+----------------|---------------------------------------|---------------------------------
+'frequencyDesc' | most frequent category ('b')          | least frequent category ('c')
+'frequencyAsc'  | least frequent category ('c')         | most frequent category ('b')
+'alphabetDesc'  | last alphabetical category ('c')      | first alphabetical category ('a')
+'alphabetAsc'   | first alphabetical category ('a')     | last alphabetical category ('c')
+~~~
+
+If the label column is of type string, it will first be transformed to double with [StringIndexer](ml-features.html#stringindexer) using `frequencyDesc` ordering.
+If the label column does not exist in the DataFrame, the output label column will be created from the specified response variable in the formula.
+
+**Note:** The ordering option `stringOrderType` is NOT used for the label column. When the label column is indexed, it uses the default descending frequency ordering in `StringIndexer`.
+
+**Examples**
+
+Assume that we have a DataFrame with the columns `id`, `country`, `hour`, and `clicked`:
diff --git a/docs/ml-migration-guides.md b/docs/ml-migration-guides.md
index e4736411fb5fe..2047065f71eb8 100644
--- a/docs/ml-migration-guides.md
+++ b/docs/ml-migration-guides.md
@@ -289,7 +289,7 @@ In the `spark.mllib` package, there were several breaking changes.  The first ch
 In the `spark.ml` package, the main API changes are from Spark SQL.  We list the most important changes here:
-* The old [SchemaRDD](http://spark.apache.org/docs/1.2.1/api/scala/index.html#org.apache.spark.sql.SchemaRDD) has been replaced with [DataFrame](api/scala/index.html#org.apache.spark.sql.DataFrame) with a somewhat modified API.  All algorithms in `spark.ml` which used to use SchemaRDD now use DataFrame.
+* The old [SchemaRDD](https://spark.apache.org/docs/1.2.1/api/scala/index.html#org.apache.spark.sql.SchemaRDD) has been replaced with [DataFrame](api/scala/index.html#org.apache.spark.sql.DataFrame) with a somewhat modified API.  All algorithms in `spark.ml` which used to use SchemaRDD now use DataFrame.
 * In Spark 1.2, we used implicit conversions from `RDD`s of `LabeledPoint` into `SchemaRDD`s by calling `import sqlContext._` where `sqlContext` was an instance of `SQLContext`.  These implicits have been moved, so we now call `import sqlContext.implicits._`.
 * Java APIs for SQL have also changed accordingly.  Please see the examples above and the [Spark SQL Programming Guide](sql-programming-guide.html) for details.
diff --git a/docs/ml-statistics.md b/docs/ml-statistics.md
index abfb3cab1e566..6c82b3bb94b24 100644
--- a/docs/ml-statistics.md
+++ b/docs/ml-statistics.md
@@ -89,4 +89,32 @@ Refer to the [`ChiSquareTest` Python docs](api/python/index.html#pyspark.ml.stat
 {% include_example python/ml/chi_square_test_example.py %}
+
+## Summarizer
+
+We provide vector column summary statistics for `DataFrame` through `Summarizer`.
+Available metrics are the column-wise max, min, mean, variance, and number of nonzeros, as well as the total count.
    +
    +The following example demonstrates using [`Summarizer`](api/scala/index.html#org.apache.spark.ml.stat.Summarizer$) +to compute the mean and variance for a vector column of the input dataframe, with and without a weight column. + +{% include_example scala/org/apache/spark/examples/ml/SummarizerExample.scala %} +
    + +
    +The following example demonstrates using [`Summarizer`](api/java/org/apache/spark/ml/stat/Summarizer.html) +to compute the mean and variance for a vector column of the input dataframe, with and without a weight column. + +{% include_example java/org/apache/spark/examples/ml/JavaSummarizerExample.java %} +
    + +
+Refer to the [`Summarizer` Python docs](api/python/index.html#pyspark.ml.stat.Summarizer) for details on the API. + +{% include_example python/ml/summarizer_example.py %} +
    +
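Since the tabs above only reference external example files, here is a minimal inline sketch (assuming an active `spark` session; the data is made up):

{% highlight python %}
# Sketch: computing mean/variance for a vector column with Summarizer.
from pyspark.ml.linalg import Vectors
from pyspark.ml.stat import Summarizer

df = spark.createDataFrame(
    [(Vectors.dense(1.0, 2.0), 1.0), (Vectors.dense(3.0, 4.0), 2.0)],
    ["features", "weight"])

# With a weight column:
df.select(Summarizer.metrics("mean", "variance")
          .summary(df.features, df.weight).alias("summary")).show(truncate=False)
# Without a weight column:
df.select(Summarizer.mean(df.features)).show(truncate=False)
{% endhighlight %}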
    \ No newline at end of file diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index 5066bb29387dc..eca101132d2e5 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -317,7 +317,7 @@ Refer to the [`Matrix` Python docs](api/python/pyspark.mllib.html#pyspark.mllib. from pyspark.mllib.linalg import Matrix, Matrices # Create a dense matrix ((1.0, 2.0), (3.0, 4.0), (5.0, 6.0)) -dm2 = Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6]) +dm2 = Matrices.dense(3, 2, [1, 3, 5, 2, 4, 6]) # Create a sparse matrix ((9.0, 0.0), (0.0, 8.0), (0.0, 6.0)) sm = Matrices.sparse(3, 2, [0, 1, 3], [0, 2, 1], [9, 6, 8]) @@ -624,7 +624,7 @@ from pyspark.mllib.linalg.distributed import CoordinateMatrix, MatrixEntry # Create an RDD of coordinate entries. # - This can be done explicitly with the MatrixEntry class: -entries = sc.parallelize([MatrixEntry(0, 0, 1.2), MatrixEntry(1, 0, 2.1), MatrixEntry(6, 1, 3.7)]) +entries = sc.parallelize([MatrixEntry(0, 0, 1.2), MatrixEntry(1, 0, 2.1), MatrixEntry(2, 1, 3.7)]) # - or using (long, long, float) tuples: entries = sc.parallelize([(0, 0, 1.2), (1, 0, 2.1), (2, 1, 3.7)]) diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md index d9dbbab4840a3..c65ecdcb67ee4 100644 --- a/docs/mllib-evaluation-metrics.md +++ b/docs/mllib-evaluation-metrics.md @@ -462,13 +462,13 @@ $$rel_D(r) = \begin{cases}1 & \text{if $r \in D$}, \\ 0 & \text{otherwise}.\end{ Normalized Discounted Cumulative Gain $NDCG(k)=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{IDCG(D_i, k)}\sum_{j=0}^{n-1} - \frac{rel_{D_i}(R_i(j))}{\text{ln}(j+1)}} \\ + \frac{rel_{D_i}(R_i(j))}{\text{ln}(j+2)}} \\ \text{Where} \\ \hspace{5 mm} n = \text{min}\left(\text{max}\left(|R_i|,|D_i|\right),k\right) \\ - \hspace{5 mm} IDCG(D, k) = \sum_{j=0}^{\text{min}(\left|D\right|, k) - 1} \frac{1}{\text{ln}(j+1)}$ + \hspace{5 mm} IDCG(D, k) = \sum_{j=0}^{\text{min}(\left|D\right|, k) - 1} \frac{1}{\text{ln}(j+2)}$ - NDCG at k is a + NDCG at k is a measure of how many of the first k recommended documents are in the set of true relevant documents averaged across all users. In contrast to precision at k, this metric takes into account the order of the recommendations (documents are assumed to be in order of decreasing relevance). diff --git a/docs/monitoring.md b/docs/monitoring.md index 6eaf33135744d..f6d52ef4597e9 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -388,6 +388,158 @@ value triggering garbage collection on jobs, and `spark.ui.retainedStages` that Note that the garbage collection takes place on playback: it is possible to retrieve more entries by increasing these values and restarting the history server. +### Executor Task Metrics + +The REST API exposes the values of the Task Metrics collected by Spark executors with the granularity +of task execution. The metrics can be used for performance troubleshooting and workload characterization. +A list of the available metrics, with a short description: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+| Spark Executor Task Metric name | Short description |
+| ------------------------------- | ----------------- |
+| executorRunTime | Elapsed time the executor spent running this task. This includes time fetching shuffle data. The value is expressed in milliseconds. |
+| executorCpuTime | CPU time the executor spent running this task. This includes time fetching shuffle data. The value is expressed in nanoseconds. |
+| executorDeserializeTime | Elapsed time spent to deserialize this task. The value is expressed in milliseconds. |
+| executorDeserializeCpuTime | CPU time taken on the executor to deserialize this task. The value is expressed in nanoseconds. |
+| resultSize | The number of bytes this task transmitted back to the driver as the TaskResult. |
+| jvmGCTime | Elapsed time the JVM spent in garbage collection while executing this task. The value is expressed in milliseconds. |
+| resultSerializationTime | Elapsed time spent serializing the task result. The value is expressed in milliseconds. |
+| memoryBytesSpilled | The number of in-memory bytes spilled by this task. |
+| diskBytesSpilled | The number of on-disk bytes spilled by this task. |
+| peakExecutionMemory | Peak memory used by internal data structures created during shuffles, aggregations and joins. The value of this accumulator should be approximately the sum of the peak sizes across all such data structures created in this task. For SQL jobs, this only tracks all unsafe operators and ExternalSort. |
+| inputMetrics.* | Metrics related to reading data from [[org.apache.spark.rdd.HadoopRDD]] or from persisted data. |
+| .bytesRead | Total number of bytes read. |
+| .recordsRead | Total number of records read. |
+| outputMetrics.* | Metrics related to writing data externally (e.g. to a distributed filesystem), defined only in tasks with output. |
+| .bytesWritten | Total number of bytes written. |
+| .recordsWritten | Total number of records written. |
+| shuffleReadMetrics.* | Metrics related to shuffle read operations. |
+| .recordsRead | Number of records read in shuffle operations. |
+| .remoteBlocksFetched | Number of remote blocks fetched in shuffle operations. |
+| .localBlocksFetched | Number of local (as opposed to read from a remote executor) blocks fetched in shuffle operations. |
+| .totalBlocksFetched | Number of blocks fetched in shuffle operations (both local and remote). |
+| .remoteBytesRead | Number of remote bytes read in shuffle operations. |
+| .localBytesRead | Number of bytes read in shuffle operations from local disk (as opposed to read from a remote executor). |
+| .totalBytesRead | Number of bytes read in shuffle operations (both local and remote). |
+| .remoteBytesReadToDisk | Number of remote bytes read to disk in shuffle operations. Large blocks are fetched to disk in shuffle read operations, as opposed to being read into memory, which is the default behavior. |
+| .fetchWaitTime | Time the task spent waiting for remote shuffle blocks. This only includes the time blocking on shuffle input data. For instance, if block B is being fetched while the task is still not finished processing block A, it is not considered to be blocking on block B. The value is expressed in milliseconds. |
+| shuffleWriteMetrics.* | Metrics related to operations writing shuffle data. |
+| .bytesWritten | Number of bytes written in shuffle operations. |
+| .recordsWritten | Number of records written in shuffle operations. |
+| .writeTime | Time spent blocking on writes to disk or buffer cache. The value is expressed in nanoseconds. |
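As an illustration, per-stage aggregates of several of these metrics can be fetched from the REST API described above; a sketch using `requests` (the URLs are placeholders for a driver UI or history server):

{% highlight python %}
# Sketch: fetch per-stage aggregates of some task metrics via the REST API.
import requests

base = "http://localhost:4040/api/v1"  # driver UI; a history server typically listens on :18080
app_id = requests.get(base + "/applications").json()[0]["id"]

for stage in requests.get("%s/applications/%s/stages" % (base, app_id)).json():
    print(stage["stageId"], stage["executorRunTime"], stage["memoryBytesSpilled"])
{% endhighlight %}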
    + + + ### API Versioning Policy These endpoints have been strongly versioned to make it easier to develop applications on top. @@ -435,6 +587,7 @@ set of sinks to which metrics are reported. The following instances are currentl * `executor`: A Spark executor. * `driver`: The Spark driver process (the process in which your SparkContext is created). * `shuffleService`: The Spark shuffle service. +* `applicationMaster`: The Spark ApplicationMaster when running on YARN. Each instance can report to zero or more _sinks_. Sinks are contained in the `org.apache.spark.metrics.sink` package: diff --git a/docs/quick-start.md b/docs/quick-start.md index f1a2096cd4dbd..ef7af6c3f6cec 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -12,7 +12,7 @@ interactive shell (in Python or Scala), then show how to write applications in Java, Scala, and Python. To follow along with this guide, first, download a packaged release of Spark from the -[Spark website](http://spark.apache.org/downloads.html). Since we won't be using HDFS, +[Spark website](https://spark.apache.org/downloads.html). Since we won't be using HDFS, you can download a package for any version of Hadoop. Note that, before Spark 2.0, the main programming interface of Spark was the Resilient Distributed Dataset (RDD). After Spark 2.0, RDDs are replaced by Dataset, which is strongly-typed like an RDD, but with richer optimizations under the hood. The RDD interface is still supported, and you can get a more detailed reference at the [RDD programming guide](rdd-programming-guide.html). However, we highly recommend you to switch to use Dataset, which has better performance than RDD. See the [SQL programming guide](sql-programming-guide.html) to get more information about Dataset. diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md index b6424090d2fea..d95b757f36859 100644 --- a/docs/rdd-programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -106,7 +106,7 @@ You can also use `bin/pyspark` to launch an interactive Python shell. If you wish to access HDFS data, you need to use a build of PySpark linking to your version of HDFS. -[Prebuilt packages](http://spark.apache.org/downloads.html) are also available on the Spark homepage +[Prebuilt packages](https://spark.apache.org/downloads.html) are also available on the Spark homepage for common HDFS versions. Finally, you need to import some Spark classes into your program. Add the following line: @@ -1569,7 +1569,7 @@ as Spark does not support two contexts running concurrently in the same program. # Where to Go from Here -You can see some [example Spark programs](http://spark.apache.org/examples.html) on the Spark website. +You can see some [example Spark programs](https://spark.apache.org/examples.html) on the Spark website. In addition, Spark includes several samples in the `examples` directory ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples), [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples), diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 408e446ea4822..4ae7acaae2314 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -117,6 +117,45 @@ If the local proxy is running at localhost:8001, `--master k8s://http://127.0.0. spark-submit. Finally, notice that in the above example we specify a jar with a specific URI with a scheme of `local://`. 
This URI is the location of the example jar that is already in the Docker image. +## Client Mode + +Starting with Spark 2.4.0, it is possible to run Spark applications on Kubernetes in client mode. When your application +runs in client mode, the driver can run inside a pod or on a physical host. When running an application in client mode, +it is recommended to account for the following factors: + +### Client Mode Networking + +Spark executors must be able to connect to the Spark driver over a hostname and a port that is routable from the Spark +executors. The specific network configuration that will be required for Spark to work in client mode will vary per +setup. If you run your driver inside a Kubernetes pod, you can use a +[headless service](https://kubernetes.io/docs/concepts/services-networking/service/#headless-services) to allow your +driver pod to be routable from the executors by a stable hostname. When deploying your headless service, ensure that +the service's label selector will only match the driver pod and no other pods; it is recommended to assign your driver +pod a sufficiently unique label and to use that label in the label selector of the headless service. Specify the driver's +hostname via `spark.driver.host` and your spark driver's port to `spark.driver.port`. + +### Client Mode Executor Pod Garbage Collection + +If you run your Spark driver in a pod, it is highly recommended to set `spark.driver.pod.name` to the name of that pod. +When this property is set, the Spark scheduler will deploy the executor pods with an +[OwnerReference](https://kubernetes.io/docs/concepts/workloads/controllers/garbage-collection/), which in turn will +ensure that once the driver pod is deleted from the cluster, all of the application's executor pods will also be deleted. +The driver will look for a pod with the given name in the namespace specified by `spark.kubernetes.namespace`, and +an OwnerReference pointing to that pod will be added to each executor pod's OwnerReferences list. Be careful to avoid +setting the OwnerReference to a pod that is not actually that driver pod, or else the executors may be terminated +prematurely when the wrong pod is deleted. + +If your application is not running inside a pod, or if `spark.driver.pod.name` is not set when your application is +actually running in a pod, keep in mind that the executor pods may not be properly deleted from the cluster when the +application exits. The Spark scheduler attempts to delete these pods, but if the network request to the API server fails +for any reason, these pods will remain in the cluster. The executor processes should exit when they cannot reach the +driver, so the executor pods should not consume compute resources (cpu and memory) in the cluster after your application +exits. + +### Authentication Parameters + +Use the exact prefix `spark.kubernetes.authenticate` for Kubernetes authentication parameters in client mode. 
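To tie the client-mode settings above together, a driver running inside a pod might be configured along these lines; a sketch in which the master URL, service name, pod name, namespace and image are all placeholders:

{% highlight python %}
# Sketch: client-mode driver settings from this section, set programmatically.
from pyspark.sql import SparkSession

spark = (SparkSession.builder
    .master("k8s://https://kubernetes.default.svc")
    .config("spark.driver.host", "my-driver-svc.default.svc.cluster.local")  # headless service
    .config("spark.driver.port", "7078")
    .config("spark.driver.pod.name", "my-driver-pod")  # enables executor pod cleanup via OwnerReference
    .config("spark.kubernetes.container.image", "example/spark-py:2.4.0")
    .config("spark.kubernetes.namespace", "default")
    .getOrCreate())
{% endhighlight %}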
+ ## Dependency Management If your application's dependencies are all hosted in remote locations like HDFS or HTTP servers, they may be referred to @@ -146,6 +185,49 @@ To use a secret through an environment variable use the following options to the --conf spark.kubernetes.executor.secretKeyRef.ENV_NAME=name:key ``` +## Using Kubernetes Volumes + +Starting with Spark 2.4.0, users can mount the following types of Kubernetes [volumes](https://kubernetes.io/docs/concepts/storage/volumes/) into the driver and executor pods: +* [hostPath](https://kubernetes.io/docs/concepts/storage/volumes/#hostpath): mounts a file or directory from the host node’s filesystem into a pod. +* [emptyDir](https://kubernetes.io/docs/concepts/storage/volumes/#emptydir): an initially empty volume created when a pod is assigned to a node. +* [persistentVolumeClaim](https://kubernetes.io/docs/concepts/storage/volumes/#persistentvolumeclaim): used to mount a `PersistentVolume` into a pod. + +To mount a volume of any of the types above into the driver pod, use the following configuration property: + +``` +--conf spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.path= +--conf spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.readOnly= +``` + +Specifically, `VolumeType` can be one of the following values: `hostPath`, `emptyDir`, and `persistentVolumeClaim`. `VolumeName` is the name you want to use for the volume under the `volumes` field in the pod specification. + +Each supported type of volumes may have some specific configuration options, which can be specified using configuration properties of the following form: + +``` +spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].options.[OptionName]= +``` + +For example, the claim name of a `persistentVolumeClaim` with volume name `checkpointpvc` can be specified using the following property: + +``` +spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.options.claimName=check-point-pvc-claim +``` + +The configuration properties for mounting volumes into the executor pods use prefix `spark.kubernetes.executor.` instead of `spark.kubernetes.driver.`. For a complete list of available options for each supported type of volumes, please refer to the [Spark Properties](#spark-properties) section below. + +## Local Storage + +Spark uses temporary scratch space to spill data to disk during shuffles and other operations. When using Kubernetes as the resource manager the pods will be created with an [emptyDir](https://kubernetes.io/docs/concepts/storage/volumes/#emptydir) volume mounted for each directory listed in `SPARK_LOCAL_DIRS`. If no directories are explicitly specified then a default directory is created and configured appropriately. + +`emptyDir` volumes use the ephemeral storage feature of Kubernetes and do not persist beyond the life of the pod. + +### Using RAM for local storage + +`emptyDir` volumes use the nodes backing storage for ephemeral storage by default, this behaviour may not be appropriate for some compute environments. For example if you have diskless nodes with remote storage mounted over a network, having lots of executors doing IO to this remote storage may actually degrade performance. + +In this case it may be desirable to set `spark.kubernetes.local.dirs.tmpfs=true` in your configuration which will cause the `emptyDir` volumes to be configured as `tmpfs` i.e. RAM backed volumes. 
When configured like this Sparks local storage usage will count towards your pods memory usage therefore you may wish to increase your memory requests by increasing the value of `spark.kubernetes.memoryOverheadFactor` as appropriate. + + ## Introspection and Debugging These are the different ways in which you can investigate a running/completed Spark application, monitor progress, and @@ -258,27 +340,17 @@ RBAC authorization and how to configure Kubernetes service accounts for pods, pl [Using RBAC Authorization](https://kubernetes.io/docs/admin/authorization/rbac/) and [Configure Service Accounts for Pods](https://kubernetes.io/docs/tasks/configure-pod-container/configure-service-account/). -## Client Mode - -Client mode is not currently supported. - ## Future Work -There are several Spark on Kubernetes features that are currently being incubated in a fork - -[apache-spark-on-k8s/spark](https://github.com/apache-spark-on-k8s/spark), which are expected to eventually make it into -future versions of the spark-kubernetes integration. +There are several Spark on Kubernetes features that are currently being worked on or planned to be worked on. Those features are expected to eventually make it into future versions of the spark-kubernetes integration. Some of these include: -* R -* Dynamic Executor Scaling +* Dynamic Resource Allocation and External Shuffle Service * Local File Dependency Management * Spark Application Management * Job Queues and Resource Management -You can refer to the [documentation](https://apache-spark-on-k8s.github.io/userdocs/) if you want to try these features -and provide feedback to the development team. - # Configuration See the [configuration page](configuration.html) for information on Spark configurations. The following configurations are @@ -354,7 +426,7 @@ specific to Spark on Kubernetes. Path to the CA cert file for connecting to the Kubernetes API server over TLS when starting the driver. This file must be located on the submitting machine's disk. Specify this as a path as opposed to a URI (i.e. do not provide - a scheme). + a scheme). In client mode, use spark.kubernetes.authenticate.caCertFile instead. @@ -363,7 +435,7 @@ specific to Spark on Kubernetes. Path to the client key file for authenticating against the Kubernetes API server when starting the driver. This file must be located on the submitting machine's disk. Specify this as a path as opposed to a URI (i.e. do not provide - a scheme). + a scheme). In client mode, use spark.kubernetes.authenticate.clientKeyFile instead. @@ -372,7 +444,7 @@ specific to Spark on Kubernetes. Path to the client cert file for authenticating against the Kubernetes API server when starting the driver. This file must be located on the submitting machine's disk. Specify this as a path as opposed to a URI (i.e. do not - provide a scheme). + provide a scheme). In client mode, use spark.kubernetes.authenticate.clientCertFile instead. @@ -381,7 +453,7 @@ specific to Spark on Kubernetes. OAuth token to use when authenticating against the Kubernetes API server when starting the driver. Note that unlike the other authentication options, this is expected to be the exact string value of the token to use for - the authentication. + the authentication. In client mode, use spark.kubernetes.authenticate.oauthToken instead. @@ -390,7 +462,7 @@ specific to Spark on Kubernetes. Path to the OAuth token file containing the token to use when authenticating against the Kubernetes API server when starting the driver. 
This file must be located on the submitting machine's disk. Specify this as a path as opposed to a URI (i.e. do not - provide a scheme). + provide a scheme). In client mode, use spark.kubernetes.authenticate.oauthTokenFile instead. @@ -399,7 +471,8 @@ specific to Spark on Kubernetes. Path to the CA cert file for connecting to the Kubernetes API server over TLS from the driver pod when requesting executors. This file must be located on the submitting machine's disk, and will be uploaded to the driver pod. - Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + Specify this as a path as opposed to a URI (i.e. do not provide a scheme). In client mode, use + spark.kubernetes.authenticate.caCertFile instead. @@ -407,10 +480,9 @@ specific to Spark on Kubernetes. (none) Path to the client key file for authenticating against the Kubernetes API server from the driver pod when requesting - executors. This file must be located on the submitting machine's disk, and will be uploaded to the driver pod. - Specify this as a path as opposed to a URI (i.e. do not provide a scheme). If this is specified, it is highly - recommended to set up TLS for the driver submission server, as this value is sensitive information that would be - passed to the driver pod in plaintext otherwise. + executors. This file must be located on the submitting machine's disk, and will be uploaded to the driver pod as + a Kubernetes secret. Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + In client mode, use spark.kubernetes.authenticate.clientKeyFile instead. @@ -419,7 +491,8 @@ specific to Spark on Kubernetes. Path to the client cert file for authenticating against the Kubernetes API server from the driver pod when requesting executors. This file must be located on the submitting machine's disk, and will be uploaded to the - driver pod. Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + driver pod as a Kubernetes secret. Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + In client mode, use spark.kubernetes.authenticate.clientCertFile instead. @@ -428,9 +501,8 @@ specific to Spark on Kubernetes. OAuth token to use when authenticating against the Kubernetes API server from the driver pod when requesting executors. Note that unlike the other authentication options, this must be the exact string value of - the token to use for the authentication. This token value is uploaded to the driver pod. If this is specified, it is - highly recommended to set up TLS for the driver submission server, as this value is sensitive information that would - be passed to the driver pod in plaintext otherwise. + the token to use for the authentication. This token value is uploaded to the driver pod as a Kubernetes secret. + In client mode, use spark.kubernetes.authenticate.oauthToken instead. @@ -439,9 +511,8 @@ specific to Spark on Kubernetes. Path to the OAuth token file containing the token to use when authenticating against the Kubernetes API server from the driver pod when requesting executors. Note that unlike the other authentication options, this file must contain the exact string value of - the token to use for the authentication. This token value is uploaded to the driver pod. If this is specified, it is - highly recommended to set up TLS for the driver submission server, as this value is sensitive information that would - be passed to the driver pod in plaintext otherwise. + the token to use for the authentication. 
This token value is uploaded to the driver pod as a secret. In client mode, use + spark.kubernetes.authenticate.oauthTokenFile instead. @@ -450,7 +521,8 @@ specific to Spark on Kubernetes. Path to the CA cert file for connecting to the Kubernetes API server over TLS from the driver pod when requesting executors. This path must be accessible from the driver pod. - Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + Specify this as a path as opposed to a URI (i.e. do not provide a scheme). In client mode, use + spark.kubernetes.authenticate.caCertFile instead. @@ -459,7 +531,8 @@ specific to Spark on Kubernetes. Path to the client key file for authenticating against the Kubernetes API server from the driver pod when requesting executors. This path must be accessible from the driver pod. - Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + Specify this as a path as opposed to a URI (i.e. do not provide a scheme). In client mode, use + spark.kubernetes.authenticate.clientKeyFile instead. @@ -468,7 +541,8 @@ specific to Spark on Kubernetes. Path to the client cert file for authenticating against the Kubernetes API server from the driver pod when requesting executors. This path must be accessible from the driver pod. - Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + Specify this as a path as opposed to a URI (i.e. do not provide a scheme). In client mode, use + spark.kubernetes.authenticate.clientCertFile instead. @@ -477,7 +551,8 @@ specific to Spark on Kubernetes. Path to the file containing the OAuth token to use when authenticating against the Kubernetes API server from the driver pod when requesting executors. This path must be accessible from the driver pod. - Note that unlike the other authentication options, this file must contain the exact string value of the token to use for the authentication. + Note that unlike the other authentication options, this file must contain the exact string value of the token to use + for the authentication. In client mode, use spark.kubernetes.authenticate.oauthTokenFile instead. @@ -486,7 +561,48 @@ specific to Spark on Kubernetes. Service account that is used when running the driver pod. The driver pod uses this service account when requesting executor pods from the API server. Note that this cannot be specified alongside a CA cert file, client key file, - client cert file, and/or OAuth token. + client cert file, and/or OAuth token. In client mode, use spark.kubernetes.authenticate.serviceAccountName instead. + + + + spark.kubernetes.authenticate.caCertFile + (none) + + In client mode, path to the CA cert file for connecting to the Kubernetes API server over TLS when + requesting executors. Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + + + + spark.kubernetes.authenticate.clientKeyFile + (none) + + In client mode, path to the client key file for authenticating against the Kubernetes API server + when requesting executors. Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + + + + spark.kubernetes.authenticate.clientCertFile + (none) + + In client mode, path to the client cert file for authenticating against the Kubernetes API server + when requesting executors. Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + + + + spark.kubernetes.authenticate.oauthToken + (none) + + In client mode, the OAuth token to use when authenticating against the Kubernetes API server when + requesting executors. 
Note that unlike the other authentication options, this must be the exact string value of + the token to use for the authentication. + + + + spark.kubernetes.authenticate.oauthTokenFile + (none) + + In client mode, path to the file containing the OAuth token to use when authenticating against the Kubernetes API + server when requesting executors. @@ -529,8 +645,11 @@ specific to Spark on Kubernetes. spark.kubernetes.driver.pod.name (none) - Name of the driver pod. If not set, the driver pod name is set to "spark.app.name" suffixed by the current timestamp - to avoid name conflicts. + Name of the driver pod. In cluster mode, if this is not set, the driver pod name is set to "spark.app.name" + suffixed by the current timestamp to avoid name conflicts. In client mode, if your application is running + inside a pod, it is highly recommended to set this to the name of the pod your driver is running in. Setting this + value in client mode allows the driver to become the owner of its executor pods, which in turn allows the executor + pods to be garbage collected by the cluster. @@ -629,6 +748,62 @@ specific to Spark on Kubernetes. Add as an environment variable to the executor container with name EnvName (case sensitive), the value referenced by key key in the data of the referenced Kubernetes Secret. For example, spark.kubernetes.executor.secrets.ENV_VAR=spark-secret:key. + + + spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.path + (none) + + Add the Kubernetes Volume named VolumeName of the VolumeType type to the driver pod on the path specified in the value. For example, + spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.mount.path=/checkpoint. + + + + spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.readOnly + (none) + + Specify whether the mounted volume is read-only. For example, + spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.mount.readOnly=false. + + + + spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].options.[OptionName] + (none) + + Configure options passed to Kubernetes for the volume, with OptionName as the key and the specified value; option names must conform to the Kubernetes option format. For example, + spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.options.claimName=spark-pvc-claim. + + + + spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].mount.path + (none) + + Add the Kubernetes Volume named VolumeName of the VolumeType type to the executor pod on the path specified in the value. For example, + spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.mount.path=/checkpoint. + + + + spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].mount.readOnly + false + + Specify whether the mounted volume is read-only. For example, + spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.mount.readOnly=false. + + + + spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].options.[OptionName] + (none) + + Configure options passed to Kubernetes for the volume, with OptionName as the key and the specified value. For example, + spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.options.claimName=spark-pvc-claim. + + + + spark.kubernetes.local.dirs.tmpfs + false + + Configure the emptyDir volumes used to back SPARK_LOCAL_DIRS within the Spark driver and executor pods to use tmpfs backing, i.e. RAM. See Local Storage earlier on this page + for more discussion of this.
+ spark.kubernetes.memoryOverheadFactor @@ -639,7 +814,7 @@ specific to Spark on Kubernetes. - spark.kubernetes.pyspark.pythonversion + spark.kubernetes.pyspark.pythonVersion "2" This sets the major Python version of the docker image used to run the driver and executor containers. Can either be 2 or 3.
diff --git a/docs/running-on-mesos.md index 66ffb17949845..b473e654563d6 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -174,6 +174,8 @@ can find the results of the driver from the Mesos Web UI. To use cluster mode, you must start the `MesosClusterDispatcher` in your cluster via the `sbin/start-mesos-dispatcher.sh` script, passing in the Mesos master URL (e.g: mesos://host:5050). This starts the `MesosClusterDispatcher` as a daemon running on the host. +Note that the `MesosClusterDispatcher` does not support authentication. You should ensure that all network access to it is +protected (port 7077 by default). By setting the Mesos proxy config property (requires Mesos version >= 1.4), e.g. `--conf spark.mesos.proxy.baseURL=http://localhost:5050`, when launching the dispatcher, the Mesos sandbox URI for each driver is added to the Mesos dispatcher UI. @@ -670,7 +672,7 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.dispatcher.historyServer.url (none) - Set the URL of the history + Set the URL of the history server. The dispatcher will then link each driver to its entry in the history server.
diff --git a/docs/running-on-yarn.md index 575da7205b529..e3d67c34d53eb 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -61,7 +61,7 @@ In `cluster` mode, the driver runs on a different machine than the client, so `S # Preparations Running Spark on YARN requires a binary distribution of Spark which is built with YARN support. -Binary distributions can be downloaded from the [downloads page](http://spark.apache.org/downloads.html) of the project website. +Binary distributions can be downloaded from the [downloads page](https://spark.apache.org/downloads.html) of the project website. To build Spark yourself, refer to [Building Spark](building-spark.html). To make Spark runtime jars accessible from YARN side, you can specify `spark.yarn.archive` or `spark.yarn.jars`. For details please refer to [Spark Properties](running-on-yarn.html#spark-properties). If neither `spark.yarn.archive` nor `spark.yarn.jars` is specified, Spark will create a zip file with all jars under `$SPARK_HOME/jars` and upload it to the distributed cache. @@ -218,9 +218,10 @@ To use a custom metrics.properties for the application master and executors, upd spark.yarn.dist.forceDownloadSchemes (none) - Comma-separated list of schemes for which files will be downloaded to the local disk prior to + Comma-separated list of schemes for which resources will be downloaded to the local disk prior to being added to YARN's distributed cache. For use in cases where the YARN service does not - support schemes that are supported by Spark, like http, https and ftp. + support schemes that are supported by Spark, like http, https and ftp, or jars required to be in the + local YARN client's classpath. The wildcard '*' downloads resources for all schemes. @@ -420,7 +421,14 @@ To use a custom metrics.properties for the application master and executors, upd spark.blacklist.application.maxFailedExecutorsPerNode. - + + spark.yarn.metrics.namespace + (none) + + The root namespace for AM metrics reporting.
+ If it is not set, the YARN application ID is used. + +
# Important notes
diff --git a/docs/security.md index 6ef3a808e0471..7fb3e17de94c9 100644 --- a/docs/security.md +++ b/docs/security.md @@ -22,7 +22,12 @@ secrets to be secure. For other resource managers, `spark.authenticate.secret` must be configured on each of the nodes. This secret will be shared by all the daemons and applications, so this deployment configuration is -not as secure as the above, especially when considering multi-tenant clusters. +not as secure as the above, especially when considering multi-tenant clusters. In this +configuration, a user with the secret can effectively impersonate any other user. + +The REST Submission Server and the MesosClusterDispatcher do not support authentication. You should +ensure that all network access to the REST API & MesosClusterDispatcher (ports 6066 and 7077 +respectively by default) is restricted to hosts that are trusted to submit jobs. @@ -44,7 +49,7 @@ not as secure as the above, especially when considering multi-tenant clusters. Spark supports AES-based encryption for RPC connections. For encryption to be enabled, RPC authentication must also be enabled and properly configured. AES encryption uses the -[Apache Commons Crypto](http://commons.apache.org/proper/commons-crypto/) library, and Spark's +[Apache Commons Crypto](https://commons.apache.org/proper/commons-crypto/) library, and Spark's configuration system allows access to that library's configuration for advanced users. There is also support for SASL-based encryption, although it should be considered deprecated. It @@ -164,7 +169,7 @@ The following settings cover enabling encryption for data written to disk: ## Authentication and Authorization -Enabling authentication for the Web UIs is done using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html). +Enabling authentication for the Web UIs is done using [javax servlet filters](https://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html). You will need a filter that implements the authentication method you want to deploy. Spark does not provide any built-in authentication filters. @@ -278,7 +283,7 @@ To enable authorization in the SHS, a few extra options are used:
Property Name | Default | Meaning
spark.history.ui.acls.enable | false | Specifies whether ACLs should be checked to authorize users viewing the applications in @@ -292,7 +297,7 @@ To enable authorization in the SHS, a few extra options are used:
spark.history.ui.admin.acls | None | Comma separated list of users that have view access to all the Spark applications in history @@ -300,7 +305,7 @@ To enable authorization in the SHS, a few extra options are used:
spark.history.ui.admin.acls.groups | None | Comma separated list of groups that have view access to all the Spark applications in history
@@ -487,7 +492,7 @@ distributed with the application using the `--files` command line argument (or t configuration should just reference the file name with no absolute path. Distributing local key stores this way may require the files to be staged in HDFS (or other similar -distributed file system used by the cluster), so it's recommended that the undelying file system be +distributed file system used by the cluster), so it's recommended that the underlying file system be configured with security in mind (e.g. by enabling authentication and wire encryption).
### Standalone mode @@ -501,6 +506,7 @@ can be accomplished by setting `spark.ssl.useNodeLocalConf` to `true`. In that c provided by the user on the client side are not used.
### Mesos mode + Mesos 1.3.0 and newer supports `Secrets` primitives as both file-based and environment based secrets. Spark allows the specification of file-based and environment variable based secrets with `spark.mesos.driver.secret.filenames` and `spark.mesos.driver.secret.envkeys`, respectively.
@@ -562,8 +568,12 @@ Security. # Configuring Ports for Network Security -Spark makes heavy use of the network, and some environments have strict requirements for using tight -firewall settings. Below are the primary ports that Spark uses for its communication and how to +Generally speaking, a Spark cluster and its services are not deployed on the public internet. +They are generally private services, and should only be accessible within the network of the +organization that deploys Spark. Access to the hosts and ports used by Spark services should +be limited to origin hosts that need to access the services. + +Below are the primary ports that Spark uses for its communication and how to configure those ports.
## Standalone mode only @@ -597,6 +607,14 @@ configure those ports. SPARK_MASTER_PORT Set to "0" to choose a port randomly. Standalone mode only.
External Service | Standalone Master | 6066 | Submit job to cluster via REST API | spark.master.rest.port | Use spark.master.rest.enabled to enable/disable this service. Standalone mode only.
Standalone Master Standalone Worker
diff --git a/docs/spark-standalone.md index 14d742de5655c..7975b0c8b11ca 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -362,8 +362,15 @@ You can run Spark alongside your existing Hadoop cluster by just launching it as # Configuring Ports for Network Security -Spark makes heavy use of the network, and some environments have strict requirements for using -tight firewall settings. For a complete list of ports to configure, see the +Generally speaking, a Spark cluster and its services are not deployed on the public internet. +They are generally private services, and should only be accessible within the network of the +organization that deploys Spark. Access to the hosts and ports used by Spark services should +be limited to origin hosts that need to access the services. + +This is particularly important for clusters using the standalone resource manager, as they do +not support fine-grained access control in a way that other resource managers do. + +For a complete list of ports to configure, see the [security page](security.html#configuring-ports-for-network-security). # High Availability @@ -376,7 +383,7 @@ By default, standalone scheduling clusters are resilient to Worker failures (ins Utilizing ZooKeeper to provide leader election and some state storage, you can launch multiple Masters in your cluster connected to the same ZooKeeper instance. One will be elected "leader" and the others will remain in standby mode. If the current leader dies, another Master will be elected, recover the old Master's state, and then resume scheduling. The entire recovery process (from the time the first leader goes down) should take between 1 and 2 minutes. Note that this delay only affects scheduling _new_ applications -- applications that were already running during Master failover are unaffected. -Learn more about getting started with ZooKeeper [here](http://zookeeper.apache.org/doc/current/zookeeperStarted.html). +Learn more about getting started with ZooKeeper [here](https://zookeeper.apache.org/doc/current/zookeeperStarted.html). **Configuration** @@ -419,6 +426,6 @@ In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spa **Details** -* This solution can be used in tandem with a process monitor/manager like [monit](http://mmonit.com/monit/), or just to enable manual recovery via restart. +* This solution can be used in tandem with a process monitor/manager like [monit](https://mmonit.com/monit/), or just to enable manual recovery via restart. * While filesystem recovery seems straightforwardly better than not doing any recovery at all, this mode may be suboptimal for certain development or experimental purposes. In particular, killing a master via stop-master.sh does not clean up its recovery state, so whenever you start a new Master, it will enter recovery mode. This could increase the startup time by up to 1 minute if it needs to wait for all previously-registered Workers/clients to timeout. * While it's not officially supported, you could mount an NFS directory as the recovery directory. If the original Master node dies completely, you could then start a Master on a different node, which would correctly recover all previously registered Workers/applications (equivalent to ZooKeeper recovery). Future applications will have to be able to find the new Master, however, in order to register.
diff --git a/docs/sparkr.md index 4faad2c4c1824..b4248e8bb21de 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -128,7 +128,7 @@ head(df) SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. This section describes the general methods for loading and saving data using Data Sources. You can check the Spark SQL programming guide for more [specific options](sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. The general method for creating SparkDataFrames from data sources is `read.df`. This method takes in the path for the file to load and the type of data source, and the currently active SparkSession will be used automatically. -SparkR supports reading JSON, CSV and Parquet files natively, and through packages available from sources like [Third Party Projects](http://spark.apache.org/third-party-projects.html), you can find data source connectors for popular file formats like Avro. These packages can either be added by +SparkR supports reading JSON, CSV and Parquet files natively, and through packages available from sources like [Third Party Projects](https://spark.apache.org/third-party-projects.html), you can find data source connectors for popular file formats like Avro. These packages can either be added by specifying `--packages` with `spark-submit` or `sparkR` commands, or if initializing SparkSession with `sparkPackages` parameter when in an interactive R shell or from RStudio.
@@ -667,3 +667,7 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma ## Upgrading to SparkR 2.3.1 and above - In SparkR 2.3.0 and earlier, the `start` parameter of `substr` method was wrongly subtracted by one and considered as 0-based. This can lead to inconsistent substring results and also does not match with the behaviour of `substr` in R. In version 2.3.1 and later, it has been fixed so the `start` parameter of `substr` method is now 1-based. As an example, `substr(lit('abcdef'), 2, 4)` would result in `abc` in SparkR 2.3.0, and the result would be `bcd` in SparkR 2.3.1. + +## Upgrading to SparkR 2.4.0 + + - Previously, the validity of the size of the last layer in `spark.mlp` was not checked. For example, if the training data only has two labels, a `layers` param like `c(1, 3)` did not previously cause an error; now it does.
diff --git a/docs/sql-programming-guide.md index cd7329b621122..9da7d64322eb6 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1345,8 +1345,8 @@ the following case-insensitive options: These options must all be specified if any of them is specified. In addition, numPartitions must be specified. They describe how to partition the table when reading in parallel from multiple workers. - partitionColumn must be a numeric column from the table in question. Notice - that lowerBound and upperBound are just used to decide the + partitionColumn must be a numeric, date, or timestamp column from the table in question. + Notice that lowerBound and upperBound are just used to decide the partition stride, not for filtering the rows in table. So all rows in the table will be partitioned and returned. This option applies only to reading. @@ -1407,6 +1407,13 @@ the following case-insensitive options: This is a JDBC writer related option. When SaveMode.Overwrite is enabled, this option causes Spark to truncate an existing table instead of dropping and recreating it. This can be more efficient, and prevents the table metadata (e.g., indices) from being removed. However, it will not work in some cases, such as when the new data has a different schema. It defaults to false. This option applies only to writing.
cascadeTruncate | This is a JDBC writer related option. If enabled and supported by the JDBC database (PostgreSQL and Oracle at the moment), this option allows execution of a TRUNCATE TABLE t CASCADE (in the case of PostgreSQL a TRUNCATE TABLE ONLY t CASCADE is executed to prevent inadvertently truncating descendant tables). This will affect other tables, and thus should be used with care. This option applies only to writing. It defaults to the default cascading truncate behaviour of the JDBC database in question, specified by isCascadeTruncate in each JDBCDialect.
createTableOptions
@@ -1428,6 +1435,13 @@ the following case-insensitive options: The custom schema to use for reading data from JDBC connectors. For example, "id DECIMAL(38, 0), name STRING". You can also specify partial fields, and the others use the default type mapping. For example, "id DECIMAL(38, 0)". The column names should be identical to the corresponding column names of the JDBC table. Users can specify the corresponding data types of Spark SQL instead of using the defaults. This option applies only to reading.
pushDownPredicate | The option to enable or disable predicate push-down into the JDBC data source. The default value is true, in which case Spark will push down filters to the JDBC data source as much as possible. Otherwise, if set to false, no filter will be pushed down to the JDBC data source and thus all filters will be handled by Spark. Predicate push-down is usually turned off when the predicate filtering is performed faster by Spark than by the JDBC data source.
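To make the new predicate push-down option concrete, here is a sketch of a read that disables it (the connection URL, table name, and credentials are illustrative placeholders, not part of the patch):

```python
# a sketch: keep all filtering in Spark instead of pushing predicates to the database
jdbcDF = spark.read \
    .format("jdbc") \
    .option("url", "jdbc:postgresql:dbserver") \
    .option("dbtable", "schema.tablename") \
    .option("user", "username") \
    .option("password", "password") \
    .option("pushDownPredicate", "false") \
    .load()
```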
    @@ -1468,6 +1482,9 @@ SELECT * FROM resultTable
+## Avro Files +See the [Apache Avro Data Source Guide](avro-data-source-guide.html). + ## Troubleshooting * The JDBC driver class must be visible to the primordial class loader on the client session and on all executors. This is because Java's DriverManager class does a security check that results in it ignoring all drivers not visible to the primordial class loader when one goes to open a connection. One convenient way to do this is to modify compute_classpath.sh on all worker nodes to include your driver JARs. @@ -1782,7 +1799,7 @@ strings, e.g. integer indices. See [pandas.DataFrame](https://pandas.pydata.org/ on how to label columns when constructing a `pandas.DataFrame`. Note that all data for a group will be loaded into memory before the function is applied. This can -lead to out of memory exceptons, especially if the group sizes are skewed. The configuration for +lead to out of memory exceptions, especially if the group sizes are skewed. The configuration for [maxRecordsPerBatch](#setting-arrow-batch-size) is not applied on groups and it is up to the user to ensure that the grouped data will fit into the available memory. @@ -1797,6 +1814,25 @@ The following example shows how to use `groupby().apply()` to subtract the mean For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/pyspark.sql.html#pyspark.sql.functions.pandas_udf) and [`pyspark.sql.GroupedData.apply`](api/python/pyspark.sql.html#pyspark.sql.GroupedData.apply). +### Grouped Aggregate + +Grouped aggregate Pandas UDFs are similar to Spark aggregate functions. Grouped aggregate Pandas UDFs are used with `groupBy().agg()` and +[`pyspark.sql.Window`](api/python/pyspark.sql.html#pyspark.sql.Window). A grouped aggregate Pandas UDF defines an aggregation from one or more `pandas.Series` +to a scalar value, where each `pandas.Series` represents a column within the group or window. + +Note that this type of UDF does not support partial aggregation, and all data for a group or window will be loaded into memory. Also, +only unbounded windows are currently supported with grouped aggregate Pandas UDFs. + +The following example shows how to use this type of UDF to compute the mean with groupBy and window operations: +
    +
    +{% include_example grouped_agg_pandas_udf python/sql/arrow.py %} +
    +
+ +For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/pyspark.sql.html#pyspark.sql.functions.pandas_udf). + ## Usage Notes ### Supported SQL Types @@ -1843,6 +1879,8 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.3 to 2.4 + - Since Spark 2.4, Spark will evaluate the set operations referenced in a query by following a precedence rule as per the SQL standard. If the order is not specified by parentheses, set operations are performed from left to right with the exception that all INTERSECT operations are performed before any UNION, EXCEPT or MINUS operations. The old behaviour of giving equal precedence to all the set operations is preserved under a newly added configuration `spark.sql.legacy.setopsPrecedence.enabled` with a default value of `false`. When this property is set to `true`, Spark will evaluate the set operators from left to right as they appear in the query, given that no explicit ordering is enforced by parentheses. + - Since Spark 2.4, Spark will display table description column Last Access value as UNKNOWN when the value was Jan 01 1970. - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unable to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing an empty dataframe. @@ -1850,12 +1888,16 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with an empty schema. - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promoting both sides to TIMESTAMP. Setting `spark.sql.hive.compareDateTimestampInTimestamp` to `false` restores the previous behavior. This option will be removed in Spark 3.0. - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. Setting `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` to `true` restores the previous behavior. This option will be removed in Spark 3.0. + - Since Spark 2.4, renaming a managed table to an existing location is not allowed. An exception is thrown when attempting to rename a managed table to an existing location.
- Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, regardless of the order of the input arguments. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. - Since Spark 2.4, Spark has enabled non-cascading SQL cache invalidation in addition to the traditional cache invalidation mechanism. The non-cascading cache invalidation mechanism allows users to remove a cache without impacting its dependent caches. This new cache invalidation mechanism is used in scenarios where the data of the cache to be removed is still valid, e.g., calling unpersist() on a Dataset, or dropping a temporary view. This allows users to free up memory and keep the desired caches valid at the same time. - - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. + - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behavior to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files.
- Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. Setting `spark.sql.hive.convertMetastoreOrc` to `false` restores the previous behavior. - In version 2.3 and earlier, CSV rows are considered as malformed if at least one column value in the row is malformed. The CSV parser dropped such rows in the DROPMALFORMED mode or output an error in the FAILFAST mode. Since Spark 2.4, a CSV row is considered malformed only when it contains malformed column values requested from the CSV data source; other values can be ignored. As an example, a CSV file contains the "id,name" header and one row "1234". In Spark 2.4, selecting the id column yields a row with one column value 1234, but in Spark 2.3 and earlier it is empty in the DROPMALFORMED mode. To restore the previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`. + - Since Spark 2.4, file listing for computing statistics is done in parallel by default. This can be disabled by setting `spark.sql.parallelFileListingInStatsComputation.enabled` to `false`. + - Since Spark 2.4, metadata files (e.g. Parquet summary files) and temporary files are not counted as data files when calculating table size during statistics computation. + - Since Spark 2.4, empty strings are saved as quoted empty strings `""`. In version 2.3 and earlier, empty strings were written the same way as `null` values and were not reflected by any characters in saved CSV files. For example, the row `"a", null, "", 1` was written as `a,,,1`. Since Spark 2.4, the same row is saved as `a,,"",1`. To restore the previous behavior, set the CSV option `emptyValue` to an empty (not quoted) string. ## Upgrading From Spark SQL 2.3.0 to 2.3.1 and above @@ -2124,7 +2166,7 @@ See the API docs for `SQLContext.read` ( Python ) and `DataFrame.write` ( Scala, - Java, + Java, Python ) for more information. @@ -3037,3 +3079,10 @@ Specifically: - In aggregations, all NaN values are grouped together. - NaN is treated as a normal value in join keys. - NaN values go last when in ascending order, larger than any other numeric value. + + ## Arithmetic operations + +Operations performed on numeric types (with the exception of `decimal`) are not checked for overflow. +This means that if an operation causes an overflow, the result is the same as what the same operation +returns in a Java/Scala program (e.g. if the sum of 2 integers is higher than the maximum representable value, +the result is a negative number).
diff --git a/docs/streaming-kinesis-integration.md index 678b0643fd706..6a52e8a7b0ebd 100644 --- a/docs/streaming-kinesis-integration.md +++ b/docs/streaming-kinesis-integration.md @@ -196,7 +196,7 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m #### Running the Example To run the example, -- Download a Spark binary from the [download site](http://spark.apache.org/downloads.html). +- Download a Spark binary from the [download site](https://spark.apache.org/downloads.html). - Set up Kinesis stream (see earlier section) within AWS. Note the name of the Kinesis stream and the endpoint URL corresponding to the region where the stream was created.
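With those in hand, the example can then be launched; as a sketch (the bracketed arguments are placeholders for your own app name, stream name, and endpoint URL, and are illustrative rather than part of the patch):

```
$ bin/run-example streaming.KinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL]
```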
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index c30959263cdfa..0ca0f2a8b54d5 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -915,8 +915,7 @@ JavaPairDStream runningCounts = pairs.updateStateByKey(updateFu The update function will be called for each word, with `newValues` having a sequence of 1's (from the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete Java code, take a look at the example -[JavaStatefulNetworkWordCount.java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming -/JavaStatefulNetworkWordCount.java). +[JavaStatefulNetworkWordCount.java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java).
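The same update logic is compact enough to sketch in Python (`pairs` is the DStream of `(word, 1)` pairs from the example; this sketch is illustrative, not part of the patch):

```python
# a sketch of the update function described above: add the new 1's
# to the previous running count (None on the first batch)
def update_function(new_values, running_count):
    return sum(new_values) + (running_count or 0)

running_counts = pairs.updateStateByKey(update_function)
```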
@@ -2176,6 +2175,8 @@ the input data stream (using `inputStream.repartition()`). This distributes the received batches of data across the specified number of machines in the cluster before further processing. +For the direct stream, please refer to the [Spark Streaming + Kafka Integration Guide](streaming-kafka-integration.html). + ### Level of Parallelism in Data Processing {:.no_toc} Cluster resources can be under-utilized if the number of parallel tasks used in any stage of the @@ -2468,7 +2469,7 @@ additional effort may be necessary to achieve exactly-once semantics. There are - [Kafka Integration Guide](streaming-kafka-integration.html) - [Kinesis Integration Guide](streaming-kinesis-integration.html) - [Custom Receiver Guide](streaming-custom-receivers.html) -* Third-party DStream data sources can be found in [Third Party Projects](http://spark.apache.org/third-party-projects.html) +* Third-party DStream data sources can be found in [Third Party Projects](https://spark.apache.org/third-party-projects.html) * API documentation - Scala docs * [StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) and
diff --git a/docs/structured-streaming-programming-guide.md index 0842e8dd88672..73de1892977ac 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -17,7 +17,7 @@ In this guide, we are going to walk you through the programming model and the AP # Quick Example Let’s say you want to maintain a running word count of text data received from a data server listening on a TCP socket. Let’s see how you can express this using Structured Streaming. You can see the full code in [Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala)/[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java)/[Python]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/sql/streaming/structured_network_wordcount.py)/[R]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/r/streaming/structured_network_wordcount.R). -And if you [download Spark](http://spark.apache.org/downloads.html), you can directly [run the example](index.html#running-the-examples-and-shell). In any case, let’s walk through the example step-by-step and understand how it works. First, we have to import the necessary classes and create a local SparkSession, the starting point of all functionalities related to Spark. +And if you [download Spark](https://spark.apache.org/downloads.html), you can directly [run the example](index.html#running-the-examples-and-shell). In any case, let’s walk through the example step-by-step and understand how it works. First, we have to import the necessary classes and create a local SparkSession, the starting point of all functionalities related to Spark.
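In Python, for instance, that first step is just the following (a sketch matching the quick example; the app name comes from the example itself):

```python
from pyspark.sql import SparkSession

# the starting point for all Spark functionality in this example
spark = SparkSession \
    .builder \
    .appName("StructuredNetworkWordCount") \
    .getOrCreate()
```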
    @@ -522,7 +522,7 @@ Here are the details of all the sources in Spark.
    maxFilesPerTrigger: maximum number of new files to be considered in every trigger (default: no max)
    - latestFirst: whether to processs the latest new files first, useful when there is a large backlog of files (default: false) + latestFirst: whether to process the latest new files first, useful when there is a large backlog of files (default: false)
fileNameOnly: whether to check new files based on only the filename instead of on the full path (default: false). With this set to `true`, files such as "file:///dataset.txt" and "s3://a/b/c/dataset.txt" would be considered the same file, because their filenames, "dataset.txt", are the same. Taken together, these options can be set as in the sketch below.
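A sketch of a file-source read using the options above (the format and path are illustrative, not from the guide):

```python
# a sketch: stream text files, newest first, at most 100 new files per trigger
stream_df = spark.readStream \
    .format("text") \
    .option("maxFilesPerTrigger", 100) \
    .option("latestFirst", "true") \
    .load("/data/incoming")
```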
    @@ -1005,7 +1005,7 @@ Here is an illustration. As shown in the illustration, the maximum event time tracked by the engine is the *blue dashed line*, and the watermark set as `(max event time - '10 mins')` -at the beginning of every trigger is the red line For example, when the engine observes the data +at the beginning of every trigger is the red line. For example, when the engine observes the data `(12:14, dog)`, it sets the watermark for the next trigger as `12:04`. This watermark lets the engine maintain intermediate state for additional 10 minutes to allow late data to be counted. For example, the data `(12:09, cat)` is out of order and late, and it falls in @@ -1162,7 +1162,7 @@ In other words, you will have to do the following additional steps in the join. old rows of one input is not going to be required (i.e. will not satisfy the time constraint) for matches with the other input. This constraint can be defined in one of the two ways. - 1. Time range join conditions (e.g. `...JOIN ON leftTime BETWEN rightTime AND rightTime + INTERVAL 1 HOUR`), + 1. Time range join conditions (e.g. `...JOIN ON leftTime BETWEEN rightTime AND rightTime + INTERVAL 1 HOUR`), 1. Join on event-time windows (e.g. `...JOIN ON leftTimeWindow = rightTimeWindow`). diff --git a/docs/tuning.md b/docs/tuning.md index 1c3bd0e8758ff..f60971aa2e0af 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -35,7 +35,7 @@ in your operations) and performance. It provides two serialization libraries: Java serialization is flexible but often quite slow, and leads to large serialized formats for many classes. * [Kryo serialization](https://github.com/EsotericSoftware/kryo): Spark can also use - the Kryo library (version 2) to serialize objects more quickly. Kryo is significantly + the Kryo library (version 4) to serialize objects more quickly. Kryo is significantly faster and more compact than Java serialization (often as much as 10x), but does not support all `Serializable` types and requires you to *register* the classes you'll use in the program in advance for best performance. diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSummarizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSummarizerExample.java new file mode 100644 index 0000000000000..e9b84365d86ed --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSummarizerExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.sql.*; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.linalg.Vector; +import org.apache.spark.ml.linalg.Vectors; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.stat.Summarizer; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaSummarizerExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaSummarizerExample") + .getOrCreate(); + + // $example on$ + List<Row> data = Arrays.asList( + RowFactory.create(Vectors.dense(2.0, 3.0, 5.0), 1.0), + RowFactory.create(Vectors.dense(4.0, 6.0, 7.0), 2.0) + ); + + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + new StructField("weight", DataTypes.DoubleType, false, Metadata.empty()) + }); + + Dataset<Row> df = spark.createDataFrame(data, schema); + + Row result1 = df.select(Summarizer.metrics("mean", "variance") + .summary(new Column("features"), new Column("weight")).as("summary")) + .select("summary.mean", "summary.variance").first(); + System.out.println("with weight: mean = " + result1.getAs(0).toString() + + ", variance = " + result1.getAs(1).toString()); + + Row result2 = df.select( + Summarizer.mean(new Column("features")), + Summarizer.variance(new Column("features")) + ).first(); + System.out.println("without weight: mean = " + result2.getAs(0).toString() + + ", variance = " + result2.getAs(1).toString()); + // $example off$ + spark.stop(); + } +}
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaHypothesisTestingExample.java index b48b95ff1d2a3..273273652c955 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaHypothesisTestingExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaHypothesisTestingExample.java @@ -67,7 +67,7 @@ public static void main(String[] args) { ) ); - // The contingency table is constructed from the raw (feature, label) pairs and used to conduct + // The contingency table is constructed from the raw (label, feature) pairs and used to conduct // the independence test. Returns an array containing the ChiSquaredTestResult for every feature // against the label. ChiSqTestResult[] featureTestResults = Statistics.chiSqTest(obs.rdd());
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java index b6b163fa8b2cd..748bf58f30350 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java @@ -26,7 +26,9 @@ import scala.Tuple2; +import org.apache.kafka.clients.consumer.ConsumerConfig; import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.serialization.StringDeserializer; import org.apache.spark.SparkConf; import org.apache.spark.streaming.api.java.*; @@ -37,30 +39,33 @@ /** * Consumes messages from one or more topics in Kafka and does wordcount.
- * Usage: JavaDirectKafkaWordCount <brokers> <topics> + * Usage: JavaDirectKafkaWordCount <brokers> <groupId> <topics> * <brokers> is a list of one or more Kafka brokers + * <groupId> is a consumer group name to consume from topics * <topics> is a list of one or more kafka topics to consume from * * Example: * $ bin/run-example streaming.JavaDirectKafkaWordCount broker1-host:port,broker2-host:port \ - * topic1,topic2 + * consumer-group topic1,topic2 */ public final class JavaDirectKafkaWordCount { private static final Pattern SPACE = Pattern.compile(" "); public static void main(String[] args) throws Exception { - if (args.length < 2) { - System.err.println("Usage: JavaDirectKafkaWordCount <brokers> <topics>\n" + - " <brokers> is a list of one or more Kafka brokers\n" + - " <topics> is a list of one or more kafka topics to consume from\n\n"); + if (args.length < 3) { + System.err.println("Usage: JavaDirectKafkaWordCount <brokers> <groupId> <topics>\n" + + " <brokers> is a list of one or more Kafka brokers\n" + + " <groupId> is a consumer group name to consume from topics\n" + + " <topics> is a list of one or more kafka topics to consume from\n\n"); System.exit(1); } StreamingExamples.setStreamingLogLevels(); String brokers = args[0]; - String topics = args[1]; + String groupId = args[1]; + String topics = args[2]; // Create context with a 2 seconds batch interval SparkConf sparkConf = new SparkConf().setAppName("JavaDirectKafkaWordCount"); @@ -68,7 +73,10 @@ public static void main(String[] args) throws Exception { Set<String> topicsSet = new HashSet<>(Arrays.asList(topics.split(","))); Map<String, Object> kafkaParams = new HashMap<>(); - kafkaParams.put("metadata.broker.list", brokers); + kafkaParams.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, brokers); + kafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, groupId); + kafkaParams.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); + kafkaParams.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); // Create direct kafka stream with brokers and topics JavaInputDStream<ConsumerRecord<String, String>> messages = KafkaUtils.createDirectStream(
diff --git a/examples/src/main/python/ml/summarizer_example.py new file mode 100644 index 0000000000000..8835f189a1ad4 --- /dev/null +++ b/examples/src/main/python/ml/summarizer_example.py @@ -0,0 +1,59 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +An example for summarizer.
+Run with: + bin/spark-submit examples/src/main/python/ml/summarizer_example.py +""" +from __future__ import print_function + +from pyspark.sql import SparkSession +# $example on$ +from pyspark.ml.stat import Summarizer +from pyspark.sql import Row +from pyspark.ml.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + spark = SparkSession \ + .builder \ + .appName("SummarizerExample") \ + .getOrCreate() + sc = spark.sparkContext + + # $example on$ + df = sc.parallelize([Row(weight=1.0, features=Vectors.dense(1.0, 1.0, 1.0)), + Row(weight=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF() + + # create summarizer for multiple metrics "mean" and "count" + summarizer = Summarizer.metrics("mean", "count") + + # compute statistics for multiple metrics with weight + df.select(summarizer.summary(df.features, df.weight)).show(truncate=False) + + # compute statistics for multiple metrics without weight + df.select(summarizer.summary(df.features)).show(truncate=False) + + # compute statistics for single metric "mean" with weight + df.select(Summarizer.mean(df.features, df.weight)).show(truncate=False) + + # compute statistics for single metric "mean" without weight + df.select(Summarizer.mean(df.features)).show(truncate=False) + # $example off$ + + spark.stop() diff --git a/examples/src/main/python/mllib/hypothesis_testing_example.py b/examples/src/main/python/mllib/hypothesis_testing_example.py index e566ead0d318d..21a5584fd6e06 100644 --- a/examples/src/main/python/mllib/hypothesis_testing_example.py +++ b/examples/src/main/python/mllib/hypothesis_testing_example.py @@ -51,7 +51,7 @@ [LabeledPoint(1.0, [1.0, 0.0, 3.0]), LabeledPoint(1.0, [1.0, 2.0, 0.0]), LabeledPoint(1.0, [-1.0, 0.0, -0.5])] - ) # LabeledPoint(feature, label) + ) # LabeledPoint(label, feature) # The contingency table is constructed from an RDD of LabeledPoint and used to conduct # the independence test. 
Returns an array containing the ChiSquaredTestResult for every feature diff --git a/examples/src/main/python/sql/arrow.py b/examples/src/main/python/sql/arrow.py index 4c5aefb6ff4a6..5eb164b20ad04 100644 --- a/examples/src/main/python/sql/arrow.py +++ b/examples/src/main/python/sql/arrow.py @@ -95,12 +95,12 @@ def grouped_map_pandas_udf_example(spark): ("id", "v")) @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) - def substract_mean(pdf): + def subtract_mean(pdf): # pdf is a pandas.DataFrame v = pdf.v return pdf.assign(v=v - v.mean()) - df.groupby("id").apply(substract_mean).show() + df.groupby("id").apply(subtract_mean).show() # +---+----+ # | id| v| # +---+----+ @@ -113,6 +113,43 @@ def substract_mean(pdf): # $example off:grouped_map_pandas_udf$ +def grouped_agg_pandas_udf_example(spark): + # $example on:grouped_agg_pandas_udf$ + from pyspark.sql.functions import pandas_udf, PandasUDFType + from pyspark.sql import Window + + df = spark.createDataFrame( + [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ("id", "v")) + + @pandas_udf("double", PandasUDFType.GROUPED_AGG) + def mean_udf(v): + return v.mean() + + df.groupby("id").agg(mean_udf(df['v'])).show() + # +---+-----------+ + # | id|mean_udf(v)| + # +---+-----------+ + # | 1| 1.5| + # | 2| 6.0| + # +---+-----------+ + + w = Window \ + .partitionBy('id') \ + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() + # +---+----+------+ + # | id| v|mean_v| + # +---+----+------+ + # | 1| 1.0| 1.5| + # | 1| 2.0| 1.5| + # | 2| 3.0| 6.0| + # | 2| 5.0| 6.0| + # | 2|10.0| 6.0| + # +---+----+------+ + # $example off:grouped_agg_pandas_udf$ + + if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SummarizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SummarizerExample.scala new file mode 100644 index 0000000000000..2f54d1d81bc48 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SummarizerExample.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+// scalastyle:off println
+package org.apache.spark.examples.ml
+
+// $example on$
+import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.stat.Summarizer
+// $example off$
+import org.apache.spark.sql.SparkSession
+
+object SummarizerExample {
+  def main(args: Array[String]): Unit = {
+    val spark = SparkSession
+      .builder
+      .appName("SummarizerExample")
+      .getOrCreate()
+
+    import spark.implicits._
+    import Summarizer._
+
+    // $example on$
+    val data = Seq(
+      (Vectors.dense(2.0, 3.0, 5.0), 1.0),
+      (Vectors.dense(4.0, 6.0, 7.0), 2.0)
+    )
+
+    val df = data.toDF("features", "weight")
+
+    val (meanVal, varianceVal) = df.select(metrics("mean", "variance")
+      .summary($"features", $"weight").as("summary"))
+      .select("summary.mean", "summary.variance")
+      .as[(Vector, Vector)].first()
+
+    println(s"with weight: mean = ${meanVal}, variance = ${varianceVal}")
+
+    val (meanVal2, varianceVal2) = df.select(mean($"features"), variance($"features"))
+      .as[(Vector, Vector)].first()
+
+    println(s"without weight: mean = ${meanVal2}, variance = ${varianceVal2}")
+    // $example off$
+
+    spark.stop()
+  }
+}
+// scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala
index add1719739539..9b3c3266ee30a 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala
@@ -61,9 +61,9 @@ object HypothesisTestingExample {
        LabeledPoint(-1.0, Vectors.dense(-1.0, 0.0, -0.5)
          )
        )
-    ) // (feature, label) pairs.
+    ) // (label, feature) pairs.

-    // The contingency table is constructed from the raw (feature, label) pairs and used to conduct
+    // The contingency table is constructed from the raw (label, feature) pairs and used to conduct
     // the independence test. Returns an array containing the ChiSquaredTestResult for every feature
     // against the label.
     val featureTestResults: Array[ChiSqTestResult] = Statistics.chiSqTest(obs)
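For context, a minimal sketch (editorial, not part of the patch) of why the corrected comments read (label, feature): LabeledPoint takes the label as its first argument and the feature vector as its second, assuming spark-mllib is on the classpath:

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// The label (1.0) comes first, then the feature vector.
val point = LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 3.0))
println(point.label)     // prints 1.0
println(point.features)  // prints [1.0,0.0,3.0]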
diff --git a/external/avro/pom.xml b/external/avro/pom.xml
new file mode 100644
index 0000000000000..8f118ba48201b
--- /dev/null
+++ b/external/avro/pom.xml
@@ -0,0 +1,78 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  ~ Licensed to the Apache Software Foundation (ASF) under one or more
+  ~ contributor license agreements. See the NOTICE file distributed with
+  ~ this work for additional information regarding copyright ownership.
+  ~ The ASF licenses this file to You under the Apache License, Version 2.0
+  ~ (the "License"); you may not use this file except in compliance with
+  ~ the License. You may obtain a copy of the License at
+  ~
+  ~    http://www.apache.org/licenses/LICENSE-2.0
+  ~
+  ~ Unless required by applicable law or agreed to in writing, software
+  ~ distributed under the License is distributed on an "AS IS" BASIS,
+  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  ~ See the License for the specific language governing permissions and
+  ~ limitations under the License.
+-->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>org.apache.spark</groupId>
+    <artifactId>spark-parent_2.11</artifactId>
+    <version>2.4.0-SNAPSHOT</version>
+    <relativePath>../../pom.xml</relativePath>
+  </parent>
+
+  <artifactId>spark-avro_2.11</artifactId>
+  <properties>
+    <sbt.project.name>avro</sbt.project.name>
+  </properties>
+  <packaging>jar</packaging>
+  <name>Spark Avro</name>
+  <url>http://spark.apache.org/</url>
+
+  <dependencies>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.scalacheck</groupId>
+      <artifactId>scalacheck_${scala.binary.version}</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-tags_${scala.binary.version}</artifactId>
+    </dependency>
+  </dependencies>
+  <build>
+    <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+    <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+  </build>
+</project>
diff --git a/external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
new file mode 100644
index 0000000000000..95835f0d4ca49
--- /dev/null
+++ b/external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -0,0 +1 @@
+org.apache.spark.sql.avro.AvroFileFormat
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala
new file mode 100644
index 0000000000000..915769fa708b0
--- /dev/null
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.avro + +import org.apache.avro.Schema +import org.apache.avro.generic.GenericDatumReader +import org.apache.avro.io.{BinaryDecoder, DecoderFactory} + +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType} + +case class AvroDataToCatalyst(child: Expression, jsonFormatSchema: String) + extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType) + + override lazy val dataType: DataType = SchemaConverters.toSqlType(avroSchema).dataType + + override def nullable: Boolean = true + + @transient private lazy val avroSchema = new Schema.Parser().parse(jsonFormatSchema) + + @transient private lazy val reader = new GenericDatumReader[Any](avroSchema) + + @transient private lazy val deserializer = new AvroDeserializer(avroSchema, dataType) + + @transient private var decoder: BinaryDecoder = _ + + @transient private var result: Any = _ + + override def nullSafeEval(input: Any): Any = { + val binary = input.asInstanceOf[Array[Byte]] + decoder = DecoderFactory.get().binaryDecoder(binary, 0, binary.length, decoder) + result = reader.read(result, decoder) + deserializer.deserialize(result) + } + + override def simpleString: String = { + s"from_avro(${child.sql}, ${dataType.simpleString})" + } + + override def sql: String = { + s"from_avro(${child.sql}, ${dataType.catalogString})" + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val expr = ctx.addReferenceObj("this", this) + defineCodeGen(ctx, ev, input => + s"(${CodeGenerator.boxedType(dataType)})$expr.nullSafeEval($input)") + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala new file mode 100644 index 0000000000000..272e7d5b388d9 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -0,0 +1,394 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.avro + +import java.math.{BigDecimal} +import java.nio.ByteBuffer + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder} +import org.apache.avro.Conversions.DecimalConversion +import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis} +import org.apache.avro.Schema.Type._ +import org.apache.avro.generic._ +import org.apache.avro.util.Utf8 + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * A deserializer to deserialize data in avro format to data in catalyst format. + */ +class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) { + private lazy val decimalConversions = new DecimalConversion() + + private val converter: Any => Any = rootCatalystType match { + // A shortcut for empty schema. + case st: StructType if st.isEmpty => + (data: Any) => InternalRow.empty + + case st: StructType => + val resultRow = new SpecificInternalRow(st.map(_.dataType)) + val fieldUpdater = new RowUpdater(resultRow) + val writer = getRecordWriter(rootAvroType, st, Nil) + (data: Any) => { + val record = data.asInstanceOf[GenericRecord] + writer(fieldUpdater, record) + resultRow + } + + case _ => + val tmpRow = new SpecificInternalRow(Seq(rootCatalystType)) + val fieldUpdater = new RowUpdater(tmpRow) + val writer = newWriter(rootAvroType, rootCatalystType, Nil) + (data: Any) => { + writer(fieldUpdater, 0, data) + tmpRow.get(0, rootCatalystType) + } + } + + def deserialize(data: Any): Any = converter(data) + + /** + * Creates a writer to write avro values to Catalyst values at the given ordinal with the given + * updater. + */ + private def newWriter( + avroType: Schema, + catalystType: DataType, + path: List[String]): (CatalystDataUpdater, Int, Any) => Unit = + (avroType.getType, catalystType) match { + case (NULL, NullType) => (updater, ordinal, _) => + updater.setNullAt(ordinal) + + // TODO: we can avoid boxing if future version of avro provide primitive accessors. + case (BOOLEAN, BooleanType) => (updater, ordinal, value) => + updater.setBoolean(ordinal, value.asInstanceOf[Boolean]) + + case (INT, IntegerType) => (updater, ordinal, value) => + updater.setInt(ordinal, value.asInstanceOf[Int]) + + case (INT, DateType) => (updater, ordinal, value) => + updater.setInt(ordinal, value.asInstanceOf[Int]) + + case (LONG, LongType) => (updater, ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[Long]) + + case (LONG, TimestampType) => avroType.getLogicalType match { + case _: TimestampMillis => (updater, ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[Long] * 1000) + case _: TimestampMicros => (updater, ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[Long]) + case null => (updater, ordinal, value) => + // For backward compatibility, if the Avro type is Long and it is not logical type, + // the value is processed as timestamp type with millisecond precision. 
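+              // (Catalyst's TimestampType stores microseconds since the epoch, so the
+              // millisecond value is scaled by 1000 below, e.g. 1532000000000 ms
+              // becomes 1532000000000000 us.)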
+ updater.setLong(ordinal, value.asInstanceOf[Long] * 1000) + case other => throw new IncompatibleSchemaException( + s"Cannot convert Avro logical type ${other} to Catalyst Timestamp type.") + } + + // Before we upgrade Avro to 1.8 for logical type support, spark-avro converts Long to Date. + // For backward compatibility, we still keep this conversion. + case (LONG, DateType) => (updater, ordinal, value) => + updater.setInt(ordinal, (value.asInstanceOf[Long] / DateTimeUtils.MILLIS_PER_DAY).toInt) + + case (FLOAT, FloatType) => (updater, ordinal, value) => + updater.setFloat(ordinal, value.asInstanceOf[Float]) + + case (DOUBLE, DoubleType) => (updater, ordinal, value) => + updater.setDouble(ordinal, value.asInstanceOf[Double]) + + case (STRING, StringType) => (updater, ordinal, value) => + val str = value match { + case s: String => UTF8String.fromString(s) + case s: Utf8 => + val bytes = new Array[Byte](s.getByteLength) + System.arraycopy(s.getBytes, 0, bytes, 0, s.getByteLength) + UTF8String.fromBytes(bytes) + } + updater.set(ordinal, str) + + case (ENUM, StringType) => (updater, ordinal, value) => + updater.set(ordinal, UTF8String.fromString(value.toString)) + + case (FIXED, BinaryType) => (updater, ordinal, value) => + updater.set(ordinal, value.asInstanceOf[GenericFixed].bytes().clone()) + + case (BYTES, BinaryType) => (updater, ordinal, value) => + val bytes = value match { + case b: ByteBuffer => + val bytes = new Array[Byte](b.remaining) + b.get(bytes) + bytes + case b: Array[Byte] => b + case other => throw new RuntimeException(s"$other is not a valid avro binary.") + } + updater.set(ordinal, bytes) + + case (FIXED, d: DecimalType) => (updater, ordinal, value) => + val bigDecimal = decimalConversions.fromFixed(value.asInstanceOf[GenericFixed], avroType, + LogicalTypes.decimal(d.precision, d.scale)) + val decimal = createDecimal(bigDecimal, d.precision, d.scale) + updater.setDecimal(ordinal, decimal) + + case (BYTES, d: DecimalType) => (updater, ordinal, value) => + val bigDecimal = decimalConversions.fromBytes(value.asInstanceOf[ByteBuffer], avroType, + LogicalTypes.decimal(d.precision, d.scale)) + val decimal = createDecimal(bigDecimal, d.precision, d.scale) + updater.setDecimal(ordinal, decimal) + + case (RECORD, st: StructType) => + val writeRecord = getRecordWriter(avroType, st, path) + (updater, ordinal, value) => + val row = new SpecificInternalRow(st) + writeRecord(new RowUpdater(row), value.asInstanceOf[GenericRecord]) + updater.set(ordinal, row) + + case (ARRAY, ArrayType(elementType, containsNull)) => + val elementWriter = newWriter(avroType.getElementType, elementType, path) + (updater, ordinal, value) => + val array = value.asInstanceOf[GenericData.Array[Any]] + val len = array.size() + val result = createArrayData(elementType, len) + val elementUpdater = new ArrayDataUpdater(result) + + var i = 0 + while (i < len) { + val element = array.get(i) + if (element == null) { + if (!containsNull) { + throw new RuntimeException(s"Array value at path ${path.mkString(".")} is not " + + "allowed to be null") + } else { + elementUpdater.setNullAt(i) + } + } else { + elementWriter(elementUpdater, i, element) + } + i += 1 + } + + updater.set(ordinal, result) + + case (MAP, MapType(keyType, valueType, valueContainsNull)) if keyType == StringType => + val keyWriter = newWriter(SchemaBuilder.builder().stringType(), StringType, path) + val valueWriter = newWriter(avroType.getValueType, valueType, path) + (updater, ordinal, value) => + val map = value.asInstanceOf[java.util.Map[AnyRef, 
AnyRef]] + val keyArray = createArrayData(keyType, map.size()) + val keyUpdater = new ArrayDataUpdater(keyArray) + val valueArray = createArrayData(valueType, map.size()) + val valueUpdater = new ArrayDataUpdater(valueArray) + val iter = map.entrySet().iterator() + var i = 0 + while (iter.hasNext) { + val entry = iter.next() + assert(entry.getKey != null) + keyWriter(keyUpdater, i, entry.getKey) + if (entry.getValue == null) { + if (!valueContainsNull) { + throw new RuntimeException(s"Map value at path ${path.mkString(".")} is not " + + "allowed to be null") + } else { + valueUpdater.setNullAt(i) + } + } else { + valueWriter(valueUpdater, i, entry.getValue) + } + i += 1 + } + + updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray)) + + case (UNION, _) => + val allTypes = avroType.getTypes.asScala + val nonNullTypes = allTypes.filter(_.getType != NULL) + if (nonNullTypes.nonEmpty) { + if (nonNullTypes.length == 1) { + newWriter(nonNullTypes.head, catalystType, path) + } else { + nonNullTypes.map(_.getType) match { + case Seq(a, b) if Set(a, b) == Set(INT, LONG) && catalystType == LongType => + (updater, ordinal, value) => value match { + case null => updater.setNullAt(ordinal) + case l: java.lang.Long => updater.setLong(ordinal, l) + case i: java.lang.Integer => updater.setLong(ordinal, i.longValue()) + } + + case Seq(a, b) if Set(a, b) == Set(FLOAT, DOUBLE) && catalystType == DoubleType => + (updater, ordinal, value) => value match { + case null => updater.setNullAt(ordinal) + case d: java.lang.Double => updater.setDouble(ordinal, d) + case f: java.lang.Float => updater.setDouble(ordinal, f.doubleValue()) + } + + case _ => + catalystType match { + case st: StructType if st.length == nonNullTypes.size => + val fieldWriters = nonNullTypes.zip(st.fields).map { + case (schema, field) => newWriter(schema, field.dataType, path :+ field.name) + }.toArray + (updater, ordinal, value) => { + val row = new SpecificInternalRow(st) + val fieldUpdater = new RowUpdater(row) + val i = GenericData.get().resolveUnion(avroType, value) + fieldWriters(i)(fieldUpdater, i, value) + updater.set(ordinal, row) + } + + case _ => + throw new IncompatibleSchemaException( + s"Cannot convert Avro to catalyst because schema at path " + + s"${path.mkString(".")} is not compatible " + + s"(avroType = $avroType, sqlType = $catalystType).\n" + + s"Source Avro schema: $rootAvroType.\n" + + s"Target Catalyst type: $rootCatalystType") + } + } + } + } else { + (updater, ordinal, value) => updater.setNullAt(ordinal) + } + + case _ => + throw new IncompatibleSchemaException( + s"Cannot convert Avro to catalyst because schema at path ${path.mkString(".")} " + + s"is not compatible (avroType = $avroType, sqlType = $catalystType).\n" + + s"Source Avro schema: $rootAvroType.\n" + + s"Target Catalyst type: $rootCatalystType") + } + + // TODO: move the following method in Decimal object on creating Decimal from BigDecimal? + private def createDecimal(decimal: BigDecimal, precision: Int, scale: Int): Decimal = { + if (precision <= Decimal.MAX_LONG_DIGITS) { + // Constructs a `Decimal` with an unscaled `Long` value if possible. + Decimal(decimal.unscaledValue().longValue(), precision, scale) + } else { + // Otherwise, resorts to an unscaled `BigInteger` instead. 
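+      // (Decimal.MAX_LONG_DIGITS is 18 in Spark, so any decimal of precision 18 or less
+      // fits in an unscaled Long; wider values take the BigDecimal path below.)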
+ Decimal(decimal, precision, scale) + } + } + + private def getRecordWriter( + avroType: Schema, + sqlType: StructType, + path: List[String]): (CatalystDataUpdater, GenericRecord) => Unit = { + val validFieldIndexes = ArrayBuffer.empty[Int] + val fieldWriters = ArrayBuffer.empty[(CatalystDataUpdater, Any) => Unit] + + val length = sqlType.length + var i = 0 + while (i < length) { + val sqlField = sqlType.fields(i) + val avroField = avroType.getField(sqlField.name) + if (avroField != null) { + validFieldIndexes += avroField.pos() + + val baseWriter = newWriter(avroField.schema(), sqlField.dataType, path :+ sqlField.name) + val ordinal = i + val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => { + if (value == null) { + fieldUpdater.setNullAt(ordinal) + } else { + baseWriter(fieldUpdater, ordinal, value) + } + } + fieldWriters += fieldWriter + } else if (!sqlField.nullable) { + throw new IncompatibleSchemaException( + s""" + |Cannot find non-nullable field ${path.mkString(".")}.${sqlField.name} in Avro schema. + |Source Avro schema: $rootAvroType. + |Target Catalyst type: $rootCatalystType. + """.stripMargin) + } + i += 1 + } + + (fieldUpdater, record) => { + var i = 0 + while (i < validFieldIndexes.length) { + fieldWriters(i)(fieldUpdater, record.get(validFieldIndexes(i))) + i += 1 + } + } + } + + private def createArrayData(elementType: DataType, length: Int): ArrayData = elementType match { + case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length)) + case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length)) + case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length)) + case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length)) + case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length)) + case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length)) + case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length)) + case _ => new GenericArrayData(new Array[Any](length)) + } + + /** + * A base interface for updating values inside catalyst data structure like `InternalRow` and + * `ArrayData`. 
+ */ + sealed trait CatalystDataUpdater { + def set(ordinal: Int, value: Any): Unit + + def setNullAt(ordinal: Int): Unit = set(ordinal, null) + def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value) + def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value) + def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value) + def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value) + def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value) + def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value) + def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value) + def setDecimal(ordinal: Int, value: Decimal): Unit = set(ordinal, value) + } + + final class RowUpdater(row: InternalRow) extends CatalystDataUpdater { + override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value) + + override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal) + override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value) + override def setDecimal(ordinal: Int, value: Decimal): Unit = + row.setDecimal(ordinal, value, value.precision) + } + + final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater { + override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value) + + override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal) + override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = array.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value) + override def setDecimal(ordinal: Int, value: Decimal): Unit = array.update(ordinal, value) + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala new file mode 100755 index 0000000000000..6df23c93e4c54 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import java.io._
+import java.net.URI
+
+import scala.util.control.NonFatal
+
+import org.apache.avro.Schema
+import org.apache.avro.file.DataFileConstants._
+import org.apache.avro.file.DataFileReader
+import org.apache.avro.generic.{GenericDatumReader, GenericRecord}
+import org.apache.avro.mapred.{AvroOutputFormat, FsInput}
+import org.apache.avro.mapreduce.AvroJob
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.mapreduce.Job
+
+import org.apache.spark.TaskContext
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile}
+import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.SerializableConfiguration
+
+private[avro] class AvroFileFormat extends FileFormat
+  with DataSourceRegister with Logging with Serializable {
+
+  override def equals(other: Any): Boolean = other match {
+    case _: AvroFileFormat => true
+    case _ => false
+  }
+
+  // Dummy hashCode() to appease ScalaStyle.
+  override def hashCode(): Int = super.hashCode()
+
+  override def inferSchema(
+      spark: SparkSession,
+      options: Map[String, String],
+      files: Seq[FileStatus]): Option[StructType] = {
+    val conf = spark.sessionState.newHadoopConf()
+    val parsedOptions = new AvroOptions(options, conf)
+
+    // Schema evolution is not supported yet. Here we only pick a single random sample file to
+    // figure out the schema of the whole dataset.
+    val sampleFile =
+      if (parsedOptions.ignoreExtension) {
+        files.headOption.getOrElse {
+          throw new FileNotFoundException("No files were found for schema inference.")
+        }
+      } else {
+        files.find(_.getPath.getName.endsWith(".avro")).getOrElse {
+          throw new FileNotFoundException(
+            "No Avro files found. If files don't have .avro extension, set ignoreExtension to true")
+        }
+      }
+
+    // User can specify an optional avro json schema.
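+    // (As an illustration, a hypothetical schema passed through the `avroSchema` option
+    // might look like:
+    //   {"type": "record", "name": "topLevelRecord", "fields": [
+    //     {"name": "id", "type": "long"},
+    //     {"name": "name", "type": ["string", "null"]}]})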
+ val avroSchema = parsedOptions.schema + .map(new Schema.Parser().parse) + .getOrElse { + val in = new FsInput(sampleFile.getPath, conf) + try { + val reader = DataFileReader.openReader(in, new GenericDatumReader[GenericRecord]()) + try { + reader.getSchema + } finally { + reader.close() + } + } finally { + in.close() + } + } + + SchemaConverters.toSqlType(avroSchema).dataType match { + case t: StructType => Some(t) + case _ => throw new RuntimeException( + s"""Avro schema cannot be converted to a Spark SQL StructType: + | + |${avroSchema.toString(true)} + |""".stripMargin) + } + } + + override def shortName(): String = "avro" + + override def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = true + + override def prepareWrite( + spark: SparkSession, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + val parsedOptions = new AvroOptions(options, spark.sessionState.newHadoopConf()) + val outputAvroSchema: Schema = parsedOptions.schema + .map(new Schema.Parser().parse) + .getOrElse(SchemaConverters.toAvroType(dataSchema, nullable = false, + parsedOptions.recordName, parsedOptions.recordNamespace)) + + AvroJob.setOutputKeySchema(job, outputAvroSchema) + + if (parsedOptions.compression == "uncompressed") { + job.getConfiguration.setBoolean("mapred.output.compress", false) + } else { + job.getConfiguration.setBoolean("mapred.output.compress", true) + logInfo(s"Compressing Avro output using the ${parsedOptions.compression} codec") + val codec = parsedOptions.compression match { + case DEFLATE_CODEC => + val deflateLevel = spark.sessionState.conf.avroDeflateLevel + logInfo(s"Avro compression level $deflateLevel will be used for $DEFLATE_CODEC codec.") + job.getConfiguration.setInt(AvroOutputFormat.DEFLATE_LEVEL_KEY, deflateLevel) + DEFLATE_CODEC + case codec @ (SNAPPY_CODEC | BZIP2_CODEC | XZ_CODEC) => codec + case unknown => throw new IllegalArgumentException(s"Invalid compression codec: $unknown") + } + job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, codec) + } + + new AvroOutputWriterFactory(dataSchema, outputAvroSchema.toString) + } + + override def buildReader( + spark: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + + val broadcastedConf = + spark.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + val parsedOptions = new AvroOptions(options, hadoopConf) + + (file: PartitionedFile) => { + val conf = broadcastedConf.value.value + val userProvidedSchema = parsedOptions.schema.map(new Schema.Parser().parse) + + // TODO Removes this check once `FileFormat` gets a general file filtering interface method. + // Doing input file filtering is improper because we may generate empty tasks that process no + // input files but stress the scheduler. We should probably add a more general input file + // filtering mechanism for `FileFormat` data sources. See SPARK-16317. 
+      if (parsedOptions.ignoreExtension || file.filePath.endsWith(".avro")) {
+        val reader = {
+          val in = new FsInput(new Path(new URI(file.filePath)), conf)
+          try {
+            val datumReader = userProvidedSchema match {
+              case Some(userSchema) => new GenericDatumReader[GenericRecord](userSchema)
+              case _ => new GenericDatumReader[GenericRecord]()
+            }
+            DataFileReader.openReader(in, datumReader)
+          } catch {
+            case NonFatal(e) =>
+              logError("Exception while opening DataFileReader", e)
+              in.close()
+              throw e
+          }
+        }
+
+        // Ensure that the reader is closed even if the task fails or doesn't consume the entire
+        // iterator of records.
+        Option(TaskContext.get()).foreach { taskContext =>
+          taskContext.addTaskCompletionListener[Unit] { _ =>
+            reader.close()
+          }
+        }
+
+        reader.sync(file.start)
+        val stop = file.start + file.length
+
+        val deserializer =
+          new AvroDeserializer(userProvidedSchema.getOrElse(reader.getSchema), requiredSchema)
+
+        new Iterator[InternalRow] {
+          private[this] var completed = false
+
+          override def hasNext: Boolean = {
+            if (completed) {
+              false
+            } else {
+              val r = reader.hasNext && !reader.pastSync(stop)
+              if (!r) {
+                reader.close()
+                completed = true
+              }
+              r
+            }
+          }
+
+          override def next(): InternalRow = {
+            if (!hasNext) {
+              throw new NoSuchElementException("next on empty iterator")
+            }
+            val record = reader.next()
+            deserializer.deserialize(record).asInstanceOf[InternalRow]
+          }
+        }
+      } else {
+        Iterator.empty
+      }
+    }
+  }
+}
+
+private[avro] object AvroFileFormat {
+  val IgnoreFilesWithoutExtensionProperty = "avro.mapred.ignore.inputs.without.extension"
+}
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala
new file mode 100644
index 0000000000000..67f56343b4524
--- /dev/null
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Options for the Avro reader and writer, stored in a case-insensitive manner.
+ */
+class AvroOptions(
+    @transient val parameters: CaseInsensitiveMap[String],
+    @transient val conf: Configuration) extends Logging with Serializable {
+
+  def this(parameters: Map[String, String], conf: Configuration) = {
+    this(CaseInsensitiveMap(parameters), conf)
+  }
+
+  /**
+   * Optional schema provided by a user in JSON format.
+   */
+  val schema: Option[String] = parameters.get("avroSchema")
+
+  /**
+   * Top level record name in write result, which is required by the Avro spec.
+   * See https://avro.apache.org/docs/1.8.2/spec.html#schema_record .
+   * Default value is "topLevelRecord".
+   */
+  val recordName: String = parameters.getOrElse("recordName", "topLevelRecord")
+
+  /**
+   * Record namespace in write result. Default value is "".
+   * See Avro spec for details: https://avro.apache.org/docs/1.8.2/spec.html#schema_record .
+   */
+  val recordNamespace: String = parameters.getOrElse("recordNamespace", "")
+
+  /**
+   * The `ignoreExtension` option controls ignoring of files without `.avro` extensions in read.
+   * If the option is enabled, all files (with and without `.avro` extension) are loaded.
+   * If the option is not set, Hadoop's config `avro.mapred.ignore.inputs.without.extension`
+   * is taken into account; if that is not set either, file extensions are ignored.
+   */
+  val ignoreExtension: Boolean = {
+    val ignoreFilesWithoutExtensionByDefault = false
+    val ignoreFilesWithoutExtension = conf.getBoolean(
+      AvroFileFormat.IgnoreFilesWithoutExtensionProperty,
+      ignoreFilesWithoutExtensionByDefault)
+
+    parameters
+      .get("ignoreExtension")
+      .map(_.toBoolean)
+      .getOrElse(!ignoreFilesWithoutExtension)
+  }
+
+  /**
+   * The `compression` option allows specifying a compression codec used in write.
+   * Currently supported codecs are `uncompressed`, `snappy`, `deflate`, `bzip2` and `xz`.
+   * If the option is not set, the `spark.sql.avro.compression.codec` config is taken into
+   * account; if that is not set either, the `snappy` codec is used by default.
+   */
+  val compression: String = {
+    parameters.get("compression").getOrElse(SQLConf.get.avroCompressionCodec)
+  }
+}
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala
new file mode 100644
index 0000000000000..06507115f5ed8
--- /dev/null
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import java.io.{IOException, OutputStream}
+
+import org.apache.avro.Schema
+import org.apache.avro.generic.GenericRecord
+import org.apache.avro.mapred.AvroKey
+import org.apache.avro.mapreduce.AvroKeyOutputFormat
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.NullWritable
+import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext}
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources.OutputWriter
+import org.apache.spark.sql.types._
+
+// NOTE: This class is instantiated and used on executor side only, no need to be serializable.
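+// (Concretely: AvroOutputWriterFactory.newInstance() creates one instance of this class per
+// output file inside a write task, and write() is then called once per row of that partition.)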
+private[avro] class AvroOutputWriter(
+    path: String,
+    context: TaskAttemptContext,
+    schema: StructType,
+    avroSchema: Schema) extends OutputWriter {
+
+  // The input rows will never be null.
+  private lazy val serializer = new AvroSerializer(schema, avroSchema, nullable = false)
+
+  /**
+   * Overrides the two methods responsible for generating the output streams / files so
+   * that the data can be correctly partitioned.
+   */
+  private val recordWriter: RecordWriter[AvroKey[GenericRecord], NullWritable] =
+    new AvroKeyOutputFormat[GenericRecord]() {
+
+      override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
+        new Path(path)
+      }
+
+      @throws(classOf[IOException])
+      override def getAvroFileOutputStream(c: TaskAttemptContext): OutputStream = {
+        val path = getDefaultWorkFile(context, ".avro")
+        path.getFileSystem(context.getConfiguration).create(path)
+      }
+
+    }.getRecordWriter(context)
+
+  override def write(row: InternalRow): Unit = {
+    val key = new AvroKey(serializer.serialize(row).asInstanceOf[GenericRecord])
+    recordWriter.write(key, NullWritable.get())
+  }
+
+  override def close(): Unit = recordWriter.close(context)
+}
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala
new file mode 100644
index 0000000000000..116020ed5c433
--- /dev/null
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import org.apache.avro.Schema
+import org.apache.hadoop.mapreduce.TaskAttemptContext
+
+import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * A factory that produces [[AvroOutputWriter]].
+ * @param catalystSchema Catalyst schema of input data.
+ * @param avroSchemaAsJsonString Avro schema of output result, in JSON string format.
+ */
+private[avro] class AvroOutputWriterFactory(
+    catalystSchema: StructType,
+    avroSchemaAsJsonString: String) extends OutputWriterFactory {
+
+  private lazy val avroSchema = new Schema.Parser().parse(avroSchemaAsJsonString)
+
+  override def getFileExtension(context: TaskAttemptContext): String = ".avro"
+
+  override def newInstance(
+      path: String,
+      dataSchema: StructType,
+      context: TaskAttemptContext): OutputWriter = {
+    new AvroOutputWriter(path, context, catalystSchema, avroSchema)
+  }
+}
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
new file mode 100644
index 0000000000000..e902b4c77eaad
--- /dev/null
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
@@ -0,0 +1,238 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import java.nio.ByteBuffer
+
+import scala.collection.JavaConverters._
+
+import org.apache.avro.{LogicalTypes, Schema}
+import org.apache.avro.Conversions.DecimalConversion
+import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
+import org.apache.avro.Schema.Type
+import org.apache.avro.Schema.Type._
+import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed, Record}
+import org.apache.avro.util.Utf8
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow}
+import org.apache.spark.sql.types._
+
+/**
+ * A serializer to serialize data in catalyst format to data in avro format.
+ */ +class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean) { + + def serialize(catalystData: Any): Any = { + converter.apply(catalystData) + } + + private val converter: Any => Any = { + val actualAvroType = resolveNullableType(rootAvroType, nullable) + val baseConverter = rootCatalystType match { + case st: StructType => + newStructConverter(st, actualAvroType).asInstanceOf[Any => Any] + case _ => + val tmpRow = new SpecificInternalRow(Seq(rootCatalystType)) + val converter = newConverter(rootCatalystType, actualAvroType) + (data: Any) => + tmpRow.update(0, data) + converter.apply(tmpRow, 0) + } + if (nullable) { + (data: Any) => + if (data == null) { + null + } else { + baseConverter.apply(data) + } + } else { + baseConverter + } + } + + private type Converter = (SpecializedGetters, Int) => Any + + private lazy val decimalConversions = new DecimalConversion() + + private def newConverter(catalystType: DataType, avroType: Schema): Converter = { + (catalystType, avroType.getType) match { + case (NullType, NULL) => + (getter, ordinal) => null + case (BooleanType, BOOLEAN) => + (getter, ordinal) => getter.getBoolean(ordinal) + case (ByteType, INT) => + (getter, ordinal) => getter.getByte(ordinal).toInt + case (ShortType, INT) => + (getter, ordinal) => getter.getShort(ordinal).toInt + case (IntegerType, INT) => + (getter, ordinal) => getter.getInt(ordinal) + case (LongType, LONG) => + (getter, ordinal) => getter.getLong(ordinal) + case (FloatType, FLOAT) => + (getter, ordinal) => getter.getFloat(ordinal) + case (DoubleType, DOUBLE) => + (getter, ordinal) => getter.getDouble(ordinal) + case (d: DecimalType, FIXED) + if avroType.getLogicalType == LogicalTypes.decimal(d.precision, d.scale) => + (getter, ordinal) => + val decimal = getter.getDecimal(ordinal, d.precision, d.scale) + decimalConversions.toFixed(decimal.toJavaBigDecimal, avroType, + LogicalTypes.decimal(d.precision, d.scale)) + + case (d: DecimalType, BYTES) + if avroType.getLogicalType == LogicalTypes.decimal(d.precision, d.scale) => + (getter, ordinal) => + val decimal = getter.getDecimal(ordinal, d.precision, d.scale) + decimalConversions.toBytes(decimal.toJavaBigDecimal, avroType, + LogicalTypes.decimal(d.precision, d.scale)) + + case (StringType, ENUM) => + val enumSymbols: Set[String] = avroType.getEnumSymbols.asScala.toSet + (getter, ordinal) => + val data = getter.getUTF8String(ordinal).toString + if (!enumSymbols.contains(data)) { + throw new IncompatibleSchemaException( + "Cannot write \"" + data + "\" since it's not defined in enum \"" + + enumSymbols.mkString("\", \"") + "\"") + } + new EnumSymbol(avroType, data) + + case (StringType, STRING) => + (getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes) + + case (BinaryType, FIXED) => + val size = avroType.getFixedSize() + (getter, ordinal) => + val data: Array[Byte] = getter.getBinary(ordinal) + if (data.length != size) { + throw new IncompatibleSchemaException( + s"Cannot write ${data.length} ${if (data.length > 1) "bytes" else "byte"} of " + + "binary data into FIXED Type with size of " + + s"$size ${if (size > 1) "bytes" else "byte"}") + } + new Fixed(avroType, data) + + case (BinaryType, BYTES) => + (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal)) + + case (DateType, INT) => + (getter, ordinal) => getter.getInt(ordinal) + + case (TimestampType, LONG) => avroType.getLogicalType match { + case _: TimestampMillis => (getter, ordinal) => getter.getLong(ordinal) / 1000 + case _: TimestampMicros => (getter, 
ordinal) => getter.getLong(ordinal) + // For backward compatibility, if the Avro type is Long and it is not logical type, + // output the timestamp value as with millisecond precision. + case null => (getter, ordinal) => getter.getLong(ordinal) / 1000 + case other => throw new IncompatibleSchemaException( + s"Cannot convert Catalyst Timestamp type to Avro logical type ${other}") + } + + case (ArrayType(et, containsNull), ARRAY) => + val elementConverter = newConverter( + et, resolveNullableType(avroType.getElementType, containsNull)) + (getter, ordinal) => { + val arrayData = getter.getArray(ordinal) + val len = arrayData.numElements() + val result = new Array[Any](len) + var i = 0 + while (i < len) { + if (containsNull && arrayData.isNullAt(i)) { + result(i) = null + } else { + result(i) = elementConverter(arrayData, i) + } + i += 1 + } + // avro writer is expecting a Java Collection, so we convert it into + // `ArrayList` backed by the specified array without data copying. + java.util.Arrays.asList(result: _*) + } + + case (st: StructType, RECORD) => + val structConverter = newStructConverter(st, avroType) + val numFields = st.length + (getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields)) + + case (MapType(kt, vt, valueContainsNull), MAP) if kt == StringType => + val valueConverter = newConverter( + vt, resolveNullableType(avroType.getValueType, valueContainsNull)) + (getter, ordinal) => + val mapData = getter.getMap(ordinal) + val len = mapData.numElements() + val result = new java.util.HashMap[String, Any](len) + val keyArray = mapData.keyArray() + val valueArray = mapData.valueArray() + var i = 0 + while (i < len) { + val key = keyArray.getUTF8String(i).toString + if (valueContainsNull && valueArray.isNullAt(i)) { + result.put(key, null) + } else { + result.put(key, valueConverter(valueArray, i)) + } + i += 1 + } + result + + case other => + throw new IncompatibleSchemaException(s"Cannot convert Catalyst type $catalystType to " + + s"Avro type $avroType.") + } + } + + private def newStructConverter( + catalystStruct: StructType, avroStruct: Schema): InternalRow => Record = { + if (avroStruct.getType != RECORD || avroStruct.getFields.size() != catalystStruct.length) { + throw new IncompatibleSchemaException(s"Cannot convert Catalyst type $catalystStruct to " + + s"Avro type $avroStruct.") + } + val fieldConverters = catalystStruct.zip(avroStruct.getFields.asScala).map { + case (f1, f2) => newConverter(f1.dataType, resolveNullableType(f2.schema(), f1.nullable)) + } + val numFields = catalystStruct.length + (row: InternalRow) => + val result = new Record(avroStruct) + var i = 0 + while (i < numFields) { + if (row.isNullAt(i)) { + result.put(i, null) + } else { + result.put(i, fieldConverters(i).apply(row, i)) + } + i += 1 + } + result + } + + private def resolveNullableType(avroType: Schema, nullable: Boolean): Schema = { + if (nullable && avroType.getType != NULL) { + // avro uses union to represent nullable type. 
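+      // (For example, a nullable string field arrives here as the union ["string", "null"];
+      // the filter below recovers the single non-NULL branch.)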
+ val fields = avroType.getTypes.asScala + assert(fields.length == 2) + val actualType = fields.filter(_.getType != Type.NULL) + assert(actualType.length == 1) + actualType.head + } else { + avroType + } + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala new file mode 100644 index 0000000000000..141ff3782adfb --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.io.ByteArrayOutputStream + +import org.apache.avro.generic.GenericDatumWriter +import org.apache.avro.io.{BinaryEncoder, EncoderFactory} + +import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.types.{BinaryType, DataType} + +case class CatalystDataToAvro(child: Expression) extends UnaryExpression { + + override def dataType: DataType = BinaryType + + @transient private lazy val avroType = + SchemaConverters.toAvroType(child.dataType, child.nullable) + + @transient private lazy val serializer = + new AvroSerializer(child.dataType, avroType, child.nullable) + + @transient private lazy val writer = + new GenericDatumWriter[Any](avroType) + + @transient private var encoder: BinaryEncoder = _ + + @transient private lazy val out = new ByteArrayOutputStream + + override def nullSafeEval(input: Any): Any = { + out.reset() + encoder = EncoderFactory.get().directBinaryEncoder(out, encoder) + val avroData = serializer.serialize(input) + writer.write(avroData, encoder) + encoder.flush() + out.toByteArray + } + + override def simpleString: String = { + s"to_avro(${child.sql}, ${child.dataType.simpleString})" + } + + override def sql: String = { + s"to_avro(${child.sql}, ${child.dataType.catalogString})" + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val expr = ctx.addReferenceObj("this", this) + defineCodeGen(ctx, ev, input => + s"(byte[]) $expr.nullSafeEval($input)") + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala new file mode 100644 index 0000000000000..bd1576587d7fa --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import scala.collection.JavaConverters._
+import scala.util.Random
+
+import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder}
+import org.apache.avro.LogicalTypes.{Date, Decimal, TimestampMicros, TimestampMillis}
+import org.apache.avro.Schema.Type._
+
+import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.Decimal.{maxPrecisionForBytes, minBytesForPrecision}
+
+/**
+ * This object contains methods that are used to convert Spark SQL schemas to Avro schemas and
+ * vice versa.
+ */
+object SchemaConverters {
+  private lazy val uuidGenerator = RandomUUIDGenerator(new Random().nextLong())
+
+  private lazy val nullSchema = Schema.create(Schema.Type.NULL)
+
+  case class SchemaType(dataType: DataType, nullable: Boolean)
+
+  /**
+   * This function takes an Avro schema and returns a Spark SQL schema.
+   */
+  def toSqlType(avroSchema: Schema): SchemaType = {
+    avroSchema.getType match {
+      case INT => avroSchema.getLogicalType match {
+        case _: Date => SchemaType(DateType, nullable = false)
+        case _ => SchemaType(IntegerType, nullable = false)
+      }
+      case STRING => SchemaType(StringType, nullable = false)
+      case BOOLEAN => SchemaType(BooleanType, nullable = false)
+      case BYTES | FIXED => avroSchema.getLogicalType match {
+        // For FIXED type, if the precision requires more bytes than fixed size, the logical
+        // type will be null, which is handled by Avro library.
+ case d: Decimal => SchemaType(DecimalType(d.getPrecision, d.getScale), nullable = false) + case _ => SchemaType(BinaryType, nullable = false) + } + + case DOUBLE => SchemaType(DoubleType, nullable = false) + case FLOAT => SchemaType(FloatType, nullable = false) + case LONG => avroSchema.getLogicalType match { + case _: TimestampMillis | _: TimestampMicros => SchemaType(TimestampType, nullable = false) + case _ => SchemaType(LongType, nullable = false) + } + + case ENUM => SchemaType(StringType, nullable = false) + + case RECORD => + val fields = avroSchema.getFields.asScala.map { f => + val schemaType = toSqlType(f.schema()) + StructField(f.name, schemaType.dataType, schemaType.nullable) + } + + SchemaType(StructType(fields), nullable = false) + + case ARRAY => + val schemaType = toSqlType(avroSchema.getElementType) + SchemaType( + ArrayType(schemaType.dataType, containsNull = schemaType.nullable), + nullable = false) + + case MAP => + val schemaType = toSqlType(avroSchema.getValueType) + SchemaType( + MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable), + nullable = false) + + case UNION => + if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) { + // In case of a union with null, eliminate it and make a recursive call + val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL) + if (remainingUnionTypes.size == 1) { + toSqlType(remainingUnionTypes.head).copy(nullable = true) + } else { + toSqlType(Schema.createUnion(remainingUnionTypes.asJava)).copy(nullable = true) + } + } else avroSchema.getTypes.asScala.map(_.getType) match { + case Seq(t1) => + toSqlType(avroSchema.getTypes.get(0)) + case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) => + SchemaType(LongType, nullable = false) + case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) => + SchemaType(DoubleType, nullable = false) + case _ => + // Convert complex unions to struct types where field names are member0, member1, etc. + // This is consistent with the behavior when converting between Avro and Parquet. 
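+          // (For example, the union ["int", "string"] maps to
+          // struct<member0:int,member1:string>.)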
+ val fields = avroSchema.getTypes.asScala.zipWithIndex.map { + case (s, i) => + val schemaType = toSqlType(s) + // All fields are nullable because only one of them is set at a time + StructField(s"member$i", schemaType.dataType, nullable = true) + } + + SchemaType(StructType(fields), nullable = false) + } + + case other => throw new IncompatibleSchemaException(s"Unsupported type $other") + } + } + + def toAvroType( + catalystType: DataType, + nullable: Boolean = false, + recordName: String = "topLevelRecord", + nameSpace: String = "") + : Schema = { + val builder = SchemaBuilder.builder() + + val schema = catalystType match { + case BooleanType => builder.booleanType() + case ByteType | ShortType | IntegerType => builder.intType() + case LongType => builder.longType() + case DateType => + LogicalTypes.date().addToSchema(builder.intType()) + case TimestampType => + LogicalTypes.timestampMicros().addToSchema(builder.longType()) + + case FloatType => builder.floatType() + case DoubleType => builder.doubleType() + case StringType => builder.stringType() + case d: DecimalType => + val avroType = LogicalTypes.decimal(d.precision, d.scale) + val fixedSize = minBytesForPrecision(d.precision) + // Need to avoid naming conflict for the fixed fields + val name = nameSpace match { + case "" => s"$recordName.fixed" + case _ => s"$nameSpace.$recordName.fixed" + } + avroType.addToSchema(SchemaBuilder.fixed(name).size(fixedSize)) + + case BinaryType => builder.bytesType() + case ArrayType(et, containsNull) => + builder.array() + .items(toAvroType(et, containsNull, recordName, nameSpace)) + case MapType(StringType, vt, valueContainsNull) => + builder.map() + .values(toAvroType(vt, valueContainsNull, recordName, nameSpace)) + case st: StructType => + val childNameSpace = if (nameSpace != "") s"$nameSpace.$recordName" else recordName + val fieldsAssembler = builder.record(recordName).namespace(nameSpace).fields() + st.foreach { f => + val fieldAvroType = + toAvroType(f.dataType, f.nullable, f.name, childNameSpace) + fieldsAssembler.name(f.name).`type`(fieldAvroType).noDefault() + } + fieldsAssembler.endRecord() + + // This should never happen. + case other => throw new IncompatibleSchemaException(s"Unexpected type $other.") + } + if (nullable) { + Schema.createUnion(schema, nullSchema) + } else { + schema + } + } +} + +class IncompatibleSchemaException(msg: String, ex: Throwable = null) extends Exception(msg, ex) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala new file mode 100755 index 0000000000000..97f9427f96c55 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.annotation.Experimental + +package object avro { + /** + * Converts a binary column of Avro format into its corresponding Catalyst value. The specified + * schema must match the read data, otherwise the behavior is undefined: it may fail or return + * an arbitrary result. + * + * @param data the binary column. + * @param jsonFormatSchema the avro schema in JSON string format. + * + * @since 2.4.0 + */ + @Experimental + def from_avro(data: Column, jsonFormatSchema: String): Column = { + new Column(AvroDataToCatalyst(data.expr, jsonFormatSchema)) + } + + /** + * Converts a column into a binary column in Avro format. + * + * @param data the data column. + * + * @since 2.4.0 + */ + @Experimental + def to_avro(data: Column): Column = { + new Column(CatalystDataToAvro(data.expr)) + } +} diff --git a/external/avro/src/test/resources/episodes.avro b/external/avro/src/test/resources/episodes.avro new file mode 100644 index 0000000000000..58a028ce19e6a Binary files /dev/null and b/external/avro/src/test/resources/episodes.avro differ diff --git a/external/avro/src/test/resources/log4j.properties b/external/avro/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..75e3b53a093f6 --- /dev/null +++ b/external/avro/src/test/resources/log4j.properties @@ -0,0 +1,28 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.spark-project.jetty=WARN + diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00000.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00000.avro new file mode 100755 index 0000000000000..fece892444979 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00000.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00001.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00001.avro new file mode 100755 index 0000000000000..1ca623a07dcf3 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00001.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00002.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00002.avro new file mode 100755 index 0000000000000..a12e9459e7461 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00002.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00003.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00003.avro new file mode 100755 index 0000000000000..60c095691d5d5 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00003.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00004.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00004.avro new file mode 100755 index 0000000000000..af56dfc8083dc Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00004.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00005.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00005.avro new file mode 100755 index 0000000000000..87d78447526f9 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00005.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00006.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00006.avro new file mode 100755 index 0000000000000..c326fc434bf18 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00006.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00007.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00007.avro new file mode 100755 index 0000000000000..279f36c317eb8 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00007.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00008.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00008.avro new file mode 100755 index 0000000000000..8d70f5d1274d4 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00008.avro differ diff --git 
a/external/avro/src/test/resources/test-random-partitioned/part-r-00009.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00009.avro new file mode 100755 index 0000000000000..6839d7217e492 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00009.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00010.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00010.avro new file mode 100755 index 0000000000000..aedc7f7e0e61c Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00010.avro differ diff --git a/external/avro/src/test/resources/test.avro b/external/avro/src/test/resources/test.avro new file mode 100644 index 0000000000000..6425e2107e304 Binary files /dev/null and b/external/avro/src/test/resources/test.avro differ diff --git a/external/avro/src/test/resources/test.avsc b/external/avro/src/test/resources/test.avsc new file mode 100644 index 0000000000000..d7119a01f6aa0 --- /dev/null +++ b/external/avro/src/test/resources/test.avsc @@ -0,0 +1,53 @@ +{ + "type" : "record", + "name" : "test_schema", + "fields" : [{ + "name" : "string", + "type" : "string", + "doc" : "Meaningless string of characters" + }, { + "name" : "simple_map", + "type" : {"type": "map", "values": "int"} + }, { + "name" : "complex_map", + "type" : {"type": "map", "values": {"type": "map", "values": "string"}} + }, { + "name" : "union_string_null", + "type" : ["null", "string"] + }, { + "name" : "union_int_long_null", + "type" : ["int", "long", "null"] + }, { + "name" : "union_float_double", + "type" : ["float", "double"] + }, { + "name": "fixed3", + "type": {"type": "fixed", "size": 3, "name": "fixed3"} + }, { + "name": "fixed2", + "type": {"type": "fixed", "size": 2, "name": "fixed2"} + }, { + "name": "enum", + "type": { "type": "enum", + "name": "Suit", + "symbols" : ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"] + } + }, { + "name": "record", + "type": { + "type": "record", + "name": "record", + "aliases": ["RecordAlias"], + "fields" : [{ + "name": "value_field", + "type": "string" + }] + } + }, { + "name": "array_of_boolean", + "type": {"type": "array", "items": "boolean"} + }, { + "name": "bytes", + "type": "bytes" + }] +} diff --git a/external/avro/src/test/resources/test.json b/external/avro/src/test/resources/test.json new file mode 100644 index 0000000000000..780189a92b378 --- /dev/null +++ b/external/avro/src/test/resources/test.json @@ -0,0 +1,42 @@ +{ + "string": "OMG SPARK IS AWESOME", + "simple_map": {"abc": 1, "bcd": 7}, + "complex_map": {"key": {"a": "b", "c": "d"}}, + "union_string_null": {"string": "abc"}, + "union_int_long_null": {"int": 1}, + "union_float_double": {"float": 3.1415926535}, + "fixed3":"\u0002\u0003\u0004", + "fixed2":"\u0011\u0012", + "enum": "SPADES", + "record": {"value_field": "Two things are infinite: the universe and human stupidity; and I'm not sure about universe."}, + "array_of_boolean": [true, false, false], + "bytes": "\u0041\u0042\u0043" +} +{ + "string": "Terran is IMBA!", + "simple_map": {"mmm": 0, "qqq": 66}, + "complex_map": {"key": {"1": "2", "3": "4"}}, + "union_string_null": {"string": "123"}, + "union_int_long_null": {"long": 66}, + "union_float_double": {"double": 6.6666666666666}, + "fixed3":"\u0007\u0007\u0007", + "fixed2":"\u0001\u0002", + "enum": "CLUBS", + "record": {"value_field": "Life did not intend to make us perfect. 
Whoever is perfect belongs in a museum."}, + "array_of_boolean": [], + "bytes": "" +} +{ + "string": "The cake is a LIE!", + "simple_map": {}, + "complex_map": {"key": {}}, + "union_string_null": {"null": null}, + "union_int_long_null": {"null": null}, + "union_float_double": {"double": 0}, + "fixed3":"\u0011\u0022\u0009", + "fixed2":"\u0010\u0090", + "enum": "DIAMONDS", + "record": {"value_field": "TEST_STR123"}, + "array_of_boolean": [false], + "bytes": "\u0053" +} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala new file mode 100644 index 0000000000000..8334cca6cd8f1 --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.RandomDataGenerator +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, GenericInternalRow, Literal} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} +import org.apache.spark.sql.types._ + +class AvroCatalystDataConversionSuite extends SparkFunSuite with ExpressionEvalHelper { + + private def roundTripTest(data: Literal): Unit = { + val avroType = SchemaConverters.toAvroType(data.dataType, data.nullable) + checkResult(data, avroType.toString, data.eval()) + } + + private def checkResult(data: Literal, schema: String, expected: Any): Unit = { + checkEvaluation( + AvroDataToCatalyst(CatalystDataToAvro(data), schema), + prepareExpectedResult(expected)) + } + + private def assertFail(data: Literal, schema: String): Unit = { + intercept[java.io.EOFException] { + AvroDataToCatalyst(CatalystDataToAvro(data), schema).eval() + } + } + + private val testingTypes = Seq( + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + DecimalType(8, 0), // 32 bits decimal without fraction + DecimalType(8, 4), // 32 bits decimal + DecimalType(16, 0), // 64 bits decimal without fraction + DecimalType(16, 11), // 64 bits decimal + DecimalType(38, 0), + DecimalType(38, 38), + StringType, + BinaryType) + + protected def prepareExpectedResult(expected: Any): Any = expected match { + // Spark byte and short both map to avro int + case b: Byte => b.toInt + case s: Short => s.toInt + case row: GenericInternalRow => InternalRow.fromSeq(row.values.map(prepareExpectedResult)) + case array: GenericArrayData => new GenericArrayData(array.array.map(prepareExpectedResult)) + case map: MapData => + val keys 
= new GenericArrayData( + map.keyArray().asInstanceOf[GenericArrayData].array.map(prepareExpectedResult)) + val values = new GenericArrayData( + map.valueArray().asInstanceOf[GenericArrayData].array.map(prepareExpectedResult)) + new ArrayBasedMapData(keys, values) + case other => other + } + + testingTypes.foreach { dt => + val seed = scala.util.Random.nextLong() + test(s"single $dt with seed $seed") { + val rand = new scala.util.Random(seed) + val data = RandomDataGenerator.forType(dt, rand = rand).get.apply() + val converter = CatalystTypeConverters.createToCatalystConverter(dt) + val input = Literal.create(converter(data), dt) + roundTripTest(input) + } + } + + for (_ <- 1 to 5) { + val seed = scala.util.Random.nextLong() + val rand = new scala.util.Random(seed) + val schema = RandomDataGenerator.randomSchema(rand, 5, testingTypes) + test(s"flat schema ${schema.catalogString} with seed $seed") { + val data = RandomDataGenerator.randomRow(rand, schema) + val converter = CatalystTypeConverters.createToCatalystConverter(schema) + val input = Literal.create(converter(data), schema) + roundTripTest(input) + } + } + + for (_ <- 1 to 5) { + val seed = scala.util.Random.nextLong() + val rand = new scala.util.Random(seed) + val schema = RandomDataGenerator.randomNestedSchema(rand, 10, testingTypes) + test(s"nested schema ${schema.catalogString} with seed $seed") { + val data = RandomDataGenerator.randomRow(rand, schema) + val converter = CatalystTypeConverters.createToCatalystConverter(schema) + val input = Literal.create(converter(data), schema) + roundTripTest(input) + } + } + + test("read int as string") { + val data = Literal(1) + val avroTypeJson = + s""" + |{ + | "type": "string", + | "name": "my_string" + |} + """.stripMargin + + // When reading an int as a string, the Avro reader is not able to parse the binary and fails. + assertFail(data, avroTypeJson) + } + + test("read string as int") { + val data = Literal("abc") + val avroTypeJson = + s""" + |{ + | "type": "int", + | "name": "my_int" + |} + """.stripMargin + + // When reading string data as an int, the Avro reader does not detect the type mismatch and + // reads the string length as the int value. + checkResult(data, avroTypeJson, 3) + } + + test("read float as double") { + val data = Literal(1.23f) + val avroTypeJson = + s""" + |{ + | "type": "double", + | "name": "my_double" + |} + """.stripMargin + + // When reading float data as a double, the Avro reader fails (trying to read 8 bytes while + // the data has only 4 bytes). + assertFail(data, avroTypeJson) + } + + test("read double as float") { + val data = Literal(1.23) + val avroTypeJson = + s""" + |{ + | "type": "float", + | "name": "my_float" + |} + """.stripMargin + + // The Avro reader reads the first 4 bytes of a double as a float, so the result is arbitrary. + checkResult(data, avroTypeJson, 5.848603E35f) + } +} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala new file mode 100644 index 0000000000000..90a4cd6ccf9dd --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.avro.Schema + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions.struct +import org.apache.spark.sql.test.SharedSQLContext + +class AvroFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("roundtrip in to_avro and from_avro - int and string") { + val df = spark.range(10).select('id, 'id.cast("string").as("str")) + + val avroDF = df.select(to_avro('id).as("a"), to_avro('str).as("b")) + val avroTypeLong = s""" + |{ + | "type": "int", + | "name": "id" + |} + """.stripMargin + val avroTypeStr = s""" + |{ + | "type": "string", + | "name": "str" + |} + """.stripMargin + checkAnswer(avroDF.select(from_avro('a, avroTypeLong), from_avro('b, avroTypeStr)), df) + } + + test("roundtrip in to_avro and from_avro - struct") { + val df = spark.range(10).select(struct('id, 'id.cast("string").as("str")).as("struct")) + val avroStructDF = df.select(to_avro('struct).as("avro")) + val avroTypeStruct = s""" + |{ + | "type": "record", + | "name": "struct", + | "fields": [ + | {"name": "col1", "type": "long"}, + | {"name": "col2", "type": "string"} + | ] + |} + """.stripMargin + checkAnswer(avroStructDF.select(from_avro('avro, avroTypeStruct)), df) + } + + test("roundtrip in to_avro and from_avro - array with null") { + val dfOne = Seq(Tuple1(Tuple1(1) :: Nil), Tuple1(null :: Nil)).toDF("array") + val avroTypeArrStruct = s""" + |[ { + | "type" : "array", + | "items" : [ { + | "type" : "record", + | "name" : "x", + | "fields" : [ { + | "name" : "y", + | "type" : "int" + | } ] + | }, "null" ] + |}, "null" ] + """.stripMargin + val readBackOne = dfOne.select(to_avro($"array").as("avro")) + .select(from_avro($"avro", avroTypeArrStruct).as("array")) + checkAnswer(dfOne, readBackOne) + } +} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala new file mode 100644 index 0000000000000..79ba2871c2264 --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala @@ -0,0 +1,351 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.avro + +import java.io.File +import java.sql.Timestamp + +import org.apache.avro.{LogicalTypes, Schema} +import org.apache.avro.Conversions.DecimalConversion +import org.apache.avro.file.DataFileWriter +import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord} + +import org.apache.spark.SparkException +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.sql.types.{StructField, StructType, TimestampType} + +class AvroLogicalTypeSuite extends QueryTest with SharedSQLContext with SQLTestUtils { + import testImplicits._ + + val dateSchema = s""" + { + "namespace": "logical", + "type": "record", + "name": "test", + "fields": [ + {"name": "date", "type": {"type": "int", "logicalType": "date"}} + ] + } + """ + + val dateInputData = Seq(7, 365, 0) + + def dateFile(path: String): String = { + val schema = new Schema.Parser().parse(dateSchema) + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + val result = s"$path/test.avro" + dataFileWriter.create(schema, new File(result)) + + dateInputData.foreach { x => + val record = new GenericData.Record(schema) + record.put("date", x) + dataFileWriter.append(record) + } + dataFileWriter.flush() + dataFileWriter.close() + result + } + + test("Logical type: date") { + withTempDir { dir => + val expected = dateInputData.map(t => Row(DateTimeUtils.toJavaDate(t))) + val dateAvro = dateFile(dir.getAbsolutePath) + val df = spark.read.format("avro").load(dateAvro) + + checkAnswer(df, expected) + + checkAnswer(spark.read.format("avro").option("avroSchema", dateSchema).load(dateAvro), + expected) + + withTempPath { path => + df.write.format("avro").save(path.toString) + checkAnswer(spark.read.format("avro").load(path.toString), expected) + } + } + } + + val timestampSchema = s""" + { + "namespace": "logical", + "type": "record", + "name": "test", + "fields": [ + {"name": "timestamp_millis", "type": {"type": "long","logicalType": "timestamp-millis"}}, + {"name": "timestamp_micros", "type": {"type": "long","logicalType": "timestamp-micros"}}, + {"name": "long", "type": "long"} + ] + } + """ + + val timestampInputData = Seq((1000L, 2000L, 3000L), (666000L, 999000L, 777000L)) + + def timestampFile(path: String): String = { + val schema = new Schema.Parser().parse(timestampSchema) + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + val result = s"$path/test.avro" + dataFileWriter.create(schema, new File(result)) + + timestampInputData.foreach { t => + val record = new GenericData.Record(schema) + record.put("timestamp_millis", t._1) + // For microsecond precision, we multiply the value by 1000 to match the expected answer as + // timestamp with millisecond precision.
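+ // (For example, t._2 == 2000 represents 2000 ms; it is stored as 2000 * 1000 = 2,000,000 + // microseconds, which reads back as the same timestamp.)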
+ record.put("timestamp_micros", t._2 * 1000) + record.put("long", t._3) + dataFileWriter.append(record) + } + dataFileWriter.flush() + dataFileWriter.close() + result + } + + test("Logical type: timestamp_millis") { + withTempDir { dir => + val expected = timestampInputData.map(t => Row(new Timestamp(t._1))) + val timestampAvro = timestampFile(dir.getAbsolutePath) + val df = spark.read.format("avro").load(timestampAvro).select('timestamp_millis) + + checkAnswer(df, expected) + + withTempPath { path => + df.write.format("avro").save(path.toString) + checkAnswer(spark.read.format("avro").load(path.toString), expected) + } + } + } + + test("Logical type: timestamp_micros") { + withTempDir { dir => + val expected = timestampInputData.map(t => Row(new Timestamp(t._2))) + val timestampAvro = timestampFile(dir.getAbsolutePath) + val df = spark.read.format("avro").load(timestampAvro).select('timestamp_micros) + + checkAnswer(df, expected) + + withTempPath { path => + df.write.format("avro").save(path.toString) + checkAnswer(spark.read.format("avro").load(path.toString), expected) + } + } + } + + test("Logical type: user specified output schema with different timestamp types") { + withTempDir { dir => + val timestampAvro = timestampFile(dir.getAbsolutePath) + val df = + spark.read.format("avro").load(timestampAvro).select('timestamp_millis, 'timestamp_micros) + + val expected = timestampInputData.map(t => Row(new Timestamp(t._1), new Timestamp(t._2))) + + val userSpecifiedTimestampSchema = s""" + { + "namespace": "logical", + "type": "record", + "name": "test", + "fields": [ + {"name": "timestamp_millis", + "type": [{"type": "long","logicalType": "timestamp-micros"}, "null"]}, + {"name": "timestamp_micros", + "type": [{"type": "long","logicalType": "timestamp-millis"}, "null"]} + ] + } + """ + + withTempPath { path => + df.write + .format("avro") + .option("avroSchema", userSpecifiedTimestampSchema) + .save(path.toString) + checkAnswer(spark.read.format("avro").load(path.toString), expected) + } + } + } + + test("Read Long type as Timestamp") { + withTempDir { dir => + val timestampAvro = timestampFile(dir.getAbsolutePath) + val schema = StructType(StructField("long", TimestampType, true) :: Nil) + val df = spark.read.format("avro").schema(schema).load(timestampAvro).select('long) + + val expected = timestampInputData.map(t => Row(new Timestamp(t._3))) + + checkAnswer(df, expected) + } + } + + test("Logical type: user specified read schema") { + withTempDir { dir => + val timestampAvro = timestampFile(dir.getAbsolutePath) + val expected = timestampInputData + .map(t => Row(new Timestamp(t._1), new Timestamp(t._2), t._3)) + + val df = spark.read.format("avro").option("avroSchema", timestampSchema).load(timestampAvro) + checkAnswer(df, expected) + } + } + + val decimalInputData = Seq("1.23", "4.56", "78.90", "-1", "-2.31") + + def decimalSchemaAndFile(path: String): (String, String) = { + val precision = 4 + val scale = 2 + val bytesFieldName = "bytes" + val bytesSchema = s"""{ + "type":"bytes", + "logicalType":"decimal", + "precision":$precision, + "scale":$scale + } + """ + + val fixedFieldName = "fixed" + val fixedSchema = s"""{ + "type":"fixed", + "size":5, + "logicalType":"decimal", + "precision":$precision, + "scale":$scale, + "name":"foo" + } + """ + val avroSchema = s""" + { + "namespace": "logical", + "type": "record", + "name": "test", + "fields": [ + {"name": "$bytesFieldName", "type": $bytesSchema}, + {"name": "$fixedFieldName", "type": $fixedSchema} + ] + } + """ + val schema = new 
Schema.Parser().parse(avroSchema) + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + val decimalConversion = new DecimalConversion + val avroFile = s"$path/test.avro" + dataFileWriter.create(schema, new File(avroFile)) + val logicalType = LogicalTypes.decimal(precision, scale) + + decimalInputData.map { x => + val avroRec = new GenericData.Record(schema) + val decimal = new java.math.BigDecimal(x).setScale(scale) + val bytes = + decimalConversion.toBytes(decimal, schema.getField(bytesFieldName).schema, logicalType) + avroRec.put(bytesFieldName, bytes) + val fixed = + decimalConversion.toFixed(decimal, schema.getField(fixedFieldName).schema, logicalType) + avroRec.put(fixedFieldName, fixed) + dataFileWriter.append(avroRec) + } + dataFileWriter.flush() + dataFileWriter.close() + + (avroSchema, avroFile) + } + + test("Logical type: Decimal") { + withTempDir { dir => + val (avroSchema, avroFile) = decimalSchemaAndFile(dir.getAbsolutePath) + val expected = + decimalInputData.map { x => Row(new java.math.BigDecimal(x), new java.math.BigDecimal(x)) } + val df = spark.read.format("avro").load(avroFile) + checkAnswer(df, expected) + checkAnswer(spark.read.format("avro").option("avroSchema", avroSchema).load(avroFile), + expected) + + withTempPath { path => + df.write.format("avro").save(path.toString) + checkAnswer(spark.read.format("avro").load(path.toString), expected) + } + } + } + + test("Logical type: write Decimal with BYTES type") { + val specifiedSchema = """ + { + "type" : "record", + "name" : "topLevelRecord", + "namespace" : "topLevelRecord", + "fields" : [ { + "name" : "bytes", + "type" : [ { + "type" : "bytes", + "namespace" : "topLevelRecord.bytes", + "logicalType" : "decimal", + "precision" : 4, + "scale" : 2 + }, "null" ] + }, { + "name" : "fixed", + "type" : [ { + "type" : "bytes", + "logicalType" : "decimal", + "precision" : 4, + "scale" : 2 + }, "null" ] + } ] + } + """ + withTempDir { dir => + val (avroSchema, avroFile) = decimalSchemaAndFile(dir.getAbsolutePath) + assert(specifiedSchema != avroSchema) + val expected = + decimalInputData.map { x => Row(new java.math.BigDecimal(x), new java.math.BigDecimal(x)) } + val df = spark.read.format("avro").load(avroFile) + + withTempPath { path => + df.write.format("avro").option("avroSchema", specifiedSchema).save(path.toString) + checkAnswer(spark.read.format("avro").load(path.toString), expected) + } + } + } + + test("Logical type: Decimal with too large precision") { + withTempDir { dir => + val schema = new Schema.Parser().parse("""{ + "namespace": "logical", + "type": "record", + "name": "test", + "fields": [{ + "name": "decimal", + "type": {"type": "bytes", "logicalType": "decimal", "precision": 4, "scale": 2} + }] + }""") + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(schema, new File(s"$dir.avro")) + val avroRec = new GenericData.Record(schema) + val decimal = new java.math.BigDecimal("0.12345678901234567890123456789012345678") + val bytes = (new DecimalConversion).toBytes(decimal, schema, LogicalTypes.decimal(39, 38)) + avroRec.put("decimal", bytes) + dataFileWriter.append(avroRec) + dataFileWriter.flush() + dataFileWriter.close() + + val msg = intercept[SparkException] { + spark.read.format("avro").load(s"$dir.avro").collect() + }.getCause.getMessage + assert(msg.contains("Unscaled value too large for precision")) + } + } +} diff 
--git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala new file mode 100644 index 0000000000000..9ad4388414eaa --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -0,0 +1,1269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.io._ +import java.net.URL +import java.nio.file.{Files, Paths} +import java.sql.{Date, Timestamp} +import java.util.{TimeZone, UUID} + +import scala.collection.JavaConverters._ + +import org.apache.avro.Schema +import org.apache.avro.Schema.{Field, Type} +import org.apache.avro.Schema.Type._ +import org.apache.avro.file.{DataFileReader, DataFileWriter} +import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord} +import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} +import org.apache.commons.io.FileUtils + +import org.apache.spark.SparkException +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.sql.types._ + +class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { + import testImplicits._ + + val episodesAvro = testFile("episodes.avro") + val testAvro = testFile("test.avro") + + override protected def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set("spark.sql.files.maxPartitionBytes", 1024) + } + + def checkReloadMatchesSaved(originalFile: String, newFile: String): Unit = { + val originalEntries = spark.read.format("avro").load(originalFile).collect() + val newEntries = spark.read.format("avro").load(newFile) + checkAnswer(newEntries, originalEntries) + } + + def checkAvroSchemaEquals(avroSchema: String, expectedAvroSchema: String): Unit = { + assert(new Schema.Parser().parse(avroSchema) == + new Schema.Parser().parse(expectedAvroSchema)) + } + + def getAvroSchemaStringFromFiles(filePath: String): String = { + new DataFileReader({ + val file = new File(filePath) + if (file.isFile) { + file + } else { + file.listFiles() + .filter(_.isFile) + .filter(_.getName.endsWith("avro")) + .head + } + }, new GenericDatumReader[Any]()).getSchema.toString(false) + } + + test("resolve avro data source") { + val databricksAvro = "com.databricks.spark.avro" + // By default, backward compatibility for com.databricks.spark.avro is enabled.
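+ // All three provider names below should therefore resolve to AvroFileFormat; resolution of + // the databricks name fails only once LEGACY_REPLACE_DATABRICKS_SPARK_AVRO_ENABLED is off.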
+ Seq("avro", "org.apache.spark.sql.avro.AvroFileFormat", databricksAvro).foreach { provider => + assert(DataSource.lookupDataSource(provider, spark.sessionState.conf) === + classOf[org.apache.spark.sql.avro.AvroFileFormat]) + } + + withSQLConf(SQLConf.LEGACY_REPLACE_DATABRICKS_SPARK_AVRO_ENABLED.key -> "false") { + val message = intercept[AnalysisException] { + DataSource.lookupDataSource(databricksAvro, spark.sessionState.conf) + }.getMessage + assert(message.contains(s"Failed to find data source: $databricksAvro")) + } + } + + test("reading from multiple paths") { + val df = spark.read.format("avro").load(episodesAvro, episodesAvro) + assert(df.count == 16) + } + + test("reading and writing partitioned data") { + val df = spark.read.format("avro").load(episodesAvro) + val fields = List("title", "air_date", "doctor") + for (field <- fields) { + withTempPath { dir => + val outputDir = s"$dir/${UUID.randomUUID}" + df.write.partitionBy(field).format("avro").save(outputDir) + val input = spark.read.format("avro").load(outputDir) + // makes sure that no fields got dropped. + // We convert Rows to Seqs in order to work around SPARK-10325 + assert(input.select(field).collect().map(_.toSeq).toSet === + df.select(field).collect().map(_.toSeq).toSet) + } + } + } + + test("request no fields") { + val df = spark.read.format("avro").load(episodesAvro) + df.createOrReplaceTempView("avro_table") + assert(spark.sql("select count(*) from avro_table").collect().head === Row(8)) + } + + test("convert formats") { + withTempPath { dir => + val df = spark.read.format("avro").load(episodesAvro) + df.write.parquet(dir.getCanonicalPath) + assert(spark.read.parquet(dir.getCanonicalPath).count() === df.count) + } + } + + test("rearrange internal schema") { + withTempPath { dir => + val df = spark.read.format("avro").load(episodesAvro) + df.select("doctor", "title").write.format("avro").save(dir.getCanonicalPath) + } + } + + test("test NULL avro type") { + withTempPath { dir => + val fields = + Seq(new Field("null", Schema.create(Type.NULL), "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(schema, new File(s"$dir.avro")) + val avroRec = new GenericData.Record(schema) + avroRec.put("null", null) + dataFileWriter.append(avroRec) + dataFileWriter.flush() + dataFileWriter.close() + + intercept[IncompatibleSchemaException] { + spark.read.format("avro").load(s"$dir.avro") + } + } + } + + test("union(int, long) is read as long") { + withTempPath { dir => + val avroSchema: Schema = { + val union = + Schema.createUnion(List(Schema.create(Type.INT), Schema.create(Type.LONG)).asJava) + val fields = Seq(new Field("field1", union, "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + schema + } + + val datumWriter = new GenericDatumWriter[GenericRecord](avroSchema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(avroSchema, new File(s"$dir.avro")) + val rec1 = new GenericData.Record(avroSchema) + rec1.put("field1", 1.toLong) + dataFileWriter.append(rec1) + val rec2 = new GenericData.Record(avroSchema) + rec2.put("field1", 2) + dataFileWriter.append(rec2) + dataFileWriter.flush() + dataFileWriter.close() + val df = spark.read.format("avro").load(s"$dir.avro") + assert(df.schema.fields === 
Seq(StructField("field1", LongType, nullable = true))) + assert(df.collect().toSet == Set(Row(1L), Row(2L))) + } + } + + test("union(float, double) is read as double") { + withTempPath { dir => + val avroSchema: Schema = { + val union = + Schema.createUnion(List(Schema.create(Type.FLOAT), Schema.create(Type.DOUBLE)).asJava) + val fields = Seq(new Field("field1", union, "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + schema + } + + val datumWriter = new GenericDatumWriter[GenericRecord](avroSchema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(avroSchema, new File(s"$dir.avro")) + val rec1 = new GenericData.Record(avroSchema) + rec1.put("field1", 1.toFloat) + dataFileWriter.append(rec1) + val rec2 = new GenericData.Record(avroSchema) + rec2.put("field1", 2.toDouble) + dataFileWriter.append(rec2) + dataFileWriter.flush() + dataFileWriter.close() + val df = spark.read.format("avro").load(s"$dir.avro") + assert(df.schema.fields === Seq(StructField("field1", DoubleType, nullable = true))) + assert(df.collect().toSet == Set(Row(1.toDouble), Row(2.toDouble))) + } + } + + test("union(float, double, null) is read as nullable double") { + withTempPath { dir => + val avroSchema: Schema = { + val union = Schema.createUnion( + List(Schema.create(Type.FLOAT), + Schema.create(Type.DOUBLE), + Schema.create(Type.NULL) + ).asJava + ) + val fields = Seq(new Field("field1", union, "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + schema + } + + val datumWriter = new GenericDatumWriter[GenericRecord](avroSchema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(avroSchema, new File(s"$dir.avro")) + val rec1 = new GenericData.Record(avroSchema) + rec1.put("field1", 1.toFloat) + dataFileWriter.append(rec1) + val rec2 = new GenericData.Record(avroSchema) + rec2.put("field1", null) + dataFileWriter.append(rec2) + dataFileWriter.flush() + dataFileWriter.close() + val df = spark.read.format("avro").load(s"$dir.avro") + assert(df.schema.fields === Seq(StructField("field1", DoubleType, nullable = true))) + assert(df.collect().toSet == Set(Row(1.toDouble), Row(null))) + } + } + + test("Union of a single type") { + withTempPath { dir => + val UnionOfOne = Schema.createUnion(List(Schema.create(Type.INT)).asJava) + val fields = Seq(new Field("field1", UnionOfOne, "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(schema, new File(s"$dir.avro")) + val avroRec = new GenericData.Record(schema) + + avroRec.put("field1", 8) + + dataFileWriter.append(avroRec) + dataFileWriter.flush() + dataFileWriter.close() + + val df = spark.read.format("avro").load(s"$dir.avro") + assert(df.first() == Row(8)) + } + } + + test("Complex Union Type") { + withTempPath { dir => + val fixedSchema = Schema.createFixed("fixed_name", "doc", "namespace", 4) + val enumSchema = Schema.createEnum("enum_name", "doc", "namespace", List("e1", "e2").asJava) + val complexUnionType = Schema.createUnion( + List(Schema.create(Type.INT), Schema.create(Type.STRING), fixedSchema, enumSchema).asJava) + val fields = Seq( + new Field("field1", complexUnionType, "doc", null), + new Field("field2", 
complexUnionType, "doc", null), + new Field("field3", complexUnionType, "doc", null), + new Field("field4", complexUnionType, "doc", null) + ).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(schema, new File(s"$dir.avro")) + val avroRec = new GenericData.Record(schema) + val field1 = 1234 + val field2 = "Hope that was not load bearing" + val field3 = Array[Byte](1, 2, 3, 4) + val field4 = "e2" + avroRec.put("field1", field1) + avroRec.put("field2", field2) + avroRec.put("field3", new Fixed(fixedSchema, field3)) + avroRec.put("field4", new EnumSymbol(enumSchema, field4)) + dataFileWriter.append(avroRec) + dataFileWriter.flush() + dataFileWriter.close() + + val df = spark.sqlContext.read.format("avro").load(s"$dir.avro") + assertResult(field1)(df.selectExpr("field1.member0").first().get(0)) + assertResult(field2)(df.selectExpr("field2.member1").first().get(0)) + assertResult(field3)(df.selectExpr("field3.member2").first().get(0)) + assertResult(field4)(df.selectExpr("field4.member3").first().get(0)) + } + } + + test("Lots of nulls") { + withTempPath { dir => + val schema = StructType(Seq( + StructField("binary", BinaryType, true), + StructField("timestamp", TimestampType, true), + StructField("array", ArrayType(ShortType), true), + StructField("map", MapType(StringType, StringType), true), + StructField("struct", StructType(Seq(StructField("int", IntegerType, true)))))) + val rdd = spark.sparkContext.parallelize(Seq[Row]( + Row(null, new Timestamp(1), Array[Short](1, 2, 3), null, null), + Row(null, null, null, null, null), + Row(null, null, null, null, null), + Row(null, null, null, null, null))) + val df = spark.createDataFrame(rdd, schema) + df.write.format("avro").save(dir.toString) + assert(spark.read.format("avro").load(dir.toString).count == rdd.count) + } + } + + test("Struct field type") { + withTempPath { dir => + val schema = StructType(Seq( + StructField("float", FloatType, true), + StructField("short", ShortType, true), + StructField("byte", ByteType, true), + StructField("boolean", BooleanType, true) + )) + val rdd = spark.sparkContext.parallelize(Seq( + Row(1f, 1.toShort, 1.toByte, true), + Row(2f, 2.toShort, 2.toByte, true), + Row(3f, 3.toShort, 3.toByte, true) + )) + val df = spark.createDataFrame(rdd, schema) + df.write.format("avro").save(dir.toString) + assert(spark.read.format("avro").load(dir.toString).count == rdd.count) + } + } + + test("Date field type") { + withTempPath { dir => + val schema = StructType(Seq( + StructField("float", FloatType, true), + StructField("date", DateType, true) + )) + TimeZone.setDefault(TimeZone.getTimeZone("UTC")) + val rdd = spark.sparkContext.parallelize(Seq( + Row(1f, null), + Row(2f, new Date(1451948400000L)), + Row(3f, new Date(1460066400500L)) + )) + val df = spark.createDataFrame(rdd, schema) + df.write.format("avro").save(dir.toString) + assert(spark.read.format("avro").load(dir.toString).count == rdd.count) + checkAnswer( + spark.read.format("avro").load(dir.toString).select("date"), + Seq(Row(null), Row(new Date(1451865600000L)), Row(new Date(1459987200000L)))) + } + } + + test("Array data types") { + withTempPath { dir => + val testSchema = StructType(Seq( + StructField("byte_array", ArrayType(ByteType), true), + StructField("short_array", ArrayType(ShortType), true), + StructField("float_array", 
ArrayType(FloatType), true), + StructField("bool_array", ArrayType(BooleanType), true), + StructField("long_array", ArrayType(LongType), true), + StructField("double_array", ArrayType(DoubleType), true), + StructField("decimal_array", ArrayType(DecimalType(10, 0)), true), + StructField("bin_array", ArrayType(BinaryType), true), + StructField("timestamp_array", ArrayType(TimestampType), true), + StructField("array_array", ArrayType(ArrayType(StringType), true), true), + StructField("struct_array", ArrayType( + StructType(Seq(StructField("name", StringType, true))))))) + + val arrayOfByte = new Array[Byte](4) + for (i <- arrayOfByte.indices) { + arrayOfByte(i) = i.toByte + } + + val rdd = spark.sparkContext.parallelize(Seq( + Row(arrayOfByte, Array[Short](1, 2, 3, 4), Array[Float](1f, 2f, 3f, 4f), + Array[Boolean](true, false, true, false), Array[Long](1L, 2L), Array[Double](1.0, 2.0), + Array[BigDecimal](BigDecimal.valueOf(3)), Array[Array[Byte]](arrayOfByte, arrayOfByte), + Array[Timestamp](new Timestamp(0)), + Array[Array[String]](Array[String]("CSH, tearing down the walls that divide us", "-jd")), + Array[Row](Row("Bobby G. can't swim"))))) + val df = spark.createDataFrame(rdd, testSchema) + df.write.format("avro").save(dir.toString) + assert(spark.read.format("avro").load(dir.toString).count == rdd.count) + } + } + + test("write with compression - sql configs") { + withTempPath { dir => + val uncompressDir = s"$dir/uncompress" + val bzip2Dir = s"$dir/bzip2" + val xzDir = s"$dir/xz" + val deflateDir = s"$dir/deflate" + val snappyDir = s"$dir/snappy" + + val df = spark.read.format("avro").load(testAvro) + spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, "uncompressed") + df.write.format("avro").save(uncompressDir) + spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, "bzip2") + df.write.format("avro").save(bzip2Dir) + spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, "xz") + df.write.format("avro").save(xzDir) + spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, "deflate") + spark.conf.set(SQLConf.AVRO_DEFLATE_LEVEL.key, "9") + df.write.format("avro").save(deflateDir) + spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, "snappy") + df.write.format("avro").save(snappyDir) + + val uncompressSize = FileUtils.sizeOfDirectory(new File(uncompressDir)) + val bzip2Size = FileUtils.sizeOfDirectory(new File(bzip2Dir)) + val xzSize = FileUtils.sizeOfDirectory(new File(xzDir)) + val deflateSize = FileUtils.sizeOfDirectory(new File(deflateDir)) + val snappySize = FileUtils.sizeOfDirectory(new File(snappyDir)) + + assert(uncompressSize > deflateSize) + assert(snappySize > deflateSize) + assert(snappySize > bzip2Size) + assert(bzip2Size > xzSize) + } + } + + test("dsl test") { + val results = spark.read.format("avro").load(episodesAvro).select("title").collect() + assert(results.length === 8) + } + + test("old avro data source name works") { + val results = + spark.read.format("com.databricks.spark.avro") + .load(episodesAvro).select("title").collect() + assert(results.length === 8) + } + + test("support of various data types") { + // This test uses data from test.avro. 
You can see the data and the schema of this file in + // test.json and test.avsc + val all = spark.read.format("avro").load(testAvro).collect() + assert(all.length == 3) + + val str = spark.read.format("avro").load(testAvro).select("string").collect() + assert(str.map(_(0)).toSet.contains("Terran is IMBA!")) + + val simple_map = spark.read.format("avro").load(testAvro).select("simple_map").collect() + assert(simple_map(0)(0).getClass.toString.contains("Map")) + assert(simple_map.map(_(0).asInstanceOf[Map[String, Some[Int]]].size).toSet == Set(2, 0)) + + val union0 = spark.read.format("avro").load(testAvro).select("union_string_null").collect() + assert(union0.map(_(0)).toSet == Set("abc", "123", null)) + + val union1 = spark.read.format("avro").load(testAvro).select("union_int_long_null").collect() + assert(union1.map(_(0)).toSet == Set(66, 1, null)) + + val union2 = spark.read.format("avro").load(testAvro).select("union_float_double").collect() + assert( + union2 + .map(x => new java.lang.Double(x(0).toString)) + .exists(p => Math.abs(p - Math.PI) < 0.001)) + + val fixed = spark.read.format("avro").load(testAvro).select("fixed3").collect() + assert(fixed.map(_(0).asInstanceOf[Array[Byte]]).exists(p => p(1) == 3)) + + val enum = spark.read.format("avro").load(testAvro).select("enum").collect() + assert(enum.map(_(0)).toSet == Set("SPADES", "CLUBS", "DIAMONDS")) + + val record = spark.read.format("avro").load(testAvro).select("record").collect() + assert(record(0)(0).getClass.toString.contains("Row")) + assert(record.map(_(0).asInstanceOf[Row](0)).contains("TEST_STR123")) + + val array_of_boolean = + spark.read.format("avro").load(testAvro).select("array_of_boolean").collect() + assert(array_of_boolean.map(_(0).asInstanceOf[Seq[Boolean]].size).toSet == Set(3, 1, 0)) + + val bytes = spark.read.format("avro").load(testAvro).select("bytes").collect() + assert(bytes.map(_(0).asInstanceOf[Array[Byte]].length).toSet == Set(3, 1, 0)) + } + + test("sql test") { + spark.sql( + s""" + |CREATE TEMPORARY VIEW avroTable + |USING avro + |OPTIONS (path "${episodesAvro}") + """.stripMargin.replaceAll("\n", " ")) + + assert(spark.sql("SELECT * FROM avroTable").collect().length === 8) + } + + test("conversion to avro and back") { + // Note that test.avro includes a variety of types, some of which are nullable. We expect to + // get the same values back. + withTempPath { dir => + val avroDir = s"$dir/avro" + spark.read.format("avro").load(testAvro).write.format("avro").save(avroDir) + checkReloadMatchesSaved(testAvro, avroDir) + } + } + + test("conversion to avro and back with namespace") { + // Note that test.avro includes a variety of types, some of which are nullable. We expect to + // get the same values back. 
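+ // The recordName and recordNamespace options control the name and namespace of the + // top-level record in the written Avro schema, which the raw-file check below asserts on.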
+ withTempPath { tempDir => + val name = "AvroTest" + val namespace = "org.apache.spark.avro" + val parameters = Map("recordName" -> name, "recordNamespace" -> namespace) + + val avroDir = tempDir + "/namedAvro" + spark.read.format("avro").load(testAvro) + .write.options(parameters).format("avro").save(avroDir) + checkReloadMatchesSaved(testAvro, avroDir) + + // Look at the raw file and make sure it has namespace info + val rawSaved = spark.sparkContext.textFile(avroDir) + val schema = rawSaved.collect().mkString("") + assert(schema.contains(name)) + assert(schema.contains(namespace)) + } + } + + test("converting some specific sparkSQL types to avro") { + withTempPath { tempDir => + val testSchema = StructType(Seq( + StructField("Name", StringType, false), + StructField("Length", IntegerType, true), + StructField("Time", TimestampType, false), + StructField("Decimal", DecimalType(10, 2), true), + StructField("Binary", BinaryType, false))) + + val arrayOfByte = new Array[Byte](4) + for (i <- arrayOfByte.indices) { + arrayOfByte(i) = i.toByte + } + val cityRDD = spark.sparkContext.parallelize(Seq( + Row("San Francisco", 12, new Timestamp(666), null, arrayOfByte), + Row("Palo Alto", null, new Timestamp(777), null, arrayOfByte), + Row("Munich", 8, new Timestamp(42), Decimal(3.14), arrayOfByte))) + val cityDataFrame = spark.createDataFrame(cityRDD, testSchema) + + val avroDir = tempDir + "/avro" + cityDataFrame.write.format("avro").save(avroDir) + assert(spark.read.format("avro").load(avroDir).collect().length == 3) + + // Timestamps are converted to longs + val times = spark.read.format("avro").load(avroDir).select("Time").collect() + assert(times.map(_(0)).toSet == + Set(new Timestamp(666), new Timestamp(777), new Timestamp(42))) + + // DecimalType values should be read back as java.math.BigDecimal + val decimals = spark.read.format("avro").load(avroDir).select("Decimal").collect() + assert(decimals.map(_(0)).contains(new java.math.BigDecimal("3.14"))) + + // There should be a null entry + val length = spark.read.format("avro").load(avroDir).select("Length").collect() + assert(length.map(_(0)).contains(null)) + + val binary = spark.read.format("avro").load(avroDir).select("Binary").collect() + for (i <- arrayOfByte.indices) { + assert(binary(1)(0).asInstanceOf[Array[Byte]](i) == arrayOfByte(i)) + } + } + } + + test("correctly read long as date/timestamp type") { + withTempPath { tempDir => + val currentTime = new Timestamp(System.currentTimeMillis()) + val currentDate = new Date(System.currentTimeMillis()) + val schema = StructType(Seq( + StructField("_1", DateType, false), StructField("_2", TimestampType, false))) + val writeDs = Seq((currentDate, currentTime)).toDS + + val avroDir = tempDir + "/avro" + writeDs.write.format("avro").save(avroDir) + assert(spark.read.format("avro").load(avroDir).collect().length == 1) + + val readDs = spark.read.schema(schema).format("avro").load(avroDir).as[(Date, Timestamp)] + + assert(readDs.collect().sameElements(writeDs.collect())) + } + } + + test("support of globbed paths") { + val resourceDir = testFile(".") + val e1 = spark.read.format("avro").load(resourceDir + "../*/episodes.avro").collect() + assert(e1.length == 8) + + val e2 = spark.read.format("avro").load(resourceDir + "../../*/*/episodes.avro").collect() + assert(e2.length == 8) + } + + test("does not coerce null date/timestamp value to 0 epoch.") { + withTempPath { tempDir => + val nullTime: Timestamp = null + val nullDate: Date = null + val schema = StructType(Seq( + StructField("_1", DateType, nullable = true),
StructField("_2", TimestampType, nullable = true)) + ) + val writeDs = Seq((nullDate, nullTime)).toDS + + val avroDir = tempDir + "/avro" + writeDs.write.format("avro").save(avroDir) + val readValues = + spark.read.schema(schema).format("avro").load(avroDir).as[(Date, Timestamp)].collect + + assert(readValues.size == 1) + assert(readValues.head == ((nullDate, nullTime))) + } + } + + test("support user provided avro schema") { + val avroSchema = + """ + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [{ + | "name" : "string", + | "type" : "string", + | "doc" : "Meaningless string of characters" + | }] + |} + """.stripMargin + val result = spark + .read + .option("avroSchema", avroSchema) + .format("avro") + .load(testAvro) + .collect() + val expected = spark.read.format("avro").load(testAvro).select("string").collect() + assert(result.sameElements(expected)) + } + + test("support user provided avro schema with defaults for missing fields") { + val avroSchema = + """ + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [{ + | "name" : "missingField", + | "type" : "string", + | "default" : "foo" + | }] + |} + """.stripMargin + val result = spark + .read + .option("avroSchema", avroSchema) + .format("avro").load(testAvro).select("missingField").first + assert(result === Row("foo")) + } + + test("support user provided avro schema for writing nullable enum type") { + withTempPath { tempDir => + val avroSchema = + """ + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [{ + | "name": "enum", + | "type": [{ "type": "enum", + | "name": "Suit", + | "symbols" : ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"] + | }, "null"] + | }] + |} + """.stripMargin + + val df = spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row("SPADES"), Row(null), Row("HEARTS"), Row("DIAMONDS"), + Row(null), Row("CLUBS"), Row("HEARTS"), Row("SPADES"))), + StructType(Seq(StructField("Suit", StringType, true)))) + + val tempSaveDir = s"$tempDir/save/" + + df.write.format("avro").option("avroSchema", avroSchema).save(tempSaveDir) + + checkAnswer(df, spark.read.format("avro").load(tempSaveDir)) + checkAvroSchemaEquals(avroSchema, getAvroSchemaStringFromFiles(tempSaveDir)) + + // Writing df containing data not in the enum will throw an exception + val message = intercept[SparkException] { + spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row("SPADES"), Row("NOT-IN-ENUM"), Row("HEARTS"), Row("DIAMONDS"))), + StructType(Seq(StructField("Suit", StringType, true)))) + .write.format("avro").option("avroSchema", avroSchema) + .save(s"$tempDir/${UUID.randomUUID()}") + }.getCause.getMessage + assert(message.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " + + "Cannot write \"NOT-IN-ENUM\" since it's not defined in enum")) + } + } + + test("support user provided avro schema for writing non-nullable enum type") { + withTempPath { tempDir => + val avroSchema = + """ + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [{ + | "name": "enum", + | "type": { "type": "enum", + | "name": "Suit", + | "symbols" : ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"] + | } + | }] + |} + """.stripMargin + + val dfWithNull = spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row("SPADES"), Row(null), Row("HEARTS"), Row("DIAMONDS"), + Row(null), Row("CLUBS"), Row("HEARTS"), Row("SPADES"))), + StructType(Seq(StructField("Suit", StringType, true)))) + + val df = spark.createDataFrame(dfWithNull.na.drop().rdd, + StructType(Seq(StructField("Suit", 
StringType, false)))) + + val tempSaveDir = s"$tempDir/save/" + + df.write.format("avro").option("avroSchema", avroSchema).save(tempSaveDir) + + checkAnswer(df, spark.read.format("avro").load(tempSaveDir)) + checkAvroSchemaEquals(avroSchema, getAvroSchemaStringFromFiles(tempSaveDir)) + + // Writing df containing nulls without using avro union type will + // throw an exception as avro uses union type to handle null. + val message1 = intercept[SparkException] { + dfWithNull.write.format("avro") + .option("avroSchema", avroSchema).save(s"$tempDir/${UUID.randomUUID()}") + }.getCause.getMessage + assert(message1.contains("org.apache.avro.AvroRuntimeException: Not a union:")) + + // Writing df containing data not in the enum will throw an exception + val message2 = intercept[SparkException] { + spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row("SPADES"), Row("NOT-IN-ENUM"), Row("HEARTS"), Row("DIAMONDS"))), + StructType(Seq(StructField("Suit", StringType, false)))) + .write.format("avro").option("avroSchema", avroSchema) + .save(s"$tempDir/${UUID.randomUUID()}") + }.getCause.getMessage + assert(message2.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " + + "Cannot write \"NOT-IN-ENUM\" since it's not defined in enum")) + } + } + + test("support user provided avro schema for writing nullable fixed type") { + withTempPath { tempDir => + val avroSchema = + """ + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [{ + | "name": "fixed2", + | "type": [{ "type": "fixed", + | "size": 2, + | "name": "fixed2" + | }, "null"] + | }] + |} + """.stripMargin + + val df = spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row(Array(192, 168).map(_.toByte)), Row(null))), + StructType(Seq(StructField("fixed2", BinaryType, true)))) + + val tempSaveDir = s"$tempDir/save/" + + df.write.format("avro").option("avroSchema", avroSchema).save(tempSaveDir) + + checkAnswer(df, spark.read.format("avro").load(tempSaveDir)) + checkAvroSchemaEquals(avroSchema, getAvroSchemaStringFromFiles(tempSaveDir)) + + // Writing df containing binary data that doesn't fit FIXED size will throw an exception + val message1 = intercept[SparkException] { + spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row(Array(192, 168, 1).map(_.toByte)))), + StructType(Seq(StructField("fixed2", BinaryType, true)))) + .write.format("avro").option("avroSchema", avroSchema) + .save(s"$tempDir/${UUID.randomUUID()}") + }.getCause.getMessage + assert(message1.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " + + "Cannot write 3 bytes of binary data into FIXED Type with size of 2 bytes")) + + // Writing df containing binary data that doesn't fit FIXED size will throw an exception + val message2 = intercept[SparkException] { + spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row(Array(192).map(_.toByte)))), + StructType(Seq(StructField("fixed2", BinaryType, true)))) + .write.format("avro").option("avroSchema", avroSchema) + .save(s"$tempDir/${UUID.randomUUID()}") + }.getCause.getMessage + assert(message2.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " + + "Cannot write 1 byte of binary data into FIXED Type with size of 2 bytes")) + } + } + + test("support user provided avro schema for writing non-nullable fixed type") { + withTempPath { tempDir => + val avroSchema = + """ + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [{ + | "name": "fixed2", + | "type": { "type": "fixed", + | "size": 2, + | "name": "fixed2" + | } + | }] + 
|} + """.stripMargin + + val df = spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row(Array(192, 168).map(_.toByte)), Row(Array(1, 1).map(_.toByte)))), + StructType(Seq(StructField("fixed2", BinaryType, false)))) + + val tempSaveDir = s"$tempDir/save/" + + df.write.format("avro").option("avroSchema", avroSchema).save(tempSaveDir) + + checkAnswer(df, spark.read.format("avro").load(tempSaveDir)) + checkAvroSchemaEquals(avroSchema, getAvroSchemaStringFromFiles(tempSaveDir)) + + // Writing df containing binary data that doesn't fit FIXED size will throw an exception + val message1 = intercept[SparkException] { + spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row(Array(192, 168, 1).map(_.toByte)))), + StructType(Seq(StructField("fixed2", BinaryType, false)))) + .write.format("avro").option("avroSchema", avroSchema) + .save(s"$tempDir/${UUID.randomUUID()}") + }.getCause.getMessage + assert(message1.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " + + "Cannot write 3 bytes of binary data into FIXED Type with size of 2 bytes")) + + // Writing df containing binary data that doesn't fit FIXED size will throw an exception + val message2 = intercept[SparkException] { + spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row(Array(192).map(_.toByte)))), + StructType(Seq(StructField("fixed2", BinaryType, false)))) + .write.format("avro").option("avroSchema", avroSchema) + .save(s"$tempDir/${UUID.randomUUID()}") + }.getCause.getMessage + assert(message2.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " + + "Cannot write 1 byte of binary data into FIXED Type with size of 2 bytes")) + } + } + + test("throw exception if unable to write with user provided Avro schema") { + val input: Seq[(DataType, Schema.Type)] = Seq( + (NullType, NULL), + (BooleanType, BOOLEAN), + (ByteType, INT), + (ShortType, INT), + (IntegerType, INT), + (LongType, LONG), + (FloatType, FLOAT), + (DoubleType, DOUBLE), + (BinaryType, BYTES), + (DateType, INT), + (TimestampType, LONG), + (DecimalType(4, 2), BYTES) + ) + def assertException(f: () => AvroSerializer) { + val message = intercept[org.apache.spark.sql.avro.IncompatibleSchemaException] { + f() + }.getMessage + assert(message.contains("Cannot convert Catalyst type")) + } + + def resolveNullable(schema: Schema, nullable: Boolean): Schema = { + if (nullable && schema.getType != NULL) { + Schema.createUnion(schema, Schema.create(NULL)) + } else { + schema + } + } + for { + i <- input + j <- input + nullable <- Seq(true, false) + } if (i._2 != j._2) { + val avroType = resolveNullable(Schema.create(j._2), nullable) + val avroArrayType = resolveNullable(Schema.createArray(avroType), nullable) + val avroMapType = resolveNullable(Schema.createMap(avroType), nullable) + val name = "foo" + val avroField = new Field(name, avroType, "", null) + val recordSchema = Schema.createRecord("name", "doc", "space", true, Seq(avroField).asJava) + val avroRecordType = resolveNullable(recordSchema, nullable) + + val catalystType = i._1 + val catalystArrayType = ArrayType(catalystType, nullable) + val catalystMapType = MapType(StringType, catalystType, nullable) + val catalystStructType = StructType(Seq(StructField(name, catalystType, nullable))) + + for { + avro <- Seq(avroType, avroArrayType, avroMapType, avroRecordType) + catalyst <- Seq(catalystType, catalystArrayType, catalystMapType, catalystStructType) + } { + assertException(() => new AvroSerializer(catalyst, avro, nullable)) + } + } + } + + test("reading from invalid path throws 
exception") { + + // Directory given has no avro files + intercept[AnalysisException] { + withTempPath(dir => spark.read.format("avro").load(dir.getCanonicalPath)) + } + + intercept[AnalysisException] { + spark.read.format("avro").load("very/invalid/path/123.avro") + } + + // In case of globbed path that can't be matched to anything, another exception is thrown (and + // exception message is helpful) + intercept[AnalysisException] { + spark.read.format("avro").load("*/*/*/*/*/*/*/something.avro") + } + + intercept[FileNotFoundException] { + withTempPath { dir => + FileUtils.touch(new File(dir, "test")) + withSQLConf(AvroFileFormat.IgnoreFilesWithoutExtensionProperty -> "true") { + spark.read.format("avro").load(dir.toString) + } + } + } + + intercept[FileNotFoundException] { + withTempPath { dir => + FileUtils.touch(new File(dir, "test")) + + spark + .read + .option("ignoreExtension", false) + .format("avro") + .load(dir.toString) + } + } + } + + test("SQL test insert overwrite") { + withTempPath { tempDir => + val tempEmptyDir = s"$tempDir/sqlOverwrite" + // Create a temp directory for table that will be overwritten + new File(tempEmptyDir).mkdirs() + spark.sql( + s""" + |CREATE TEMPORARY VIEW episodes + |USING avro + |OPTIONS (path "${episodesAvro}") + """.stripMargin.replaceAll("\n", " ")) + spark.sql( + s""" + |CREATE TEMPORARY VIEW episodesEmpty + |(name string, air_date string, doctor int) + |USING avro + |OPTIONS (path "$tempEmptyDir") + """.stripMargin.replaceAll("\n", " ")) + + assert(spark.sql("SELECT * FROM episodes").collect().length === 8) + assert(spark.sql("SELECT * FROM episodesEmpty").collect().isEmpty) + + spark.sql( + s""" + |INSERT OVERWRITE TABLE episodesEmpty + |SELECT * FROM episodes + """.stripMargin.replaceAll("\n", " ")) + assert(spark.sql("SELECT * FROM episodesEmpty").collect().length == 8) + } + } + + test("test save and load") { + // Test if load works as expected + withTempPath { tempDir => + val df = spark.read.format("avro").load(episodesAvro) + assert(df.count == 8) + + val tempSaveDir = s"$tempDir/save/" + + df.write.format("avro").save(tempSaveDir) + val newDf = spark.read.format("avro").load(tempSaveDir) + assert(newDf.count == 8) + } + } + + test("test load with non-Avro file") { + // Test if load works as expected + withTempPath { tempDir => + val df = spark.read.format("avro").load(episodesAvro) + assert(df.count == 8) + + val tempSaveDir = s"$tempDir/save/" + df.write.format("avro").save(tempSaveDir) + + Files.createFile(new File(tempSaveDir, "non-avro").toPath) + + withSQLConf(AvroFileFormat.IgnoreFilesWithoutExtensionProperty -> "true") { + val newDf = spark.read.format("avro").load(tempSaveDir) + assert(newDf.count() == 8) + } + } + } + + test("read avro with user defined schema: read partial columns") { + val partialColumns = StructType(Seq( + StructField("string", StringType, false), + StructField("simple_map", MapType(StringType, IntegerType), false), + StructField("complex_map", MapType(StringType, MapType(StringType, StringType)), false), + StructField("union_string_null", StringType, true), + StructField("union_int_long_null", LongType, true), + StructField("fixed3", BinaryType, true), + StructField("fixed2", BinaryType, true), + StructField("enum", StringType, false), + StructField("record", StructType(Seq(StructField("value_field", StringType, false))), false), + StructField("array_of_boolean", ArrayType(BooleanType), false), + StructField("bytes", BinaryType, true))) + val withSchema = 
spark.read.schema(partialColumns).format("avro").load(testAvro).collect() + val withOutSchema = spark + .read + .format("avro") + .load(testAvro) + .select("string", "simple_map", "complex_map", "union_string_null", "union_int_long_null", + "fixed3", "fixed2", "enum", "record", "array_of_boolean", "bytes") + .collect() + assert(withSchema.sameElements(withOutSchema)) + } + + test("read avro with user defined schema: read non-exist columns") { + val schema = + StructType( + Seq( + StructField("non_exist_string", StringType, true), + StructField( + "record", + StructType(Seq( + StructField("non_exist_field", StringType, false), + StructField("non_exist_field2", StringType, false))), + false))) + val withEmptyColumn = spark.read.schema(schema).format("avro").load(testAvro).collect() + + assert(withEmptyColumn.forall(_ == Row(null: String, Row(null: String, null: String)))) + } + + test("read avro file partitioned") { + withTempPath { dir => + val df = (0 to 1024 * 3).toDS.map(i => s"record${i}").toDF("records") + val outputDir = s"$dir/${UUID.randomUUID}" + df.write.format("avro").save(outputDir) + val input = spark.read.format("avro").load(outputDir) + assert(input.collect.toSet.size === 1024 * 3 + 1) + assert(input.rdd.partitions.size > 2) + } + } + + case class NestedBottom(id: Int, data: String) + + case class NestedMiddle(id: Int, data: NestedBottom) + + case class NestedTop(id: Int, data: NestedMiddle) + + test("Validate namespace in avro file that has nested records with the same name") { + withTempPath { dir => + val writeDf = spark.createDataFrame(List(NestedTop(1, NestedMiddle(2, NestedBottom(3, "1"))))) + writeDf.write.format("avro").save(dir.toString) + val schema = getAvroSchemaStringFromFiles(dir.toString) + assert(schema.contains("\"namespace\":\"topLevelRecord\"")) + assert(schema.contains("\"namespace\":\"topLevelRecord.data\"")) + } + } + + test("saving avro that has nested records with the same name") { + withTempPath { tempDir => + // Save avro file on output folder path + val writeDf = spark.createDataFrame(List(NestedTop(1, NestedMiddle(2, NestedBottom(3, "1"))))) + val outputFolder = s"$tempDir/duplicate_names/" + writeDf.write.format("avro").save(outputFolder) + // Read avro file saved on the last step + val readDf = spark.read.format("avro").load(outputFolder) + // Check if the written DataFrame is equals than read DataFrame + assert(readDf.collect().sameElements(writeDf.collect())) + } + } + + test("check namespace - toAvroType") { + val sparkSchema = StructType(Seq( + StructField("name", StringType, nullable = false), + StructField("address", StructType(Seq( + StructField("city", StringType, nullable = false), + StructField("state", StringType, nullable = false))), + nullable = false))) + val employeeType = SchemaConverters.toAvroType(sparkSchema, + recordName = "employee", + nameSpace = "foo.bar") + + assert(employeeType.getFullName == "foo.bar.employee") + assert(employeeType.getName == "employee") + assert(employeeType.getNamespace == "foo.bar") + + val addressType = employeeType.getField("address").schema() + assert(addressType.getFullName == "foo.bar.employee.address") + assert(addressType.getName == "address") + assert(addressType.getNamespace == "foo.bar.employee") + } + + test("check empty namespace - toAvroType") { + val sparkSchema = StructType(Seq( + StructField("name", StringType, nullable = false), + StructField("address", StructType(Seq( + StructField("city", StringType, nullable = false), + StructField("state", StringType, nullable = false))), + 
nullable = false))) + val employeeType = SchemaConverters.toAvroType(sparkSchema, + recordName = "employee") + + assert(employeeType.getFullName == "employee") + assert(employeeType.getName == "employee") + assert(employeeType.getNamespace == null) + + val addressType = employeeType.getField("address").schema() + assert(addressType.getFullName == "employee.address") + assert(addressType.getName == "address") + assert(addressType.getNamespace == "employee") + } + + case class NestedMiddleArray(id: Int, data: Array[NestedBottom]) + + case class NestedTopArray(id: Int, data: NestedMiddleArray) + + test("saving avro that has nested records with the same name inside an array") { + withTempPath { tempDir => + // Save avro file on output folder path + val writeDf = spark.createDataFrame( + List(NestedTopArray(1, NestedMiddleArray(2, Array( + NestedBottom(3, "1"), NestedBottom(4, "2") + )))) + ) + val outputFolder = s"$tempDir/duplicate_names_array/" + writeDf.write.format("avro").save(outputFolder) + // Read avro file saved on the last step + val readDf = spark.read.format("avro").load(outputFolder) + // Check if the written DataFrame is equals than read DataFrame + assert(readDf.collect().sameElements(writeDf.collect())) + } + } + + case class NestedMiddleMap(id: Int, data: Map[String, NestedBottom]) + + case class NestedTopMap(id: Int, data: NestedMiddleMap) + + test("saving avro that has nested records with the same name inside a map") { + withTempPath { tempDir => + // Save avro file on output folder path + val writeDf = spark.createDataFrame( + List(NestedTopMap(1, NestedMiddleMap(2, Map( + "1" -> NestedBottom(3, "1"), "2" -> NestedBottom(4, "2") + )))) + ) + val outputFolder = s"$tempDir/duplicate_names_map/" + writeDf.write.format("avro").save(outputFolder) + // Read avro file saved on the last step + val readDf = spark.read.format("avro").load(outputFolder) + // Check if the written DataFrame is equals than read DataFrame + assert(readDf.collect().sameElements(writeDf.collect())) + } + } + + test("SPARK-24805: do not ignore files without .avro extension by default") { + withTempDir { dir => + Files.copy( + Paths.get(new URL(episodesAvro).toURI), + Paths.get(dir.getCanonicalPath, "episodes")) + + val fileWithoutExtension = s"${dir.getCanonicalPath}/episodes" + val df1 = spark.read.format("avro").load(fileWithoutExtension) + assert(df1.count == 8) + + val schema = new StructType() + .add("title", StringType) + .add("air_date", StringType) + .add("doctor", IntegerType) + val df2 = spark.read.schema(schema).format("avro").load(fileWithoutExtension) + assert(df2.count == 8) + } + } + + test("SPARK-24836: checking the ignoreExtension option") { + withTempPath { tempDir => + val df = spark.read.format("avro").load(episodesAvro) + assert(df.count == 8) + + val tempSaveDir = s"$tempDir/save/" + df.write.format("avro").save(tempSaveDir) + + Files.createFile(new File(tempSaveDir, "non-avro").toPath) + + val newDf = spark + .read + .option("ignoreExtension", false) + .format("avro") + .load(tempSaveDir) + + assert(newDf.count == 8) + } + } + + test("SPARK-24836: ignoreExtension must override hadoop's config") { + withTempDir { dir => + Files.copy( + Paths.get(new URL(episodesAvro).toURI), + Paths.get(dir.getCanonicalPath, "episodes")) + + val hadoopConf = spark.sessionState.newHadoopConf() + withSQLConf(AvroFileFormat.IgnoreFilesWithoutExtensionProperty -> "true") { + val newDf = spark + .read + .option("ignoreExtension", "true") + .format("avro") + .load(s"${dir.getCanonicalPath}/episodes") + 
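// A minimal usage sketch of the per-read option exercised by these tests, assuming a
// live `spark` session and a hypothetical input directory: `ignoreExtension` makes the
// reader pick up files without the ".avro" suffix and takes precedence over the
// Hadoop-level property used elsewhere in this suite.
val dfNoExt = spark.read
  .option("ignoreExtension", "true") // also read files that lack the .avro extension
  .format("avro")
  .load("/hypothetical/extensionless/avro/dir")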
assert(newDf.count() == 8) + } + } + } + + test("SPARK-24881: write with compression - avro options") { + def getCodec(dir: String): Option[String] = { + val files = new File(dir) + .listFiles() + .filter(_.isFile) + .filter(_.getName.endsWith("avro")) + files.map { file => + val reader = new DataFileReader(file, new GenericDatumReader[Any]()) + val r = reader.getMetaString("avro.codec") + r + }.map(v => if (v == "null") "uncompressed" else v).headOption + } + def checkCodec(df: DataFrame, dir: String, codec: String): Unit = { + val subdir = s"$dir/$codec" + df.write.option("compression", codec).format("avro").save(subdir) + assert(getCodec(subdir) == Some(codec)) + } + withTempPath { dir => + val path = dir.toString + val df = spark.read.format("avro").load(testAvro) + + checkCodec(df, path, "uncompressed") + checkCodec(df, path, "deflate") + checkCodec(df, path, "snappy") + checkCodec(df, path, "bzip2") + checkCodec(df, path, "xz") + } + } +} diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index 8512496e5fe52..09a2cd83aed6b 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.jdbc +import java.math.BigDecimal import java.sql.{Connection, Date, Timestamp} import java.util.{Properties, TimeZone} -import java.math.BigDecimal -import org.apache.spark.sql.{DataFrame, QueryTest, Row, SaveMode} +import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.execution.{RowDataSourceScanExec, WholeStageCodegenExec} +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCRelation} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -86,7 +88,8 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo conn.prepareStatement( "CREATE TABLE tableWithCustomSchema (id NUMBER, n1 NUMBER(1), n2 NUMBER(1))").executeUpdate() conn.prepareStatement( - "INSERT INTO tableWithCustomSchema values(12312321321321312312312312123, 1, 0)").executeUpdate() + "INSERT INTO tableWithCustomSchema values(12312321321321312312312312123, 1, 0)") + .executeUpdate() conn.commit() sql( @@ -108,15 +111,36 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo """.stripMargin.replaceAll("\n", " ")) - conn.prepareStatement("CREATE TABLE numerics (b DECIMAL(1), f DECIMAL(3, 2), i DECIMAL(10))").executeUpdate() + conn.prepareStatement("CREATE TABLE numerics (b DECIMAL(1), f DECIMAL(3, 2), i DECIMAL(10))") + .executeUpdate() conn.prepareStatement( "INSERT INTO numerics VALUES (4, 1.23, 9999999999)").executeUpdate() conn.commit() - conn.prepareStatement("CREATE TABLE oracle_types (d BINARY_DOUBLE, f BINARY_FLOAT)").executeUpdate() + conn.prepareStatement("CREATE TABLE oracle_types (d BINARY_DOUBLE, f BINARY_FLOAT)") + .executeUpdate() conn.commit() - } + conn.prepareStatement("CREATE TABLE datetimePartitionTest (id NUMBER(10), d DATE, t TIMESTAMP)") + .executeUpdate() + conn.prepareStatement( + """INSERT INTO datetimePartitionTest VALUES + |(1, {d '2018-07-06'}, {ts '2018-07-06 05:50:00'}) + 
""".stripMargin.replaceAll("\n", " ")).executeUpdate() + conn.prepareStatement( + """INSERT INTO datetimePartitionTest VALUES + |(2, {d '2018-07-06'}, {ts '2018-07-06 08:10:08'}) + """.stripMargin.replaceAll("\n", " ")).executeUpdate() + conn.prepareStatement( + """INSERT INTO datetimePartitionTest VALUES + |(3, {d '2018-07-08'}, {ts '2018-07-08 13:32:01'}) + """.stripMargin.replaceAll("\n", " ")).executeUpdate() + conn.prepareStatement( + """INSERT INTO datetimePartitionTest VALUES + |(4, {d '2018-07-12'}, {ts '2018-07-12 09:51:15'}) + """.stripMargin.replaceAll("\n", " ")).executeUpdate() + conn.commit() + } test("SPARK-16625 : Importing Oracle numeric types") { val df = sqlContext.read.jdbc(jdbcUrl, "numerics", new Properties) @@ -399,4 +423,54 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo assert(values.getDouble(0) === 1.1) assert(values.getFloat(1) === 2.2f) } + + test("SPARK-22814 support date/timestamp types in partitionColumn") { + val expectedResult = Set( + (1, "2018-07-06", "2018-07-06 05:50:00"), + (2, "2018-07-06", "2018-07-06 08:10:08"), + (3, "2018-07-08", "2018-07-08 13:32:01"), + (4, "2018-07-12", "2018-07-12 09:51:15") + ).map { case (id, date, timestamp) => + Row(BigDecimal.valueOf(id), Date.valueOf(date), Timestamp.valueOf(timestamp)) + } + + // DateType partition column + val df1 = spark.read.format("jdbc") + .option("url", jdbcUrl) + .option("dbtable", "datetimePartitionTest") + .option("partitionColumn", "d") + .option("lowerBound", "2018-07-06") + .option("upperBound", "2018-07-20") + .option("numPartitions", 3) + .load() + + df1.logicalPlan match { + case LogicalRelation(JDBCRelation(_, parts, _), _, _, _) => + val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet + assert(whereClauses === Set( + """"D" < '2018-07-10' or "D" is null""", + """"D" >= '2018-07-10' AND "D" < '2018-07-14'""", + """"D" >= '2018-07-14'""")) + } + assert(df1.collect.toSet === expectedResult) + + // TimestampType partition column + val df2 = spark.read.format("jdbc") + .option("url", jdbcUrl) + .option("dbtable", "datetimePartitionTest") + .option("partitionColumn", "t") + .option("lowerBound", "2018-07-04 03:30:00.0") + .option("upperBound", "2018-07-27 14:11:05.0") + .option("numPartitions", 2) + .load() + + df2.logicalPlan match { + case LogicalRelation(JDBCRelation(_, parts, _), _, _, _) => + val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet + assert(whereClauses === Set( + """"T" < '2018-07-15 20:50:32.5' or "T" is null""", + """"T" >= '2018-07-15 20:50:32.5'""")) + } + assert(df2.collect.toSet === expectedResult) + } } diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index 4324cc6d0f804..9241b13c100f1 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -50,13 +50,18 @@ class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfterAll with val utils = new PollingFlumeTestUtils override def beforeAll(): Unit = { + super.beforeAll() _sc = new SparkContext(conf) } override def afterAll(): Unit = { - if (_sc != null) { - _sc.stop() - _sc = null + try { + if (_sc != null) { + _sc.stop() + _sc = null + } + } finally { + super.afterAll() } } diff --git a/external/kafka-0-10-assembly/pom.xml 
b/external/kafka-0-10-assembly/pom.xml index a742b8d6dbddb..f80f8e3a0183d 100644 --- a/external/kafka-0-10-assembly/pom.xml +++ b/external/kafka-0-10-assembly/pom.xml @@ -95,11 +95,6 @@ log4j provided - - net.java.dev.jets3t - jets3t - provided - org.scala-lang scala-library diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index 16bbc6db641ca..8588e8be052eb 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -29,10 +29,11 @@ spark-sql-kafka-0-10_2.11 sql-kafka-0-10 - 0.10.0.1 + + 2.0.0 jar - Kafka 0.10 Source for Structured Streaming + Kafka 0.10+ Source for Structured Streaming http://spark.apache.org/ @@ -73,6 +74,20 @@ kafka_${scala.binary.version} ${kafka.version} test + + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + jackson-databind + + + com.fasterxml.jackson.core + jackson-annotations + + net.sf.jopt-simple @@ -80,6 +95,12 @@ 3.2 test + + org.eclipse.jetty + jetty-servlet + ${jetty.version} + test + org.scalacheck scalacheck_${scala.binary.version} @@ -108,13 +129,4 @@ target/scala-${scala.binary.version}/test-classes - - - scala-2.12 - - 0.10.1.1 - - - - diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala index 571140b0afbc7..cd680adf44365 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala @@ -33,8 +33,12 @@ private[kafka010] object CachedKafkaProducer extends Logging { private type Producer = KafkaProducer[Array[Byte], Array[Byte]] + private val defaultCacheExpireTimeout = TimeUnit.MINUTES.toMillis(10) + private lazy val cacheExpireTimeout: Long = - SparkEnv.get.conf.getTimeAsMs("spark.kafka.producer.cache.timeout", "10m") + Option(SparkEnv.get).map(_.conf.getTimeAsMs( + "spark.kafka.producer.cache.timeout", + s"${defaultCacheExpireTimeout}ms")).getOrElse(defaultCacheExpireTimeout) private val cacheLoader = new CacheLoader[Seq[(String, Object)], Producer] { override def load(config: Seq[(String, Object)]): Producer = { @@ -102,7 +106,7 @@ private[kafka010] object CachedKafkaProducer extends Logging { } } - private def clear(): Unit = { + private[kafka010] def clear(): Unit = { logInfo("Cleaning up guava cache.") guavaCache.invalidateAll() } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala similarity index 72% rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala index badaa69cc303c..1753a28fba2fb 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala @@ -25,15 +25,15 @@ import org.apache.kafka.common.TopicPartition import org.apache.spark.TaskContext import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import 
org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming._ import org.apache.spark.sql.types.StructType /** - * A [[ContinuousReader]] for data from kafka. + * A [[ContinuousReadSupport]] for data from kafka. * * @param offsetReader a reader used to get kafka offsets. Note that the actual data will be * read by per-task consumers generated later. @@ -46,70 +46,49 @@ import org.apache.spark.sql.types.StructType * scenarios, where some offsets after the specified initial ones can't be * properly read. */ -class KafkaContinuousReader( +class KafkaContinuousReadSupport( offsetReader: KafkaOffsetReader, kafkaParams: ju.Map[String, Object], sourceOptions: Map[String, String], metadataPath: String, initialOffsets: KafkaOffsetRangeLimit, failOnDataLoss: Boolean) - extends ContinuousReader with SupportsScanUnsafeRow with Logging { - - private lazy val session = SparkSession.getActiveSession.get - private lazy val sc = session.sparkContext + extends ContinuousReadSupport with Logging { private val pollTimeoutMs = sourceOptions.getOrElse("kafkaConsumer.pollTimeoutMs", "512").toLong - // Initialized when creating reader factories. If this diverges from the partitions at the latest - // offsets, we need to reconfigure. - // Exposed outside this object only for unit tests. - @volatile private[sql] var knownPartitions: Set[TopicPartition] = _ - - override def readSchema: StructType = KafkaOffsetReader.kafkaSchema - - private var offset: Offset = _ - override def setStartOffset(start: ju.Optional[Offset]): Unit = { - offset = start.orElse { - val offsets = initialOffsets match { - case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) - case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets()) - case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss) - } - logInfo(s"Initial offsets: $offsets") - offsets + override def initialOffset(): Offset = { + val offsets = initialOffsets match { + case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) + case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets()) + case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss) } + logInfo(s"Initial offsets: $offsets") + offsets } - override def getStartOffset(): Offset = offset + override def fullSchema(): StructType = KafkaOffsetReader.kafkaSchema + + override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = { + new KafkaContinuousScanConfigBuilder(fullSchema(), start, offsetReader, reportDataLoss) + } override def deserializeOffset(json: String): Offset = { KafkaSourceOffset(JsonUtils.partitionOffsets(json)) } - override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = { - import scala.collection.JavaConverters._ - - val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset) - - val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet - val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet) - val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq) - - val deletedPartitions = 
oldStartPartitionOffsets.keySet.diff(currentPartitionSet) - if (deletedPartitions.nonEmpty) { - reportDataLoss(s"Some partitions were deleted: $deletedPartitions") - } - - val startOffsets = newPartitionOffsets ++ - oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_)) - knownPartitions = startOffsets.keySet - + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val startOffsets = config.asInstanceOf[KafkaContinuousScanConfig].startOffsets startOffsets.toSeq.map { case (topicPartition, start) => KafkaContinuousInputPartition( topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss) - .asInstanceOf[InputPartition[UnsafeRow]] - }.asJava + }.toArray + } + + override def createContinuousReaderFactory( + config: ScanConfig): ContinuousPartitionReaderFactory = { + KafkaContinuousReaderFactory } /** Stop this source and free any resources it has allocated. */ @@ -126,8 +105,9 @@ class KafkaContinuousReader( KafkaSourceOffset(mergedMap) } - override def needsReconfiguration(): Boolean = { - knownPartitions != null && offsetReader.fetchLatestOffsets().keySet != knownPartitions + override def needsReconfiguration(config: ScanConfig): Boolean = { + val knownPartitions = config.asInstanceOf[KafkaContinuousScanConfig].knownPartitions + offsetReader.fetchLatestOffsets().keySet != knownPartitions } override def toString(): String = s"KafkaSource[$offsetReader]" @@ -161,22 +141,51 @@ case class KafkaContinuousInputPartition( startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends ContinuousInputPartition[UnsafeRow] { - - override def createContinuousReader(offset: PartitionOffset): InputPartitionReader[UnsafeRow] = { - val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset] - require(kafkaOffset.topicPartition == topicPartition, - s"Expected topicPartition: $topicPartition, but got: ${kafkaOffset.topicPartition}") - new KafkaContinuousInputPartitionReader( - topicPartition, kafkaOffset.partitionOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) + failOnDataLoss: Boolean) extends InputPartition + +object KafkaContinuousReaderFactory extends ContinuousPartitionReaderFactory { + override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { + val p = partition.asInstanceOf[KafkaContinuousInputPartition] + new KafkaContinuousPartitionReader( + p.topicPartition, p.startOffset, p.kafkaParams, p.pollTimeoutMs, p.failOnDataLoss) } +} + +class KafkaContinuousScanConfigBuilder( + schema: StructType, + startOffset: Offset, + offsetReader: KafkaOffsetReader, + reportDataLoss: String => Unit) + extends ScanConfigBuilder { + + override def build(): ScanConfig = { + val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(startOffset) + + val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet + val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet) + val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq) - override def createPartitionReader(): KafkaContinuousInputPartitionReader = { - new KafkaContinuousInputPartitionReader( - topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) + val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet) + if (deletedPartitions.nonEmpty) { + reportDataLoss(s"Some partitions were deleted: $deletedPartitions") + } + + val startOffsets = newPartitionOffsets ++ + 
oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_)) + KafkaContinuousScanConfig(schema, startOffsets) } } +case class KafkaContinuousScanConfig( + readSchema: StructType, + startOffsets: Map[TopicPartition, Long]) + extends ScanConfig { + + // Created when building the scan config builder. If this diverges from the partitions at the + // latest offsets, we need to reconfigure the kafka read support. + def knownPartitions: Set[TopicPartition] = startOffsets.keySet +} + /** * A per-task data reader for continuous Kafka processing. * @@ -187,12 +196,12 @@ case class KafkaContinuousInputPartition( * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets * are skipped. */ -class KafkaContinuousInputPartitionReader( +class KafkaContinuousPartitionReader( topicPartition: TopicPartition, startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends ContinuousInputPartitionReader[UnsafeRow] { + failOnDataLoss: Boolean) extends ContinuousPartitionReader[InternalRow] { private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache = false) private val converter = new KafkaRecordToUnsafeRowConverter @@ -214,11 +223,11 @@ class KafkaContinuousInputPartitionReader( } catch { // We didn't read within the timeout. We're supposed to block indefinitely for new data, so // swallow and ignore this. - case _: TimeoutException => + case _: TimeoutException | _: org.apache.kafka.common.errors.TimeoutException => // This is a failOnDataLoss exception. Retry if nextKafkaOffset is within the data range, // or if it's the endpoint of the data range (i.e. the "true" next offset). - case e: IllegalStateException if e.getCause.isInstanceOf[OffsetOutOfRangeException] => + case e: IllegalStateException if e.getCause.isInstanceOf[OffsetOutOfRangeException] => val range = consumer.getAvailableOffsetRange() if (range.latest >= nextKafkaOffset && range.earliest <= nextKafkaOffset) { // retry diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala index 941f0ab177e48..ceb9e318b283b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala @@ -33,9 +33,19 @@ import org.apache.spark.util.UninterruptibleThread private[kafka010] sealed trait KafkaDataConsumer { /** - * Get the record for the given offset if available. Otherwise it will either throw error - * (if failOnDataLoss = true), or return the next available offset within [offset, untilOffset), - * or null. + * Get the record for the given offset if available. + * + * If the record is invisible (either a + * transaction message, or an aborted message when the consumer's `isolation.level` is + * `read_committed`), it will be skipped and this method will try to fetch next available record + * within [offset, untilOffset). + * + * This method also will try its best to detect data loss. If `failOnDataLoss` is `true`, it will + * throw an exception when we detect an unavailable offset. If `failOnDataLoss` is `false`, this + * method will try to fetch next available record within [offset, untilOffset). + * + * When this method tries to skip offsets due to either invisible messages or data loss and + * reaches `untilOffset`, it will return `null`. 
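// A hypothetical caller loop matching the contract documented above, a sketch only:
// `get` returns the record at (or after) `offset`, or null once every remaining offset
// in [offset, untilOffset) has been skipped as invisible or lost. `handle` is a stand-in.
import org.apache.kafka.clients.consumer.ConsumerRecord

def drainVisible(
    consumer: KafkaDataConsumer,
    offset: Long,
    untilOffset: Long)(
    handle: ConsumerRecord[Array[Byte], Array[Byte]] => Unit): Unit = {
  var next = offset
  var done = false
  while (!done && next < untilOffset) {
    val record = consumer.get(next, untilOffset, pollTimeoutMs = 512L, failOnDataLoss = false)
    if (record == null) {
      done = true // the rest of the range was invisible or lost
    } else {
      handle(record)
      next = record.offset + 1
    }
  }
}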
* * @param offset the offset to fetch. * @param untilOffset the max offset to fetch. Exclusive. @@ -80,6 +90,83 @@ private[kafka010] case class InternalKafkaConsumer( kafkaParams: ju.Map[String, Object]) extends Logging { import InternalKafkaConsumer._ + /** + * The internal object to store the fetched data from Kafka consumer and the next offset to poll. + * + * @param _records the pre-fetched Kafka records. + * @param _nextOffsetInFetchedData the next offset in `records`. We use this to verify if we + * should check if the pre-fetched data is still valid. + * @param _offsetAfterPoll the Kafka offset after calling `poll`. We will use this offset to + * poll when `records` is drained. + */ + private case class FetchedData( + private var _records: ju.ListIterator[ConsumerRecord[Array[Byte], Array[Byte]]], + private var _nextOffsetInFetchedData: Long, + private var _offsetAfterPoll: Long) { + + def withNewPoll( + records: ju.ListIterator[ConsumerRecord[Array[Byte], Array[Byte]]], + offsetAfterPoll: Long): FetchedData = { + this._records = records + this._nextOffsetInFetchedData = UNKNOWN_OFFSET + this._offsetAfterPoll = offsetAfterPoll + this + } + + /** Whether there are more elements */ + def hasNext: Boolean = _records.hasNext + + /** Move `records` forward and return the next record. */ + def next(): ConsumerRecord[Array[Byte], Array[Byte]] = { + val record = _records.next() + _nextOffsetInFetchedData = record.offset + 1 + record + } + + /** Move `records` backward and return the previous record. */ + def previous(): ConsumerRecord[Array[Byte], Array[Byte]] = { + assert(_records.hasPrevious, "fetchedData cannot move back") + val record = _records.previous() + _nextOffsetInFetchedData = record.offset + record + } + + /** Reset the internal pre-fetched data. */ + def reset(): Unit = { + _records = ju.Collections.emptyListIterator() + } + + /** + * Returns the next offset in `records`. We use this to verify if we should check if the + * pre-fetched data is still valid. + */ + def nextOffsetInFetchedData: Long = _nextOffsetInFetchedData + + /** + * Returns the next offset to poll after draining the pre-fetched records. + */ + def offsetAfterPoll: Long = _offsetAfterPoll + } + + /** + * The internal object returned by the `fetchRecord` method. If `record` is empty, it means it is + * invisible (either a transaction message, or an aborted message when the consumer's + * `isolation.level` is `read_committed`), and the caller should use `nextOffsetToFetch` to fetch + * instead. + */ + private case class FetchedRecord( + var record: ConsumerRecord[Array[Byte], Array[Byte]], + var nextOffsetToFetch: Long) { + + def withRecord( + record: ConsumerRecord[Array[Byte], Array[Byte]], + nextOffsetToFetch: Long): FetchedRecord = { + this.record = record + this.nextOffsetToFetch = nextOffsetToFetch + this + } + } + private val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] @volatile private var consumer = createConsumer @@ -90,10 +177,21 @@ private[kafka010] case class InternalKafkaConsumer( /** indicate whether this consumer is going to be stopped in the next release */ @volatile var markedForClose = false - /** Iterator to the already fetch data */ - @volatile private var fetchedData = - ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]] - @volatile private var nextOffsetInFetchedData = UNKNOWN_OFFSET + /** + * The fetched data returned from Kafka consumer. This is a reusable private object to avoid + * memory allocation. 
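// A standalone reduction of the reuse pattern described here (illustrative type only):
// one mutable holder is allocated up front and refilled on every call, which is safe
// because each InternalKafkaConsumer instance is confined to a single task thread.
final class ReusableHolder[A >: Null <: AnyRef] {
  private var value: A = null
  def fill(v: A): ReusableHolder[A] = { value = v; this } // no per-call allocation
  def get: A = value
}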
+ */ + private val fetchedData = FetchedData( + ju.Collections.emptyListIterator[ConsumerRecord[Array[Byte], Array[Byte]]], + UNKNOWN_OFFSET, + UNKNOWN_OFFSET) + + /** + * The fetched record returned from the `fetchRecord` method. This is a reusable private object to + * avoid memory allocation. + */ + private val fetchedRecord: FetchedRecord = FetchedRecord(null, UNKNOWN_OFFSET) + /** Create a KafkaConsumer to fetch records for `topicPartition` */ private def createConsumer: KafkaConsumer[Array[Byte], Array[Byte]] = { @@ -125,20 +223,7 @@ private[kafka010] case class InternalKafkaConsumer( AvailableOffsetRange(earliestOffset, latestOffset) } - /** - * Get the record for the given offset if available. Otherwise it will either throw error - * (if failOnDataLoss = true), or return the next available offset within [offset, untilOffset), - * or null. - * - * @param offset the offset to fetch. - * @param untilOffset the max offset to fetch. Exclusive. - * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. - * @param failOnDataLoss When `failOnDataLoss` is `true`, this method will either return record at - * offset if available, or throw exception.when `failOnDataLoss` is `false`, - * this method will either return record at offset if available, or return - * the next earliest available record less than untilOffset, or null. It - * will not throw any exception. - */ + /** @see [[KafkaDataConsumer.get]] */ def get( offset: Long, untilOffset: Long, @@ -147,21 +232,32 @@ private[kafka010] case class InternalKafkaConsumer( ConsumerRecord[Array[Byte], Array[Byte]] = runUninterruptiblyIfPossible { require(offset < untilOffset, s"offset must always be less than untilOffset [offset: $offset, untilOffset: $untilOffset]") - logDebug(s"Get $groupId $topicPartition nextOffset $nextOffsetInFetchedData requested $offset") + logDebug(s"Get $groupId $topicPartition nextOffset ${fetchedData.nextOffsetInFetchedData} " + + s"requested $offset") // The following loop is basically for `failOnDataLoss = false`. When `failOnDataLoss` is // `false`, first, we will try to fetch the record at `offset`. If no such record exists, then // we will move to the next available offset within `[offset, untilOffset)` and retry. // If `failOnDataLoss` is `true`, the loop body will be executed only once. var toFetchOffset = offset - var consumerRecord: ConsumerRecord[Array[Byte], Array[Byte]] = null + var fetchedRecord: FetchedRecord = null // We want to break out of the while loop on a successful fetch to avoid using "return" - // which may causes a NonLocalReturnControl exception when this method is used as a function. + // which may cause a NonLocalReturnControl exception when this method is used as a function. 
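// Background for the comment above, as a self-contained sketch: a `return` inside a
// function value compiles to throwing scala.runtime.NonLocalReturnControl, which a
// broad catch in a helper that runs the body as a function could swallow; hence the
// explicit completion flag used below instead of an early return.
def firstNegative(xs: Seq[Int]): Option[Int] = {
  var found: Option[Int] = None
  xs.foreach { x =>
    if (found.isEmpty && x < 0) found = Some(x) // completion flag, no `return Some(x)`
  }
  found
}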
var isFetchComplete = false while (toFetchOffset != UNKNOWN_OFFSET && !isFetchComplete) { try { - consumerRecord = fetchData(toFetchOffset, untilOffset, pollTimeoutMs, failOnDataLoss) - isFetchComplete = true + fetchedRecord = fetchRecord(toFetchOffset, untilOffset, pollTimeoutMs, failOnDataLoss) + if (fetchedRecord.record != null) { + isFetchComplete = true + } else { + toFetchOffset = fetchedRecord.nextOffsetToFetch + if (toFetchOffset >= untilOffset) { + fetchedData.reset() + toFetchOffset = UNKNOWN_OFFSET + } else { + logDebug(s"Skipped offsets [$offset, $toFetchOffset]") + } + } } catch { case e: OffsetOutOfRangeException => // When there is some error thrown, it's better to use a new consumer to drop all cached @@ -174,9 +270,9 @@ private[kafka010] case class InternalKafkaConsumer( } if (isFetchComplete) { - consumerRecord + fetchedRecord.record } else { - resetFetchedData() + fetchedData.reset() null } } @@ -239,57 +335,73 @@ private[kafka010] case class InternalKafkaConsumer( } /** - * Get the record for the given offset if available. Otherwise it will either throw error - * (if failOnDataLoss = true), or return the next available offset within [offset, untilOffset), - * or null. + * Get the fetched record for the given offset if available. + * + * If the record is invisible (either a transaction message, or an aborted message when the + * consumer's `isolation.level` is `read_committed`), it will return a `FetchedRecord` with the + * next offset to fetch. + * + * This method also will try the best to detect data loss. If `failOnDataLoss` is true`, it will + * throw an exception when we detect an unavailable offset. If `failOnDataLoss` is `false`, this + * method will return `null` if the next available record is within [offset, untilOffset). * * @throws OffsetOutOfRangeException if `offset` is out of range * @throws TimeoutException if cannot fetch the record in `pollTimeoutMs` milliseconds. */ - private def fetchData( + private def fetchRecord( offset: Long, untilOffset: Long, pollTimeoutMs: Long, - failOnDataLoss: Boolean): ConsumerRecord[Array[Byte], Array[Byte]] = { - if (offset != nextOffsetInFetchedData || !fetchedData.hasNext()) { - // This is the first fetch, or the last pre-fetched data has been drained. - // Seek to the offset because we may call seekToBeginning or seekToEnd before this. - seek(offset) - poll(pollTimeoutMs) - } - - if (!fetchedData.hasNext()) { - // We cannot fetch anything after `poll`. Two possible cases: - // - `offset` is out of range so that Kafka returns nothing. Just throw - // `OffsetOutOfRangeException` to let the caller handle it. - // - Cannot fetch any data before timeout. TimeoutException will be thrown. - val range = getAvailableOffsetRange() - if (offset < range.earliest || offset >= range.latest) { - throw new OffsetOutOfRangeException( - Map(topicPartition -> java.lang.Long.valueOf(offset)).asJava) + failOnDataLoss: Boolean): FetchedRecord = { + if (offset != fetchedData.nextOffsetInFetchedData) { + // This is the first fetch, or the fetched data has been reset. + // Fetch records from Kafka and update `fetchedData`. + fetchData(offset, pollTimeoutMs) + } else if (!fetchedData.hasNext) { // The last pre-fetched data has been drained. + if (offset < fetchedData.offsetAfterPoll) { + // Offsets in [offset, fetchedData.offsetAfterPoll) are invisible. Return a record to ask + // the next call to start from `fetchedData.offsetAfterPoll`. 
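// Why offsets can be "invisible" here: with read_committed isolation, transaction
// markers and aborted records still occupy offsets but are never handed to the
// application. A plain Kafka-client sketch of that setting (standard config key only):
import org.apache.kafka.clients.consumer.ConsumerConfig

val isolationProps = new java.util.Properties()
isolationProps.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, "read_committed")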
+ fetchedData.reset() + return fetchedRecord.withRecord(null, fetchedData.offsetAfterPoll) } else { - throw new TimeoutException( - s"Cannot fetch record for offset $offset in $pollTimeoutMs milliseconds") + // Fetch records from Kafka and update `fetchedData`. + fetchData(offset, pollTimeoutMs) } + } + + if (!fetchedData.hasNext) { + // When we reach here, we have already tried to poll from Kafka. As `fetchedData` is still + // empty, all messages in [offset, fetchedData.offsetAfterPoll) are invisible. Return a + // record to ask the next call to start from `fetchedData.offsetAfterPoll`. + assert(offset <= fetchedData.offsetAfterPoll, + s"seek to $offset and poll but the offset was reset to ${fetchedData.offsetAfterPoll}") + fetchedRecord.withRecord(null, fetchedData.offsetAfterPoll) } else { val record = fetchedData.next() - nextOffsetInFetchedData = record.offset + 1 // In general, Kafka uses the specified offset as the start point, and tries to fetch the next // available offset. Hence we need to handle offset mismatch. if (record.offset > offset) { + val range = getAvailableOffsetRange() + if (range.earliest <= offset) { + // `offset` is still valid but the corresponding message is invisible. We should skip it + // and jump to `record.offset`. Here we move `fetchedData` back so that the next call of + // `fetchRecord` can just return `record` directly. + fetchedData.previous() + return fetchedRecord.withRecord(null, record.offset) + } // This may happen when some records aged out but their offsets already got verified if (failOnDataLoss) { reportDataLoss(true, s"Cannot fetch records in [$offset, ${record.offset})") // Never happen as "reportDataLoss" will throw an exception - null + throw new IllegalStateException( + "reportDataLoss didn't throw an exception when 'failOnDataLoss' is true") + } else if (record.offset >= untilOffset) { + reportDataLoss(false, s"Skip missing records in [$offset, $untilOffset)") + // Set `nextOffsetToFetch` to `untilOffset` to finish the current batch. + fetchedRecord.withRecord(null, untilOffset) } else { - if (record.offset >= untilOffset) { - reportDataLoss(false, s"Skip missing records in [$offset, $untilOffset)") - null - } else { - reportDataLoss(false, s"Skip missing records in [$offset, ${record.offset})") - record - } + reportDataLoss(false, s"Skip missing records in [$offset, ${record.offset})") + fetchedRecord.withRecord(record, fetchedData.nextOffsetInFetchedData) } } else if (record.offset < offset) { // This should not happen. If it does happen, then we probably misunderstand Kafka internal @@ -297,7 +409,7 @@ private[kafka010] case class InternalKafkaConsumer( throw new IllegalStateException( s"Tried to fetch $offset but the returned record offset was ${record.offset}") } else { - record + fetchedRecord.withRecord(record, fetchedData.nextOffsetInFetchedData) } } } @@ -306,13 +418,7 @@ private[kafka010] case class InternalKafkaConsumer( private def resetConsumer(): Unit = { consumer.close() consumer = createConsumer - resetFetchedData() - } - - /** Reset the internal pre-fetched data. */ - private def resetFetchedData(): Unit = { - nextOffsetInFetchedData = UNKNOWN_OFFSET - fetchedData = ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]] + fetchedData.reset() } /** @@ -346,11 +452,40 @@ private[kafka010] case class InternalKafkaConsumer( consumer.seek(topicPartition, offset) } - private def poll(pollTimeoutMs: Long): Unit = { + /** + * Poll messages from Kafka starting from `offset` and update `fetchedData`. 
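// A self-contained sketch of the range guard this class leans on, using only the plain
// KafkaConsumer API (the class's own getAvailableOffsetRange may compute it differently):
// offsets outside [earliest, latest) can never be fetched again, so a gap there means
// data loss rather than merely invisible records.
import scala.collection.JavaConverters._
import org.apache.kafka.clients.consumer.KafkaConsumer
import org.apache.kafka.common.TopicPartition

def availableRange(
    c: KafkaConsumer[Array[Byte], Array[Byte]],
    tp: TopicPartition): (Long, Long) = {
  val earliest: Long = c.beginningOffsets(Seq(tp).asJava).get(tp)
  val latest: Long = c.endOffsets(Seq(tp).asJava).get(tp)
  (earliest, latest)
}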
`fetchedData` may be
+   * empty if the Kafka consumer fetches some messages but all of them are not visible messages
+   * (either transaction messages, or aborted messages when `isolation.level` is
+   * `read_committed`).
+   *
+   * @throws OffsetOutOfRangeException if `offset` is out of range.
+   * @throws TimeoutException if the consumer position is not changed after polling. It means the
+   *                          consumer polls nothing before timeout.
+   */
+  private def fetchData(offset: Long, pollTimeoutMs: Long): Unit = {
+    // Seek to the offset because we may call seekToBeginning or seekToEnd before this.
+    seek(offset)
     val p = consumer.poll(pollTimeoutMs)
     val r = p.records(topicPartition)
     logDebug(s"Polled $groupId ${p.partitions()} ${r.size}")
-    fetchedData = r.iterator
+    val offsetAfterPoll = consumer.position(topicPartition)
+    logDebug(s"Offset changed from $offset to $offsetAfterPoll after polling")
+    fetchedData.withNewPoll(r.listIterator, offsetAfterPoll)
+    if (!fetchedData.hasNext) {
+      // We cannot fetch anything after `poll`. Three possible cases:
+      // - `offset` is out of range so that Kafka returns nothing. `OffsetOutOfRangeException` will
+      //   be thrown.
+      // - Cannot fetch any data before timeout. `TimeoutException` will be thrown.
+      // - Fetched something but none of it is visible. This is a valid case and we let the
+      //   caller handle it.
+      val range = getAvailableOffsetRange()
+      if (offset < range.earliest || offset >= range.latest) {
+        throw new OffsetOutOfRangeException(
+          Map(topicPartition -> java.lang.Long.valueOf(offset)).asJava)
+      } else if (offset == offsetAfterPoll) {
+        throw new TimeoutException(
+          s"Cannot fetch record for offset $offset in $pollTimeoutMs milliseconds")
+      }
+    }
+  }
 }
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala
similarity index 83%
rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala
rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala
index 737da2e51b125..bb4de674c3c72 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala
@@ -21,25 +21,25 @@ import java.{util => ju}
 import java.io._
 import java.nio.charset.StandardCharsets
-import scala.collection.JavaConverters._
-
 import org.apache.commons.io.IOUtils
 import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.ExecutorCacheTaskLocation
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset}
+import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder}
+import org.apache.spark.sql.execution.streaming.sources.RateControlMicroBatchReadSupport
 import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
 import org.apache.spark.sql.sources.v2.DataSourceOptions
-import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow}
-import
org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset} import org.apache.spark.sql.types.StructType import org.apache.spark.util.UninterruptibleThread /** - * A [[MicroBatchReader]] that reads data from Kafka. + * A [[MicroBatchReadSupport]] that reads data from Kafka. * * The [[KafkaSourceOffset]] is the custom [[Offset]] defined for this source that contains * a map of TopicPartition -> offset. Note that this offset is 1 + (available offset). For @@ -54,17 +54,13 @@ import org.apache.spark.util.UninterruptibleThread * To avoid this issue, you should make sure stopping the query before stopping the Kafka brokers * and not use wrong broker addresses. */ -private[kafka010] class KafkaMicroBatchReader( +private[kafka010] class KafkaMicroBatchReadSupport( kafkaOffsetReader: KafkaOffsetReader, executorKafkaParams: ju.Map[String, Object], options: DataSourceOptions, metadataPath: String, startingOffsets: KafkaOffsetRangeLimit, - failOnDataLoss: Boolean) - extends MicroBatchReader with SupportsScanUnsafeRow with Logging { - - private var startPartitionOffsets: PartitionOffsetMap = _ - private var endPartitionOffsets: PartitionOffsetMap = _ + failOnDataLoss: Boolean) extends RateControlMicroBatchReadSupport with Logging { private val pollTimeoutMs = options.getLong( "kafkaConsumer.pollTimeoutMs", @@ -74,34 +70,40 @@ private[kafka010] class KafkaMicroBatchReader( Option(options.get("maxOffsetsPerTrigger").orElse(null)).map(_.toLong) private val rangeCalculator = KafkaOffsetRangeCalculator(options) + + private var endPartitionOffsets: KafkaSourceOffset = _ + /** * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only * called in StreamExecutionThread. Otherwise, interrupting a thread while running * `KafkaConsumer.poll` may hang forever (KAFKA-1894). 
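// The confinement trick described in this comment, reduced to a sketch with stand-in
// types: hiding the first Kafka call behind a lazy val defers it to whichever thread
// first forces the value, and the read support only forces it from the stream
// execution thread, sidestepping the KAFKA-1894 interruption hang.
class DeferredOffsets(fetchFromKafka: () => Map[String, Long]) {
  private lazy val initial: Map[String, Long] = fetchFromKafka() // runs on first access only
  def initialOffsets(): Map[String, Long] = initial
}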
   */
-  private lazy val initialPartitionOffsets = getOrCreateInitialPartitionOffsets()
-
-  override def setOffsetRange(start: ju.Optional[Offset], end: ju.Optional[Offset]): Unit = {
-    // Make sure initialPartitionOffsets is initialized
-    initialPartitionOffsets
-
-    startPartitionOffsets = Option(start.orElse(null))
-        .map(_.asInstanceOf[KafkaSourceOffset].partitionToOffsets)
-        .getOrElse(initialPartitionOffsets)
-
-    endPartitionOffsets = Option(end.orElse(null))
-        .map(_.asInstanceOf[KafkaSourceOffset].partitionToOffsets)
-        .getOrElse {
-          val latestPartitionOffsets = kafkaOffsetReader.fetchLatestOffsets()
-          maxOffsetsPerTrigger.map { maxOffsets =>
-            rateLimit(maxOffsets, startPartitionOffsets, latestPartitionOffsets)
-          }.getOrElse {
-            latestPartitionOffsets
-          }
-        }
+  override def initialOffset(): Offset = {
+    KafkaSourceOffset(getOrCreateInitialPartitionOffsets())
   }

-  override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = {
+  override def latestOffset(start: Offset): Offset = {
+    val startPartitionOffsets = start.asInstanceOf[KafkaSourceOffset].partitionToOffsets
+    val latestPartitionOffsets = kafkaOffsetReader.fetchLatestOffsets()
+    endPartitionOffsets = KafkaSourceOffset(maxOffsetsPerTrigger.map { maxOffsets =>
+      rateLimit(maxOffsets, startPartitionOffsets, latestPartitionOffsets)
+    }.getOrElse {
+      latestPartitionOffsets
+    })
+    endPartitionOffsets
+  }
+
+  override def fullSchema(): StructType = KafkaOffsetReader.kafkaSchema
+
+  override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = {
+    new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end))
+  }
+
+  override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
+    val sc = config.asInstanceOf[SimpleStreamingScanConfig]
+    val startPartitionOffsets = sc.start.asInstanceOf[KafkaSourceOffset].partitionToOffsets
+    val endPartitionOffsets = sc.end.get.asInstanceOf[KafkaSourceOffset].partitionToOffsets
+
     // Find the new partitions, and get their earliest offsets
     val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet)
     val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq)
@@ -142,27 +144,20 @@ private[kafka010] class KafkaMicroBatchReader(
     val reuseKafkaConsumer = offsetRanges.map(_.topicPartition).toSet.size == offsetRanges.size

     // Generate input partitions based on the offset ranges
-    val factories = offsetRanges.map { range =>
-      new KafkaMicroBatchInputPartition(
+    offsetRanges.map { range =>
+      KafkaMicroBatchInputPartition(
         range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer)
-    }
-    factories.map(_.asInstanceOf[InputPartition[UnsafeRow]]).asJava
-  }
-
-  override def getStartOffset: Offset = {
-    KafkaSourceOffset(startPartitionOffsets)
+    }.toArray
   }

-  override def getEndOffset: Offset = {
-    KafkaSourceOffset(endPartitionOffsets)
+  override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = {
+    KafkaMicroBatchReaderFactory
   }

   override def deserializeOffset(json: String): Offset = {
     KafkaSourceOffset(JsonUtils.partitionOffsets(json))
   }

-  override def readSchema(): StructType = KafkaOffsetReader.kafkaSchema
-
   override def commit(end: Offset): Unit = {}

   override def stop(): Unit = {
@@ -305,22 +300,23 @@ private[kafka010] case class KafkaMicroBatchInputPartition(
     executorKafkaParams: ju.Map[String, Object],
     pollTimeoutMs: Long,
     failOnDataLoss: Boolean,
-    reuseKafkaConsumer: Boolean) extends InputPartition[UnsafeRow] {
+    reuseKafkaConsumer: Boolean) extends InputPartition
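The `rateLimit` helper invoked from `latestOffset` above is defined elsewhere in this class and is not part of this hunk. Conceptually, it caps how many new offsets a trigger may consume, spreading the `maxOffsetsPerTrigger` budget across partitions in proportion to how far each one lags behind. A rough sketch of that idea, assuming plain `Map`-based offsets (the real implementation also seeds newly discovered partitions from their earliest offsets, which this sketch omits):

```scala
import org.apache.kafka.common.TopicPartition

object RateLimitSketch {
  // Sketch only: proportional rate limiting for maxOffsetsPerTrigger.
  // New-partition handling and the rounding details of the real code are omitted.
  def rateLimit(
      maxOffsets: Long,
      start: Map[TopicPartition, Long],
      latest: Map[TopicPartition, Long]): Map[TopicPartition, Long] = {
    // How many records each partition could contribute in this trigger.
    val lag = latest.map { case (tp, end) =>
      tp -> math.max(end - start.getOrElse(tp, end), 0L)
    }
    val totalLag = lag.values.sum
    if (totalLag <= maxOffsets) {
      latest // under budget: read up to the latest offset everywhere
    } else {
      latest.map { case (tp, end) =>
        // Each partition receives a share of the budget proportional to its lag.
        val share = (maxOffsets.toDouble * lag(tp) / totalLag).toLong
        tp -> math.min(start.getOrElse(tp, end) + share, end)
      }
    }
  }
}
```

Proportional allocation keeps one hot partition from starving the others, which is why `latestOffset` applies the cap before recording `endPartitionOffsets` for the batch.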
- override def preferredLocations(): Array[String] = offsetRange.preferredLoc.toArray - - override def createPartitionReader(): InputPartitionReader[UnsafeRow] = - new KafkaMicroBatchInputPartitionReader(offsetRange, executorKafkaParams, pollTimeoutMs, - failOnDataLoss, reuseKafkaConsumer) +private[kafka010] object KafkaMicroBatchReaderFactory extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val p = partition.asInstanceOf[KafkaMicroBatchInputPartition] + KafkaMicroBatchPartitionReader(p.offsetRange, p.executorKafkaParams, p.pollTimeoutMs, + p.failOnDataLoss, p.reuseKafkaConsumer) + } } -/** A [[InputPartitionReader]] for reading Kafka data in a micro-batch streaming query. */ -private[kafka010] case class KafkaMicroBatchInputPartitionReader( +/** A [[PartitionReader]] for reading Kafka data in a micro-batch streaming query. */ +private[kafka010] case class KafkaMicroBatchPartitionReader( offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean, - reuseKafkaConsumer: Boolean) extends InputPartitionReader[UnsafeRow] with Logging { + reuseKafkaConsumer: Boolean) extends PartitionReader[InternalRow] with Logging { private val consumer = KafkaDataConsumer.acquire( offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer) @@ -336,6 +332,7 @@ private[kafka010] case class KafkaMicroBatchInputPartitionReader( val record = consumer.get(nextOffset, rangeToRead.untilOffset, pollTimeoutMs, failOnDataLoss) if (record != null) { nextRow = converter.toUnsafeRow(record) + nextOffset = record.offset + 1 true } else { false @@ -347,7 +344,6 @@ private[kafka010] case class KafkaMicroBatchInputPartitionReader( override def get(): UnsafeRow = { assert(nextRow != null) - nextOffset += 1 nextRow } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala index 6631ae84167c8..fb209c724afba 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala @@ -29,7 +29,6 @@ import org.apache.spark.sql.sources.v2.DataSourceOptions private[kafka010] class KafkaOffsetRangeCalculator(val minPartitions: Option[Int]) { require(minPartitions.isEmpty || minPartitions.get > 0) - import KafkaOffsetRangeCalculator._ /** * Calculate the offset ranges that we are going to process this batch. 
If `minPartitions`
 * is not set or is set less than or equal to the number of `topicPartitions` that we're going to
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala
index c31e6ed3e0903..e6f9d1259e43e 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala
@@ -17,7 +17,6 @@
 package org.apache.spark.sql.kafka010
 
-import java.{util => ju}
 import java.util.UUID
 
 import org.apache.kafka.common.TopicPartition
@@ -117,7 +116,7 @@ private[kafka010] class KafkaRelation(
         DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(cr.timestamp)),
         cr.timestampType.id)
     }
-    sqlContext.internalCreateDataFrame(rdd, schema).rdd
+    sqlContext.internalCreateDataFrame(rdd.setName("kafka"), schema).rdd
   }
 
   private def getPartitionOffsets(
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
index 101e649727fcf..66ec7e0cd084a 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
@@ -215,7 +215,7 @@ private[kafka010] class KafkaSource(
     }
     if (start.isDefined && start.get == end) {
       return sqlContext.internalCreateDataFrame(
-        sqlContext.sparkContext.emptyRDD, schema, isStreaming = true)
+        sqlContext.sparkContext.emptyRDD[InternalRow].setName("empty"), schema, isStreaming = true)
     }
     val fromPartitionOffsets = start match {
       case Some(prevBatchEndOffset) =>
@@ -299,7 +299,7 @@ private[kafka010] class KafkaSource(
     logInfo("GetBatch generating RDD of offset range: " +
       offsetRanges.sortBy(_.topicPartition.toString).mkString(", "))
-    sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true)
+    sqlContext.internalCreateDataFrame(rdd.setName("kafka"), schema, isStreaming = true)
   }
 
   /** Stop this source and free any resources it has allocated.
*/ diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index d225c1ea6b7f1..28c9853bfea9c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -30,9 +30,8 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputPartitionReader -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -46,9 +45,9 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSinkProvider with RelationProvider with CreatableRelationProvider - with StreamWriteSupport - with ContinuousReadSupport - with MicroBatchReadSupport + with StreamingWriteSupportProvider + with ContinuousReadSupportProvider + with MicroBatchReadSupportProvider with Logging { import KafkaSourceProvider._ @@ -108,13 +107,12 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } /** - * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader]] to read batches - * of Kafka data in a micro-batch streaming query. + * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport]] to read + * batches of Kafka data in a micro-batch streaming query. */ - override def createMicroBatchReader( - schema: Optional[StructType], + override def createMicroBatchReadSupport( metadataPath: String, - options: DataSourceOptions): KafkaMicroBatchReader = { + options: DataSourceOptions): KafkaMicroBatchReadSupport = { val parameters = options.asMap().asScala.toMap validateStreamOptions(parameters) @@ -140,7 +138,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister parameters, driverGroupIdPrefix = s"$uniqueGroupId-driver") - new KafkaMicroBatchReader( + new KafkaMicroBatchReadSupport( kafkaOffsetReader, kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), options, @@ -150,13 +148,12 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } /** - * Creates a [[ContinuousInputPartitionReader]] to read + * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport]] to read * Kafka data in a continuous streaming query. */ - override def createContinuousReader( - schema: Optional[StructType], + override def createContinuousReadSupport( metadataPath: String, - options: DataSourceOptions): KafkaContinuousReader = { + options: DataSourceOptions): KafkaContinuousReadSupport = { val parameters = options.asMap().asScala.toMap validateStreamOptions(parameters) // Each running query should use its own group id. 
Otherwise, the query may be only assigned @@ -181,7 +178,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister parameters, driverGroupIdPrefix = s"$uniqueGroupId-driver") - new KafkaContinuousReader( + new KafkaContinuousReadSupport( kafkaOffsetReader, kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), parameters, @@ -270,11 +267,11 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } } - override def createStreamWriter( + override def createStreamingWriteSupport( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamWriter = { + options: DataSourceOptions): StreamingWriteSupport = { import scala.collection.JavaConverters._ val spark = SparkSession.getActiveSession.get @@ -285,7 +282,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister KafkaWriter.validateQuery( schema.toAttributes, new java.util.HashMap[String, Object](producerParams.asJava), topic) - new KafkaStreamWriter(topic, producerParams, schema) + new KafkaStreamingWriteSupport(topic, producerParams, schema) } private def strategy(caseInsensitiveParams: Map[String, String]) = diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala index 498e344ea39f4..f8b90056d2931 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala @@ -77,44 +77,6 @@ private[kafka010] class KafkaSourceRDD( offsetRanges.zipWithIndex.map { case (o, i) => new KafkaSourceRDDPartition(i, o) }.toArray } - override def count(): Long = offsetRanges.map(_.size).sum - - override def countApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] = { - val c = count - new PartialResult(new BoundedDouble(c, 1.0, c, c), true) - } - - override def isEmpty(): Boolean = count == 0L - - override def take(num: Int): Array[ConsumerRecord[Array[Byte], Array[Byte]]] = { - val nonEmptyPartitions = - this.partitions.map(_.asInstanceOf[KafkaSourceRDDPartition]).filter(_.offsetRange.size > 0) - - if (num < 1 || nonEmptyPartitions.isEmpty) { - return new Array[ConsumerRecord[Array[Byte], Array[Byte]]](0) - } - - // Determine in advance how many messages need to be taken from each partition - val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) => - val remain = num - result.values.sum - if (remain > 0) { - val taken = Math.min(remain, part.offsetRange.size) - result + (part.index -> taken.toInt) - } else { - result - } - } - - val buf = new ArrayBuffer[ConsumerRecord[Array[Byte], Array[Byte]]] - val res = context.runJob( - this, - (tc: TaskContext, it: Iterator[ConsumerRecord[Array[Byte], Array[Byte]]]) => - it.take(parts(tc.partitionId)).toArray, parts.keys.toArray - ) - res.foreach(buf ++= _) - buf.toArray - } - override def getPreferredLocations(split: Partition): Seq[String] = { val part = split.asInstanceOf[KafkaSourceRDDPartition] part.offsetRange.preferredLoc.map(Seq(_)).getOrElse(Seq.empty) @@ -124,8 +86,6 @@ private[kafka010] class KafkaSourceRDD( thePart: Partition, context: TaskContext): Iterator[ConsumerRecord[Array[Byte], Array[Byte]]] = { val sourcePartition = thePart.asInstanceOf[KafkaSourceRDDPartition] - val topic = sourcePartition.offsetRange.topic - val kafkaPartition = sourcePartition.offsetRange.partition val consumer = 
KafkaDataConsumer.acquire(
       sourcePartition.offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer)
@@ -138,6 +98,7 @@ private[kafka010] class KafkaSourceRDD(
     if (range.fromOffset == range.untilOffset) {
       logInfo(s"Beginning offset ${range.fromOffset} is the same as ending offset, " +
         s"skipping ${range.topic} ${range.partition}")
+      consumer.release()
       Iterator.empty
     } else {
       val underlying = new NextIterator[ConsumerRecord[Array[Byte], Array[Byte]]]() {
@@ -166,7 +127,7 @@
         }
       }
       // Release consumer, either by removing it or indicating we're no longer using it
-      context.addTaskCompletionListener { _ =>
+      context.addTaskCompletionListener[Unit] { _ =>
         underlying.closeIfNeeded()
       }
       underlying
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala
similarity index 88%
rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala
rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala
index 32923dc9f5a6b..927c56d9ce829 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery
 import org.apache.spark.sql.sources.v2.writer._
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport}
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -33,20 +33,20 @@ import org.apache.spark.sql.types.StructType
 case object KafkaWriterCommitMessage extends WriterCommitMessage
 
 /**
- * A [[StreamWriter]] for Kafka writing. Responsible for generating the writer factory.
+ * A [[StreamingWriteSupport]] for Kafka writing. Responsible for generating the writer factory.
  *
  * @param topic The topic this writer is responsible for. If None, topic will be inferred from
  *              a `topic` field in the incoming data.
  * @param producerParams Parameters for Kafka producers in each task.
  * @param schema The schema of the input data.
  */
-class KafkaStreamWriter(
+class KafkaStreamingWriteSupport(
     topic: Option[String], producerParams: Map[String, String], schema: StructType)
-  extends StreamWriter with SupportsWriteInternalRow {
+  extends StreamingWriteSupport {
 
   validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic)
 
-  override def createInternalRowWriterFactory(): KafkaStreamWriterFactory =
+  override def createStreamingWriterFactory(): KafkaStreamWriterFactory =
     KafkaStreamWriterFactory(topic, producerParams, schema)
 
   override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
@@ -54,8 +54,8 @@ class KafkaStreamWriter(
 }
 
 /**
- * A [[DataWriterFactory]] for Kafka writing. Will be serialized and sent to executors to generate
- * the per-task data writers.
+ * A [[StreamingDataWriterFactory]] for Kafka writing. Will be serialized and sent to executors to
+ * generate the per-task data writers.
  * @param topic The topic that should be written to. If None, topic will be inferred from
  *              a `topic` field in the incoming data.
* @param producerParams Parameters for Kafka producers in each task. @@ -63,9 +63,9 @@ class KafkaStreamWriter( */ case class KafkaStreamWriterFactory( topic: Option[String], producerParams: Map[String, String], schema: StructType) - extends DataWriterFactory[InternalRow] { + extends StreamingDataWriterFactory { - override def createDataWriter( + override def createWriter( partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] = { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index d90630a8adc93..041fac7717635 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -110,7 +110,7 @@ private[kafka010] abstract class KafkaRowWriter( case t => throw new IllegalStateException(s"${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " + s"attribute unsupported type $t. ${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " + - "must be a StringType") + s"must be a ${StringType.catalogString}") } val keyExpression = inputSchema.find(_.name == KafkaWriter.KEY_ATTRIBUTE_NAME) .getOrElse(Literal(null, BinaryType)) @@ -118,7 +118,7 @@ private[kafka010] abstract class KafkaRowWriter( case StringType | BinaryType => // good case t => throw new IllegalStateException(s"${KafkaWriter.KEY_ATTRIBUTE_NAME} " + - s"attribute unsupported type $t") + s"attribute unsupported type ${t.catalogString}") } val valueExpression = inputSchema .find(_.name == KafkaWriter.VALUE_ATTRIBUTE_NAME).getOrElse( @@ -129,7 +129,7 @@ private[kafka010] abstract class KafkaRowWriter( case StringType | BinaryType => // good case t => throw new IllegalStateException(s"${KafkaWriter.VALUE_ATTRIBUTE_NAME} " + - s"attribute unsupported type $t") + s"attribute unsupported type ${t.catalogString}") } UnsafeProjection.create( Seq(topicExpression, Cast(keyExpression, BinaryType), diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index 15cd44812cb0c..fc09938a43a8c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -57,7 +57,7 @@ private[kafka010] object KafkaWriter extends Logging { ).dataType match { case StringType => // good case _ => - throw new AnalysisException(s"Topic type must be a String") + throw new AnalysisException(s"Topic type must be a ${StringType.catalogString}") } schema.find(_.name == KEY_ATTRIBUTE_NAME).getOrElse( Literal(null, StringType) @@ -65,7 +65,7 @@ private[kafka010] object KafkaWriter extends Logging { case StringType | BinaryType => // good case _ => throw new AnalysisException(s"$KEY_ATTRIBUTE_NAME attribute type " + - s"must be a String or BinaryType") + s"must be a ${StringType.catalogString} or ${BinaryType.catalogString}") } schema.find(_.name == VALUE_ATTRIBUTE_NAME).getOrElse( throw new AnalysisException(s"Required attribute '$VALUE_ATTRIBUTE_NAME' not found") @@ -73,7 +73,7 @@ private[kafka010] object KafkaWriter extends Logging { case StringType | BinaryType => // good case _ => throw new AnalysisException(s"$VALUE_ATTRIBUTE_NAME attribute type " + - s"must be a String or BinaryType") + s"must be a ${StringType.catalogString} or 
${BinaryType.catalogString}") } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala index 789bffa9da126..0b3355426df10 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala @@ -26,14 +26,13 @@ import org.scalatest.PrivateMethodTester import org.apache.spark.sql.test.SharedSQLContext -class CachedKafkaProducerSuite extends SharedSQLContext with PrivateMethodTester { +class CachedKafkaProducerSuite extends SharedSQLContext with PrivateMethodTester with KafkaTest { type KP = KafkaProducer[Array[Byte], Array[Byte]] protected override def beforeEach(): Unit = { super.beforeEach() - val clear = PrivateMethod[Unit]('clear) - CachedKafkaProducer.invokePrivate(clear()) + CachedKafkaProducer.clear() } test("Should return the cached instance on calling getOrCreate with same params.") { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala index ddfc0c1a4be2d..3f6fcf6b2e52c 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -40,12 +40,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { override val streamingTimeout = 30.seconds - override def beforeAll(): Unit = { - super.beforeAll() - testUtils = new KafkaTestUtils( - withBrokerProps = Map("auto.create.topics.enable" -> "false")) - testUtils.setup() - } + override val brokerProps = Map("auto.create.topics.enable" -> "false") override def afterAll(): Unit = { if (testUtils != null) { @@ -314,7 +309,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "value attribute type must be a string or binarytype")) + "value attribute type must be a string or binary")) try { /* key field wrong type */ @@ -330,7 +325,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "key attribute type must be a string or binarytype")) + "key attribute type must be a string or binary")) } test("streaming - write to non-existing topic") { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index aab8ec42189fb..af510219a6f6f 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -17,12 +17,159 @@ package org.apache.spark.sql.kafka010 +import org.apache.kafka.clients.producer.ProducerRecord + import org.apache.spark.sql.Dataset -import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec +import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.streaming.Trigger // Run 
tests in KafkaSourceSuiteBase in continuous execution mode. -class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest +class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest { + import testImplicits._ + + test("read Kafka transactional messages: read_committed") { + val table = "kafka_continuous_source_test" + withTable(table) { + val topic = newTopic() + testUtils.createTopic(topic) + testUtils.withTranscationalProducer { producer => + val df = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.isolation.level", "read_committed") + .option("startingOffsets", "earliest") + .option("subscribe", topic) + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + .map(kv => kv._2.toInt) + + val q = df + .writeStream + .format("memory") + .queryName(table) + .trigger(ContinuousTrigger(100)) + .start() + try { + producer.beginTransaction() + (1 to 5).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + + // Should not read any messages before they are committed + assert(spark.table(table).isEmpty) + + producer.commitTransaction() + + eventually(timeout(streamingTimeout)) { + // Should read all committed messages + checkAnswer(spark.table(table), (1 to 5).toDF) + } + + producer.beginTransaction() + (6 to 10).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.abortTransaction() + + // Should not read aborted messages + checkAnswer(spark.table(table), (1 to 5).toDF) + + producer.beginTransaction() + (11 to 15).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.commitTransaction() + + eventually(timeout(streamingTimeout)) { + // Should skip aborted messages and read new committed ones. 
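+          // Values 6 to 10 were aborted, so only the original 1 to 5 and the newly
+          // committed 11 to 15 should be returned.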
+ checkAnswer(spark.table(table), ((1 to 5) ++ (11 to 15)).toDF) + } + } finally { + q.stop() + } + } + } + } + + test("read Kafka transactional messages: read_uncommitted") { + val table = "kafka_continuous_source_test" + withTable(table) { + val topic = newTopic() + testUtils.createTopic(topic) + testUtils.withTranscationalProducer { producer => + val df = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.isolation.level", "read_uncommitted") + .option("startingOffsets", "earliest") + .option("subscribe", topic) + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + .map(kv => kv._2.toInt) + + val q = df + .writeStream + .format("memory") + .queryName(table) + .trigger(ContinuousTrigger(100)) + .start() + try { + producer.beginTransaction() + (1 to 5).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + + eventually(timeout(streamingTimeout)) { + // Should read uncommitted messages + checkAnswer(spark.table(table), (1 to 5).toDF) + } + + producer.commitTransaction() + + eventually(timeout(streamingTimeout)) { + // Should read all committed messages + checkAnswer(spark.table(table), (1 to 5).toDF) + } + + producer.beginTransaction() + (6 to 10).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.abortTransaction() + + eventually(timeout(streamingTimeout)) { + // Should read aborted messages + checkAnswer(spark.table(table), (1 to 10).toDF) + } + + producer.beginTransaction() + (11 to 15).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + + eventually(timeout(streamingTimeout)) { + // Should read all messages including committed, aborted and uncommitted messages + checkAnswer(spark.table(table), (1 to 15).toDF) + } + + producer.commitTransaction() + + eventually(timeout(streamingTimeout)) { + // Should read all messages including committed and aborted messages + checkAnswer(spark.table(table), (1 to 15).toDF) + } + } finally { + q.stop() + } + } + } + } +} class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { import testImplicits._ @@ -42,6 +189,7 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") + .option("kafka.default.api.timeout.ms", "3000") .option("subscribePattern", s"$topicPrefix-.*") .option("failOnDataLoss", "false") @@ -59,11 +207,13 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { testUtils.createTopic(topic2, partitions = 5) eventually(timeout(streamingTimeout)) { assert( - query.lastExecution.logical.collectFirst { - case StreamingDataSourceV2Relation(_, _, _, r: KafkaContinuousReader) => r - }.exists { r => + query.lastExecution.executedPlan.collectFirst { + case scan: DataSourceV2ScanExec + if scan.readSupport.isInstanceOf[KafkaContinuousReadSupport] => + scan.scanConfig.asInstanceOf[KafkaContinuousScanConfig] + }.exists { config => // Ensure the new topic is present and the old topic is gone. 
- r.knownPartitions.exists(_.topic == topic2) + config.knownPartitions.exists(_.topic == topic2) }, s"query never reconfigured to new topic $topic2") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index fa1468a3943c8..fa6bdc20bd4f9 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.SparkContext import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart} -import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.streaming.Trigger @@ -46,8 +46,10 @@ trait KafkaContinuousTest extends KafkaSourceTest { testUtils.addPartitions(topic, newCount) eventually(timeout(streamingTimeout)) { assert( - query.lastExecution.logical.collectFirst { - case StreamingDataSourceV2Relation(_, _, _, r: KafkaContinuousReader) => r + query.lastExecution.executedPlan.collectFirst { + case scan: DataSourceV2ScanExec + if scan.readSupport.isInstanceOf[KafkaContinuousReadSupport] => + scan.scanConfig.asInstanceOf[KafkaContinuousScanConfig] }.exists(_.knownPartitions.size == newCount), s"query never reconfigured to $newCount partitions") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDontFailOnDataLossSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDontFailOnDataLossSuite.scala new file mode 100644 index 0000000000000..39c4e3fda1a4b --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDontFailOnDataLossSuite.scala @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.kafka010
+
+import java.util.Properties
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.collection.mutable
+import scala.util.Random
+
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter}
+import org.apache.spark.sql.streaming.{StreamTest, Trigger}
+import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession}
+
+/**
+ * This is a basic test trait which will set up a Kafka cluster that keeps only several records in
+ * a topic and ages out records very quickly. This is a helper trait to test the
+ * "failOnDataLoss=false" case with missing offsets.
+ *
+ * Note: there is a hard-coded 30-second delay (kafka.log.LogManager.InitialTaskDelayMs) before
+ * records are cleaned up. Hence each class extending this trait needs to wait at least 30 seconds
+ * (or even longer when running on a slow Jenkins machine) before records start to be removed. To
+ * make sure a test does see missing offsets, you can check the earliest offset in `eventually` and
+ * make sure it's not 0, rather than sleeping for a hard-coded duration.
+ */
+trait KafkaMissingOffsetsTest extends SharedSQLContext {
+
+  protected var testUtils: KafkaTestUtils = _
+
+  override def createSparkSession(): TestSparkSession = {
+    // Set maxRetries to 3 to handle NPE from `poll` when deleting a topic
+    new TestSparkSession(new SparkContext("local[2,3]", "test-sql-context", sparkConf))
+  }
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    testUtils = new KafkaTestUtils {
+      override def brokerConfiguration: Properties = {
+        val props = super.brokerConfiguration
+        // Try to make Kafka clean up messages as fast as possible. However, there is a hard-coded
+        // 30-second delay (kafka.log.LogManager.InitialTaskDelayMs), so this test should run for
+        // at least 30 seconds.
+        props.put("log.cleaner.backoff.ms", "100")
+        // The size of RecordBatch V2 increases to support transactional write.
+        props.put("log.segment.bytes", "70")
+        props.put("log.retention.bytes", "40")
+        props.put("log.retention.check.interval.ms", "100")
+        props.put("delete.retention.ms", "10")
+        props.put("log.flush.scheduler.interval.ms", "10")
+        props
+      }
+    }
+    testUtils.setup()
+  }
+
+  override def afterAll(): Unit = {
+    if (testUtils != null) {
+      testUtils.teardown()
+      testUtils = null
+    }
+    super.afterAll()
+  }
+}
+
+class KafkaDontFailOnDataLossSuite extends StreamTest with KafkaMissingOffsetsTest {
+
+  import testImplicits._
+
+  private val topicId = new AtomicInteger(0)
+
+  private def newTopic(): String = s"failOnDataLoss-${topicId.getAndIncrement()}"
+
+  /**
+   * @param testStreamingQuery whether to test a streaming query or a batch query.
+   * @param writeToTable the function to write the specified [[DataFrame]] to the given table.
+ */ + private def verifyMissingOffsetsDontCauseDuplicatedRecords( + testStreamingQuery: Boolean)(writeToTable: (DataFrame, String) => Unit): Unit = { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 1) + testUtils.sendMessages(topic, (0 until 50).map(_.toString).toArray) + + eventually(timeout(60.seconds)) { + assert( + testUtils.getEarliestOffsets(Set(topic)).head._2 > 0, + "Kafka didn't delete records after 1 minute") + } + + val table = "DontFailOnDataLoss" + withTable(table) { + val kafkaOptions = Map( + "kafka.bootstrap.servers" -> testUtils.brokerAddress, + "kafka.metadata.max.age.ms" -> "1", + "subscribe" -> topic, + "startingOffsets" -> s"""{"$topic":{"0":0}}""", + "failOnDataLoss" -> "false", + "kafkaConsumer.pollTimeoutMs" -> "1000") + val df = + if (testStreamingQuery) { + val reader = spark.readStream.format("kafka") + kafkaOptions.foreach(kv => reader.option(kv._1, kv._2)) + reader.load() + } else { + val reader = spark.read.format("kafka") + kafkaOptions.foreach(kv => reader.option(kv._1, kv._2)) + reader.load() + } + writeToTable(df.selectExpr("CAST(value AS STRING)"), table) + val result = spark.table(table).as[String].collect().toList + assert(result.distinct.size === result.size, s"$result contains duplicated records") + // Make sure Kafka did remove some records so that this test is valid. + assert(result.size > 0 && result.size < 50) + } + } + + test("failOnDataLoss=false should not return duplicated records: v1") { + withSQLConf( + "spark.sql.streaming.disabledV2MicroBatchReaders" -> + classOf[KafkaSourceProvider].getCanonicalName) { + verifyMissingOffsetsDontCauseDuplicatedRecords(testStreamingQuery = true) { (df, table) => + val query = df.writeStream.format("memory").queryName(table).start() + try { + query.processAllAvailable() + } finally { + query.stop() + } + } + } + } + + test("failOnDataLoss=false should not return duplicated records: v2") { + verifyMissingOffsetsDontCauseDuplicatedRecords(testStreamingQuery = true) { (df, table) => + val query = df.writeStream.format("memory").queryName(table).start() + try { + query.processAllAvailable() + } finally { + query.stop() + } + } + } + + test("failOnDataLoss=false should not return duplicated records: continuous processing") { + verifyMissingOffsetsDontCauseDuplicatedRecords(testStreamingQuery = true) { (df, table) => + val query = df.writeStream + .format("memory") + .queryName(table) + .trigger(Trigger.Continuous(100)) + .start() + try { + // `processAllAvailable` doesn't work for continuous processing, so just wait until the last + // record appears in the table. + eventually(timeout(streamingTimeout)) { + assert(spark.table(table).as[String].collect().contains("49")) + } + } finally { + query.stop() + } + } + } + + test("failOnDataLoss=false should not return duplicated records: batch") { + verifyMissingOffsetsDontCauseDuplicatedRecords(testStreamingQuery = false) { (df, table) => + df.write.saveAsTable(table) + } + } +} + +class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with KafkaMissingOffsetsTest { + + import testImplicits._ + + private val topicId = new AtomicInteger(0) + + private def newTopic(): String = s"failOnDataLoss-${topicId.getAndIncrement()}" + + protected def startStream(ds: Dataset[Int]) = { + ds.writeStream.foreach(new ForeachWriter[Int] { + + override def open(partitionId: Long, version: Long): Boolean = true + + override def process(value: Int): Unit = { + // Slow down the processing speed so that messages may be aged out. 
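+        // The aggressive retention settings in `KafkaMissingOffsetsTest` delete records while
+        // the query is still running, which is exactly the condition under test.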
+        Thread.sleep(Random.nextInt(500))
+      }
+
+      override def close(errorOrNull: Throwable): Unit = {}
+    }).start()
+  }
+
+  test("stress test for failOnDataLoss=false") {
+    val reader = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("kafka.metadata.max.age.ms", "1")
+      .option("kafka.default.api.timeout.ms", "3000")
+      .option("subscribePattern", "failOnDataLoss.*")
+      .option("startingOffsets", "earliest")
+      .option("failOnDataLoss", "false")
+      .option("fetchOffset.retryIntervalMs", "3000")
+    val kafka = reader.load()
+      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+      .as[(String, String)]
+    val query = startStream(kafka.map(kv => kv._2.toInt))
+
+    val testTime = 1.minutes
+    val startTime = System.currentTimeMillis()
+    // Track the current existing topics
+    val topics = mutable.ArrayBuffer[String]()
+    // Track topics that have been deleted
+    val deletedTopics = mutable.Set[String]()
+    while (System.currentTimeMillis() - startTime < testTime.toMillis) {
+      Random.nextInt(10) match {
+        case 0 => // Create a new topic
+          val topic = newTopic()
+          topics += topic
+          // As pushing messages into Kafka updates Zookeeper asynchronously, there is a small
+          // chance that a topic will be recreated after deletion due to the asynchronous update.
+          // Hence, always overwrite to handle this race condition.
+          testUtils.createTopic(topic, partitions = 1, overwrite = true)
+          logInfo(s"Create topic $topic")
+        case 1 if topics.nonEmpty => // Delete an existing topic
+          val topic = topics.remove(Random.nextInt(topics.size))
+          testUtils.deleteTopic(topic)
+          logInfo(s"Delete topic $topic")
+          deletedTopics += topic
+        case 2 if deletedTopics.nonEmpty => // Recreate a topic that was deleted.
+          val topic = deletedTopics.toSeq(Random.nextInt(deletedTopics.size))
+          deletedTopics -= topic
+          topics += topic
+          // As pushing messages into Kafka updates Zookeeper asynchronously, there is a small
+          // chance that a topic will be recreated after deletion due to the asynchronous update.
+          // Hence, always overwrite to handle this race condition.
+          testUtils.createTopic(topic, partitions = 1, overwrite = true)
+          logInfo(s"Create topic $topic")
+        case 3 =>
+          Thread.sleep(1000)
+        case _ => // Push random messages
+          for (topic <- topics) {
+            val size = Random.nextInt(10)
+            for (_ <- 0 until size) {
+              testUtils.sendMessages(topic, Array(Random.nextInt(10).toString))
+            }
+          }
+      }
+      // Since `failOnDataLoss` is `false`, we should not fail the query
+      if (query.exception.nonEmpty) {
+        throw query.exception.get
+      }
+    }
+
+    query.stop()
+    // Since `failOnDataLoss` is `false`, we should not fail the query
+    if (query.exception.nonEmpty) {
+      throw query.exception.get
+    }
+  }
+}
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
index c6412eac97dba..8e246dbbf5d70 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
@@ -20,36 +20,31 @@ package org.apache.spark.sql.kafka010
 import java.io._
 import java.nio.charset.StandardCharsets.UTF_8
 import java.nio.file.{Files, Paths}
-import java.util.{Locale, Optional, Properties}
+import java.util.Locale
 import java.util.concurrent.ConcurrentLinkedQueue
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.JavaConverters._
-import scala.collection.mutable
 import scala.io.Source
 import scala.util.Random
 
-import org.apache.kafka.clients.producer.RecordMetadata
+import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord, RecordMetadata}
 import org.apache.kafka.common.TopicPartition
 import org.scalatest.concurrent.PatienceConfiguration.Timeout
 import org.scalatest.time.SpanSugar._
 
-import org.apache.spark.SparkContext
-import org.apache.spark.sql.{Dataset, ForeachWriter, SparkSession}
-import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Update
+import org.apache.spark.sql.{ForeachWriter, SparkSession}
 import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
 import org.apache.spark.sql.functions.{count, window}
 import org.apache.spark.sql.kafka010.KafkaSourceProvider._
 import org.apache.spark.sql.sources.v2.DataSourceOptions
-import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2}
 import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest}
 import org.apache.spark.sql.streaming.util.StreamManualClock
-import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.test.SharedSQLContext
 
-abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
+abstract class KafkaSourceTest extends StreamTest with SharedSQLContext with KafkaTest {
 
   protected var testUtils: KafkaTestUtils = _
 
@@ -117,14 +112,16 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
         query.nonEmpty,
         "Cannot add data when there is no query for finding the active Kafka source")
 
-      val sources = {
+      val sources: Seq[BaseStreamingSource] = {
         query.get.logicalPlan.collect {
           case StreamingExecutionRelation(source: KafkaSource, _) => source
-          case StreamingExecutionRelation(source: KafkaMicroBatchReader, _) => source
+          case StreamingExecutionRelation(source: KafkaMicroBatchReadSupport, _) => source
+        } ++ (query.get.lastExecution match {
          case null => Seq()
          case e => e.logical.collect {
-            case StreamingDataSourceV2Relation(_, _, _, reader: KafkaContinuousReader) => reader
+            case r: StreamingDataSourceV2Relation
+                if r.readSupport.isInstanceOf[KafkaContinuousReadSupport] =>
+              r.readSupport.asInstanceOf[KafkaContinuousReadSupport]
          }
        })
      }.distinct
@@ -160,6 +157,23 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
       s"AddKafkaData(topics = $topics, data = $data, message = $message)"
   }
 
+  object WithOffsetSync {
+    /**
+     * Run `func` to write some Kafka messages and wait until the latest offset of the given
+     * `TopicPartition` is not less than `expectedOffset`.
+     */
+    def apply(
+        topicPartition: TopicPartition,
+        expectedOffset: Long)(func: () => Unit): StreamAction = {
+      Execute("Run Kafka Producer")(_ => {
+        func()
+        // This works around the race condition that a committed message may not be visible to
+        // the consumer for a short time.
+        testUtils.waitUntilOffsetAppears(topicPartition, expectedOffset)
+      })
+    }
+  }
+
   private val topicId = new AtomicInteger(0)
 
   protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}"
 }
@@ -290,6 +304,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase {
       .format("kafka")
       .option("kafka.bootstrap.servers", testUtils.brokerAddress)
       .option("kafka.metadata.max.age.ms", "1")
+      .option("kafka.default.api.timeout.ms", "3000")
       .option("subscribePattern", s"$topicPrefix-.*")
       .option("failOnDataLoss", "false")
@@ -467,6 +482,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase {
       .format("kafka")
       .option("kafka.bootstrap.servers", testUtils.brokerAddress)
       .option("kafka.metadata.max.age.ms", "1")
+      .option("kafka.default.api.timeout.ms", "3000")
       .option("subscribe", topic)
       // If a topic is deleted and we try to poll data starting from offset 0,
       // the Kafka consumer will just block until timeout and return an empty result.
@@ -595,6 +611,248 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase {
     }
     )
   }
+
+  test("read Kafka transactional messages: read_committed") {
+    // This test will cover the following cases:
+    // 1. the whole batch contains no data messages
+    // 2. the first offset in a batch is not a committed data message
+    // 3. the last offset in a batch is not a committed data message
+    // 4. there is a gap in the middle of a batch
+
+    val topic = newTopic()
+    testUtils.createTopic(topic, partitions = 1)
+
+    val reader = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("kafka.metadata.max.age.ms", "1")
+      .option("kafka.isolation.level", "read_committed")
+      .option("maxOffsetsPerTrigger", 3)
+      .option("subscribe", topic)
+      .option("startingOffsets", "earliest")
+      // Set a short timeout to make the test fast. When a batch doesn't contain any visible data
+      // messages, "poll" will wait until timeout.
+      .option("kafkaConsumer.pollTimeoutMs", 5000)
+    val kafka = reader.load()
+      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+      .as[(String, String)]
+    val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt)
+
+    val clock = new StreamManualClock
+
+    // Wait until the manual clock is waiting on further instructions to move forward. Then we can
+    // ensure all batches we are waiting for have been processed.
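+    // If the query has already failed, rethrow its exception rather than letting `eventually`
+    // time out with a less informative error.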
+    val waitUntilBatchProcessed = Execute { q =>
+      eventually(Timeout(streamingTimeout)) {
+        if (!q.exception.isDefined) {
+          assert(clock.isStreamWaitingAt(clock.getTimeMillis()))
+        }
+      }
+      if (q.exception.isDefined) {
+        throw q.exception.get
+      }
+    }
+
+    val topicPartition = new TopicPartition(topic, 0)
+    // The message values are the same as their offsets to make the test easy to follow
+    testUtils.withTranscationalProducer { producer =>
+      testStream(mapped)(
+        StartStream(ProcessingTime(100), clock),
+        waitUntilBatchProcessed,
+        CheckAnswer(),
+        WithOffsetSync(topicPartition, expectedOffset = 5) { () =>
+          // Send 5 messages. They should be visible only after being committed.
+          producer.beginTransaction()
+          (0 to 4).foreach { i =>
+            producer.send(new ProducerRecord[String, String](topic, i.toString)).get()
+          }
+        },
+        AdvanceManualClock(100),
+        waitUntilBatchProcessed,
+        // Should not see any uncommitted messages
+        CheckNewAnswer(),
+        WithOffsetSync(topicPartition, expectedOffset = 6) { () =>
+          producer.commitTransaction()
+        },
+        AdvanceManualClock(100),
+        waitUntilBatchProcessed,
+        CheckNewAnswer(0, 1, 2), // offset 0, 1, 2
+        AdvanceManualClock(100),
+        waitUntilBatchProcessed,
+        CheckNewAnswer(3, 4), // offset: 3, 4, 5* [* means it's not a committed data message]
+        WithOffsetSync(topicPartition, expectedOffset = 12) { () =>
+          // Send 5 messages and abort the transaction. They should not be read.
+          producer.beginTransaction()
+          (6 to 10).foreach { i =>
+            producer.send(new ProducerRecord[String, String](topic, i.toString)).get()
+          }
+          producer.abortTransaction()
+        },
+        AdvanceManualClock(100),
+        waitUntilBatchProcessed,
+        CheckNewAnswer(), // offset: 6*, 7*, 8*
+        AdvanceManualClock(100),
+        waitUntilBatchProcessed,
+        CheckNewAnswer(), // offset: 9*, 10*, 11*
+        WithOffsetSync(topicPartition, expectedOffset = 18) { () =>
+          // Send 5 messages again. The consumer should skip the aborted messages above and read
+          // the new ones.
+          producer.beginTransaction()
+          (12 to 16).foreach { i =>
+            producer.send(new ProducerRecord[String, String](topic, i.toString)).get()
+          }
+          producer.commitTransaction()
+        },
+        AdvanceManualClock(100),
+        waitUntilBatchProcessed,
+        CheckNewAnswer(12, 13, 14), // offset: 12, 13, 14
+        AdvanceManualClock(100),
+        waitUntilBatchProcessed,
+        CheckNewAnswer(15, 16), // offset: 15, 16, 17*
+        WithOffsetSync(topicPartition, expectedOffset = 25) { () =>
+          producer.beginTransaction()
+          producer.send(new ProducerRecord[String, String](topic, "18")).get()
+          producer.commitTransaction()
+          producer.beginTransaction()
+          producer.send(new ProducerRecord[String, String](topic, "20")).get()
+          producer.commitTransaction()
+          producer.beginTransaction()
+          producer.send(new ProducerRecord[String, String](topic, "22")).get()
+          producer.send(new ProducerRecord[String, String](topic, "23")).get()
+          producer.commitTransaction()
+        },
+        AdvanceManualClock(100),
+        waitUntilBatchProcessed,
+        CheckNewAnswer(18, 20), // offset: 18, 19*, 20
+        AdvanceManualClock(100),
+        waitUntilBatchProcessed,
+        CheckNewAnswer(22, 23), // offset: 21*, 22, 23
+        AdvanceManualClock(100),
+        waitUntilBatchProcessed,
+        CheckNewAnswer() // offset: 24*
+      )
+    }
+  }
+
+  test("read Kafka transactional messages: read_uncommitted") {
+    // This test will cover the following cases:
+    // 1. the whole batch contains no data messages
+    // 2. the first offset in a batch is not a committed data message
+    // 3. the last offset in a batch is not a committed data message
+    // 4. there is a gap in the middle of a batch
+
+    val topic = newTopic()
+    testUtils.createTopic(topic, partitions = 1)
+
+    val reader = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("kafka.metadata.max.age.ms", "1")
+      .option("kafka.isolation.level", "read_uncommitted")
+      .option("maxOffsetsPerTrigger", 3)
+      .option("subscribe", topic)
+      .option("startingOffsets", "earliest")
+      // Set a short timeout to make the test fast. When a batch doesn't contain any visible data
+      // messages, "poll" will wait until timeout.
+      .option("kafkaConsumer.pollTimeoutMs", 5000)
+    val kafka = reader.load()
+      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+      .as[(String, String)]
+    val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt)
+
+    val clock = new StreamManualClock
+
+    // Wait until the manual clock is waiting on further instructions to move forward. Then we can
+    // ensure all batches we are waiting for have been processed.
+    val waitUntilBatchProcessed = Execute { q =>
+      eventually(Timeout(streamingTimeout)) {
+        if (!q.exception.isDefined) {
+          assert(clock.isStreamWaitingAt(clock.getTimeMillis()))
+        }
+      }
+      if (q.exception.isDefined) {
+        throw q.exception.get
+      }
+    }
+
+    val topicPartition = new TopicPartition(topic, 0)
+    // The message values are the same as their offsets to make the test easy to follow
+    testUtils.withTranscationalProducer { producer =>
+      testStream(mapped)(
+        StartStream(ProcessingTime(100), clock),
+        waitUntilBatchProcessed,
+        CheckNewAnswer(),
+        WithOffsetSync(topicPartition, expectedOffset = 5) { () =>
+          // Send 5 messages. With read_uncommitted, they should be visible even before the
+          // transaction is committed.
+          producer.beginTransaction()
+          (0 to 4).foreach { i =>
+            producer.send(new ProducerRecord[String, String](topic, i.toString)).get()
+          }
+        },
+        AdvanceManualClock(100),
+        waitUntilBatchProcessed,
+        CheckNewAnswer(0, 1, 2), // offset 0, 1, 2
+        WithOffsetSync(topicPartition, expectedOffset = 6) { () =>
+          producer.commitTransaction()
+        },
+        AdvanceManualClock(100),
+        waitUntilBatchProcessed,
+        CheckNewAnswer(3, 4), // offset: 3, 4, 5* [* means it's not a committed data message]
+        WithOffsetSync(topicPartition, expectedOffset = 12) { () =>
+          // Send 5 messages and abort the transaction. With read_uncommitted, they are still read.
+          producer.beginTransaction()
+          (6 to 10).foreach { i =>
+            producer.send(new ProducerRecord[String, String](topic, i.toString)).get()
+          }
+          producer.abortTransaction()
+        },
+        AdvanceManualClock(100),
+        waitUntilBatchProcessed,
+        CheckNewAnswer(6, 7, 8), // offset: 6, 7, 8
+        AdvanceManualClock(100),
+        waitUntilBatchProcessed,
+        CheckNewAnswer(9, 10), // offset: 9, 10, 11*
+        WithOffsetSync(topicPartition, expectedOffset = 18) { () =>
+          // Send 5 messages again. With read_uncommitted, the aborted messages above have already
+          // been read, and these new committed ones simply follow them.
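+          // Offsets 6 to 10 hold the aborted records and offset 11 the abort marker, so these
+          // five records land at offsets 12 to 16 (the commit marker then takes offset 17).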
+          producer.beginTransaction()
+          (12 to 16).foreach { i =>
+            producer.send(new ProducerRecord[String, String](topic, i.toString)).get()
+          }
+          producer.commitTransaction()
+        },
+        AdvanceManualClock(100),
+        waitUntilBatchProcessed,
+        CheckNewAnswer(12, 13, 14), // offset: 12, 13, 14
+        AdvanceManualClock(100),
+        waitUntilBatchProcessed,
+        CheckNewAnswer(15, 16), // offset: 15, 16, 17*
+        WithOffsetSync(topicPartition, expectedOffset = 25) { () =>
+          producer.beginTransaction()
+          producer.send(new ProducerRecord[String, String](topic, "18")).get()
+          producer.commitTransaction()
+          producer.beginTransaction()
+          producer.send(new ProducerRecord[String, String](topic, "20")).get()
+          producer.commitTransaction()
+          producer.beginTransaction()
+          producer.send(new ProducerRecord[String, String](topic, "22")).get()
+          producer.send(new ProducerRecord[String, String](topic, "23")).get()
+          producer.commitTransaction()
+        },
+        AdvanceManualClock(100),
+        waitUntilBatchProcessed,
+        CheckNewAnswer(18, 20), // offset: 18, 19*, 20
+        AdvanceManualClock(100),
+        waitUntilBatchProcessed,
+        CheckNewAnswer(22, 23), // offset: 21*, 22, 23
+        AdvanceManualClock(100),
+        waitUntilBatchProcessed,
+        CheckNewAnswer() // offset: 24*
+      )
+    }
+  }
 }
@@ -647,7 +905,7 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase {
       makeSureGetOffsetCalled,
       AssertOnQuery { query =>
         query.logicalPlan.collect {
-          case StreamingExecutionRelation(_: KafkaMicroBatchReader, _) => true
+          case StreamingExecutionRelation(_: KafkaMicroBatchReadSupport, _) => true
         }.nonEmpty
       }
     )
@@ -672,17 +930,16 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase {
          "kafka.bootstrap.servers" -> testUtils.brokerAddress,
          "subscribe" -> topic
        ) ++ Option(minPartitions).map { p => "minPartitions" -> p}
-        val reader = provider.createMicroBatchReader(
-          Optional.empty[StructType], dir.getAbsolutePath, new DataSourceOptions(options.asJava))
-        reader.setOffsetRange(
-          Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 0L))),
-          Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L)))
-        )
-        val factories = reader.planUnsafeInputPartitions().asScala
+        val readSupport = provider.createMicroBatchReadSupport(
+          dir.getAbsolutePath, new DataSourceOptions(options.asJava))
+        val config = readSupport.newScanConfigBuilder(
+          KafkaSourceOffset(Map(tp -> 0L)),
+          KafkaSourceOffset(Map(tp -> 100L))).build()
+        val inputPartitions = readSupport.planInputPartitions(config)
          .map(_.asInstanceOf[KafkaMicroBatchInputPartition])
-        withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") {
-          assert(factories.size == numPartitionsGenerated)
-          factories.foreach { f => assert(f.reuseKafkaConsumer == reusesConsumers) }
+        withClue(s"minPartitions = $minPartitions generated input partitions $inputPartitions\n\t") {
+          assert(inputPartitions.size == numPartitionsGenerated)
+          inputPartitions.foreach { f => assert(f.reuseKafkaConsumer == reusesConsumers) }
        }
      }
    }
@@ -933,7 +1190,8 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest {
       makeSureGetOffsetCalled,
       Execute { q =>
         // wait to reach the last offset in every partition
-        q.awaitOffset(0, KafkaSourceOffset(partitionOffsets.mapValues(_ => 3L)))
+        q.awaitOffset(
+          0, KafkaSourceOffset(partitionOffsets.mapValues(_ => 3L)), streamingTimeout.toMillis)
       },
       CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22),
       StopStream,
@@ -1103,6 +1361,7 @@ class KafkaSourceStressSuite extends KafkaSourceTest {
       .option("kafka.metadata.max.age.ms", "1")
       .option("subscribePattern", "stress.*")
.option("failOnDataLoss", "false") + .option("kafka.default.api.timeout.ms", "3000") .load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] @@ -1148,132 +1407,3 @@ class KafkaSourceStressSuite extends KafkaSourceTest { iterations = 50) } } - -class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with SharedSQLContext { - - import testImplicits._ - - private var testUtils: KafkaTestUtils = _ - - private val topicId = new AtomicInteger(0) - - private def newTopic(): String = s"failOnDataLoss-${topicId.getAndIncrement()}" - - override def createSparkSession(): TestSparkSession = { - // Set maxRetries to 3 to handle NPE from `poll` when deleting a topic - new TestSparkSession(new SparkContext("local[2,3]", "test-sql-context", sparkConf)) - } - - override def beforeAll(): Unit = { - super.beforeAll() - testUtils = new KafkaTestUtils { - override def brokerConfiguration: Properties = { - val props = super.brokerConfiguration - // Try to make Kafka clean up messages as fast as possible. However, there is a hard-code - // 30 seconds delay (kafka.log.LogManager.InitialTaskDelayMs) so this test should run at - // least 30 seconds. - props.put("log.cleaner.backoff.ms", "100") - props.put("log.segment.bytes", "40") - props.put("log.retention.bytes", "40") - props.put("log.retention.check.interval.ms", "100") - props.put("delete.retention.ms", "10") - props.put("log.flush.scheduler.interval.ms", "10") - props - } - } - testUtils.setup() - } - - override def afterAll(): Unit = { - if (testUtils != null) { - testUtils.teardown() - testUtils = null - super.afterAll() - } - } - - protected def startStream(ds: Dataset[Int]) = { - ds.writeStream.foreach(new ForeachWriter[Int] { - - override def open(partitionId: Long, version: Long): Boolean = { - true - } - - override def process(value: Int): Unit = { - // Slow down the processing speed so that messages may be aged out. - Thread.sleep(Random.nextInt(500)) - } - - override def close(errorOrNull: Throwable): Unit = { - } - }).start() - } - - test("stress test for failOnDataLoss=false") { - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", "failOnDataLoss.*") - .option("startingOffsets", "earliest") - .option("failOnDataLoss", "false") - .option("fetchOffset.retryIntervalMs", "3000") - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val query = startStream(kafka.map(kv => kv._2.toInt)) - - val testTime = 1.minutes - val startTime = System.currentTimeMillis() - // Track the current existing topics - val topics = mutable.ArrayBuffer[String]() - // Track topics that have been deleted - val deletedTopics = mutable.Set[String]() - while (System.currentTimeMillis() - testTime.toMillis < startTime) { - Random.nextInt(10) match { - case 0 => // Create a new topic - val topic = newTopic() - topics += topic - // As pushing messages into Kafka updates Zookeeper asynchronously, there is a small - // chance that a topic will be recreated after deletion due to the asynchronous update. - // Hence, always overwrite to handle this race condition. 
- testUtils.createTopic(topic, partitions = 1, overwrite = true) - logInfo(s"Create topic $topic") - case 1 if topics.nonEmpty => // Delete an existing topic - val topic = topics.remove(Random.nextInt(topics.size)) - testUtils.deleteTopic(topic) - logInfo(s"Delete topic $topic") - deletedTopics += topic - case 2 if deletedTopics.nonEmpty => // Recreate a topic that was deleted. - val topic = deletedTopics.toSeq(Random.nextInt(deletedTopics.size)) - deletedTopics -= topic - topics += topic - // As pushing messages into Kafka updates Zookeeper asynchronously, there is a small - // chance that a topic will be recreated after deletion due to the asynchronous update. - // Hence, always overwrite to handle this race condition. - testUtils.createTopic(topic, partitions = 1, overwrite = true) - logInfo(s"Create topic $topic") - case 3 => - Thread.sleep(1000) - case _ => // Push random messages - for (topic <- topics) { - val size = Random.nextInt(10) - for (_ <- 0 until size) { - testUtils.sendMessages(topic, Array(Random.nextInt(10).toString)) - } - } - } - // `failOnDataLoss` is `false`, we should not fail the query - if (query.exception.nonEmpty) { - throw query.exception.get - } - } - - query.stop() - // `failOnDataLoss` is `false`, we should not fail the query - if (query.exception.nonEmpty) { - throw query.exception.get - } - } -} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala index 91893df4ec32f..8cfca56433f5d 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.kafka010 import java.util.Locale import java.util.concurrent.atomic.AtomicInteger +import org.apache.kafka.clients.producer.ProducerRecord import org.apache.kafka.common.TopicPartition -import org.scalatest.BeforeAndAfter import org.apache.spark.sql.QueryTest import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -class KafkaRelationSuite extends QueryTest with BeforeAndAfter with SharedSQLContext { +class KafkaRelationSuite extends QueryTest with SharedSQLContext with KafkaTest { import testImplicits._ @@ -48,9 +48,12 @@ class KafkaRelationSuite extends QueryTest with BeforeAndAfter with SharedSQLCon } override def afterAll(): Unit = { - if (testUtils != null) { - testUtils.teardown() - testUtils = null + try { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + } + } finally { super.afterAll() } } @@ -235,4 +238,103 @@ class KafkaRelationSuite extends QueryTest with BeforeAndAfter with SharedSQLCon testBadOptions("subscribe" -> "")("no topics to subscribe") testBadOptions("subscribePattern" -> "")("pattern to subscribe is empty") } + + test("read Kafka transactional messages: read_committed") { + val topic = newTopic() + testUtils.createTopic(topic) + testUtils.withTranscationalProducer { producer => + val df = spark + .read + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.isolation.level", "read_committed") + .option("subscribe", topic) + .load() + .selectExpr("CAST(value AS STRING)") + + producer.beginTransaction() + (1 to 5).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + + // Should not read any messages before they are committed 
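+ // ("read_committed" consumers only see records below the partition's last stable + // offset, and that offset stays at the open transaction's first record until the + // transaction commits or aborts.)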
+ assert(df.isEmpty) + + producer.commitTransaction() + + // Should read all committed messages + testUtils.waitUntilOffsetAppears(new TopicPartition(topic, 0), 6) + checkAnswer(df, (1 to 5).map(_.toString).toDF) + + producer.beginTransaction() + (6 to 10).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.abortTransaction() + + // Should not read aborted messages + testUtils.waitUntilOffsetAppears(new TopicPartition(topic, 0), 12) + checkAnswer(df, (1 to 5).map(_.toString).toDF) + + producer.beginTransaction() + (11 to 15).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.commitTransaction() + + // Should skip aborted messages and read new committed ones. + testUtils.waitUntilOffsetAppears(new TopicPartition(topic, 0), 18) + checkAnswer(df, ((1 to 5) ++ (11 to 15)).map(_.toString).toDF) + } + } + + test("read Kafka transactional messages: read_uncommitted") { + val topic = newTopic() + testUtils.createTopic(topic) + testUtils.withTranscationalProducer { producer => + val df = spark + .read + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.isolation.level", "read_uncommitted") + .option("subscribe", topic) + .load() + .selectExpr("CAST(value AS STRING)") + + producer.beginTransaction() + (1 to 5).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + + // "read_uncommitted" should see all messages including uncommitted ones + testUtils.waitUntilOffsetAppears(new TopicPartition(topic, 0), 5) + checkAnswer(df, (1 to 5).map(_.toString).toDF) + + producer.commitTransaction() + + // Should read all committed messages + testUtils.waitUntilOffsetAppears(new TopicPartition(topic, 0), 6) + checkAnswer(df, (1 to 5).map(_.toString).toDF) + + producer.beginTransaction() + (6 to 10).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.abortTransaction() + + // "read_uncommitted" should see all messages including uncommitted or aborted ones + testUtils.waitUntilOffsetAppears(new TopicPartition(topic, 0), 12) + checkAnswer(df, (1 to 10).map(_.toString).toDF) + + producer.beginTransaction() + (11 to 15).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.commitTransaction() + + // Should read all messages + testUtils.waitUntilOffsetAppears(new TopicPartition(topic, 0), 18) + checkAnswer(df, (1 to 15).map(_.toString).toDF) + } + } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index 7079ac6453ffc..81832fbdcd7ec 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.streaming._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BinaryType, DataType} -class KafkaSinkSuite extends StreamTest with SharedSQLContext { +class KafkaSinkSuite extends StreamTest with SharedSQLContext with KafkaTest { import testImplicits._ protected var testUtils: KafkaTestUtils = _ @@ -48,9 +48,12 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { } override def afterAll(): Unit = { - if (testUtils != null) { - testUtils.teardown() - 
testUtils = null + try { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + } + } finally { super.afterAll() } } @@ -303,7 +306,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "value attribute type must be a string or binarytype")) + "value attribute type must be a string or binary")) try { ex = intercept[StreamingQueryException] { @@ -318,7 +321,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "key attribute type must be a string or binarytype")) + "key attribute type must be a string or binary")) } test("streaming - write to non-existing topic") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTest.scala similarity index 53% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala rename to external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTest.scala index 1b25a4b191f86..19acda95c707c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTest.scala @@ -1,42 +1,32 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.codegen - -import org.apache.spark.SparkFunSuite -import org.apache.spark.unsafe.types.UTF8String - -class UTF8StringBuilderSuite extends SparkFunSuite { - - test("basic test") { - val sb = new UTF8StringBuilder() - assert(sb.build() === UTF8String.EMPTY_UTF8) - - sb.append("") - assert(sb.build() === UTF8String.EMPTY_UTF8) - - sb.append("abcd") - assert(sb.build() === UTF8String.fromString("abcd")) - - sb.append(UTF8String.fromString("1234")) - assert(sb.build() === UTF8String.fromString("abcd1234")) - - // expect to grow an internal buffer - sb.append(UTF8String.fromString("efgijk567890")) - assert(sb.build() === UTF8String.fromString("abcd1234efgijk567890")) - } -} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkFunSuite + +/** A trait to clean cached Kafka producers in `afterAll` */ +trait KafkaTest extends BeforeAndAfterAll { + self: SparkFunSuite => + + override def afterAll(): Unit = { + super.afterAll() + CachedKafkaProducer.clear() + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index 75245943c4936..bf6934be52705 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.kafka010 import java.io.{File, IOException} import java.lang.{Integer => JInt} import java.net.InetSocketAddress -import java.util.{Map => JMap, Properties} +import java.util.{Map => JMap, Properties, UUID} import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ @@ -29,20 +29,23 @@ import scala.util.Random import kafka.admin.AdminUtils import kafka.api.Request -import kafka.common.TopicAndPartition -import kafka.server.{KafkaConfig, KafkaServer, OffsetCheckpoint} +import kafka.server.{KafkaConfig, KafkaServer} +import kafka.server.checkpoints.OffsetCheckpointFile import kafka.utils.ZkUtils +import org.apache.kafka.clients.CommonClientConfigs +import org.apache.kafka.clients.admin.{AdminClient, CreatePartitionsOptions, NewPartitions} import org.apache.kafka.clients.consumer.KafkaConsumer import org.apache.kafka.clients.producer._ import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.network.ListenerName import org.apache.kafka.common.serialization.{StringDeserializer, StringSerializer} import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** * This is a helper class for Kafka test suites. 
This has the functionality to set up @@ -53,17 +56,18 @@ import org.apache.spark.util.Utils class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends Logging { // Zookeeper related configurations - private val zkHost = "localhost" + private val zkHost = "127.0.0.1" private var zkPort: Int = 0 private val zkConnectionTimeout = 60000 - private val zkSessionTimeout = 6000 + private val zkSessionTimeout = 10000 private var zookeeper: EmbeddedZookeeper = _ private var zkUtils: ZkUtils = _ + private var adminClient: AdminClient = null // Kafka broker related configurations - private val brokerHost = "localhost" + private val brokerHost = "127.0.0.1" private var brokerPort = 0 private var brokerConf: KafkaConfig = _ @@ -76,6 +80,7 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L // Flag to test whether the system is correctly started private var zkReady = false private var brokerReady = false + private var leakDetector: AnyRef = null def zkAddress: String = { assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper address") @@ -113,21 +118,37 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L brokerConf = new KafkaConfig(brokerConfiguration, doLog = false) server = new KafkaServer(brokerConf) server.startup() - brokerPort = server.boundPort() + brokerPort = server.boundPort(new ListenerName("PLAINTEXT")) (server, brokerPort) }, new SparkConf(), "KafkaBroker") brokerReady = true + val props = new Properties() + props.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, s"$brokerHost:$brokerPort") + adminClient = AdminClient.create(props) } /** setup the whole embedded servers, including Zookeeper and Kafka brokers */ def setup(): Unit = { + // Set up a KafkaTestUtils leak detector so that we can see where the leaked KafkaTestUtils was + // created. + val exception = new SparkException("It was created at: ") + leakDetector = ShutdownHookManager.addShutdownHook { () => + logError("Found a leaked KafkaTestUtils.", exception) + } + + setupEmbeddedZookeeper() + setupEmbeddedKafkaServer() + eventually(timeout(60.seconds)) { + assert(zkUtils.getAllBrokersInCluster().nonEmpty, "Broker was not up in 60 seconds") + } } /** Teardown the whole servers, including Kafka broker and Zookeeper */ def teardown(): Unit = { + if (leakDetector != null) { + ShutdownHookManager.removeShutdownHook(leakDetector) + } brokerReady = false zkReady = false @@ -136,6 +157,10 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L producer = null } + if (adminClient != null) { + adminClient.close() + } + if (server != null) { server.shutdown() server.awaitShutdown() @@ -203,7 +228,9 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L /** Add new partitions to a Kafka topic */ def addPartitions(topic: String, partitions: Int): Unit = { - AdminUtils.addPartitions(zkUtils, topic, partitions) + adminClient.createPartitions( + Map(topic -> NewPartitions.increaseTo(partitions)).asJava, + new CreatePartitionsOptions) // wait until metadata is propagated (0 until partitions).foreach { p => waitUntilMetadataIsPropagated(topic, p) @@ -287,15 +314,23 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L protected def brokerConfiguration: Properties = { val props = new Properties() props.put("broker.id", "0") - props.put("host.name", "localhost") - props.put("advertised.host.name", "localhost") + props.put("host.name", "127.0.0.1") + props.put("advertised.host.name", "127.0.0.1") props.put("port", brokerPort.toString) props.put("log.dir", Utils.createTempDir().getAbsolutePath) props.put("zookeeper.connect", zkAddress) + props.put("zookeeper.connection.timeout.ms", "60000") props.put("log.flush.interval.messages", "1") props.put("replica.socket.timeout.ms", "1500") props.put("delete.topic.enable", "true") + props.put("group.initial.rebalance.delay.ms", "10") + + // Change the following settings as we have only 1 broker props.put("offsets.topic.num.partitions", "1") + props.put("offsets.topic.replication.factor", "1") + props.put("transaction.state.log.replication.factor", "1") + props.put("transaction.state.log.min.isr", "1") + // Cannot use properties.putAll(propsMap.asJava) in scala-2.12 // See https://github.com/scala/bug/issues/10418 withBrokerProps.foreach { case (k, v) => props.put(k, v) } @@ -312,6 +347,19 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L props } + /** Call `f` with a `KafkaProducer` that has initialized transactions.
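+ * A fresh random UUID serves as the `transactional.id` on each call, so producers + * created by successive tests never fence one another. Sketch of typical usage: + * {{{ + * testUtils.withTranscationalProducer { producer => + * producer.beginTransaction() + * producer.send(new ProducerRecord[String, String](topic, "value")).get() + * producer.commitTransaction() + * } + * }}}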
*/ + def withTranscationalProducer(f: KafkaProducer[String, String] => Unit): Unit = { + val props = producerConfiguration + props.put("transactional.id", UUID.randomUUID().toString) + val producer = new KafkaProducer[String, String](props) + try { + producer.initTransactions() + f(producer) + } finally { + producer.close() + } + } + private def consumerConfiguration: Properties = { val props = new Properties() props.put("bootstrap.servers", brokerAddress) @@ -327,7 +375,7 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L topic: String, numPartitions: Int, servers: Seq[KafkaServer]): Unit = { - val topicAndPartitions = (0 until numPartitions).map(TopicAndPartition(topic, _)) + val topicAndPartitions = (0 until numPartitions).map(new TopicPartition(topic, _)) import ZkUtils._ // wait until admin path for delete topic is deleted, signaling completion of topic deletion @@ -337,7 +385,7 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L assert(!zkUtils.pathExists(getTopicPath(topic)), s"${getTopicPath(topic)} still exists") // ensure that the topic-partition has been deleted from all brokers' replica managers assert(servers.forall(server => topicAndPartitions.forall(tp => - server.replicaManager.getPartition(tp.topic, tp.partition) == None)), + server.replicaManager.getPartition(tp) == None)), s"topic $topic still exists in the replica manager") // ensure that logs from all replicas are deleted if delete topic is marked successful assert(servers.forall(server => topicAndPartitions.forall(tp => @@ -345,8 +393,8 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L s"topic $topic still exists in log mananger") // ensure that topic is removed from all cleaner offsets assert(servers.forall(server => topicAndPartitions.forall { tp => - val checkpoints = server.getLogManager().logDirs.map { logDir => - new OffsetCheckpoint(new File(logDir, "cleaner-offset-checkpoint")).read() + val checkpoints = server.getLogManager().liveLogDirs.map { logDir => + new OffsetCheckpointFile(new File(logDir, "cleaner-offset-checkpoint")).read() } checkpoints.forall(checkpointsPerLogDir => !checkpointsPerLogDir.contains(tp)) }), s"checkpoint for topic $topic still exists") @@ -379,11 +427,9 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = { def isPropagated = server.apis.metadataCache.getPartitionInfo(topic, partition) match { case Some(partitionState) => - val leaderAndInSyncReplicas = partitionState.leaderIsrAndControllerEpoch.leaderAndIsr - zkUtils.getLeaderForPartition(topic, partition).isDefined && - Request.isValidBrokerId(leaderAndInSyncReplicas.leader) && - leaderAndInSyncReplicas.isr.nonEmpty + Request.isValidBrokerId(partitionState.basePartitionState.leader) && + !partitionState.basePartitionState.replicas.isEmpty case _ => false @@ -393,6 +439,16 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L } } + /** + * Wait until the latest offset of the given `TopicPartition` is not less than `offset`. 
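+ * For example, after committing 5 records written at offsets 0-4, the commit marker + * sits at offset 5 and callers wait for offset 6; markers occupy offsets too.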
+ */ + def waitUntilOffsetAppears(topicPartition: TopicPartition, offset: Long): Unit = { + eventually(timeout(60.seconds)) { + val currentOffset = getLatestOffsets(Set(topicPartition.topic)).get(topicPartition) + assert(currentOffset.nonEmpty && currentOffset.get >= offset) + } + } + private class EmbeddedZookeeper(val zkConnect: String) { val snapshotDir = Utils.createTempDir() val logDir = Utils.createTempDir() diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index 3b124b2a69d50..a97fd35bfbb73 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -28,7 +28,8 @@ spark-streaming-kafka-0-10_2.11 streaming-kafka-0-10 - 0.10.0.1 + + 2.0.0 jar Spark Integration for Kafka 0.10 @@ -58,6 +59,20 @@ kafka_${scala.binary.version} ${kafka.version} test + + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + jackson-databind + + + com.fasterxml.jackson.core + jackson-annotations + + net.sf.jopt-simple @@ -93,13 +108,4 @@ target/scala-${scala.binary.version}/test-classes - - - scala-2.12 - - 0.10.1.1 - - - - diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index c3221481556f5..0acc9b8d2a0cf 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -154,7 +154,8 @@ private[spark] class DirectKafkaInputDStream[K, V]( if (effectiveRateLimitPerPartition.values.sum > 0) { val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000 Some(effectiveRateLimitPerPartition.map { - case (tp, limit) => tp -> Math.max((secsPerBatch * limit).toLong, 1L) + case (tp, limit) => tp -> Math.max((secsPerBatch * limit).toLong, + ppc.minRatePerPartition(tp)) }) } else { None @@ -166,6 +167,8 @@ private[spark] class DirectKafkaInputDStream[K, V]( * which would throw off consumer position. Fix position if this happens. 
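+ * (All partitions are paused before the poll, so `poll(0)` can only refresh metadata + * and positions; it can no longer fetch records as a side effect.)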
*/ private def paranoidPoll(c: Consumer[K, V]): Unit = { + // don't actually want to consume any messages, so pause all partitions + c.pause(c.assignment()) val msgs = c.poll(0) if (!msgs.isEmpty) { // position should be minimum offset per topicpartition @@ -204,8 +207,6 @@ private[spark] class DirectKafkaInputDStream[K, V]( // position for new partitions determined by auto.offset.reset if no commit currentOffsets = currentOffsets ++ newPartitions.map(tp => tp -> c.position(tp)).toMap - // don't want to consume messages, so pause - c.pause(newPartitions.asJava) // find latest available offsets c.seekToEnd(currentOffsets.keySet.asJava) parts.map(tp => tp -> c.position(tp)).toMap @@ -262,9 +263,6 @@ private[spark] class DirectKafkaInputDStream[K, V]( tp -> c.position(tp) }.toMap } - - // don't actually want to consume any messages, so pause all partitions - c.pause(currentOffsets.keySet.asJava) } override def stop(): Unit = this.synchronized { diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala index 3efc90fe466b2..4513dca44c7c6 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala @@ -237,7 +237,7 @@ private class KafkaRDDIterator[K, V]( cacheLoadFactor: Float ) extends Iterator[ConsumerRecord[K, V]] { - context.addTaskCompletionListener(_ => closeIfNeeded()) + context.addTaskCompletionListener[Unit](_ => closeIfNeeded()) val consumer = { KafkaDataConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/PerPartitionConfig.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/PerPartitionConfig.scala index 4792f2a955110..4017fdbcaf95e 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/PerPartitionConfig.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/PerPartitionConfig.scala @@ -34,6 +34,7 @@ abstract class PerPartitionConfig extends Serializable { * from each Kafka partition. 
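+ * `minRatePerPartition` below supplies the matching floor, so backpressure can never + * throttle a partition all the way down to zero.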
*/ def maxRatePerPartition(topicPartition: TopicPartition): Long + def minRatePerPartition(topicPartition: TopicPartition): Long = 1 } /** @@ -42,6 +43,8 @@ abstract class PerPartitionConfig extends Serializable { private class DefaultPerPartitionConfig(conf: SparkConf) extends PerPartitionConfig { val maxRate = conf.getLong("spark.streaming.kafka.maxRatePerPartition", 0) + val minRate = conf.getLong("spark.streaming.kafka.minRatePerPartition", 1) def maxRatePerPartition(topicPartition: TopicPartition): Long = maxRate + override def minRatePerPartition(topicPartition: TopicPartition): Long = minRate } diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala index 35e4678f2e3c8..1974bb1e12e15 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala @@ -59,14 +59,19 @@ class DirectKafkaStreamSuite private var kafkaTestUtils: KafkaTestUtils = _ override def beforeAll { + super.beforeAll() kafkaTestUtils = new KafkaTestUtils kafkaTestUtils.setup() } override def afterAll { - if (kafkaTestUtils != null) { - kafkaTestUtils.teardown() - kafkaTestUtils = null + try { + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } + } finally { + super.afterAll() } } @@ -664,7 +669,8 @@ class DirectKafkaStreamSuite kafkaStream.stop() } - test("maxMessagesPerPartition with zero offset and rate equal to one") { + test("maxMessagesPerPartition with zero offset and rate equal to the specified" + + " minimum with default 1") { val topic = "backpressure" val kafkaParams = getKafkaParams() val batchIntervalMilliseconds = 60000 @@ -674,6 +680,8 @@ class DirectKafkaStreamSuite .setMaster("local[1]") .setAppName(this.getClass.getSimpleName) .set("spark.streaming.kafka.maxRatePerPartition", "100") + .set("spark.streaming.kafka.minRatePerPartition", "5") + // Setup the streaming context ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds)) @@ -704,12 +712,13 @@ class DirectKafkaStreamSuite ) val result = kafkaStream.maxMessagesPerPartition(offsets) val expected = Map( - new TopicPartition(topic, 0) -> 1L, + new TopicPartition(topic, 0) -> 5L, new TopicPartition(topic, 1) -> 10L, new TopicPartition(topic, 2) -> 20L, new TopicPartition(topic, 3) -> 30L ) - assert(result.contains(expected), s"Number of messages per partition must be at least 1") + assert(result.contains(expected), s"Number of messages per partition must be at least equal" + + s" to the specified minimum") } /** Get the generated offset ranges from the DirectKafkaStream */ diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala index 271adea1df731..561bca5f55370 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala @@ -23,11 +23,11 @@ import java.io.File import scala.collection.JavaConverters._ import scala.util.Random -import kafka.common.TopicAndPartition -import kafka.log._ -import kafka.message._ +import kafka.log.{CleanerConfig, Log, LogCleaner, LogConfig, ProducerStateManager} +import 
kafka.server.{BrokerTopicStats, LogDirFailureChannel} import kafka.utils.Pool import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.record.{CompressionType, MemoryRecords, SimpleRecord} import org.apache.kafka.common.serialization.StringDeserializer import org.scalatest.BeforeAndAfterAll @@ -44,20 +44,27 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { private var sc: SparkContext = _ override def beforeAll { + super.beforeAll() sc = new SparkContext(sparkConf) kafkaTestUtils = new KafkaTestUtils kafkaTestUtils.setup() } override def afterAll { - if (sc != null) { - sc.stop - sc = null - } - - if (kafkaTestUtils != null) { - kafkaTestUtils.teardown() - kafkaTestUtils = null + try { + try { + if (sc != null) { + sc.stop + sc = null + } + } finally { + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } + } + } finally { + super.afterAll() } } @@ -72,33 +79,39 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { private def compactLogs(topic: String, partition: Int, messages: Array[(String, String)]) { val mockTime = new MockTime() - // LogCleaner in 0.10 version of Kafka is still expecting the old TopicAndPartition api - val logs = new Pool[TopicAndPartition, Log]() + val logs = new Pool[TopicPartition, Log]() val logDir = kafkaTestUtils.brokerLogDir val dir = new File(logDir, topic + "-" + partition) dir.mkdirs() val logProps = new ju.Properties() logProps.put(LogConfig.CleanupPolicyProp, LogConfig.Compact) logProps.put(LogConfig.MinCleanableDirtyRatioProp, java.lang.Float.valueOf(0.1f)) + val logDirFailureChannel = new LogDirFailureChannel(1) + val topicPartition = new TopicPartition(topic, partition) val log = new Log( dir, LogConfig(logProps), 0L, + 0L, mockTime.scheduler, - mockTime + new BrokerTopicStats(), + mockTime, + Int.MaxValue, + Int.MaxValue, + topicPartition, + new ProducerStateManager(topicPartition, dir), + logDirFailureChannel ) messages.foreach { case (k, v) => - val msg = new ByteBufferMessageSet( - NoCompressionCodec, - new Message(v.getBytes, k.getBytes, Message.NoTimestamp, Message.CurrentMagicValue)) - log.append(msg) + val record = new SimpleRecord(k.getBytes, v.getBytes) + log.appendAsLeader(MemoryRecords.withRecords(CompressionType.NONE, record), 0); } log.roll() - logs.put(TopicAndPartition(topic, partition), log) + logs.put(topicPartition, log) - val cleaner = new LogCleaner(CleanerConfig(), logDirs = Array(dir), logs = logs) + val cleaner = new LogCleaner(CleanerConfig(), Array(dir), logs, logDirFailureChannel) cleaner.startup() - cleaner.awaitCleaned(topic, partition, log.activeSegment.baseOffset, 1000) + cleaner.awaitCleaned(new TopicPartition(topic, partition), log.activeSegment.baseOffset, 1000) cleaner.shutdown() mockTime.scheduler.shutdown() diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala index 70b579d96d692..efcd5d6a5cdd3 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala @@ -32,13 +32,14 @@ import kafka.api.Request import kafka.server.{KafkaConfig, KafkaServer} import kafka.utils.ZkUtils import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord} +import org.apache.kafka.common.network.ListenerName import 
org.apache.kafka.common.serialization.StringSerializer import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.streaming.Time -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** * This is a helper class for Kafka test suites. This has the functionality to set up @@ -49,17 +50,17 @@ import org.apache.spark.util.Utils private[kafka010] class KafkaTestUtils extends Logging { // Zookeeper related configurations - private val zkHost = "localhost" + private val zkHost = "127.0.0.1" private var zkPort: Int = 0 private val zkConnectionTimeout = 60000 - private val zkSessionTimeout = 6000 + private val zkSessionTimeout = 10000 private var zookeeper: EmbeddedZookeeper = _ private var zkUtils: ZkUtils = _ // Kafka broker related configurations - private val brokerHost = "localhost" + private val brokerHost = "127.0.0.1" private var brokerPort = 0 private var brokerConf: KafkaConfig = _ @@ -72,6 +73,7 @@ private[kafka010] class KafkaTestUtils extends Logging { // Flag to test whether the system is correctly started private var zkReady = false private var brokerReady = false + private var leakDetector: AnyRef = null def zkAddress: String = { assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper address") @@ -109,7 +111,7 @@ private[kafka010] class KafkaTestUtils extends Logging { brokerConf = new KafkaConfig(brokerConfiguration, doLog = false) server = new KafkaServer(brokerConf) server.startup() - brokerPort = server.boundPort() + brokerPort = server.boundPort(new ListenerName("PLAINTEXT")) (server, brokerPort) }, new SparkConf(), "KafkaBroker") @@ -118,12 +120,22 @@ private[kafka010] class KafkaTestUtils extends Logging { /** setup the whole embedded servers, including Zookeeper and Kafka brokers */ def setup(): Unit = { + // Set up a KafkaTestUtils leak detector so that we can see where the leaked KafkaTestUtils was + // created. + val exception = new SparkException("It was created at: ") + leakDetector = ShutdownHookManager.addShutdownHook { () => + logError("Found a leaked KafkaTestUtils.", exception) + } + setupEmbeddedZookeeper() setupEmbeddedKafkaServer() } /** Teardown the whole servers, including Kafka broker and Zookeeper */ def teardown(): Unit = { + if (leakDetector != null) { + ShutdownHookManager.removeShutdownHook(leakDetector) + } brokerReady = false zkReady = false @@ -216,12 +228,18 @@ private[kafka010] class KafkaTestUtils extends Logging { private def brokerConfiguration: Properties = { val props = new Properties() props.put("broker.id", "0") - props.put("host.name", "localhost") + props.put("host.name", "127.0.0.1") + props.put("advertised.host.name", "127.0.0.1") props.put("port", brokerPort.toString) props.put("log.dir", brokerLogDir) props.put("zookeeper.connect", zkAddress) + props.put("zookeeper.connection.timeout.ms", "60000") props.put("log.flush.interval.messages", "1") props.put("replica.socket.timeout.ms", "1500") + props.put("delete.topic.enable", "true") + props.put("offsets.topic.num.partitions", "1") + props.put("offsets.topic.replication.factor", "1") + props.put("group.initial.rebalance.delay.ms", "10") props } @@ -270,12 +288,10 @@ private[kafka010] class KafkaTestUtils extends Logging { private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = { def isPropagated = server.apis.metadataCache.getPartitionInfo(topic, partition) match { case Some(partitionState) => - val leaderAndInSyncReplicas = partitionState.leaderIsrAndControllerEpoch.leaderAndIsr - + val leader = partitionState.basePartitionState.leader + val isr = partitionState.basePartitionState.isr zkUtils.getLeaderForPartition(topic, partition).isDefined && - Request.isValidBrokerId(leaderAndInSyncReplicas.leader) && - leaderAndInSyncReplicas.isr.nonEmpty - + Request.isValidBrokerId(leader) && !isr.isEmpty case _ => false } diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala index 928e1a6ef54b9..4811d041e7e9e 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala @@ -21,7 +21,8 @@ import java.util.concurrent.TimeUnit import scala.collection.mutable.PriorityQueue -import kafka.utils.{Scheduler, Time} +import kafka.utils.Scheduler +import org.apache.kafka.common.utils.Time /** * A mock scheduler that executes tasks synchronously using a mock time instance. diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala index a68f94db1f689..8a8646ee4eb94 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming.kafka010.mocks import java.util.concurrent._ -import kafka.utils.Time +import org.apache.kafka.common.utils.Time /** * A class used for unit testing things which depend on the Time interface.
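+ * (Kafka 2.0's `org.apache.kafka.common.utils.Time` also declares `hiResClockMs`, + * which is overridden below in terms of the same manual clock.)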
@@ -36,12 +36,14 @@ private[kafka010] class MockTime(@volatile private var currentMs: Long) extends def this() = this(System.currentTimeMillis) - def milliseconds: Long = currentMs + override def milliseconds: Long = currentMs - def nanoseconds: Long = + override def hiResClockMs(): Long = milliseconds + + override def nanoseconds: Long = TimeUnit.NANOSECONDS.convert(currentMs, TimeUnit.MILLISECONDS) - def sleep(ms: Long) { + override def sleep(ms: Long) { this.currentMs += ms scheduler.tick() } diff --git a/external/kafka-0-8-assembly/pom.xml b/external/kafka-0-8-assembly/pom.xml index 41bc8b3e3ee1f..6be17a81f3fed 100644 --- a/external/kafka-0-8-assembly/pom.xml +++ b/external/kafka-0-8-assembly/pom.xml @@ -95,11 +95,6 @@ log4j provided - - net.java.dev.jets3t - jets3t - provided - org.scala-lang scala-library diff --git a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index ecca38784e777..3fd37f4c8ac90 100644 --- a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -57,14 +57,19 @@ class DirectKafkaStreamSuite private var kafkaTestUtils: KafkaTestUtils = _ override def beforeAll { + super.beforeAll() kafkaTestUtils = new KafkaTestUtils kafkaTestUtils.setup() } override def afterAll { - if (kafkaTestUtils != null) { - kafkaTestUtils.teardown() - kafkaTestUtils = null + try { + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } + } finally { + super.afterAll() } } diff --git a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala index d66830cbacdee..73d528518d486 100644 --- a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala +++ b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala @@ -32,6 +32,7 @@ class KafkaClusterSuite extends SparkFunSuite with BeforeAndAfterAll { private var kafkaTestUtils: KafkaTestUtils = _ override def beforeAll() { + super.beforeAll() kafkaTestUtils = new KafkaTestUtils kafkaTestUtils.setup() @@ -41,9 +42,13 @@ class KafkaClusterSuite extends SparkFunSuite with BeforeAndAfterAll { } override def afterAll() { - if (kafkaTestUtils != null) { - kafkaTestUtils.teardown() - kafkaTestUtils = null + try { + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } + } finally { + super.afterAll() } } diff --git a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala index 809699a739962..72f954149fefe 100644 --- a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala +++ b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala @@ -35,20 +35,27 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { private var sc: SparkContext = _ override def beforeAll { + super.beforeAll() sc = new SparkContext(sparkConf) kafkaTestUtils = new KafkaTestUtils kafkaTestUtils.setup() } override def afterAll { - if (sc != null) { - sc.stop - sc = null - } - - if (kafkaTestUtils != null) { - kafkaTestUtils.teardown() - kafkaTestUtils 
= null + try { + try { + if (sc != null) { + sc.stop + sc = null + } + } finally { + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } + } + } finally { + super.afterAll() } } diff --git a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index 426cd83b4ddf8..ed130f5990955 100644 --- a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -35,19 +35,26 @@ class KafkaStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter private var kafkaTestUtils: KafkaTestUtils = _ override def beforeAll(): Unit = { + super.beforeAll() kafkaTestUtils = new KafkaTestUtils kafkaTestUtils.setup() } override def afterAll(): Unit = { - if (ssc != null) { - ssc.stop() - ssc = null - } - - if (kafkaTestUtils != null) { - kafkaTestUtils.teardown() - kafkaTestUtils = null + try { + try { + if (ssc != null) { + ssc.stop() + ssc = null + } + } finally { + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } + } + } finally { + super.afterAll() } } diff --git a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala index 57f89cc7dbc65..5da5ea49d77ed 100644 --- a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala +++ b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala @@ -51,6 +51,7 @@ class ReliableKafkaStreamSuite extends SparkFunSuite private var tempDirectory: File = null override def beforeAll(): Unit = { + super.beforeAll() kafkaTestUtils = new KafkaTestUtils kafkaTestUtils.setup() @@ -65,11 +66,15 @@ class ReliableKafkaStreamSuite extends SparkFunSuite } override def afterAll(): Unit = { - Utils.deleteRecursively(tempDirectory) + try { + Utils.deleteRecursively(tempDirectory) - if (kafkaTestUtils != null) { - kafkaTestUtils.teardown() - kafkaTestUtils = null + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } + } finally { + super.afterAll() } } diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml index 37c7d1e604ec5..68fded515626b 100644 --- a/external/kinesis-asl-assembly/pom.xml +++ b/external/kinesis-asl-assembly/pom.xml @@ -89,11 +89,6 @@ log4j provided - - net.java.dev.jets3t - jets3t - provided - org.apache.hadoop hadoop-client diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index fa0de6298a5f1..69c52365b1bf8 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -160,7 +160,6 @@ private[kinesis] class KinesisReceiver[T]( cloudWatchCreds.map(_.provider).getOrElse(kinesisProvider), workerId) .withKinesisEndpoint(endpointUrl) - .withInitialPositionInStream(initialPosition.getPosition) .withTaskBackoffTimeMillis(500) .withRegionName(regionName) @@ -169,7 +168,8 @@ private[kinesis] class KinesisReceiver[T]( initialPosition match { case ts: AtTimestamp => 
baseClientLibConfiguration.withTimestampAtInitialPositionInStream(ts.getTimestamp) - case _ => baseClientLibConfiguration + case _ => + baseClientLibConfiguration.withInitialPositionInStream(initialPosition.getPosition) } } diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala index e0e26847aa0ec..361520e292266 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala @@ -40,7 +40,11 @@ class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterE .checkpointAppName(checkpointAppName) override def afterAll(): Unit = { - ssc.stop() + try { + ssc.stop() + } finally { + super.afterAll() + } } test("should raise an exception if the StreamingContext is missing") { diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index a7a68eba910bf..6d27445c5b606 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -71,17 +71,21 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun } override def afterAll(): Unit = { - if (ssc != null) { - ssc.stop() - } - if (sc != null) { - sc.stop() - } - if (testUtils != null) { - // Delete the Kinesis stream as well as the DynamoDB table generated by - // Kinesis Client Library when consuming the stream - testUtils.deleteStream() - testUtils.deleteDynamoDBTable(appName) + try { + if (ssc != null) { + ssc.stop() + } + if (sc != null) { + sc.stop() + } + if (testUtils != null) { + // Delete the Kinesis stream as well as the DynamoDB table generated by + // Kinesis Client Library when consuming the stream + testUtils.deleteStream() + testUtils.deleteDynamoDBTable(appName) + } + } finally { + super.afterAll() } } diff --git a/graphx/pom.xml b/graphx/pom.xml index fbe77fcb958d5..0f5dc548600b2 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -53,7 +53,7 @@ org.apache.xbean - xbean-asm5-shaded + xbean-asm6-shaded com.google.guava diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index ebd65e8320e5c..1305c059b89ce 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -184,9 +184,11 @@ object PageRank extends Logging { * indexed by the position of nodes in the sources list) and * edge attributes the normalized edge weight */ - def runParallelPersonalizedPageRank[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], - numIter: Int, resetProb: Double = 0.15, - sources: Array[VertexId]): Graph[Vector, Double] = { + def runParallelPersonalizedPageRank[VD: ClassTag, ED: ClassTag]( + graph: Graph[VD, ED], + numIter: Int, + resetProb: Double = 0.15, + sources: Array[VertexId]): Graph[Vector, Double] = { require(numIter > 0, s"Number of iterations must be greater than 0," + s" but got ${numIter}") require(resetProb >= 0 && resetProb <= 1, s"Random reset probability must belong" + @@ 
-194,15 +196,13 @@ object PageRank extends Logging { require(sources.nonEmpty, s"The list of sources must be non-empty," + s" but got ${sources.mkString("[", ",", "]")}") - // TODO if one sources vertex id is outside of the int range - // we won't be able to store its activations in a sparse vector - require(sources.max <= Int.MaxValue.toLong, - s"This implementation currently only works for source vertex ids at most ${Int.MaxValue}") val zero = Vectors.sparse(sources.size, List()).asBreeze + // map of vid -> vector where for each vid, the _position of vid in source_ is set to 1.0 val sourcesInitMap = sources.zipWithIndex.map { case (vid, i) => val v = Vectors.sparse(sources.size, Array(i), Array(1.0)).asBreeze (vid, v) }.toMap + val sc = graph.vertices.sparkContext val sourcesInitMapBC = sc.broadcast(sourcesInitMap) // Initialize the PageRank graph with each edge attribute having @@ -212,13 +212,7 @@ object PageRank extends Logging { .outerJoinVertices(graph.outDegrees) { (vid, vdata, deg) => deg.getOrElse(0) } // Set the weight on the edges based on the degree .mapTriplets(e => 1.0 / e.srcAttr, TripletFields.Src) - .mapVertices { (vid, attr) => - if (sourcesInitMapBC.value contains vid) { - sourcesInitMapBC.value(vid) - } else { - zero - } - } + .mapVertices((vid, _) => sourcesInitMapBC.value.getOrElse(vid, zero)) var i = 0 while (i < numIter) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala index d76e84ed8c9ed..50b03f71379a1 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala @@ -22,8 +22,8 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.collection.mutable.HashSet import scala.language.existentials -import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor} -import org.apache.xbean.asm5.Opcodes._ +import org.apache.xbean.asm6.{ClassReader, ClassVisitor, MethodVisitor} +import org.apache.xbean.asm6.Opcodes._ import org.apache.spark.util.Utils @@ -109,14 +109,14 @@ private[graphx] object BytecodeUtils { * determine the actual method invoked by inspecting the bytecode. 
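+ * (Now built on xbean-asm6; ASM6 accepts the newer class-file versions, such as + * Java 9's, that ASM5 rejects.)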
*/ private class MethodInvocationFinder(className: String, methodName: String) - extends ClassVisitor(ASM5) { + extends ClassVisitor(ASM6) { val methodsInvoked = new HashSet[(Class[_], String)] override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { if (name == methodName) { - new MethodVisitor(ASM5) { + new MethodVisitor(ASM6) { override def visitMethodInsn( op: Int, owner: String, name: String, desc: String, itf: Boolean) { if (op == INVOKEVIRTUAL || op == INVOKESPECIAL || op == INVOKESTATIC) { diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala index 9779553ce85d1..1e4c6c74bd184 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala @@ -203,24 +203,42 @@ class PageRankSuite extends SparkFunSuite with LocalSparkContext { test("Chain PersonalizedPageRank") { withSpark { sc => - val chain1 = (0 until 9).map(x => (x, x + 1) ) + // Check that implementation can handle large vertexIds, SPARK-25149 + val vertexIdOffset = Int.MaxValue.toLong + 1 + val sourceOffest = 4 + val source = vertexIdOffset + sourceOffest + val numIter = 10 + val vertices = vertexIdOffset until vertexIdOffset + numIter + val chain1 = vertices.zip(vertices.tail) val rawEdges = sc.parallelize(chain1, 1).map { case (s, d) => (s.toLong, d.toLong) } val chain = Graph.fromEdgeTuples(rawEdges, 1.0).cache() val resetProb = 0.15 val tol = 0.0001 - val numIter = 10 val errorTol = 1.0e-1 - val staticRanks = chain.staticPersonalizedPageRank(4, numIter, resetProb).vertices - val dynamicRanks = chain.personalizedPageRank(4, tol, resetProb).vertices + val a = resetProb / (1 - Math.pow(1 - resetProb, numIter - sourceOffest)) + // We expect the rank to decay as (1 - resetProb) ^ distance + val expectedRanks = sc.parallelize(vertices).map { vid => + val rank = if (vid < source) { + 0.0 + } else { + a * Math.pow(1 - resetProb, vid - source) + } + vid -> rank + } + val expected = VertexRDD(expectedRanks) + + val staticRanks = chain.staticPersonalizedPageRank(source, numIter, resetProb).vertices + assert(compareRanks(staticRanks, expected) < errorTol) - assert(compareRanks(staticRanks, dynamicRanks) < errorTol) + val dynamicRanks = chain.personalizedPageRank(source, tol, resetProb).vertices + assert(compareRanks(dynamicRanks, expected) < errorTol) val parallelStaticRanks = chain - .staticParallelPersonalizedPageRank(Array(4), numIter, resetProb).mapVertices { + .staticParallelPersonalizedPageRank(Array(source), numIter, resetProb).mapVertices { case (vertexId, vector) => vector(0) }.vertices.cache() - assert(compareRanks(staticRanks, parallelStaticRanks) < errorTol) + assert(compareRanks(parallelStaticRanks, expected) < errorTol) } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala index 61e44dcab578c..5325978a0a1ec 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.graphx.util import org.apache.spark.SparkFunSuite +import org.apache.spark.util.ClosureCleanerSuite2 // scalastyle:off println @@ -26,6 +27,7 @@ class BytecodeUtilsSuite extends SparkFunSuite { import BytecodeUtilsSuite.TestClass 
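+ // On Scala 2.12, closures compile to LambdaMetafactory lambdas rather than + // anonymous inner classes, leaving no closure class for BytecodeUtils to walk; + // ClosureCleanerSuite2.supportsLMFs gates each test below accordingly.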
test("closure invokes a method") { + assume(!ClosureCleanerSuite2.supportsLMFs) val c1 = {e: TestClass => println(e.foo); println(e.bar); println(e.baz); } assert(BytecodeUtils.invokedMethod(c1, classOf[TestClass], "foo")) assert(BytecodeUtils.invokedMethod(c1, classOf[TestClass], "bar")) @@ -43,6 +45,7 @@ class BytecodeUtilsSuite extends SparkFunSuite { } test("closure inside a closure invokes a method") { + assume(!ClosureCleanerSuite2.supportsLMFs) val c1 = {e: TestClass => println(e.foo); println(e.bar); println(e.baz); } val c2 = {e: TestClass => c1(e); println(e.foo); } assert(BytecodeUtils.invokedMethod(c2, classOf[TestClass], "foo")) @@ -51,6 +54,7 @@ class BytecodeUtilsSuite extends SparkFunSuite { } test("closure inside a closure inside a closure invokes a method") { + assume(!ClosureCleanerSuite2.supportsLMFs) val c1 = {e: TestClass => println(e.baz); } val c2 = {e: TestClass => c1(e); println(e.foo); } val c3 = {e: TestClass => c2(e) } @@ -60,6 +64,7 @@ class BytecodeUtilsSuite extends SparkFunSuite { } test("closure calling a function that invokes a method") { + assume(!ClosureCleanerSuite2.supportsLMFs) def zoo(e: TestClass) { println(e.baz) } @@ -70,6 +75,7 @@ class BytecodeUtilsSuite extends SparkFunSuite { } test("closure calling a function that invokes a method which uses another closure") { + assume(!ClosureCleanerSuite2.supportsLMFs) val c2 = {e: TestClass => println(e.baz)} def zoo(e: TestClass) { c2(e) @@ -81,6 +87,7 @@ class BytecodeUtilsSuite extends SparkFunSuite { } test("nested closure") { + assume(!ClosureCleanerSuite2.supportsLMFs) val c2 = {e: TestClass => println(e.baz)} def zoo(e: TestClass, c: TestClass => Unit) { c(e) diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java index 4e02843480e8f..8a1256f73416e 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java @@ -28,7 +28,7 @@ * * @since Spark 2.3.0 */ -public abstract class AbstractLauncher { +public abstract class AbstractLauncher> { final SparkSubmitCommandBuilder builder; diff --git a/licenses/LICENSE-scopt.txt b/licenses-binary/LICENSE-AnchorJS.txt similarity index 100% rename from licenses/LICENSE-scopt.txt rename to licenses-binary/LICENSE-AnchorJS.txt diff --git a/licenses-binary/LICENSE-CC0.txt b/licenses-binary/LICENSE-CC0.txt new file mode 100644 index 0000000000000..1625c17936079 --- /dev/null +++ b/licenses-binary/LICENSE-CC0.txt @@ -0,0 +1,121 @@ +Creative Commons Legal Code + +CC0 1.0 Universal + + CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE + LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN + ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS + INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES + REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS + PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM + THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED + HEREUNDER. + +Statement of Purpose + +The laws of most jurisdictions throughout the world automatically confer +exclusive Copyright and Related Rights (defined below) upon the creator +and subsequent owner(s) (each and all, an "owner") of an original work of +authorship and/or a database (each, a "Work"). 
+ +Certain owners wish to permanently relinquish those rights to a Work for +the purpose of contributing to a commons of creative, cultural and +scientific works ("Commons") that the public can reliably and without fear +of later claims of infringement build upon, modify, incorporate in other +works, reuse and redistribute as freely as possible in any form whatsoever +and for any purposes, including without limitation commercial purposes. +These owners may contribute to the Commons to promote the ideal of a free +culture and the further production of creative, cultural and scientific +works, or to gain reputation or greater distribution for their Work in +part through the use and efforts of others. + +For these and/or other purposes and motivations, and without any +expectation of additional consideration or compensation, the person +associating CC0 with a Work (the "Affirmer"), to the extent that he or she +is an owner of Copyright and Related Rights in the Work, voluntarily +elects to apply CC0 to the Work and publicly distribute the Work under its +terms, with knowledge of his or her Copyright and Related Rights in the +Work and the meaning and intended legal effect of CC0 on those rights. + +1. Copyright and Related Rights. A Work made available under CC0 may be +protected by copyright and related or neighboring rights ("Copyright and +Related Rights"). Copyright and Related Rights include, but are not +limited to, the following: + + i. the right to reproduce, adapt, distribute, perform, display, + communicate, and translate a Work; + ii. moral rights retained by the original author(s) and/or performer(s); +iii. publicity and privacy rights pertaining to a person's image or + likeness depicted in a Work; + iv. rights protecting against unfair competition in regards to a Work, + subject to the limitations in paragraph 4(a), below; + v. rights protecting the extraction, dissemination, use and reuse of data + in a Work; + vi. database rights (such as those arising under Directive 96/9/EC of the + European Parliament and of the Council of 11 March 1996 on the legal + protection of databases, and under any national implementation + thereof, including any amended or successor version of such + directive); and +vii. other similar, equivalent or corresponding rights throughout the + world based on applicable law or treaty, and any national + implementations thereof. + +2. Waiver. To the greatest extent permitted by, but not in contravention +of, applicable law, Affirmer hereby overtly, fully, permanently, +irrevocably and unconditionally waives, abandons, and surrenders all of +Affirmer's Copyright and Related Rights and associated claims and causes +of action, whether now known or unknown (including existing as well as +future claims and causes of action), in the Work (i) in all territories +worldwide, (ii) for the maximum duration provided by applicable law or +treaty (including future time extensions), (iii) in any current or future +medium and for any number of copies, and (iv) for any purpose whatsoever, +including without limitation commercial, advertising or promotional +purposes (the "Waiver"). 
Affirmer makes the Waiver for the benefit of each +member of the public at large and to the detriment of Affirmer's heirs and +successors, fully intending that such Waiver shall not be subject to +revocation, rescission, cancellation, termination, or any other legal or +equitable action to disrupt the quiet enjoyment of the Work by the public +as contemplated by Affirmer's express Statement of Purpose. + +3. Public License Fallback. Should any part of the Waiver for any reason +be judged legally invalid or ineffective under applicable law, then the +Waiver shall be preserved to the maximum extent permitted taking into +account Affirmer's express Statement of Purpose. In addition, to the +extent the Waiver is so judged Affirmer hereby grants to each affected +person a royalty-free, non transferable, non sublicensable, non exclusive, +irrevocable and unconditional license to exercise Affirmer's Copyright and +Related Rights in the Work (i) in all territories worldwide, (ii) for the +maximum duration provided by applicable law or treaty (including future +time extensions), (iii) in any current or future medium and for any number +of copies, and (iv) for any purpose whatsoever, including without +limitation commercial, advertising or promotional purposes (the +"License"). The License shall be deemed effective as of the date CC0 was +applied by Affirmer to the Work. Should any part of the License for any +reason be judged legally invalid or ineffective under applicable law, such +partial invalidity or ineffectiveness shall not invalidate the remainder +of the License, and in such case Affirmer hereby affirms that he or she +will not (i) exercise any of his or her remaining Copyright and Related +Rights in the Work or (ii) assert any associated claims and causes of +action with respect to the Work, in either case contrary to Affirmer's +express Statement of Purpose. + +4. Limitations and Disclaimers. + + a. No trademark or patent rights held by Affirmer are waived, abandoned, + surrendered, licensed or otherwise affected by this document. + b. Affirmer offers the Work as-is and makes no representations or + warranties of any kind concerning the Work, express, implied, + statutory or otherwise, including without limitation warranties of + title, merchantability, fitness for a particular purpose, non + infringement, or the absence of latent or other defects, accuracy, or + the present or absence of errors, whether or not discoverable, all to + the greatest extent permissible under applicable law. + c. Affirmer disclaims responsibility for clearing rights of other persons + that may apply to the Work or any use thereof, including without + limitation any person's Copyright and Related Rights in the Work. + Further, Affirmer disclaims responsibility for obtaining any necessary + consents, permissions or other rights required for any use of the + Work. + d. Affirmer understands and acknowledges that Creative Commons is not a + party to this document and has no duty or obligation with respect to + this CC0 or use of the Work. \ No newline at end of file diff --git a/licenses/LICENSE-antlr.txt b/licenses-binary/LICENSE-antlr.txt similarity index 100% rename from licenses/LICENSE-antlr.txt rename to licenses-binary/LICENSE-antlr.txt diff --git a/licenses-binary/LICENSE-arpack.txt b/licenses-binary/LICENSE-arpack.txt new file mode 100644 index 0000000000000..a3ad80087bb63 --- /dev/null +++ b/licenses-binary/LICENSE-arpack.txt @@ -0,0 +1,8 @@ +Copyright © 2018 The University of Tennessee. 
All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +· Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +· Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer listed in this license in the documentation and/or other materials provided with the distribution. +· Neither the name of the copyright holders nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +This software is provided by the copyright holders and contributors "as is" and any express or implied warranties, including, but not limited to, the implied warranties of merchantability and fitness for a particular purpose are disclaimed. in no event shall the copyright owner or contributors be liable for any direct, indirect, incidental, special, exemplary, or consequential damages (including, but not limited to, procurement of substitute goods or services; loss of use, data, or profits; or business interruption) however caused and on any theory of liability, whether in contract, strict liability, or tort (including negligence or otherwise) arising in any way out of the use of this software, even if advised of the possibility of such damage. \ No newline at end of file diff --git a/licenses-binary/LICENSE-automaton.txt b/licenses-binary/LICENSE-automaton.txt new file mode 100644 index 0000000000000..2fc6e8c3432f0 --- /dev/null +++ b/licenses-binary/LICENSE-automaton.txt @@ -0,0 +1,24 @@ +Copyright (c) 2001-2017 Anders Moeller +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. +3. The name of the author may not be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR +IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES +OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, +INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT +NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF +THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-bootstrap.txt b/licenses-binary/LICENSE-bootstrap.txt new file mode 100644 index 0000000000000..6c711832fbc85 --- /dev/null +++ b/licenses-binary/LICENSE-bootstrap.txt @@ -0,0 +1,13 @@ +Copyright 2013 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/licenses-binary/LICENSE-cloudpickle.txt b/licenses-binary/LICENSE-cloudpickle.txt new file mode 100644 index 0000000000000..b1e20fa1eda88 --- /dev/null +++ b/licenses-binary/LICENSE-cloudpickle.txt @@ -0,0 +1,28 @@ +Copyright (c) 2012, Regents of the University of California. +Copyright (c) 2009 `PiCloud, Inc. `_. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the University of California, Berkeley nor the + names of its contributors may be used to endorse or promote + products derived from this software without specific prior written + permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-d3.min.js.txt b/licenses-binary/LICENSE-d3.min.js.txt new file mode 100644 index 0000000000000..c71e3f254c068 --- /dev/null +++ b/licenses-binary/LICENSE-d3.min.js.txt @@ -0,0 +1,26 @@ +Copyright (c) 2010-2015, Michael Bostock +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* The name Michael Bostock may not be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL MICHAEL BOSTOCK BE LIABLE FOR ANY DIRECT, +INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-Mockito.txt b/licenses-binary/LICENSE-dagre-d3.txt similarity index 94% rename from licenses/LICENSE-Mockito.txt rename to licenses-binary/LICENSE-dagre-d3.txt index e0840a446caf5..4864fe05e9803 100644 --- a/licenses/LICENSE-Mockito.txt +++ b/licenses-binary/LICENSE-dagre-d3.txt @@ -1,6 +1,4 @@ -The MIT License - -Copyright (c) 2007 Mockito contributors +Copyright (c) 2013 Chris Pettitt Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/licenses-binary/LICENSE-datatables.txt b/licenses-binary/LICENSE-datatables.txt new file mode 100644 index 0000000000000..bb7708b5b5a49 --- /dev/null +++ b/licenses-binary/LICENSE-datatables.txt @@ -0,0 +1,7 @@ +Copyright (C) 2008-2018, SpryMedia Ltd. + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-f2j.txt b/licenses-binary/LICENSE-f2j.txt similarity index 100% rename from licenses/LICENSE-f2j.txt rename to licenses-binary/LICENSE-f2j.txt diff --git a/licenses-binary/LICENSE-graphlib-dot.txt b/licenses-binary/LICENSE-graphlib-dot.txt new file mode 100644 index 0000000000000..4864fe05e9803 --- /dev/null +++ b/licenses-binary/LICENSE-graphlib-dot.txt @@ -0,0 +1,19 @@ +Copyright (c) 2013 Chris Pettitt + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-heapq.txt b/licenses-binary/LICENSE-heapq.txt new file mode 100644 index 0000000000000..0c4c4b954bea4 --- /dev/null +++ b/licenses-binary/LICENSE-heapq.txt @@ -0,0 +1,280 @@ + +# A. HISTORY OF THE SOFTWARE +# ========================== +# +# Python was created in the early 1990s by Guido van Rossum at Stichting +# Mathematisch Centrum (CWI, see http://www.cwi.nl) in the Netherlands +# as a successor of a language called ABC. Guido remains Python's +# principal author, although it includes many contributions from others. +# +# In 1995, Guido continued his work on Python at the Corporation for +# National Research Initiatives (CNRI, see http://www.cnri.reston.va.us) +# in Reston, Virginia where he released several versions of the +# software. +# +# In May 2000, Guido and the Python core development team moved to +# BeOpen.com to form the BeOpen PythonLabs team. In October of the same +# year, the PythonLabs team moved to Digital Creations (now Zope +# Corporation, see http://www.zope.com). In 2001, the Python Software +# Foundation (PSF, see http://www.python.org/psf/) was formed, a +# non-profit organization created specifically to own Python-related +# Intellectual Property. Zope Corporation is a sponsoring member of +# the PSF. +# +# All Python releases are Open Source (see http://www.opensource.org for +# the Open Source Definition). Historically, most, but not all, Python +# releases have also been GPL-compatible; the table below summarizes +# the various releases. +# +# Release Derived Year Owner GPL- +# from compatible? (1) +# +# 0.9.0 thru 1.2 1991-1995 CWI yes +# 1.3 thru 1.5.2 1.2 1995-1999 CNRI yes +# 1.6 1.5.2 2000 CNRI no +# 2.0 1.6 2000 BeOpen.com no +# 1.6.1 1.6 2001 CNRI yes (2) +# 2.1 2.0+1.6.1 2001 PSF no +# 2.0.1 2.0+1.6.1 2001 PSF yes +# 2.1.1 2.1+2.0.1 2001 PSF yes +# 2.2 2.1.1 2001 PSF yes +# 2.1.2 2.1.1 2002 PSF yes +# 2.1.3 2.1.2 2002 PSF yes +# 2.2.1 2.2 2002 PSF yes +# 2.2.2 2.2.1 2002 PSF yes +# 2.2.3 2.2.2 2003 PSF yes +# 2.3 2.2.2 2002-2003 PSF yes +# 2.3.1 2.3 2002-2003 PSF yes +# 2.3.2 2.3.1 2002-2003 PSF yes +# 2.3.3 2.3.2 2002-2003 PSF yes +# 2.3.4 2.3.3 2004 PSF yes +# 2.3.5 2.3.4 2005 PSF yes +# 2.4 2.3 2004 PSF yes +# 2.4.1 2.4 2005 PSF yes +# 2.4.2 2.4.1 2005 PSF yes +# 2.4.3 2.4.2 2006 PSF yes +# 2.4.4 2.4.3 2006 PSF yes +# 2.5 2.4 2006 PSF yes +# 2.5.1 2.5 2007 PSF yes +# 2.5.2 2.5.1 2008 PSF yes +# 2.5.3 2.5.2 2008 PSF yes +# 2.6 2.5 2008 PSF yes +# 2.6.1 2.6 2008 PSF yes +# 2.6.2 2.6.1 2009 PSF yes +# 2.6.3 2.6.2 2009 PSF yes +# 2.6.4 2.6.3 2009 PSF yes +# 2.6.5 2.6.4 2010 PSF yes +# 2.7 2.6 2010 PSF yes +# +# Footnotes: +# +# (1) GPL-compatible doesn't mean that we're distributing Python under +# the GPL. All Python licenses, unlike the GPL, let you distribute +# a modified version without making your changes open source. The +# GPL-compatible licenses make it possible to combine Python with +# other software that is released under the GPL; the others don't. 
+# +# (2) According to Richard Stallman, 1.6.1 is not GPL-compatible, +# because its license has a choice of law clause. According to +# CNRI, however, Stallman's lawyer has told CNRI's lawyer that 1.6.1 +# is "not incompatible" with the GPL. +# +# Thanks to the many outside volunteers who have worked under Guido's +# direction to make these releases possible. +# +# +# B. TERMS AND CONDITIONS FOR ACCESSING OR OTHERWISE USING PYTHON +# =============================================================== +# +# PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 +# -------------------------------------------- +# +# 1. This LICENSE AGREEMENT is between the Python Software Foundation +# ("PSF"), and the Individual or Organization ("Licensee") accessing and +# otherwise using this software ("Python") in source or binary form and +# its associated documentation. +# +# 2. Subject to the terms and conditions of this License Agreement, PSF hereby +# grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, +# analyze, test, perform and/or display publicly, prepare derivative works, +# distribute, and otherwise use Python alone or in any derivative version, +# provided, however, that PSF's License Agreement and PSF's notice of copyright, +# i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, +# 2011, 2012, 2013 Python Software Foundation; All Rights Reserved" are retained +# in Python alone or in any derivative version prepared by Licensee. +# +# 3. In the event Licensee prepares a derivative work that is based on +# or incorporates Python or any part thereof, and wants to make +# the derivative work available to others as provided herein, then +# Licensee hereby agrees to include in any such work a brief summary of +# the changes made to Python. +# +# 4. PSF is making Python available to Licensee on an "AS IS" +# basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +# FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 6. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 7. Nothing in this License Agreement shall be deemed to create any +# relationship of agency, partnership, or joint venture between PSF and +# Licensee. This License Agreement does not grant permission to use PSF +# trademarks or trade name in a trademark sense to endorse or promote +# products or services of Licensee, or any third party. +# +# 8. By copying, installing or otherwise using Python, Licensee +# agrees to be bound by the terms and conditions of this License +# Agreement. +# +# +# BEOPEN.COM LICENSE AGREEMENT FOR PYTHON 2.0 +# ------------------------------------------- +# +# BEOPEN PYTHON OPEN SOURCE LICENSE AGREEMENT VERSION 1 +# +# 1. This LICENSE AGREEMENT is between BeOpen.com ("BeOpen"), having an +# office at 160 Saratoga Avenue, Santa Clara, CA 95051, and the +# Individual or Organization ("Licensee") accessing and otherwise using +# this software in source or binary form and its associated +# documentation ("the Software"). +# +# 2. 
Subject to the terms and conditions of this BeOpen Python License +# Agreement, BeOpen hereby grants Licensee a non-exclusive, +# royalty-free, world-wide license to reproduce, analyze, test, perform +# and/or display publicly, prepare derivative works, distribute, and +# otherwise use the Software alone or in any derivative version, +# provided, however, that the BeOpen Python License is retained in the +# Software, alone or in any derivative version prepared by Licensee. +# +# 3. BeOpen is making the Software available to Licensee on an "AS IS" +# basis. BEOPEN MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, BEOPEN MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF THE SOFTWARE WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 4. BEOPEN SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF THE +# SOFTWARE FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS +# AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THE SOFTWARE, OR ANY +# DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 5. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 6. This License Agreement shall be governed by and interpreted in all +# respects by the law of the State of California, excluding conflict of +# law provisions. Nothing in this License Agreement shall be deemed to +# create any relationship of agency, partnership, or joint venture +# between BeOpen and Licensee. This License Agreement does not grant +# permission to use BeOpen trademarks or trade names in a trademark +# sense to endorse or promote products or services of Licensee, or any +# third party. As an exception, the "BeOpen Python" logos available at +# http://www.pythonlabs.com/logos.html may be used according to the +# permissions granted on that web page. +# +# 7. By copying, installing or otherwise using the software, Licensee +# agrees to be bound by the terms and conditions of this License +# Agreement. +# +# +# CNRI LICENSE AGREEMENT FOR PYTHON 1.6.1 +# --------------------------------------- +# +# 1. This LICENSE AGREEMENT is between the Corporation for National +# Research Initiatives, having an office at 1895 Preston White Drive, +# Reston, VA 20191 ("CNRI"), and the Individual or Organization +# ("Licensee") accessing and otherwise using Python 1.6.1 software in +# source or binary form and its associated documentation. +# +# 2. Subject to the terms and conditions of this License Agreement, CNRI +# hereby grants Licensee a nonexclusive, royalty-free, world-wide +# license to reproduce, analyze, test, perform and/or display publicly, +# prepare derivative works, distribute, and otherwise use Python 1.6.1 +# alone or in any derivative version, provided, however, that CNRI's +# License Agreement and CNRI's notice of copyright, i.e., "Copyright (c) +# 1995-2001 Corporation for National Research Initiatives; All Rights +# Reserved" are retained in Python 1.6.1 alone or in any derivative +# version prepared by Licensee. Alternately, in lieu of CNRI's License +# Agreement, Licensee may substitute the following text (omitting the +# quotes): "Python 1.6.1 is made available subject to the terms and +# conditions in CNRI's License Agreement. This Agreement together with +# Python 1.6.1 may be located on the Internet using the following +# unique, persistent identifier (known as a handle): 1895.22/1013. 
This +# Agreement may also be obtained from a proxy server on the Internet +# using the following URL: http://hdl.handle.net/1895.22/1013". +# +# 3. In the event Licensee prepares a derivative work that is based on +# or incorporates Python 1.6.1 or any part thereof, and wants to make +# the derivative work available to others as provided herein, then +# Licensee hereby agrees to include in any such work a brief summary of +# the changes made to Python 1.6.1. +# +# 4. CNRI is making Python 1.6.1 available to Licensee on an "AS IS" +# basis. CNRI MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, CNRI MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON 1.6.1 WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 5. CNRI SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +# 1.6.1 FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON 1.6.1, +# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 6. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 7. This License Agreement shall be governed by the federal +# intellectual property law of the United States, including without +# limitation the federal copyright law, and, to the extent such +# U.S. federal law does not apply, by the law of the Commonwealth of +# Virginia, excluding Virginia's conflict of law provisions. +# Notwithstanding the foregoing, with regard to derivative works based +# on Python 1.6.1 that incorporate non-separable material that was +# previously distributed under the GNU General Public License (GPL), the +# law of the Commonwealth of Virginia shall govern this License +# Agreement only as to issues arising under or with respect to +# Paragraphs 4, 5, and 7 of this License Agreement. Nothing in this +# License Agreement shall be deemed to create any relationship of +# agency, partnership, or joint venture between CNRI and Licensee. This +# License Agreement does not grant permission to use CNRI trademarks or +# trade name in a trademark sense to endorse or promote products or +# services of Licensee, or any third party. +# +# 8. By clicking on the "ACCEPT" button where indicated, or by copying, +# installing or otherwise using Python 1.6.1, Licensee agrees to be +# bound by the terms and conditions of this License Agreement. +# +# ACCEPT +# +# +# CWI LICENSE AGREEMENT FOR PYTHON 0.9.0 THROUGH 1.2 +# -------------------------------------------------- +# +# Copyright (c) 1991 - 1995, Stichting Mathematisch Centrum Amsterdam, +# The Netherlands. All rights reserved. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose and without fee is hereby granted, +# provided that the above copyright notice appear in all copies and that +# both that copyright notice and this permission notice appear in +# supporting documentation, and that the name of Stichting Mathematisch +# Centrum or CWI not be used in advertising or publicity pertaining to +# distribution of the software without specific, written prior +# permission. 
+# +# STICHTING MATHEMATISCH CENTRUM DISCLAIMS ALL WARRANTIES WITH REGARD TO +# THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND +# FITNESS, IN NO EVENT SHALL STICHTING MATHEMATISCH CENTRUM BE LIABLE +# FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-janino.txt b/licenses-binary/LICENSE-janino.txt new file mode 100644 index 0000000000000..d1e1f237c4641 --- /dev/null +++ b/licenses-binary/LICENSE-janino.txt @@ -0,0 +1,31 @@ +Janino - An embedded Java[TM] compiler + +Copyright (c) 2001-2016, Arno Unkrig +Copyright (c) 2015-2016 TIBCO Software Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + 2. Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials + provided with the distribution. + 3. Neither the name of JANINO nor the names of its contributors + may be used to endorse or promote products derived from this + software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER +IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN +IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-javassist.html b/licenses-binary/LICENSE-javassist.html new file mode 100644 index 0000000000000..5abd563a0c4d9 --- /dev/null +++ b/licenses-binary/LICENSE-javassist.html @@ -0,0 +1,373 @@ + + + Javassist License + + + + +
    MOZILLA PUBLIC LICENSE
    Version + 1.1 +

    +


    +
    +

    1. Definitions. +

      1.0.1. "Commercial Use" means distribution or otherwise making the + Covered Code available to a third party. +

      1.1. ''Contributor'' means each entity that creates or contributes + to the creation of Modifications. +

      1.2. ''Contributor Version'' means the combination of the Original + Code, prior Modifications used by a Contributor, and the Modifications made by + that particular Contributor. +

      1.3. ''Covered Code'' means the Original Code or Modifications or + the combination of the Original Code and Modifications, in each case including + portions thereof. +

      1.4. ''Electronic Distribution Mechanism'' means a mechanism + generally accepted in the software development community for the electronic + transfer of data. +

      1.5. ''Executable'' means Covered Code in any form other than Source + Code. +

      1.6. ''Initial Developer'' means the individual or entity identified + as the Initial Developer in the Source Code notice required by Exhibit + A. +

      1.7. ''Larger Work'' means a work which combines Covered Code or + portions thereof with code not governed by the terms of this License. +

      1.8. ''License'' means this document. +

      1.8.1. "Licensable" means having the right to grant, to the maximum + extent possible, whether at the time of the initial grant or subsequently + acquired, any and all of the rights conveyed herein. +

      1.9. ''Modifications'' means any addition to or deletion from the + substance or structure of either the Original Code or any previous + Modifications. When Covered Code is released as a series of files, a + Modification is: +

        A. Any addition to or deletion from the contents of a file + containing Original Code or previous Modifications. +

        B. Any new file that contains any part of the Original Code or + previous Modifications.
         

      1.10. ''Original Code'' +means Source Code of computer software code which is described in the Source +Code notice required by Exhibit A as Original Code, and which, at the +time of its release under this License is not already Covered Code governed by +this License. +

      1.10.1. "Patent Claims" means any patent claim(s), now owned or + hereafter acquired, including without limitation,  method, process, and + apparatus claims, in any patent Licensable by grantor. +

      1.11. ''Source Code'' means the preferred form of the Covered Code + for making modifications to it, including all modules it contains, plus any + associated interface definition files, scripts used to control compilation and + installation of an Executable, or source code differential comparisons against + either the Original Code or another well known, available Covered Code of the + Contributor's choice. The Source Code can be in a compressed or archival form, + provided the appropriate decompression or de-archiving software is widely + available for no charge. +

      1.12. "You'' (or "Your")  means an individual or a legal entity + exercising rights under, and complying with all of the terms of, this License + or a future version of this License issued under Section 6.1. For legal + entities, "You'' includes any entity which controls, is controlled by, or is + under common control with You. For purposes of this definition, "control'' + means (a) the power, direct or indirect, to cause the direction or management + of such entity, whether by contract or otherwise, or (b) ownership of more + than fifty percent (50%) of the outstanding shares or beneficial ownership of + such entity.

    2. Source Code License. +
      2.1. The Initial Developer Grant.
      The Initial Developer hereby + grants You a world-wide, royalty-free, non-exclusive license, subject to third + party intellectual property claims: +
        (a)  under intellectual property rights (other than + patent or trademark) Licensable by Initial Developer to use, reproduce, + modify, display, perform, sublicense and distribute the Original Code (or + portions thereof) with or without Modifications, and/or as part of a Larger + Work; and +

        (b) under Patents Claims infringed by the making, using or selling + of Original Code, to make, have made, use, practice, sell, and offer for + sale, and/or otherwise dispose of the Original Code (or portions thereof). +

          +
          (c) the licenses granted in this Section 2.1(a) and (b) + are effective on the date Initial Developer first distributes Original Code + under the terms of this License. +

          (d) Notwithstanding Section 2.1(b) above, no patent license is + granted: 1) for code that You delete from the Original Code; 2) separate + from the Original Code;  or 3) for infringements caused by: i) the + modification of the Original Code or ii) the combination of the Original + Code with other software or devices.
           

        2.2. Contributor + Grant.
        Subject to third party intellectual property claims, each + Contributor hereby grants You a world-wide, royalty-free, non-exclusive + license +

          (a)  under intellectual property rights (other + than patent or trademark) Licensable by Contributor, to use, reproduce, + modify, display, perform, sublicense and distribute the Modifications + created by such Contributor (or portions thereof) either on an unmodified + basis, with other Modifications, as Covered Code and/or as part of a Larger + Work; and +

          (b) under Patent Claims infringed by the making, using, or selling + of  Modifications made by that Contributor either alone and/or in combination with its Contributor Version (or portions of such + combination), to make, use, sell, offer for sale, have made, and/or + otherwise dispose of: 1) Modifications made by that Contributor (or portions + thereof); and 2) the combination of  Modifications made by that + Contributor with its Contributor Version (or portions of such + combination). +

          (c) the licenses granted in Sections 2.2(a) and 2.2(b) are + effective on the date Contributor first makes Commercial Use of the Covered + Code. +

          (d)    Notwithstanding Section 2.2(b) above, no + patent license is granted: 1) for any code that Contributor has deleted from + the Contributor Version; 2)  separate from the Contributor + Version;  3)  for infringements caused by: i) third party + modifications of Contributor Version or ii)  the combination of + Modifications made by that Contributor with other software  (except as + part of the Contributor Version) or other devices; or 4) under Patent Claims + infringed by Covered Code in the absence of Modifications made by that + Contributor.

      +


      3. Distribution Obligations. +

        3.1. Application of License.
        The Modifications which You create + or to which You contribute are governed by the terms of this License, + including without limitation Section 2.2. The Source Code version of + Covered Code may be distributed only under the terms of this License or a + future version of this License released under Section 6.1, and You must + include a copy of this License with every copy of the Source Code You + distribute. You may not offer or impose any terms on any Source Code version + that alters or restricts the applicable version of this License or the + recipients' rights hereunder. However, You may include an additional document + offering the additional rights described in Section 3.5. +

        3.2. Availability of Source Code.
        Any Modification which You + create or to which You contribute must be made available in Source Code form + under the terms of this License either on the same media as an Executable + version or via an accepted Electronic Distribution Mechanism to anyone to whom + you made an Executable version available; and if made available via Electronic + Distribution Mechanism, must remain available for at least twelve (12) months + after the date it initially became available, or at least six (6) months after + a subsequent version of that particular Modification has been made available + to such recipients. You are responsible for ensuring that the Source Code + version remains available even if the Electronic Distribution Mechanism is + maintained by a third party. +

        3.3. Description of Modifications.
        You must cause all Covered + Code to which You contribute to contain a file documenting the changes You + made to create that Covered Code and the date of any change. You must include + a prominent statement that the Modification is derived, directly or + indirectly, from Original Code provided by the Initial Developer and including + the name of the Initial Developer in (a) the Source Code, and (b) in any + notice in an Executable version or related documentation in which You describe + the origin or ownership of the Covered Code. +

        3.4. Intellectual Property Matters +

          (a) Third Party Claims.
          If Contributor has knowledge that a + license under a third party's intellectual property rights is required to + exercise the rights granted by such Contributor under Sections 2.1 or 2.2, + Contributor must include a text file with the Source Code distribution + titled "LEGAL'' which describes the claim and the party making the claim in + sufficient detail that a recipient will know whom to contact. If Contributor + obtains such knowledge after the Modification is made available as described + in Section 3.2, Contributor shall promptly modify the LEGAL file in all + copies Contributor makes available thereafter and shall take other steps + (such as notifying appropriate mailing lists or newsgroups) reasonably + calculated to inform those who received the Covered Code that new knowledge + has been obtained. +

          (b) Contributor APIs.
          If Contributor's Modifications include + an application programming interface and Contributor has knowledge of patent + licenses which are reasonably necessary to implement that API, Contributor + must also include this information in the LEGAL file. +
           

                  +(c)    Representations. +
          Contributor represents that, except as disclosed pursuant to Section + 3.4(a) above, Contributor believes that Contributor's Modifications are + Contributor's original creation(s) and/or Contributor has sufficient rights + to grant the rights conveyed by this License.
        +


        3.5. Required Notices.
        You must duplicate the notice in + Exhibit A in each file of the Source Code.  If it is not possible + to put such notice in a particular Source Code file due to its structure, then + You must include such notice in a location (such as a relevant directory) + where a user would be likely to look for such a notice.  If You created + one or more Modification(s) You may add your name as a Contributor to the + notice described in Exhibit A.  You must also duplicate this + License in any documentation for the Source Code where You describe + recipients' rights or ownership rights relating to Covered Code.  You may + choose to offer, and to charge a fee for, warranty, support, indemnity or + liability obligations to one or more recipients of Covered Code. However, You + may do so only on Your own behalf, and not on behalf of the Initial Developer + or any Contributor. You must make it absolutely clear than any such warranty, + support, indemnity or liability obligation is offered by You alone, and You + hereby agree to indemnify the Initial Developer and every Contributor for any + liability incurred by the Initial Developer or such Contributor as a result of + warranty, support, indemnity or liability terms You offer. +

        3.6. Distribution of Executable Versions.
        You may distribute + Covered Code in Executable form only if the requirements of Section + 3.1-3.5 have been met for that Covered Code, and if You include a + notice stating that the Source Code version of the Covered Code is available + under the terms of this License, including a description of how and where You + have fulfilled the obligations of Section 3.2. The notice must be + conspicuously included in any notice in an Executable version, related + documentation or collateral in which You describe recipients' rights relating + to the Covered Code. You may distribute the Executable version of Covered Code + or ownership rights under a license of Your choice, which may contain terms + different from this License, provided that You are in compliance with the + terms of this License and that the license for the Executable version does not + attempt to limit or alter the recipient's rights in the Source Code version + from the rights set forth in this License. If You distribute the Executable + version under a different license You must make it absolutely clear that any + terms which differ from this License are offered by You alone, not by the + Initial Developer or any Contributor. You hereby agree to indemnify the + Initial Developer and every Contributor for any liability incurred by the + Initial Developer or such Contributor as a result of any such terms You offer. + +

        3.7. Larger Works.
        You may create a Larger Work by combining + Covered Code with other code not governed by the terms of this License and + distribute the Larger Work as a single product. In such a case, You must make + sure the requirements of this License are fulfilled for the Covered + Code.

      4. Inability to Comply Due to Statute or Regulation. +
        If it is impossible for You to comply with any of the terms of this + License with respect to some or all of the Covered Code due to statute, + judicial order, or regulation then You must: (a) comply with the terms of this + License to the maximum extent possible; and (b) describe the limitations and + the code they affect. Such description must be included in the LEGAL file + described in Section 3.4 and must be included with all distributions of + the Source Code. Except to the extent prohibited by statute or regulation, + such description must be sufficiently detailed for a recipient of ordinary + skill to be able to understand it.
      5. Application of this License. +
        This License applies to code to which the Initial Developer has attached + the notice in Exhibit A and to related Covered Code.
      6. Versions + of the License. +
        6.1. New Versions.
        Netscape Communications Corporation + (''Netscape'') may publish revised and/or new versions of the License from + time to time. Each version will be given a distinguishing version number. +

        6.2. Effect of New Versions.
        Once Covered Code has been + published under a particular version of the License, You may always continue + to use it under the terms of that version. You may also choose to use such + Covered Code under the terms of any subsequent version of the License + published by Netscape. No one other than Netscape has the right to modify the + terms applicable to Covered Code created under this License. +

        6.3. Derivative Works.
        If You create or use a modified version + of this License (which you may only do in order to apply it to code which is + not already Covered Code governed by this License), You must (a) rename Your + license so that the phrases ''Mozilla'', ''MOZILLAPL'', ''MOZPL'', + ''Netscape'', "MPL", ''NPL'' or any confusingly similar phrase do not appear + in your license (except to note that your license differs from this License) + and (b) otherwise make it clear that Your version of the license contains + terms which differ from the Mozilla Public License and Netscape Public + License. (Filling in the name of the Initial Developer, Original Code or + Contributor in the notice described in Exhibit A shall not of + themselves be deemed to be modifications of this License.)

      7. + DISCLAIMER OF WARRANTY. +
        COVERED CODE IS PROVIDED UNDER THIS LICENSE ON AN "AS IS'' BASIS, WITHOUT + WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, WITHOUT + LIMITATION, WARRANTIES THAT THE COVERED CODE IS FREE OF DEFECTS, MERCHANTABLE, + FIT FOR A PARTICULAR PURPOSE OR NON-INFRINGING. THE ENTIRE RISK AS TO THE + QUALITY AND PERFORMANCE OF THE COVERED CODE IS WITH YOU. SHOULD ANY COVERED + CODE PROVE DEFECTIVE IN ANY RESPECT, YOU (NOT THE INITIAL DEVELOPER OR ANY + OTHER CONTRIBUTOR) ASSUME THE COST OF ANY NECESSARY SERVICING, REPAIR OR + CORRECTION. THIS DISCLAIMER OF WARRANTY CONSTITUTES AN ESSENTIAL PART OF THIS + LICENSE. NO USE OF ANY COVERED CODE IS AUTHORIZED HEREUNDER EXCEPT UNDER THIS + DISCLAIMER.
      8. TERMINATION. +
        8.1.  This License and the rights granted hereunder will + terminate automatically if You fail to comply with terms herein and fail to + cure such breach within 30 days of becoming aware of the breach. All + sublicenses to the Covered Code which are properly granted shall survive any + termination of this License. Provisions which, by their nature, must remain in + effect beyond the termination of this License shall survive. +

        8.2.  If You initiate litigation by asserting a patent + infringement claim (excluding declatory judgment actions) against Initial + Developer or a Contributor (the Initial Developer or Contributor against whom + You file such action is referred to as "Participant")  alleging that: +

        (a)  such Participant's Contributor Version directly or + indirectly infringes any patent, then any and all rights granted by such + Participant to You under Sections 2.1 and/or 2.2 of this License shall, upon + 60 days notice from Participant terminate prospectively, unless if within 60 + days after receipt of notice You either: (i)  agree in writing to pay + Participant a mutually agreeable reasonable royalty for Your past and future + use of Modifications made by such Participant, or (ii) withdraw Your + litigation claim with respect to the Contributor Version against such + Participant.  If within 60 days of notice, a reasonable royalty and + payment arrangement are not mutually agreed upon in writing by the parties or + the litigation claim is not withdrawn, the rights granted by Participant to + You under Sections 2.1 and/or 2.2 automatically terminate at the expiration of + the 60 day notice period specified above. +

        (b)  any software, hardware, or device, other than such + Participant's Contributor Version, directly or indirectly infringes any + patent, then any rights granted to You by such Participant under Sections + 2.1(b) and 2.2(b) are revoked effective as of the date You first made, used, + sold, distributed, or had made, Modifications made by that Participant. +

        8.3.  If You assert a patent infringement claim against + Participant alleging that such Participant's Contributor Version directly or + indirectly infringes any patent where such claim is resolved (such as by + license or settlement) prior to the initiation of patent infringement + litigation, then the reasonable value of the licenses granted by such + Participant under Sections 2.1 or 2.2 shall be taken into account in + determining the amount or value of any payment or license. +

        8.4.  In the event of termination under Sections 8.1 or 8.2 + above,  all end user license agreements (excluding distributors and + resellers) which have been validly granted by You or any distributor hereunder + prior to termination shall survive termination.

      9. LIMITATION OF + LIABILITY. +
        UNDER NO CIRCUMSTANCES AND UNDER NO LEGAL THEORY, WHETHER TORT (INCLUDING + NEGLIGENCE), CONTRACT, OR OTHERWISE, SHALL YOU, THE INITIAL DEVELOPER, ANY + OTHER CONTRIBUTOR, OR ANY DISTRIBUTOR OF COVERED CODE, OR ANY SUPPLIER OF ANY + OF SUCH PARTIES, BE LIABLE TO ANY PERSON FOR ANY INDIRECT, SPECIAL, + INCIDENTAL, OR CONSEQUENTIAL DAMAGES OF ANY CHARACTER INCLUDING, WITHOUT + LIMITATION, DAMAGES FOR LOSS OF GOODWILL, WORK STOPPAGE, COMPUTER FAILURE OR + MALFUNCTION, OR ANY AND ALL OTHER COMMERCIAL DAMAGES OR LOSSES, EVEN IF SUCH + PARTY SHALL HAVE BEEN INFORMED OF THE POSSIBILITY OF SUCH DAMAGES. THIS + LIMITATION OF LIABILITY SHALL NOT APPLY TO LIABILITY FOR DEATH OR PERSONAL + INJURY RESULTING FROM SUCH PARTY'S NEGLIGENCE TO THE EXTENT APPLICABLE LAW + PROHIBITS SUCH LIMITATION. SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OR + LIMITATION OF INCIDENTAL OR CONSEQUENTIAL DAMAGES, SO THIS EXCLUSION AND + LIMITATION MAY NOT APPLY TO YOU.
      10. U.S. GOVERNMENT END USERS.
        The Covered Code is a ''commercial item,'' as that term is defined in 48 C.F.R. 2.101 (Oct. 1995), consisting of ''commercial computer software'' and ''commercial computer software documentation,'' as such terms are used in 48 C.F.R. 12.212 (Sept. 1995). Consistent with 48 C.F.R. 12.212 and 48 C.F.R. 227.7202-1 through 227.7202-4 (June 1995), all U.S. Government End Users acquire Covered Code with only those rights set forth herein.
      11. MISCELLANEOUS.
        This License represents the complete agreement concerning subject matter hereof. If any provision of this License is held to be unenforceable, such provision shall be reformed only to the extent necessary to make it enforceable. This License shall be governed by California law provisions (except to the extent applicable law, if any, provides otherwise), excluding its conflict-of-law provisions. With respect to disputes in which at least one party is a citizen of, or an entity chartered or registered to do business in the United States of America, any litigation relating to this License shall be subject to the jurisdiction of the Federal Courts of the Northern District of California, with venue lying in Santa Clara County, California, with the losing party responsible for costs, including without limitation, court costs and reasonable attorneys' fees and expenses. The application of the United Nations Convention on Contracts for the International Sale of Goods is expressly excluded. Any law or regulation which provides that the language of a contract shall be construed against the drafter shall not apply to this License.
      12. RESPONSIBILITY FOR CLAIMS.
        As between Initial Developer and the Contributors, each party is responsible for claims and damages arising, directly or indirectly, out of its utilization of rights under this License and You agree to work with Initial Developer and Contributors to distribute such responsibility on an equitable basis. Nothing herein is intended or shall be deemed to constitute any admission of liability.
      13. MULTIPLE-LICENSED CODE.
        Initial Developer may designate portions of the Covered Code as "Multiple-Licensed". "Multiple-Licensed" means that the Initial Developer permits you to utilize portions of the Covered Code under Your choice of the MPL or the alternative licenses, if any, specified by the Initial Developer in the file described in Exhibit A.


      EXHIBIT A - Mozilla Public License.

        The contents of this file are subject to the Mozilla Public License Version 1.1 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
        http://www.mozilla.org/MPL/

        Software distributed under the License is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License for the specific language governing rights and limitations under the License.

        The Original Code is Javassist.

        The Initial Developer of the Original Code is Shigeru Chiba. Portions created by the Initial Developer are Copyright (C) 1999- Shigeru Chiba. All Rights Reserved.

        Contributor(s): __Bill Burke, Jason T. Greene______________.

        Alternatively, the contents of this software may be used under the terms of the GNU Lesser General Public License Version 2.1 or later (the "LGPL"), or the Apache License Version 2.0 (the "AL"), in which case the provisions of the LGPL or the AL are applicable instead of those above. If you wish to allow use of your version of this software only under the terms of either the LGPL or the AL, and not to allow others to use your version of this software under the terms of the MPL, indicate your decision by deleting the provisions above and replace them with the notice and other provisions required by the LGPL or the AL. If you do not delete the provisions above, a recipient may use your version of this software under the terms of any one of the MPL, the LGPL or the AL.

      + + \ No newline at end of file diff --git a/licenses/LICENSE-javolution.txt b/licenses-binary/LICENSE-javolution.txt similarity index 100% rename from licenses/LICENSE-javolution.txt rename to licenses-binary/LICENSE-javolution.txt diff --git a/licenses/LICENSE-jline.txt b/licenses-binary/LICENSE-jline.txt similarity index 100% rename from licenses/LICENSE-jline.txt rename to licenses-binary/LICENSE-jline.txt diff --git a/licenses/LICENSE-junit-interface.txt b/licenses-binary/LICENSE-jodd.txt similarity index 69% rename from licenses/LICENSE-junit-interface.txt rename to licenses-binary/LICENSE-jodd.txt index e835350c4e2a4..cc6b458adb386 100644 --- a/licenses/LICENSE-junit-interface.txt +++ b/licenses-binary/LICENSE-jodd.txt @@ -1,15 +1,15 @@ -Copyright (c) 2009-2012, Stefan Zeiger +Copyright (c) 2003-present, Jodd Team (https://jodd.org) All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - * Redistributions of source code must retain the above copyright notice, - this list of conditions and the following disclaimer. +1. Redistributions of source code must retain the above copyright notice, +this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE diff --git a/licenses/LICENSE-DPark.txt b/licenses-binary/LICENSE-join.txt similarity index 100% rename from licenses/LICENSE-DPark.txt rename to licenses-binary/LICENSE-join.txt diff --git a/licenses-binary/LICENSE-jquery.txt b/licenses-binary/LICENSE-jquery.txt new file mode 100644 index 0000000000000..45930542204fb --- /dev/null +++ b/licenses-binary/LICENSE-jquery.txt @@ -0,0 +1,20 @@ +Copyright JS Foundation and other contributors, https://js.foundation/ + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
\ No newline at end of file diff --git a/licenses-binary/LICENSE-json-formatter.txt b/licenses-binary/LICENSE-json-formatter.txt new file mode 100644 index 0000000000000..5193348fce126 --- /dev/null +++ b/licenses-binary/LICENSE-json-formatter.txt @@ -0,0 +1,6 @@ +Copyright 2014 Mohsen Azimi + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. \ No newline at end of file diff --git a/licenses-binary/LICENSE-jtransforms.html b/licenses-binary/LICENSE-jtransforms.html new file mode 100644 index 0000000000000..351c17412357b --- /dev/null +++ b/licenses-binary/LICENSE-jtransforms.html @@ -0,0 +1,388 @@ + + +Mozilla Public License version 1.1 + + + + +

      Mozilla Public License Version 1.1


      1. Definitions.

      1.0.1. "Commercial Use"
      means distribution or otherwise making the Covered Code available to a third party.
      1.1. "Contributor"
      means each entity that creates or contributes to the creation of Modifications.
      1.2. "Contributor Version"
      means the combination of the Original Code, prior Modifications used by a Contributor, and the Modifications made by that particular Contributor.
      1.3. "Covered Code"
      means the Original Code or Modifications or the combination of the Original Code and Modifications, in each case including portions thereof.
      1.4. "Electronic Distribution Mechanism"
      means a mechanism generally accepted in the software development community for the electronic transfer of data.
      1.5. "Executable"
      means Covered Code in any form other than Source Code.
      1.6. "Initial Developer"
      means the individual or entity identified as the Initial Developer in the Source Code notice required by Exhibit A.
      1.7. "Larger Work"
      means a work which combines Covered Code or portions thereof with code not governed by the terms of this License.
      1.8. "License"
      means this document.
      1.8.1. "Licensable"
      means having the right to grant, to the maximum extent possible, whether at the time of the initial grant or subsequently acquired, any and all of the rights conveyed herein.
      1.9. "Modifications"
      means any addition to or deletion from the substance or structure of either the Original Code or any previous Modifications. When Covered Code is released as a series of files, a Modification is:

      a. Any addition to or deletion from the contents of a file containing Original Code or previous Modifications.
      b. Any new file that contains any part of the Original Code or previous Modifications.

      1.10. "Original Code"
      means Source Code of computer software code which is described in the Source Code notice required by Exhibit A as Original Code, and which, at the time of its release under this License is not already Covered Code governed by this License.
      1.10.1. "Patent Claims"
      means any patent claim(s), now owned or hereafter acquired, including without limitation, method, process, and apparatus claims, in any patent Licensable by grantor.
      1.11. "Source Code"
      means the preferred form of the Covered Code for making modifications to it, including all modules it contains, plus any associated interface definition files, scripts used to control compilation and installation of an Executable, or source code differential comparisons against either the Original Code or another well known, available Covered Code of the Contributor's choice. The Source Code can be in a compressed or archival form, provided the appropriate decompression or de-archiving software is widely available for no charge.
      1.12. "You" (or "Your")
      means an individual or a legal entity exercising rights under, and complying with all of the terms of, this License or a future version of this License issued under Section 6.1. For legal entities, "You" includes any entity which controls, is controlled by, or is under common control with You. For purposes of this definition, "control" means (a) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (b) ownership of more than fifty percent (50%) of the outstanding shares or beneficial ownership of such entity.

      2. Source Code License.


      2.1. The Initial Developer Grant.


      The Initial Developer hereby grants You a world-wide, royalty-free, non-exclusive license, subject to third party intellectual property claims:

      a. under intellectual property rights (other than patent or trademark) Licensable by Initial Developer to use, reproduce, modify, display, perform, sublicense and distribute the Original Code (or portions thereof) with or without Modifications, and/or as part of a Larger Work; and
      b. under Patents Claims infringed by the making, using or selling of Original Code, to make, have made, use, practice, sell, and offer for sale, and/or otherwise dispose of the Original Code (or portions thereof).
      c. the licenses granted in this Section 2.1 (a) and (b) are effective on the date Initial Developer first distributes Original Code under the terms of this License.
      d. Notwithstanding Section 2.1 (b) above, no patent license is granted: 1) for code that You delete from the Original Code; 2) separate from the Original Code; or 3) for infringements caused by: i) the modification of the Original Code or ii) the combination of the Original Code with other software or devices.

      2.2. Contributor Grant.


      Subject to third party intellectual property claims, each Contributor hereby grants You a world-wide, royalty-free, non-exclusive license

      a. under intellectual property rights (other than patent or trademark) Licensable by Contributor, to use, reproduce, modify, display, perform, sublicense and distribute the Modifications created by such Contributor (or portions thereof) either on an unmodified basis, with other Modifications, as Covered Code and/or as part of a Larger Work; and
      b. under Patent Claims infringed by the making, using, or selling of Modifications made by that Contributor either alone and/or in combination with its Contributor Version (or portions of such combination), to make, use, sell, offer for sale, have made, and/or otherwise dispose of: 1) Modifications made by that Contributor (or portions thereof); and 2) the combination of Modifications made by that Contributor with its Contributor Version (or portions of such combination).
      c. the licenses granted in Sections 2.2 (a) and 2.2 (b) are effective on the date Contributor first makes Commercial Use of the Covered Code.
      d. Notwithstanding Section 2.2 (b) above, no patent license is granted: 1) for any code that Contributor has deleted from the Contributor Version; 2) separate from the Contributor Version; 3) for infringements caused by: i) third party modifications of Contributor Version or ii) the combination of Modifications made by that Contributor with other software (except as part of the Contributor Version) or other devices; or 4) under Patent Claims infringed by Covered Code in the absence of Modifications made by that Contributor.

      3. Distribution Obligations.


      3.1. Application of License.


      The Modifications which You create or to which You contribute are governed by the terms of this License, including without limitation Section 2.2. The Source Code version of Covered Code may be distributed only under the terms of this License or a future version of this License released under Section 6.1, and You must include a copy of this License with every copy of the Source Code You distribute. You may not offer or impose any terms on any Source Code version that alters or restricts the applicable version of this License or the recipients' rights hereunder. However, You may include an additional document offering the additional rights described in Section 3.5.

      3.2. Availability of Source Code.


      Any Modification which You create or to which You contribute must be made available in Source Code form under the terms of this License either on the same media as an Executable version or via an accepted Electronic Distribution Mechanism to anyone to whom you made an Executable version available; and if made available via Electronic Distribution Mechanism, must remain available for at least twelve (12) months after the date it initially became available, or at least six (6) months after a subsequent version of that particular Modification has been made available to such recipients. You are responsible for ensuring that the Source Code version remains available even if the Electronic Distribution Mechanism is maintained by a third party.

      3.3. Description of Modifications.


      You must cause all Covered Code to which You contribute to contain a file documenting the changes You made to create that Covered Code and the date of any change. You must include a prominent statement that the Modification is derived, directly or indirectly, from Original Code provided by the Initial Developer and including the name of the Initial Developer in (a) the Source Code, and (b) in any notice in an Executable version or related documentation in which You describe the origin or ownership of the Covered Code.

      3.4. Intellectual Property Matters


      (a) Third Party Claims


      If Contributor has knowledge that a license under a third party's intellectual property rights is required to exercise the rights granted by such Contributor under Sections 2.1 or 2.2, Contributor must include a text file with the Source Code distribution titled "LEGAL" which describes the claim and the party making the claim in sufficient detail that a recipient will know whom to contact. If Contributor obtains such knowledge after the Modification is made available as described in Section 3.2, Contributor shall promptly modify the LEGAL file in all copies Contributor makes available thereafter and shall take other steps (such as notifying appropriate mailing lists or newsgroups) reasonably calculated to inform those who received the Covered Code that new knowledge has been obtained.

      (b) Contributor APIs


      If Contributor's Modifications include an application programming interface and Contributor has knowledge of patent licenses which are reasonably necessary to implement that API, Contributor must also include this information in the LEGAL file.

      (c) Representations.


      Contributor represents that, except as disclosed pursuant to Section 3.4 (a) above, Contributor believes that Contributor's Modifications are Contributor's original creation(s) and/or Contributor has sufficient rights to grant the rights conveyed by this License.

      3.5. Required Notices.


      You must duplicate the notice in Exhibit A in each file of the Source Code. If it is not possible to put such notice in a particular Source Code file due to its structure, then You must include such notice in a location (such as a relevant directory) where a user would be likely to look for such a notice. If You created one or more Modification(s) You may add your name as a Contributor to the notice described in Exhibit A. You must also duplicate this License in any documentation for the Source Code where You describe recipients' rights or ownership rights relating to Covered Code. You may choose to offer, and to charge a fee for, warranty, support, indemnity or liability obligations to one or more recipients of Covered Code. However, You may do so only on Your own behalf, and not on behalf of the Initial Developer or any Contributor. You must make it absolutely clear than any such warranty, support, indemnity or liability obligation is offered by You alone, and You hereby agree to indemnify the Initial Developer and every Contributor for any liability incurred by the Initial Developer or such Contributor as a result of warranty, support, indemnity or liability terms You offer.

      3.6. Distribution of Executable Versions.


      You may distribute Covered Code in Executable form only if the requirements of Sections 3.1, 3.2, 3.3, 3.4 and 3.5 have been met for that Covered Code, and if You include a notice stating that the Source Code version of the Covered Code is available under the terms of this License, including a description of how and where You have fulfilled the obligations of Section 3.2. The notice must be conspicuously included in any notice in an Executable version, related documentation or collateral in which You describe recipients' rights relating to the Covered Code. You may distribute the Executable version of Covered Code or ownership rights under a license of Your choice, which may contain terms different from this License, provided that You are in compliance with the terms of this License and that the license for the Executable version does not attempt to limit or alter the recipient's rights in the Source Code version from the rights set forth in this License. If You distribute the Executable version under a different license You must make it absolutely clear that any terms which differ from this License are offered by You alone, not by the Initial Developer or any Contributor. You hereby agree to indemnify the Initial Developer and every Contributor for any liability incurred by the Initial Developer or such Contributor as a result of any such terms You offer.

      3.7. Larger Works.


      You may create a Larger Work by combining Covered Code with other code not governed by the terms of this License and distribute the Larger Work as a single product. In such a case, You must make sure the requirements of this License are fulfilled for the Covered Code.

      4. Inability to Comply Due to Statute or Regulation.


      If it is impossible for You to comply with any of the terms of this License with respect to some or all of the Covered Code due to statute, judicial order, or regulation then You must: (a) comply with the terms of this License to the maximum extent possible; and (b) describe the limitations and the code they affect. Such description must be included in the LEGAL file described in Section 3.4 and must be included with all distributions of the Source Code. Except to the extent prohibited by statute or regulation, such description must be sufficiently detailed for a recipient of ordinary skill to be able to understand it.

      5. Application of this License.


      This License applies to code to which the Initial Developer has attached the notice in Exhibit A and to related Covered Code.

      6. Versions of the License.


      6.1. New Versions


      Netscape Communications Corporation ("Netscape") may publish revised and/or new versions of the License from time to time. Each version will be given a distinguishing version number.

      6.2. Effect of New Versions


      Once Covered Code has been published under a particular version of the License, You may always continue to use it under the terms of that version. You may also choose to use such Covered Code under the terms of any subsequent version of the License published by Netscape. No one other than Netscape has the right to modify the terms applicable to Covered Code created under this License.

      6.3. Derivative Works


      If You create or use a modified version of this License (which you may only do in order to apply it to code which is not already Covered Code governed by this License), You must (a) rename Your license so that the phrases "Mozilla", "MOZILLAPL", "MOZPL", "Netscape", "MPL", "NPL" or any confusingly similar phrase do not appear in your license (except to note that your license differs from this License) and (b) otherwise make it clear that Your version of the license contains terms which differ from the Mozilla Public License and Netscape Public License. (Filling in the name of the Initial Developer, Original Code or Contributor in the notice described in Exhibit A shall not of themselves be deemed to be modifications of this License.)

      7. Disclaimer of warranty


      Covered code is provided under this license on an "as is" basis, without warranty of any kind, either expressed or implied, including, without limitation, warranties that the covered code is free of defects, merchantable, fit for a particular purpose or non-infringing. The entire risk as to the quality and performance of the covered code is with you. Should any covered code prove defective in any respect, you (not the initial developer or any other contributor) assume the cost of any necessary servicing, repair or correction. This disclaimer of warranty constitutes an essential part of this license. No use of any covered code is authorized hereunder except under this disclaimer.

      8. Termination


      8.1. This License and the rights granted hereunder will terminate automatically if You fail to comply with terms herein and fail to cure such breach within 30 days of becoming aware of the breach. All sublicenses to the Covered Code which are properly granted shall survive any termination of this License. Provisions which, by their nature, must remain in effect beyond the termination of this License shall survive.

      8.2. If You initiate litigation by asserting a patent infringement claim (excluding declatory judgment actions) against Initial Developer or a Contributor (the Initial Developer or Contributor against whom You file such action is referred to as "Participant") alleging that:

      a. such Participant's Contributor Version directly or indirectly infringes any patent, then any and all rights granted by such Participant to You under Sections 2.1 and/or 2.2 of this License shall, upon 60 days notice from Participant terminate prospectively, unless if within 60 days after receipt of notice You either: (i) agree in writing to pay Participant a mutually agreeable reasonable royalty for Your past and future use of Modifications made by such Participant, or (ii) withdraw Your litigation claim with respect to the Contributor Version against such Participant. If within 60 days of notice, a reasonable royalty and payment arrangement are not mutually agreed upon in writing by the parties or the litigation claim is not withdrawn, the rights granted by Participant to You under Sections 2.1 and/or 2.2 automatically terminate at the expiration of the 60 day notice period specified above.
      b. any software, hardware, or device, other than such Participant's Contributor Version, directly or indirectly infringes any patent, then any rights granted to You by such Participant under Sections 2.1(b) and 2.2(b) are revoked effective as of the date You first made, used, sold, distributed, or had made, Modifications made by that Participant.

      8.3. If You assert a patent infringement claim against Participant alleging that such Participant's Contributor Version directly or indirectly infringes any patent where such claim is resolved (such as by license or settlement) prior to the initiation of patent infringement litigation, then the reasonable value of the licenses granted by such Participant under Sections 2.1 or 2.2 shall be taken into account in determining the amount or value of any payment or license.

      8.4. In the event of termination under Sections 8.1 or 8.2 above, all end user license agreements (excluding distributors and resellers) which have been validly granted by You or any distributor hereunder prior to termination shall survive termination.

      9. Limitation of liability


      Under no circumstances and under no legal theory, whether tort (including negligence), contract, or otherwise, shall you, the initial developer, any other contributor, or any distributor of covered code, or any supplier of any of such parties, be liable to any person for any indirect, special, incidental, or consequential damages of any character including, without limitation, damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses, even if such party shall have been informed of the possibility of such damages. This limitation of liability shall not apply to liability for death or personal injury resulting from such party's negligence to the extent applicable law prohibits such limitation. Some jurisdictions do not allow the exclusion or limitation of incidental or consequential damages, so this exclusion and limitation may not apply to you.

      10. U.S. government end users


      The Covered Code is a "commercial item," as that term is defined in 48 C.F.R. 2.101 (Oct. 1995), consisting of "commercial computer software" and "commercial computer software documentation," as such terms are used in 48 C.F.R. 12.212 (Sept. 1995). Consistent with 48 C.F.R. 12.212 and 48 C.F.R. 227.7202-1 through 227.7202-4 (June 1995), all U.S. Government End Users acquire Covered Code with only those rights set forth herein.

      11. Miscellaneous


      This License represents the complete agreement concerning subject matter hereof. If any provision of this License is held to be unenforceable, such provision shall be reformed only to the extent necessary to make it enforceable. This License shall be governed by California law provisions (except to the extent applicable law, if any, provides otherwise), excluding its conflict-of-law provisions. With respect to disputes in which at least one party is a citizen of, or an entity chartered or registered to do business in the United States of America, any litigation relating to this License shall be subject to the jurisdiction of the Federal Courts of the Northern District of California, with venue lying in Santa Clara County, California, with the losing party responsible for costs, including without limitation, court costs and reasonable attorneys' fees and expenses. The application of the United Nations Convention on Contracts for the International Sale of Goods is expressly excluded. Any law or regulation which provides that the language of a contract shall be construed against the drafter shall not apply to this License.

      12. Responsibility for claims


      As between Initial Developer and the Contributors, each party is responsible for claims and damages arising, directly or indirectly, out of its utilization of rights under this License and You agree to work with Initial Developer and Contributors to distribute such responsibility on an equitable basis. Nothing herein is intended or shall be deemed to constitute any admission of liability.

      13. Multiple-licensed code


      Initial Developer may designate portions of the Covered Code as "Multiple-Licensed". "Multiple-Licensed" means that the Initial Developer permits you to utilize portions of the Covered Code under Your choice of the MPL or the alternative licenses, if any, specified by the Initial Developer in the file described in Exhibit A.

      Exhibit A - Mozilla Public License.

      "The contents of this file are subject to the Mozilla Public License
      Version 1.1 (the "License"); you may not use this file except in
      compliance with the License. You may obtain a copy of the License at
      http://www.mozilla.org/MPL/

      Software distributed under the License is distributed on an "AS IS"
      basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See the
      License for the specific language governing rights and limitations
      under the License.

      The Original Code is JTransforms.

      The Initial Developer of the Original Code is
      Piotr Wendykier, Emory University.
      Portions created by the Initial Developer are Copyright (C) 2007-2009
      the Initial Developer. All Rights Reserved.

      Alternatively, the contents of this file may be used under the terms of
      either the GNU General Public License Version 2 or later (the "GPL"), or
      the GNU Lesser General Public License Version 2.1 or later (the "LGPL"),
      in which case the provisions of the GPL or the LGPL are applicable instead
      of those above. If you wish to allow use of your version of this file only
      under the terms of either the GPL or the LGPL, and not to allow others to
      use your version of this file under the terms of the MPL, indicate your
      decision by deleting the provisions above and replace them with the notice
      and other provisions required by the GPL or the LGPL. If you do not delete
      the provisions above, a recipient may use your version of this file under
      the terms of any one of the MPL, the GPL or the LGPL.

      NOTE: The text of this Exhibit A may differ slightly from the text of the notices in the Source Code files of the Original Code. You should use the text of this Exhibit A rather than the text found in the Original Code Source Code for Your Modifications.

      \ No newline at end of file diff --git a/licenses/LICENSE-kryo.txt b/licenses-binary/LICENSE-kryo.txt similarity index 100% rename from licenses/LICENSE-kryo.txt rename to licenses-binary/LICENSE-kryo.txt diff --git a/licenses-binary/LICENSE-leveldbjni.txt b/licenses-binary/LICENSE-leveldbjni.txt new file mode 100644 index 0000000000000..b4dabb9174c6d --- /dev/null +++ b/licenses-binary/LICENSE-leveldbjni.txt @@ -0,0 +1,27 @@ +Copyright (c) 2011 FuseSource Corp. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of FuseSource Corp. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-machinist.txt b/licenses-binary/LICENSE-machinist.txt new file mode 100644 index 0000000000000..68cc3a3e3a9c4 --- /dev/null +++ b/licenses-binary/LICENSE-machinist.txt @@ -0,0 +1,19 @@ +Copyright (c) 2011-2014 Erik Osheim, Tom Switzer + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/licenses-binary/LICENSE-matchMedia-polyfill.txt b/licenses-binary/LICENSE-matchMedia-polyfill.txt new file mode 100644 index 0000000000000..2fd0bc2b37448 --- /dev/null +++ b/licenses-binary/LICENSE-matchMedia-polyfill.txt @@ -0,0 +1 @@ +matchMedia() polyfill - Test a CSS media type/query in JS. Authors & copyright (c) 2012: Scott Jehl, Paul Irish, Nicholas Zakas. Dual MIT/BSD license \ No newline at end of file diff --git a/licenses/LICENSE-minlog.txt b/licenses-binary/LICENSE-minlog.txt similarity index 100% rename from licenses/LICENSE-minlog.txt rename to licenses-binary/LICENSE-minlog.txt diff --git a/licenses-binary/LICENSE-modernizr.txt b/licenses-binary/LICENSE-modernizr.txt new file mode 100644 index 0000000000000..2bf24b9b9f848 --- /dev/null +++ b/licenses-binary/LICENSE-modernizr.txt @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-netlib.txt b/licenses-binary/LICENSE-netlib.txt similarity index 100% rename from licenses/LICENSE-netlib.txt rename to licenses-binary/LICENSE-netlib.txt diff --git a/licenses/LICENSE-paranamer.txt b/licenses-binary/LICENSE-paranamer.txt similarity index 100% rename from licenses/LICENSE-paranamer.txt rename to licenses-binary/LICENSE-paranamer.txt diff --git a/licenses/LICENSE-jpmml-model.txt b/licenses-binary/LICENSE-pmml-model.txt similarity index 100% rename from licenses/LICENSE-jpmml-model.txt rename to licenses-binary/LICENSE-pmml-model.txt diff --git a/licenses/LICENSE-protobuf.txt b/licenses-binary/LICENSE-protobuf.txt similarity index 100% rename from licenses/LICENSE-protobuf.txt rename to licenses-binary/LICENSE-protobuf.txt diff --git a/licenses-binary/LICENSE-py4j.txt b/licenses-binary/LICENSE-py4j.txt new file mode 100644 index 0000000000000..70af3e69ed67a --- /dev/null +++ b/licenses-binary/LICENSE-py4j.txt @@ -0,0 +1,27 @@ +Copyright (c) 2009-2011, Barthelemy Dagenais All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +- Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +- Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. 
+ +- The name of the author may not be used to endorse or promote products +derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. + diff --git a/licenses/LICENSE-pyrolite.txt b/licenses-binary/LICENSE-pyrolite.txt similarity index 100% rename from licenses/LICENSE-pyrolite.txt rename to licenses-binary/LICENSE-pyrolite.txt diff --git a/licenses/LICENSE-reflectasm.txt b/licenses-binary/LICENSE-reflectasm.txt similarity index 100% rename from licenses/LICENSE-reflectasm.txt rename to licenses-binary/LICENSE-reflectasm.txt diff --git a/licenses-binary/LICENSE-respond.txt b/licenses-binary/LICENSE-respond.txt new file mode 100644 index 0000000000000..dea4ff9e5b2ea --- /dev/null +++ b/licenses-binary/LICENSE-respond.txt @@ -0,0 +1,22 @@ +Copyright (c) 2012 Scott Jehl + +Permission is hereby granted, free of charge, to any person +obtaining a copy of this software and associated documentation +files (the "Software"), to deal in the Software without +restriction, including without limitation the rights to use, +copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-sbt-launch-lib.txt b/licenses-binary/LICENSE-sbt-launch-lib.txt new file mode 100644 index 0000000000000..3b9156baaab78 --- /dev/null +++ b/licenses-binary/LICENSE-sbt-launch-lib.txt @@ -0,0 +1,26 @@ +// Generated from http://www.opensource.org/licenses/bsd-license.php +Copyright (c) 2011, Paul Phillips. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. 
+ * Neither the name of the author nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-scala.txt b/licenses-binary/LICENSE-scala.txt similarity index 100% rename from licenses/LICENSE-scala.txt rename to licenses-binary/LICENSE-scala.txt diff --git a/licenses-binary/LICENSE-scopt.txt b/licenses-binary/LICENSE-scopt.txt new file mode 100644 index 0000000000000..e92e9b592fba0 --- /dev/null +++ b/licenses-binary/LICENSE-scopt.txt @@ -0,0 +1,9 @@ +This project is licensed under the MIT license. + +Copyright (c) scopt contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
\ No newline at end of file diff --git a/licenses/LICENSE-slf4j.txt b/licenses-binary/LICENSE-slf4j.txt similarity index 100% rename from licenses/LICENSE-slf4j.txt rename to licenses-binary/LICENSE-slf4j.txt diff --git a/licenses-binary/LICENSE-sorttable.js.txt b/licenses-binary/LICENSE-sorttable.js.txt new file mode 100644 index 0000000000000..b31a5b206bf40 --- /dev/null +++ b/licenses-binary/LICENSE-sorttable.js.txt @@ -0,0 +1,16 @@ +Copyright (c) 1997-2007 Stuart Langridge + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/licenses/LICENSE-spire.txt b/licenses-binary/LICENSE-spire.txt similarity index 100% rename from licenses/LICENSE-spire.txt rename to licenses-binary/LICENSE-spire.txt diff --git a/licenses-binary/LICENSE-vis.txt b/licenses-binary/LICENSE-vis.txt new file mode 100644 index 0000000000000..18b7323059a41 --- /dev/null +++ b/licenses-binary/LICENSE-vis.txt @@ -0,0 +1,22 @@ +vis.js +https://github.com/almende/vis + +A dynamic, browser-based visualization library. + +@version 4.16.1 +@date 2016-04-18 + +@license +Copyright (C) 2011-2016 Almende B.V, http://almende.com + +Vis.js is dual licensed under both + +* The Apache 2.0 License + http://www.apache.org/licenses/LICENSE-2.0 + +and + +* The MIT License + http://opensource.org/licenses/MIT + +Vis.js may be distributed under either license. \ No newline at end of file diff --git a/licenses/LICENSE-xmlenc.txt b/licenses-binary/LICENSE-xmlenc.txt similarity index 100% rename from licenses/LICENSE-xmlenc.txt rename to licenses-binary/LICENSE-xmlenc.txt diff --git a/licenses/LICENSE-zstd-jni.txt b/licenses-binary/LICENSE-zstd-jni.txt similarity index 100% rename from licenses/LICENSE-zstd-jni.txt rename to licenses-binary/LICENSE-zstd-jni.txt diff --git a/licenses/LICENSE-zstd.txt b/licenses-binary/LICENSE-zstd.txt similarity index 100% rename from licenses/LICENSE-zstd.txt rename to licenses-binary/LICENSE-zstd.txt diff --git a/licenses/LICENSE-CC0.txt b/licenses/LICENSE-CC0.txt new file mode 100644 index 0000000000000..1625c17936079 --- /dev/null +++ b/licenses/LICENSE-CC0.txt @@ -0,0 +1,121 @@ +Creative Commons Legal Code + +CC0 1.0 Universal + + CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE + LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN + ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS + INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES + REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS + PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM + THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED + HEREUNDER. 
+ +Statement of Purpose + +The laws of most jurisdictions throughout the world automatically confer +exclusive Copyright and Related Rights (defined below) upon the creator +and subsequent owner(s) (each and all, an "owner") of an original work of +authorship and/or a database (each, a "Work"). + +Certain owners wish to permanently relinquish those rights to a Work for +the purpose of contributing to a commons of creative, cultural and +scientific works ("Commons") that the public can reliably and without fear +of later claims of infringement build upon, modify, incorporate in other +works, reuse and redistribute as freely as possible in any form whatsoever +and for any purposes, including without limitation commercial purposes. +These owners may contribute to the Commons to promote the ideal of a free +culture and the further production of creative, cultural and scientific +works, or to gain reputation or greater distribution for their Work in +part through the use and efforts of others. + +For these and/or other purposes and motivations, and without any +expectation of additional consideration or compensation, the person +associating CC0 with a Work (the "Affirmer"), to the extent that he or she +is an owner of Copyright and Related Rights in the Work, voluntarily +elects to apply CC0 to the Work and publicly distribute the Work under its +terms, with knowledge of his or her Copyright and Related Rights in the +Work and the meaning and intended legal effect of CC0 on those rights. + +1. Copyright and Related Rights. A Work made available under CC0 may be +protected by copyright and related or neighboring rights ("Copyright and +Related Rights"). Copyright and Related Rights include, but are not +limited to, the following: + + i. the right to reproduce, adapt, distribute, perform, display, + communicate, and translate a Work; + ii. moral rights retained by the original author(s) and/or performer(s); +iii. publicity and privacy rights pertaining to a person's image or + likeness depicted in a Work; + iv. rights protecting against unfair competition in regards to a Work, + subject to the limitations in paragraph 4(a), below; + v. rights protecting the extraction, dissemination, use and reuse of data + in a Work; + vi. database rights (such as those arising under Directive 96/9/EC of the + European Parliament and of the Council of 11 March 1996 on the legal + protection of databases, and under any national implementation + thereof, including any amended or successor version of such + directive); and +vii. other similar, equivalent or corresponding rights throughout the + world based on applicable law or treaty, and any national + implementations thereof. + +2. Waiver. To the greatest extent permitted by, but not in contravention +of, applicable law, Affirmer hereby overtly, fully, permanently, +irrevocably and unconditionally waives, abandons, and surrenders all of +Affirmer's Copyright and Related Rights and associated claims and causes +of action, whether now known or unknown (including existing as well as +future claims and causes of action), in the Work (i) in all territories +worldwide, (ii) for the maximum duration provided by applicable law or +treaty (including future time extensions), (iii) in any current or future +medium and for any number of copies, and (iv) for any purpose whatsoever, +including without limitation commercial, advertising or promotional +purposes (the "Waiver"). 
Affirmer makes the Waiver for the benefit of each +member of the public at large and to the detriment of Affirmer's heirs and +successors, fully intending that such Waiver shall not be subject to +revocation, rescission, cancellation, termination, or any other legal or +equitable action to disrupt the quiet enjoyment of the Work by the public +as contemplated by Affirmer's express Statement of Purpose. + +3. Public License Fallback. Should any part of the Waiver for any reason +be judged legally invalid or ineffective under applicable law, then the +Waiver shall be preserved to the maximum extent permitted taking into +account Affirmer's express Statement of Purpose. In addition, to the +extent the Waiver is so judged Affirmer hereby grants to each affected +person a royalty-free, non transferable, non sublicensable, non exclusive, +irrevocable and unconditional license to exercise Affirmer's Copyright and +Related Rights in the Work (i) in all territories worldwide, (ii) for the +maximum duration provided by applicable law or treaty (including future +time extensions), (iii) in any current or future medium and for any number +of copies, and (iv) for any purpose whatsoever, including without +limitation commercial, advertising or promotional purposes (the +"License"). The License shall be deemed effective as of the date CC0 was +applied by Affirmer to the Work. Should any part of the License for any +reason be judged legally invalid or ineffective under applicable law, such +partial invalidity or ineffectiveness shall not invalidate the remainder +of the License, and in such case Affirmer hereby affirms that he or she +will not (i) exercise any of his or her remaining Copyright and Related +Rights in the Work or (ii) assert any associated claims and causes of +action with respect to the Work, in either case contrary to Affirmer's +express Statement of Purpose. + +4. Limitations and Disclaimers. + + a. No trademark or patent rights held by Affirmer are waived, abandoned, + surrendered, licensed or otherwise affected by this document. + b. Affirmer offers the Work as-is and makes no representations or + warranties of any kind concerning the Work, express, implied, + statutory or otherwise, including without limitation warranties of + title, merchantability, fitness for a particular purpose, non + infringement, or the absence of latent or other defects, accuracy, or + the present or absence of errors, whether or not discoverable, all to + the greatest extent permissible under applicable law. + c. Affirmer disclaims responsibility for clearing rights of other persons + that may apply to the Work or any use thereof, including without + limitation any person's Copyright and Related Rights in the Work. + Further, Affirmer disclaims responsibility for obtaining any necessary + consents, permissions or other rights required for any use of the + Work. + d. Affirmer understands and acknowledges that Creative Commons is not a + party to this document and has no duty or obligation with respect to + this CC0 or use of the Work. \ No newline at end of file diff --git a/licenses/LICENSE-SnapTree.txt b/licenses/LICENSE-SnapTree.txt deleted file mode 100644 index a538825d89ec5..0000000000000 --- a/licenses/LICENSE-SnapTree.txt +++ /dev/null @@ -1,35 +0,0 @@ -SNAPTREE LICENSE - -Copyright (c) 2009-2012 Stanford University, unless otherwise specified. -All rights reserved. - -This software was developed by the Pervasive Parallelism Laboratory of -Stanford University, California, USA. 
- -Permission to use, copy, modify, and distribute this software in source -or binary form for any purpose with or without fee is hereby granted, -provided that the following conditions are met: - - 1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - 2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - 3. Neither the name of Stanford University nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - -THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT -LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY -OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF -SUCH DAMAGE. diff --git a/licenses/LICENSE-bootstrap.txt b/licenses/LICENSE-bootstrap.txt new file mode 100644 index 0000000000000..6c711832fbc85 --- /dev/null +++ b/licenses/LICENSE-bootstrap.txt @@ -0,0 +1,13 @@ +Copyright 2013 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/licenses/LICENSE-boto.txt b/licenses/LICENSE-boto.txt deleted file mode 100644 index 7bba0cd9e10a4..0000000000000 --- a/licenses/LICENSE-boto.txt +++ /dev/null @@ -1,20 +0,0 @@ -Copyright (c) 2006-2008 Mitch Garnaat http://garnaat.org/ - -Permission is hereby granted, free of charge, to any person obtaining a -copy of this software and associated documentation files (the -"Software"), to deal in the Software without restriction, including -without limitation the rights to use, copy, modify, merge, publish, dis- -tribute, sublicense, and/or sell copies of the Software, and to permit -persons to whom the Software is furnished to do so, subject to the fol- -lowing conditions: - -The above copyright notice and this permission notice shall be included -in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL- -ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT -SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, -WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS -IN THE SOFTWARE. 
\ No newline at end of file diff --git a/licenses/LICENSE-datatables.txt b/licenses/LICENSE-datatables.txt new file mode 100644 index 0000000000000..bb7708b5b5a49 --- /dev/null +++ b/licenses/LICENSE-datatables.txt @@ -0,0 +1,7 @@ +Copyright (C) 2008-2018, SpryMedia Ltd. + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-graphlib-dot.txt b/licenses/LICENSE-graphlib-dot.txt index c9e18cd562423..4864fe05e9803 100644 --- a/licenses/LICENSE-graphlib-dot.txt +++ b/licenses/LICENSE-graphlib-dot.txt @@ -1,4 +1,4 @@ -Copyright (c) 2012-2013 Chris Pettitt +Copyright (c) 2013 Chris Pettitt Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/licenses/LICENSE-jbcrypt.txt b/licenses/LICENSE-jbcrypt.txt deleted file mode 100644 index d332534c06356..0000000000000 --- a/licenses/LICENSE-jbcrypt.txt +++ /dev/null @@ -1,17 +0,0 @@ -jBCrypt is subject to the following license: - -/* - * Copyright (c) 2006 Damien Miller - * - * Permission to use, copy, modify, and distribute this software for any - * purpose with or without fee is hereby granted, provided that the above - * copyright notice and this permission notice appear in all copies. - * - * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES - * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR - * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES - * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN - * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF - * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ diff --git a/licenses/LICENSE-jmock.txt b/licenses/LICENSE-join.txt similarity index 60% rename from licenses/LICENSE-jmock.txt rename to licenses/LICENSE-join.txt index ed7964fe3d9ef..1d916090e4ea0 100644 --- a/licenses/LICENSE-jmock.txt +++ b/licenses/LICENSE-join.txt @@ -1,19 +1,21 @@ -Copyright (c) 2000-2017, jMock.org +Copyright (c) 2011, Douban Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: -Redistributions of source code must retain the above copyright notice, -this list of conditions and the following disclaimer. 
Redistributions -in binary form must reproduce the above copyright notice, this list of -conditions and the following disclaimer in the documentation and/or -other materials provided with the distribution. + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. -Neither the name of jMock nor the names of its contributors may be -used to endorse or promote products derived from this software without -specific prior written permission. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + + * Neither the name of the Douban Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -25,4 +27,4 @@ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-jquery.txt b/licenses/LICENSE-jquery.txt index e1dd696d3b6cc..45930542204fb 100644 --- a/licenses/LICENSE-jquery.txt +++ b/licenses/LICENSE-jquery.txt @@ -1,9 +1,20 @@ -The MIT License (MIT) +Copyright JS Foundation and other contributors, https://js.foundation/ -Copyright (c) +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
\ No newline at end of file +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-json-formatter.txt b/licenses/LICENSE-json-formatter.txt new file mode 100644 index 0000000000000..5193348fce126 --- /dev/null +++ b/licenses/LICENSE-json-formatter.txt @@ -0,0 +1,6 @@ +Copyright 2014 Mohsen Azimi + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. \ No newline at end of file diff --git a/licenses/LICENSE-matchMedia-polyfill.txt b/licenses/LICENSE-matchMedia-polyfill.txt new file mode 100644 index 0000000000000..2fd0bc2b37448 --- /dev/null +++ b/licenses/LICENSE-matchMedia-polyfill.txt @@ -0,0 +1 @@ +matchMedia() polyfill - Test a CSS media type/query in JS. Authors & copyright (c) 2012: Scott Jehl, Paul Irish, Nicholas Zakas. Dual MIT/BSD license \ No newline at end of file diff --git a/licenses/LICENSE-postgresql.txt b/licenses/LICENSE-postgresql.txt deleted file mode 100644 index 515bf9af4d432..0000000000000 --- a/licenses/LICENSE-postgresql.txt +++ /dev/null @@ -1,24 +0,0 @@ -PostgreSQL Database Management System -(formerly known as Postgres, then as Postgres95) - -Portions Copyright (c) 1996-2010, PostgreSQL Global Development Group - -Portions Copyright (c) 1994, The Regents of the University of California - -Permission to use, copy, modify, and distribute this software and its -documentation for any purpose, without fee, and without a written agreement -is hereby granted, provided that the above copyright notice and this -paragraph and the following two paragraphs appear in all copies. - -IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY FOR -DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING -LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS -DOCUMENTATION, EVEN IF THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. - -THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES, -INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY -AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS -ON AN "AS IS" BASIS, AND THE UNIVERSITY OF CALIFORNIA HAS NO OBLIGATIONS TO -PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. 
- diff --git a/licenses/LICENSE-respond.txt b/licenses/LICENSE-respond.txt new file mode 100644 index 0000000000000..dea4ff9e5b2ea --- /dev/null +++ b/licenses/LICENSE-respond.txt @@ -0,0 +1,22 @@ +Copyright (c) 2012 Scott Jehl + +Permission is hereby granted, free of charge, to any person +obtaining a copy of this software and associated documentation +files (the "Software"), to deal in the Software without +restriction, including without limitation the rights to use, +copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-scalacheck.txt b/licenses/LICENSE-scalacheck.txt deleted file mode 100644 index cb8f97842f4c4..0000000000000 --- a/licenses/LICENSE-scalacheck.txt +++ /dev/null @@ -1,32 +0,0 @@ -ScalaCheck LICENSE - -Copyright (c) 2007-2015, Rickard Nilsson -All rights reserved. - -Permission to use, copy, modify, and distribute this software in source -or binary form for any purpose with or without fee is hereby granted, -provided that the following conditions are met: - - 1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - 2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - 3. Neither the name of the author nor the names of its contributors - may be used to endorse or promote products derived from this - software without specific prior written permission. - - -THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT -LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY -OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF -SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-vis.txt b/licenses/LICENSE-vis.txt new file mode 100644 index 0000000000000..18b7323059a41 --- /dev/null +++ b/licenses/LICENSE-vis.txt @@ -0,0 +1,22 @@ +vis.js +https://github.com/almende/vis + +A dynamic, browser-based visualization library. 
+ +@version 4.16.1 +@date 2016-04-18 + +@license +Copyright (C) 2011-2016 Almende B.V, http://almende.com + +Vis.js is dual licensed under both + +* The Apache 2.0 License + http://www.apache.org/licenses/LICENSE-2.0 + +and + +* The MIT License + http://opensource.org/licenses/MIT + +Vis.js may be distributed under either license. \ No newline at end of file diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/mllib/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index a865cbe19b184..a7dfd2d5c1e70 100644 --- a/mllib/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1 +1,2 @@ org.apache.spark.ml.source.libsvm.LibSVMFileFormat +org.apache.spark.ml.source.image.ImageFileFormat diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index c9786f1f7ceb1..8a57bfc029d14 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -29,6 +29,7 @@ import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._ import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD @@ -96,8 +97,10 @@ class DecisionTreeClassifier @Since("1.4.0") ( @Since("1.6.0") override def setSeed(value: Long): this.type = set(seed, value) - override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = { - val instr = Instrumentation.create(this, dataset) + override protected def train( + dataset: Dataset[_]): DecisionTreeClassificationModel = instrumented { instr => + instr.logPipelineStage(this) + instr.logDataset(dataset) val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = getNumClasses(dataset) @@ -112,30 +115,27 @@ class DecisionTreeClassifier @Since("1.4.0") ( val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses) val strategy = getOldStrategy(categoricalFeatures, numClasses) - instr.logParams(maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB, + instr.logParams(this, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB, cacheNodeIds, checkpointInterval, impurity, seed) val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", seed = $(seed), instr = Some(instr), parentUID = Some(uid)) - val m = trees.head.asInstanceOf[DecisionTreeClassificationModel] - instr.logSuccess(m) - m + trees.head.asInstanceOf[DecisionTreeClassificationModel] } /** (private[ml]) Train a decision tree on an RDD */ private[ml] def train(data: RDD[LabeledPoint], - oldStrategy: OldStrategy): DecisionTreeClassificationModel = { - val instr = Instrumentation.create(this, data) - instr.logParams(maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB, + oldStrategy: OldStrategy): DecisionTreeClassificationModel = instrumented { instr => + 
instr.logPipelineStage(this) + instr.logDataset(data) + instr.logParams(this, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB, cacheNodeIds, checkpointInterval, impurity, seed) val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0L, instr = Some(instr), parentUID = Some(uid)) - val m = trees.head.asInstanceOf[DecisionTreeClassificationModel] - instr.logSuccess(m) - m + trees.head.asInstanceOf[DecisionTreeClassificationModel] } /** (private[ml]) Create a Strategy instance to use with the old API. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 337133a2e2326..33acd9914073f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -31,9 +31,9 @@ import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.GradientBoostedTrees import org.apache.spark.ml.util._ import org.apache.spark.ml.util.DefaultParamsReader.Metadata +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} -import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -152,7 +152,8 @@ class GBTClassifier @Since("1.4.0") ( set(validationIndicatorCol, value) } - override protected def train(dataset: Dataset[_]): GBTClassificationModel = { + override protected def train( + dataset: Dataset[_]): GBTClassificationModel = instrumented { instr => val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) @@ -189,8 +190,9 @@ class GBTClassifier @Since("1.4.0") ( s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") } - val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType, + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy, validationIndicatorCol) @@ -204,9 +206,7 @@ class GBTClassifier @Since("1.4.0") ( GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy)) } - val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) - instr.logSuccess(m) - m + new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) } @Since("1.4.1") diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index 38eb04556b775..1b5c02fc9a576 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -33,6 +33,7 @@ import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.linalg.VectorImplicits._ import 
org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD @@ -162,7 +163,7 @@ class LinearSVC @Since("2.2.0") ( @Since("2.2.0") override def copy(extra: ParamMap): LinearSVC = defaultCopy(extra) - override protected def train(dataset: Dataset[_]): LinearSVCModel = { + override protected def train(dataset: Dataset[_]): LinearSVCModel = instrumented { instr => val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { @@ -170,8 +171,9 @@ class LinearSVC @Since("2.2.0") ( Instance(label, weight, features) } - val instr = Instrumentation.create(this, dataset) - instr.logParams(regParam, maxIter, fitIntercept, tol, standardization, threshold, + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, regParam, maxIter, fitIntercept, tol, standardization, threshold, aggregationDepth) val (summarizer, labelSummarizer) = { @@ -187,7 +189,7 @@ class LinearSVC @Since("2.2.0") ( (new MultivariateOnlineSummarizer, new MultiClassSummarizer) )(seqOp, combOp, $(aggregationDepth)) } - instr.logNamedValue(Instrumentation.loggerTags.numExamples, summarizer.count) + instr.logNumExamples(summarizer.count) instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString) instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString) @@ -276,9 +278,7 @@ class LinearSVC @Since("2.2.0") ( (Vectors.dense(coefficientArray), intercept, scaledObjectiveHistory.result()) } - val model = copyValues(new LinearSVCModel(uid, coefficientVector, interceptVector)) - instr.logSuccess(model) - model + copyValues(new LinearSVCModel(uid, coefficientVector, interceptVector)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 92e342ed4a464..6f0804f0c8e4a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -35,6 +35,7 @@ import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer @@ -490,7 +491,7 @@ class LogisticRegression @Since("1.2.0") ( protected[spark] def train( dataset: Dataset[_], - handlePersistence: Boolean): LogisticRegressionModel = { + handlePersistence: Boolean): LogisticRegressionModel = instrumented { instr => val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { @@ -500,8 +501,9 @@ class LogisticRegression @Since("1.2.0") ( if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) - val instr = Instrumentation.create(this, dataset) - instr.logParams(regParam, elasticNetParam, standardization, threshold, + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, regParam, elasticNetParam, standardization, threshold, maxIter, tol, fitIntercept) val (summarizer, labelSummarizer) = { 
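
Note on the recurring change in these hunks: every train()/fit() in this patch migrates from a manually created Instrumentation (with an explicit logSuccess(model) call at the end) to the Instrumentation.instrumented { instr => ... } wrapper, which is why the trailing logSuccess lines disappear. A minimal, self-contained sketch of such a loan-pattern helper, with assumed names and illustrative bodies (Spark's real Instrumentation.instrumented differs in detail):

    // Self-contained sketch of a loan-pattern instrumentation helper.
    // The names (instrumented, logSuccess, logFailure) follow the diff, but
    // the bodies are illustrative stand-ins, not Spark's implementation.
    final class Instrumentation {
      def logSuccess(): Unit = println("training succeeded")
      def logFailure(e: Throwable): Unit = println(s"training failed: $e")
    }

    object Instrumentation {
      def instrumented[T](body: Instrumentation => T): T = {
        val instr = new Instrumentation
        try {
          val result = body(instr) // run the wrapped training logic
          instr.logSuccess()       // success is logged in exactly one place
          result
        } catch {
          case e: Throwable =>
            instr.logFailure(e)    // failures are logged centrally as well
            throw e
        }
      }
    }

    // Usage shaped like the new code in this patch:
    //   override protected def train(dataset: Dataset[_]): M = instrumented { instr =>
    //     instr.logPipelineStage(this); instr.logDataset(dataset); ...; model
    //   }
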
@@ -517,7 +519,7 @@ class LogisticRegression @Since("1.2.0") ( (new MultivariateOnlineSummarizer, new MultiClassSummarizer) )(seqOp, combOp, $(aggregationDepth)) } - instr.logNamedValue(Instrumentation.loggerTags.numExamples, summarizer.count) + instr.logNumExamples(summarizer.count) instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString) instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString) @@ -905,8 +907,6 @@ class LogisticRegression @Since("1.2.0") ( objectiveHistory) } model.setSummary(Some(logRegSummary)) - instr.logSuccess(model) - model } @Since("1.4.0") @@ -1484,7 +1484,7 @@ sealed trait LogisticRegressionSummary extends Serializable { /** * Convenient method for casting to binary logistic regression summary. - * This method will throws an Exception if the summary is not a binary summary. + * This method will throw an Exception if the summary is not a binary summary. */ @Since("2.3.0") def asBinary: BinaryLogisticRegressionSummary = this match { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 57ba47e596a97..4feddce1d9f2d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -23,12 +23,13 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Since import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer} -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.feature.OneHotEncoderModel import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.sql.Dataset +import org.apache.spark.ml.util.Instrumentation.instrumented +import org.apache.spark.sql.{Dataset, Row} /** Params for Multilayer Perceptron. */ private[classification] trait MultilayerPerceptronParams extends ProbabilisticClassifierParams @@ -102,36 +103,6 @@ private[classification] trait MultilayerPerceptronParams extends ProbabilisticCl solver -> LBFGS, stepSize -> 0.03) } -/** Label to vector converter. */ -private object LabelConverter { - // TODO: Use OneHotEncoder instead - /** - * Encodes a label as a vector. - * Returns a vector of given length with zeroes at all positions - * and value 1.0 at the position that corresponds to the label. - * - * @param labeledPoint labeled point - * @param labelCount total number of labels - * @return pair of features and vector encoding of a label - */ - def encodeLabeledPoint(labeledPoint: LabeledPoint, labelCount: Int): (Vector, Vector) = { - val output = Array.fill(labelCount)(0.0) - output(labeledPoint.label.toInt) = 1.0 - (labeledPoint.features, Vectors.dense(output)) - } - - /** - * Converts a vector to a label. - * Returns the position of the maximal element of a vector. - * - * @param output label encoded with a vector - * @return label - */ - def decodeLabel(output: Vector): Double = { - output.argmax.toDouble - } -} - /** * Classifier trainer based on the Multilayer Perceptron. * Each layer has sigmoid activation function, output layer has softmax. 
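
For context on the MultilayerPerceptronClassifier hunk above: the deleted LabelConverter did two jobs, one-hot encoding a label index into a vector and decoding a prediction vector back to a label via argmax. A small sketch of that round trip in plain Scala (arrays stand in for ml.linalg vectors; shapes are assumed):

    // Encode: a vector of labelCount zeros with 1.0 at the label's index.
    def encodeLabel(label: Int, labelCount: Int): Array[Double] = {
      val out = Array.fill(labelCount)(0.0)
      out(label) = 1.0
      out
    }

    // Decode: argmax of the network's output recovers the label index,
    // which is what the new predict() does via mlpModel.predict(features).argmax.
    def decodeLabel(output: Array[Double]): Int =
      output.indices.maxBy(i => output(i))

    // encodeLabel(2, 4) == Array(0.0, 0.0, 1.0, 0.0); decodeLabel of it == 2.
    // The patch replaces the encode half with OneHotEncoderModel(dropLast = false).
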
@@ -230,9 +201,11 @@ class MultilayerPerceptronClassifier @Since("1.5.0") ( * @param dataset Training dataset * @return Fitted model */ - override protected def train(dataset: Dataset[_]): MultilayerPerceptronClassificationModel = { - val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, predictionCol, layers, maxIter, tol, + override protected def train( + dataset: Dataset[_]): MultilayerPerceptronClassificationModel = instrumented { instr => + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, predictionCol, layers, maxIter, tol, blockSize, solver, stepSize, seed) val myLayers = $(layers) @@ -240,8 +213,18 @@ class MultilayerPerceptronClassifier @Since("1.5.0") ( instr.logNumClasses(labels) instr.logNumFeatures(myLayers.head) - val lpData = extractLabeledPoints(dataset) - val data = lpData.map(lp => LabelConverter.encodeLabeledPoint(lp, labels)) + // One-hot encoding for labels using OneHotEncoderModel. + // As we already know the length of encoding, we skip fitting and directly create + // the model. + val encodedLabelCol = "_encoded" + $(labelCol) + val encodeModel = new OneHotEncoderModel(uid, Array(labels)) + .setInputCols(Array($(labelCol))) + .setOutputCols(Array(encodedLabelCol)) + .setDropLast(false) + val encodedDataset = encodeModel.transform(dataset) + val data = encodedDataset.select($(featuresCol), encodedLabelCol).rdd.map { + case Row(features: Vector, encodedLabel: Vector) => (features, encodedLabel) + } val topology = FeedForwardTopology.multiLayerPerceptron(myLayers, softmaxOnTop = true) val trainer = new FeedForwardTrainer(topology, myLayers(0), myLayers.last) if (isDefined(initialWeights)) { @@ -264,10 +247,7 @@ class MultilayerPerceptronClassifier @Since("1.5.0") ( } trainer.setStackSize($(blockSize)) val mlpModel = trainer.train(data) - val model = new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights) - - instr.logSuccess(model) - model + new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights) } } @@ -323,7 +303,7 @@ class MultilayerPerceptronClassificationModel private[ml] ( * This internal method is used to implement `transform()` and output [[predictionCol]]. 
*/ override def predict(features: Vector): Double = { - LabelConverter.decodeLabel(mlpModel.predict(features)) + mlpModel.predict(features).argmax.toDouble } @Since("1.5.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 1dde18d2d1a31..51495c1a74e69 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -25,6 +25,7 @@ import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.HasWeightCol import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.functions.{col, lit} @@ -125,8 +126,9 @@ class NaiveBayes @Since("1.5.0") ( */ private[spark] def trainWithLabelCheck( dataset: Dataset[_], - positiveLabel: Boolean): NaiveBayesModel = { - val instr = Instrumentation.create(this, dataset) + positiveLabel: Boolean): NaiveBayesModel = instrumented { instr => + instr.logPipelineStage(this) + instr.logDataset(dataset) if (positiveLabel && isDefined(thresholds)) { val numClasses = getNumClasses(dataset) instr.logNumClasses(numClasses) @@ -148,7 +150,7 @@ class NaiveBayes @Since("1.5.0") ( } } - instr.logParams(labelCol, featuresCol, weightCol, predictionCol, rawPredictionCol, + instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, rawPredictionCol, probabilityCol, modelType, smoothing, thresholds) val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size @@ -160,19 +162,21 @@ class NaiveBayes @Since("1.5.0") ( // TODO: similar to reduceByKeyLocally to save one stage. 
val aggregated = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd .map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2))) - }.aggregateByKey[(Double, DenseVector)]((0.0, Vectors.zeros(numFeatures).toDense))( + }.aggregateByKey[(Double, DenseVector, Long)]((0.0, Vectors.zeros(numFeatures).toDense, 0L))( seqOp = { - case ((weightSum: Double, featureSum: DenseVector), (weight, features)) => + case ((weightSum, featureSum, count), (weight, features)) => requireValues(features) BLAS.axpy(weight, features, featureSum) - (weightSum + weight, featureSum) + (weightSum + weight, featureSum, count + 1) }, combOp = { - case ((weightSum1, featureSum1), (weightSum2, featureSum2)) => + case ((weightSum1, featureSum1, count1), (weightSum2, featureSum2, count2)) => BLAS.axpy(1.0, featureSum2, featureSum1) - (weightSum1 + weightSum2, featureSum1) + (weightSum1 + weightSum2, featureSum1, count1 + count2) }).collect().sortBy(_._1) + val numSamples = aggregated.map(_._2._3).sum + instr.logNumExamples(numSamples) val numLabels = aggregated.length instr.logNumClasses(numLabels) val numDocuments = aggregated.map(_._2._1).sum @@ -184,7 +188,7 @@ class NaiveBayes @Since("1.5.0") ( val lambda = $(smoothing) val piLogDenom = math.log(numDocuments + numLabels * lambda) var i = 0 - aggregated.foreach { case (label, (n, sumTermFreqs)) => + aggregated.foreach { case (label, (n, sumTermFreqs, _)) => labelArray(i) = label piArray(i) = math.log(n + lambda) - piLogDenom val thetaLogDenom = $(modelType) match { @@ -204,9 +208,7 @@ class NaiveBayes @Since("1.5.0") ( val pi = Vectors.dense(piArray) val theta = new DenseMatrix(numLabels, numFeatures, thetaArray, true) - val model = new NaiveBayesModel(uid, pi, theta).setOldLabels(labelArray) - instr.logSuccess(model) - model + new NaiveBayesModel(uid, pi, theta).setOldLabels(labelArray) } @Since("1.5.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 3474b61e40136..1835a91775e0a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -36,6 +36,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol} import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -362,11 +363,12 @@ final class OneVsRest @Since("1.4.0") ( } @Since("2.0.0") - override def fit(dataset: Dataset[_]): OneVsRestModel = { + override def fit(dataset: Dataset[_]): OneVsRestModel = instrumented { instr => transformSchema(dataset.schema) - val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, predictionCol, parallelism, rawPredictionCol) + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, predictionCol, parallelism, rawPredictionCol) instr.logNamedValue("classifier", $(classifier).getClass.getCanonicalName) // determine number of classes either from metadata if provided, or via computation. 
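
The NaiveBayes hunk above widens the aggregateByKey accumulator from (weightSum, featureSum) to (weightSum, featureSum, count), so the example count falls out of the same single pass and can feed instr.logNumExamples. A rough local-collections analogue of that accumulator change (toy data assumed, no Spark):

    // Per-label fold carrying both the weighted sum and a raw example count,
    // mirroring the widened (weightSum, featureSum, count) accumulator.
    val rows = Seq((0.0, 1.0), (1.0, 2.0), (0.0, 1.5)) // (label, weight) pairs
    val perLabel = rows.groupBy(_._1).map { case (label, group) =>
      val (weightSum, count) = group.foldLeft((0.0, 0L)) {
        case ((w, c), (_, weight)) => (w + weight, c + 1) // seqOp analogue
      }
      label -> (weightSum, count)
    }
    val numExamples = perLabel.values.map(_._2).sum // what logNumExamples records
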
@@ -440,7 +442,6 @@ final class OneVsRest @Since("1.4.0") ( case attr: Attribute => attr } val model = new OneVsRestModel(uid, labelAttribute.toMetadata(), models).setParent(this) - instr.logSuccess(model) copyValues(model) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 040db3b94b041..94887ac346fec 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -28,6 +28,7 @@ import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util._ import org.apache.spark.ml.util.DefaultParamsReader.Metadata +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD @@ -115,8 +116,10 @@ class RandomForestClassifier @Since("1.4.0") ( override def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) - override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = { - val instr = Instrumentation.create(this, dataset) + override protected def train( + dataset: Dataset[_]): RandomForestClassificationModel = instrumented { instr => + instr.logPipelineStage(this) + instr.logDataset(dataset) val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = getNumClasses(dataset) @@ -131,7 +134,7 @@ class RandomForestClassifier @Since("1.4.0") ( val strategy = super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity) - instr.logParams(labelCol, featuresCol, predictionCol, probabilityCol, rawPredictionCol, + instr.logParams(this, labelCol, featuresCol, predictionCol, probabilityCol, rawPredictionCol, impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval) @@ -140,11 +143,9 @@ class RandomForestClassifier @Since("1.4.0") ( .map(_.asInstanceOf[DecisionTreeClassificationModel]) val numFeatures = oldDataset.first().features.size - val m = new RandomForestClassificationModel(uid, trees, numFeatures, numClasses) instr.logNumClasses(numClasses) instr.logNumFeatures(numFeatures) - instr.logSuccess(m) - m + new RandomForestClassificationModel(uid, trees, numFeatures, numClasses) } @Since("1.4.1") diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 9c9614509c64f..5cb16cc765887 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -26,6 +26,7 @@ import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel} import org.apache.spark.mllib.linalg.VectorImplicits._ @@ -103,10 +104,6 @@ class BisectingKMeansModel private[ml] ( 
@Since("2.1.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) - /** @group expertSetParam */ - @Since("2.4.0") - def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value) - @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) @@ -257,12 +254,13 @@ class BisectingKMeans @Since("2.0.0") ( def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value) @Since("2.0.0") - override def fit(dataset: Dataset[_]): BisectingKMeansModel = { + override def fit(dataset: Dataset[_]): BisectingKMeansModel = instrumented { instr => transformSchema(dataset.schema, logging = true) val rdd = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) - val instr = Instrumentation.create(this, dataset) - instr.logParams(featuresCol, predictionCol, k, maxIter, seed, + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, featuresCol, predictionCol, k, maxIter, seed, minDivisibleClusterSize, distanceMeasure) val bkm = new MLlibBisectingKMeans() @@ -271,14 +269,13 @@ class BisectingKMeans @Since("2.0.0") ( .setMinDivisibleClusterSize($(minDivisibleClusterSize)) .setSeed($(seed)) .setDistanceMeasure($(distanceMeasure)) - val parentModel = bkm.run(rdd) + val parentModel = bkm.run(rdd, Some(instr)) val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this)) val summary = new BisectingKMeansSummary( - model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) - model.setSummary(Some(summary)) + model.transform(dataset), $(predictionCol), $(featuresCol), $(k), $(maxIter)) instr.logNamedValue("clusterSizes", summary.clusterSizes) - instr.logSuccess(model) - model + instr.logNumFeatures(model.clusterCenters.head.size) + model.setSummary(Some(summary)) } @Since("2.0.0") @@ -304,6 +301,7 @@ object BisectingKMeans extends DefaultParamsReadable[BisectingKMeans] { * @param predictionCol Name for column of predicted clusters in `predictions`. * @param featuresCol Name for column of features in `predictions`. * @param k Number of clusters. + * @param numIter Number of iterations. */ @Since("2.1.0") @Experimental @@ -311,4 +309,5 @@ class BisectingKMeansSummary private[clustering] ( predictions: DataFrame, predictionCol: String, featuresCol: String, - k: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k) + k: Int, + numIter: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k, numIter) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala index 44e832b058b62..7da4c43a1abf3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.clustering -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.sql.{DataFrame, Row} /** @@ -28,13 +28,15 @@ import org.apache.spark.sql.{DataFrame, Row} * @param predictionCol Name for column of predicted clusters in `predictions`. * @param featuresCol Name for column of features in `predictions`. * @param k Number of clusters. + * @param numIter Number of iterations. 
*/ @Experimental class ClusteringSummary private[clustering] ( @transient val predictions: DataFrame, val predictionCol: String, val featuresCol: String, - val k: Int) extends Serializable { + val k: Int, + @Since("2.4.0") val numIter: Int) extends Serializable { /** * Cluster centers of the transformed data. diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 64ecc1ebda589..88abc1605d69f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -29,6 +29,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.stat.distribution.MultivariateGaussian import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatrix, Vector => OldVector, Vectors => OldVectors} import org.apache.spark.rdd.RDD @@ -335,13 +336,13 @@ class GaussianMixture @Since("2.0.0") ( private val numSamples = 5 @Since("2.0.0") - override def fit(dataset: Dataset[_]): GaussianMixtureModel = { + override def fit(dataset: Dataset[_]): GaussianMixtureModel = instrumented { instr => transformSchema(dataset.schema, logging = true) val sc = dataset.sparkSession.sparkContext val numClusters = $(k) - val instances: RDD[Vector] = dataset + val instances = dataset .select(DatasetUtils.columnToVector(dataset, getFeaturesCol)).rdd.map { case Row(features: Vector) => features }.cache() @@ -352,8 +353,9 @@ class GaussianMixture @Since("2.0.0") ( s"than ${GaussianMixture.MAX_NUM_FEATURES} features because the size of the covariance" + s" matrix is quadratic in the number of features.") - val instr = Instrumentation.create(this, dataset) - instr.logParams(featuresCol, predictionCol, probabilityCol, k, maxIter, seed, tol) + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, featuresCol, predictionCol, probabilityCol, k, maxIter, seed, tol) instr.logNumFeatures(numFeatures) val shouldDistributeGaussians = GaussianMixture.shouldDistributeGaussians( @@ -384,6 +386,11 @@ class GaussianMixture @Since("2.0.0") ( bcWeights.destroy(blocking = false) bcGaussians.destroy(blocking = false) + if (iter == 0) { + val numSamples = sums.count + instr.logNumExamples(numSamples) + } + /* Create new distributions based on the partial assignments (often referred to as the "M" step in literature) @@ -416,6 +423,7 @@ class GaussianMixture @Since("2.0.0") ( iter += 1 } + instances.unpersist(false) val gaussianDists = gaussians.map { case (mean, covVec) => val cov = GaussianMixture.unpackUpperTriangularMatrix(numFeatures, covVec.values) new MultivariateGaussian(mean, cov) @@ -423,12 +431,10 @@ class GaussianMixture @Since("2.0.0") ( val model = copyValues(new GaussianMixtureModel(uid, weights, gaussianDists)).setParent(this) val summary = new GaussianMixtureSummary(model.transform(dataset), - $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood) - model.setSummary(Some(summary)) + $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood, iter) instr.logNamedValue("logLikelihood", logLikelihood) instr.logNamedValue("clusterSizes", summary.clusterSizes) - instr.logSuccess(model) - model + model.setSummary(Some(summary)) } @Since("2.0.0") @@ -687,6 +693,7 @@ private class ExpectationAggregator( * @param featuresCol 
Name for column of features in `predictions`. * @param k Number of clusters. * @param logLikelihood Total log-likelihood for this model on the given data. + * @param numIter Number of iterations. */ @Since("2.0.0") @Experimental @@ -696,8 +703,9 @@ class GaussianMixtureSummary private[clustering] ( @Since("2.0.0") val probabilityCol: String, featuresCol: String, k: Int, - @Since("2.2.0") val logLikelihood: Double) - extends ClusteringSummary(predictions, predictionCol, featuresCol, k) { + @Since("2.2.0") val logLikelihood: Double, + numIter: Int) + extends ClusteringSummary(predictions, predictionCol, featuresCol, k, numIter) { /** * Probability of each cluster. diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 1704412741d49..498310d6644e1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -28,6 +28,7 @@ import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ @@ -145,8 +146,12 @@ class KMeansModel private[ml] ( /** * Return the K-means cost (sum of squared distances of points to their nearest center) for this * model on the given data. + * + * @deprecated This method is deprecated and will be removed in 3.0.0. Use ClusteringEvaluator + * instead. You can also get the cost on the training dataset in the summary. */ - // TODO: Replace the temp fix when we have proper evaluators defined for clustering. + @deprecated("This method is deprecated and will be removed in 3.0.0. Use ClusteringEvaluator " + + "instead. 
You can also get the cost on the training dataset in the summary.", "2.4.0") @Since("2.0.0") def computeCost(dataset: Dataset[_]): Double = { SchemaUtils.validateVectorCompatibleColumn(dataset.schema, getFeaturesCol) @@ -332,7 +337,7 @@ class KMeans @Since("1.5.0") ( def setSeed(value: Long): this.type = set(seed, value) @Since("2.0.0") - override def fit(dataset: Dataset[_]): KMeansModel = { + override def fit(dataset: Dataset[_]): KMeansModel = instrumented { instr => transformSchema(dataset.schema, logging = true) val handlePersistence = dataset.storageLevel == StorageLevel.NONE @@ -342,8 +347,9 @@ class KMeans @Since("1.5.0") ( instances.persist(StorageLevel.MEMORY_AND_DISK) } - val instr = Instrumentation.create(this, dataset) - instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, distanceMeasure, + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, featuresCol, predictionCol, k, initMode, initSteps, distanceMeasure, maxIter, seed, tol) val algo = new MLlibKMeans() .setK($(k)) @@ -356,11 +362,15 @@ class KMeans @Since("1.5.0") ( val parentModel = algo.run(instances, Option(instr)) val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) val summary = new KMeansSummary( - model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) + model.transform(dataset), + $(predictionCol), + $(featuresCol), + $(k), + parentModel.numIter, + parentModel.trainingCost) model.setSummary(Some(summary)) instr.logNamedValue("clusterSizes", summary.clusterSizes) - instr.logSuccess(model) if (handlePersistence) { instances.unpersist() } @@ -388,6 +398,9 @@ object KMeans extends DefaultParamsReadable[KMeans] { * @param predictionCol Name for column of predicted clusters in `predictions`. * @param featuresCol Name for column of features in `predictions`. * @param k Number of clusters. + * @param numIter Number of iterations. + * @param trainingCost K-means cost (sum of squared distances to the nearest centroid for all + * points in the training dataset). This is equivalent to sklearn's inertia. 
*/ @Since("2.0.0") @Experimental @@ -395,4 +408,7 @@ class KMeansSummary private[clustering] ( predictions: DataFrame, predictionCol: String, featuresCol: String, - k: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k) + k: Int, + numIter: Int, + @Since("2.4.0") val trainingCost: Double) + extends ClusteringSummary(predictions, predictionCol, featuresCol, k, numIter) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index fed42c959b5ef..50867f776c522 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -32,6 +32,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasMaxIter, HasSeed} import org.apache.spark.ml.util._ import org.apache.spark.ml.util.DefaultParamsReader.Metadata +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel, EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, @@ -896,11 +897,12 @@ class LDA @Since("1.6.0") ( override def copy(extra: ParamMap): LDA = defaultCopy(extra) @Since("2.0.0") - override def fit(dataset: Dataset[_]): LDAModel = { + override def fit(dataset: Dataset[_]): LDAModel = instrumented { instr => transformSchema(dataset.schema, logging = true) - val instr = Instrumentation.create(this, dataset) - instr.logParams(featuresCol, topicDistributionCol, k, maxIter, subsamplingRate, + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, featuresCol, topicDistributionCol, k, maxIter, subsamplingRate, checkpointInterval, keepLastCheckpoint, optimizeDocConcentration, topicConcentration, learningDecay, optimizer, learningOffset, seed) @@ -923,9 +925,7 @@ class LDA @Since("1.6.0") ( } instr.logNumFeatures(newModel.vocabSize) - val model = copyValues(newModel).setParent(this) - instr.logSuccess(model) - model + copyValues(newModel).setParent(this) } @Since("1.6.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala index 4353c46781e9d..5c1d1aebdc315 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala @@ -21,11 +21,10 @@ import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.attribute.AttributeGroup -import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors, VectorUDT} +import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol} -import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, - SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.sql.{Column, DataFrame, Dataset} import org.apache.spark.sql.functions.{avg, col, udf} import org.apache.spark.sql.types.DoubleType @@ -107,15 +106,21 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str @Since("2.3.0") override 
def evaluate(dataset: Dataset[_]): Double = { - SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(dataset.schema, $(featuresCol)) SchemaUtils.checkNumericType(dataset.schema, $(predictionCol)) + val vectorCol = DatasetUtils.columnToVector(dataset, $(featuresCol)) + val df = dataset.select(col($(predictionCol)), + vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata)) + ($(metricName), $(distanceMeasure)) match { case ("silhouette", "squaredEuclidean") => SquaredEuclideanSilhouette.computeSilhouetteScore( - dataset, $(predictionCol), $(featuresCol)) + df, $(predictionCol), $(featuresCol)) case ("silhouette", "cosine") => - CosineSilhouette.computeSilhouetteScore(dataset, $(predictionCol), $(featuresCol)) + CosineSilhouette.computeSilhouetteScore(df, $(predictionCol), $(featuresCol)) + case (mn, dm) => + throw new IllegalArgumentException(s"No support for metric $mn, distance $dm") } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala index a906e954fecd5..0554455a66d7f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala @@ -82,14 +82,12 @@ class BucketedRandomProjectionLSHModel private[ml]( override def setOutputCol(value: String): this.type = super.set(outputCol, value) @Since("2.1.0") - override protected[ml] val hashFunction: Vector => Array[Vector] = { - key: Vector => { - val hashValues: Array[Double] = randUnitVectors.map({ - randUnitVector => Math.floor(BLAS.dot(key, randUnitVector) / $(bucketLength)) - }) - // TODO: Output vectors of dimension numHashFunctions in SPARK-18450 - hashValues.map(Vectors.dense(_)) - } + override protected[ml] def hashFunction(elems: Vector): Array[Vector] = { + val hashValues = randUnitVectors.map( + randUnitVector => Math.floor(BLAS.dot(elems, randUnitVector) / $(bucketLength)) + ) + // TODO: Output vectors of dimension numHashFunctions in SPARK-18450 + hashValues.map(Vectors.dense(_)) } @Since("2.1.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 10c48c3f52085..dc8eb8261dbe2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Since import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} import org.apache.spark.ml.linalg.{Vectors, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} @@ -317,7 +318,9 @@ class CountVectorizerModel( Vectors.sparse(dictBr.value.size, effectiveCounts) } - dataset.withColumn($(outputCol), vectorizer(col($(inputCol)))) + val attrs = vocabulary.map(_ => new NumericAttribute).asInstanceOf[Array[Attribute]] + val metadata = new AttributeGroup($(outputCol), attrs).toMetadata() + dataset.withColumn($(outputCol), vectorizer(col($(inputCol))), metadata) } @Since("1.5.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala index 
682787a830113..32d98151bdcff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala @@ -69,7 +69,8 @@ class DCT @Since("1.5.0") (@Since("1.5.0") override val uid: String) } override protected def validateInputType(inputType: DataType): Unit = { - require(inputType.isInstanceOf[VectorUDT], s"Input type must be VectorUDT but got $inputType.") + require(inputType.isInstanceOf[VectorUDT], + s"Input type must be ${(new VectorUDT).catalogString} but got ${inputType.catalogString}.") } override protected def outputDataType: DataType = new VectorUDT diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala index d67e4819b161a..dc18e1d34880a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala @@ -29,7 +29,7 @@ import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF} import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.hash.Murmur3_x86_32.{hashInt, hashLong, hashUnsafeBytes2Block} +import org.apache.spark.unsafe.hash.Murmur3_x86_32.{hashInt, hashLong, hashUnsafeBytes2} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashMap @@ -208,8 +208,9 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme require(dataType.isInstanceOf[NumericType] || dataType.isInstanceOf[StringType] || dataType.isInstanceOf[BooleanType], - s"FeatureHasher requires columns to be of NumericType, BooleanType or StringType. " + - s"Column $fieldName was $dataType") + s"FeatureHasher requires columns to be of ${NumericType.simpleString}, " + + s"${BooleanType.catalogString} or ${StringType.catalogString}. 
" + + s"Column $fieldName was ${dataType.catalogString}") } val attrGroup = new AttributeGroup($(outputCol), $(numFeatures)) SchemaUtils.appendColumn(schema, attrGroup.toStructField()) @@ -243,7 +244,8 @@ object FeatureHasher extends DefaultParamsReadable[FeatureHasher] { case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed) case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed) case s: String => - hashUnsafeBytes2Block(UTF8String.fromString(s).getMemoryBlock, seed) + val utf8 = UTF8String.fromString(s) + hashUnsafeBytes2(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed) case _ => throw new SparkException("FeatureHasher with murmur3 algorithm does not " + s"support type ${term.getClass.getCanonicalName} of input data.") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index db432b6fefaff..dbda5b8d8fd4a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -104,7 +104,7 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String) override def transformSchema(schema: StructType): StructType = { val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[ArrayType], - s"The input column must be ArrayType, but got $inputType.") + s"The input column must be ${ArrayType.simpleString}, but got ${inputType.catalogString}.") val attrGroup = new AttributeGroup($(outputCol), $(numFeatures)) SchemaUtils.appendColumn(schema, attrGroup.toStructField()) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 4ff1d0ef356f3..611f1b691b782 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -261,7 +261,8 @@ private[ml] class FeatureEncoder(numFeatures: Array[Int]) extends Serializable { */ def foreachNonzeroOutput(value: Any, f: (Int, Double) => Unit): Unit = value match { case d: Double => - assert(numFeatures.length == 1, "DoubleType columns should only contain one feature.") + assert(numFeatures.length == 1, + s"${DoubleType.catalogString} columns should only contain one feature.") val numOutputCols = numFeatures.head if (numOutputCols > 1) { assert( diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala index a70931f783f45..b20852383a6ff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala @@ -75,7 +75,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]] * The hash function of LSH, mapping an input feature vector to multiple hash vectors. * @return The mapping of LSH function. 
*/ - protected[ml] val hashFunction: Vector => Array[Vector] + protected[ml] def hashFunction(elems: Vector): Array[Vector] /** * Calculate the distance between two different keys using the distance metric corresponding @@ -97,7 +97,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]] override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - val transformUDF = udf(hashFunction, DataTypes.createArrayType(new VectorUDT)) + val transformUDF = udf(hashFunction(_: Vector), DataTypes.createArrayType(new VectorUDT)) dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol)))) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala index a67a3b0abbc1f..21cde66d8db6b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala @@ -60,18 +60,16 @@ class MinHashLSHModel private[ml]( override def setOutputCol(value: String): this.type = super.set(outputCol, value) @Since("2.1.0") - override protected[ml] val hashFunction: Vector => Array[Vector] = { - elems: Vector => { - require(elems.numNonzeros > 0, "Must have at least 1 non zero entry.") - val elemsList = elems.toSparse.indices.toList - val hashValues = randCoefficients.map { case (a, b) => - elemsList.map { elem: Int => - ((1 + elem) * a + b) % MinHashLSH.HASH_PRIME - }.min.toDouble - } - // TODO: Output vectors of dimension numHashFunctions in SPARK-18450 - hashValues.map(Vectors.dense(_)) + override protected[ml] def hashFunction(elems: Vector): Array[Vector] = { + require(elems.numNonzeros > 0, "Must have at least 1 non zero entry.") + val elemsList = elems.toSparse.indices.toList + val hashValues = randCoefficients.map { case (a, b) => + elemsList.map { elem: Int => + ((1L + elem) * a + b) % MinHashLSH.HASH_PRIME + }.min.toDouble } + // TODO: Output vectors of dimension numHashFunctions in SPARK-18450 + hashValues.map(Vectors.dense(_)) } @Since("2.1.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala index c8760f9dc178f..e0772d5af20a9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -65,7 +65,8 @@ class NGram @Since("1.5.0") (@Since("1.5.0") override val uid: String) override protected def validateInputType(inputType: DataType): Unit = { require(inputType.sameType(ArrayType(StringType)), - s"Input type must be ArrayType(StringType) but got $inputType.") + s"Input type must be ${ArrayType(StringType).catalogString} but got " + + inputType.catalogString) } override protected def outputDataType: DataType = new ArrayType(StringType, false) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 5ab6c2dde667a..27e4869a020b7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -85,7 +85,8 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e val inputFields = schema.fields require(schema(inputColName).dataType.isInstanceOf[NumericType], - s"Input column must be of type NumericType but got ${schema(inputColName).dataType}") + s"Input column must be of type ${NumericType.simpleString} but 
got " + + schema(inputColName).dataType.catalogString) require(!inputFields.exists(_.name == outputColName), s"Output column $outputColName already exists.") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 55e595eee6ffb..346e1823f00b8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -394,7 +394,7 @@ class RFormulaModel private[feature]( require(!columnNames.contains($(featuresCol)), "Features column already exists.") require( !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType.isInstanceOf[NumericType], - "Label column already exists and is not of type NumericType.") + s"Label column already exists and is not of type ${NumericType.simpleString}.") } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 0f946dd2e015b..94640a5cbe310 100755 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -131,8 +131,8 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { val inputType = schema($(inputCol)).dataType - require(inputType.sameType(ArrayType(StringType)), - s"Input type must be ArrayType(StringType) but got $inputType.") + require(inputType.sameType(ArrayType(StringType)), "Input type must be " + + s"${ArrayType(StringType).catalogString} but got ${inputType.catalogString}.") SchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index cfaf6c0e610b3..aede1f812a552 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -40,7 +40,8 @@ class Tokenizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) } override protected def validateInputType(inputType: DataType): Unit = { - require(inputType == StringType, s"Input type must be string type but got $inputType.") + require(inputType == StringType, + s"Input type must be ${StringType.catalogString} type but got ${inputType.catalogString}.") } override protected def outputDataType: DataType = new ArrayType(StringType, true) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 4061154b39c14..57e23d5072b88 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -162,7 +162,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) schema(name).dataType match { case _: NumericType | BooleanType => None case t if t.isInstanceOf[VectorUDT] => None - case other => Some(s"Data type $other of column $name is not supported.") + case other => Some(s"Data type ${other.catalogString} of column $name is not supported.") } } if (incorrectColumns.nonEmpty) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala 
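One easy-to-miss detail in the MinHashLSH hunk above: `(1 + elem) * a + b` became `(1L + elem) * a + b`, which promotes the arithmetic to Long before the modulo and so avoids Int overflow for large sparse indices. A self-contained illustration, with values invented purely to trigger the wrap-around (HASH_PRIME is the constant MinHashLSH uses):

{{{
val elem = 1 << 20            // a large sparse-vector index (invented)
val a = 1 << 15               // a hash coefficient (invented)
val b = 7                     // a hash offset (invented)
val HASH_PRIME = 2038074743   // MinHashLSH.HASH_PRIME

val wrapped = ((1 + elem) * a + b) % HASH_PRIME   // Int math: the product overflows
val correct = ((1L + elem) * a + b) % HASH_PRIME  // Long math: no overflow
// wrapped != correct whenever (1 + elem) * a exceeds Int.MaxValue
}}}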
index d7fbe28ae7a64..840a89b76d26b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -20,12 +20,15 @@ package org.apache.spark.ml.fpm import scala.reflect.ClassTag import org.apache.hadoop.fs.Path +import org.json4s.{DefaultFormats, JObject} +import org.json4s.JsonDSL._ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.HasPredictionCol import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules, FPGrowth => MLlibFPGrowth} import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset @@ -33,6 +36,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.VersionUtils /** * Common params for FPGrowth and FPGrowthModel @@ -106,7 +110,7 @@ private[fpm] trait FPGrowthParams extends Params with HasPredictionCol { protected def validateAndTransformSchema(schema: StructType): StructType = { val inputType = schema($(itemsCol)).dataType require(inputType.isInstanceOf[ArrayType], - s"The input column must be ArrayType, but got $inputType.") + s"The input column must be ${ArrayType.simpleString}, but got ${inputType.catalogString}.") SchemaUtils.appendColumn(schema, $(predictionCol), schema($(itemsCol)).dataType) } } @@ -158,11 +162,12 @@ class FPGrowth @Since("2.2.0") ( genericFit(dataset) } - private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = { + private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = instrumented { instr => val handlePersistence = dataset.storageLevel == StorageLevel.NONE - val instr = Instrumentation.create(this, dataset) - instr.logParams(params: _*) + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, params: _*) val data = dataset.select($(itemsCol)) val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[Any](0).toArray) val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport)) @@ -173,7 +178,8 @@ class FPGrowth @Since("2.2.0") ( if (handlePersistence) { items.persist(StorageLevel.MEMORY_AND_DISK) } - + val inputRowCount = items.count() + instr.logNumExamples(inputRowCount) val parentModel = mllibFP.run(items) val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq)) val schema = StructType(Seq( @@ -185,9 +191,8 @@ class FPGrowth @Since("2.2.0") ( items.unpersist() } - val model = copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this) - instr.logSuccess(model) - model + copyValues(new FPGrowthModel(uid, frequentItems, parentModel.itemSupport, inputRowCount)) + .setParent(this) } @Since("2.2.0") @@ -217,7 +222,9 @@ object FPGrowth extends DefaultParamsReadable[FPGrowth] { @Experimental class FPGrowthModel private[ml] ( @Since("2.2.0") override val uid: String, - @Since("2.2.0") @transient val freqItemsets: DataFrame) + @Since("2.2.0") @transient val freqItemsets: DataFrame, + private val itemSupport: scala.collection.Map[Any, Double], + private val numTrainingRecords: Long) extends Model[FPGrowthModel] with FPGrowthParams with MLWritable { /** @group setParam */ @@ -241,9 +248,9 @@ class FPGrowthModel private[ml] ( @transient private var _cachedRules: DataFrame = _ /** - * Get association rules fitted using the 
minConfidence. Returns a dataframe - * with three fields, "antecedent", "consequent" and "confidence", where "antecedent" and - * "consequent" are Array[T] and "confidence" is Double. + * Get association rules fitted using the minConfidence. Returns a dataframe with four fields, + * "antecedent", "consequent", "confidence" and "lift", where "antecedent" and "consequent" are + * Array[T], whereas "confidence" and "lift" are Double. */ @Since("2.2.0") @transient def associationRules: DataFrame = { @@ -251,7 +258,7 @@ class FPGrowthModel private[ml] ( _cachedRules } else { _cachedRules = AssociationRules - .getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence)) + .getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence), itemSupport) _cachedMinConf = $(minConfidence) _cachedRules } @@ -301,7 +308,7 @@ class FPGrowthModel private[ml] ( @Since("2.2.0") override def copy(extra: ParamMap): FPGrowthModel = { - val copied = new FPGrowthModel(uid, freqItemsets) + val copied = new FPGrowthModel(uid, freqItemsets, itemSupport, numTrainingRecords) copyValues(copied, extra).setParent(this.parent) } @@ -323,7 +330,8 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] { class FPGrowthModelWriter(instance: FPGrowthModel) extends MLWriter { override protected def saveImpl(path: String): Unit = { - DefaultParamsWriter.saveMetadata(instance, path, sc) + val extraMetadata: JObject = Map("numTrainingRecords" -> instance.numTrainingRecords) + DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata = Some(extraMetadata)) val dataPath = new Path(path, "data").toString instance.freqItemsets.write.parquet(dataPath) } @@ -335,10 +343,28 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] { private val className = classOf[FPGrowthModel].getName override def load(path: String): FPGrowthModel = { + implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion) + val numTrainingRecords = if (major.toInt < 2 || (major.toInt == 2 && minor.toInt < 4)) { + // 2.3 and before don't store the count + 0L + } else { + // 2.4+ + (metadata.metadata \ "numTrainingRecords").extract[Long] + } val dataPath = new Path(path, "data").toString val frequentItems = sparkSession.read.parquet(dataPath) - val model = new FPGrowthModel(metadata.uid, frequentItems) + val itemSupport = if (numTrainingRecords == 0L) { + Map.empty[Any, Double] + } else { + frequentItems.rdd.flatMap { + case Row(items: Seq[_], count: Long) if items.length == 1 => + Some(items.head -> count.toDouble / numTrainingRecords) + case _ => None + }.collectAsMap() + } + val model = new FPGrowthModel(metadata.uid, frequentItems, itemSupport, numTrainingRecords) metadata.getAndSetParams(model) model } @@ -354,27 +380,30 @@ private[fpm] object AssociationRules { * @param itemsCol column name for frequent itemsets * @param freqCol column name for appearance count of the frequent itemsets * @param minConfidence minimum confidence for generating the association rules - * @return a DataFrame("antecedent"[Array], "consequent"[Array], "confidence"[Double]) - * containing the association rules. + * @param itemSupport map containing an item and its support + * @return a DataFrame("antecedent"[Array], "consequent"[Array], "confidence"[Double], + * "lift" [Double]) containing the association rules. 
*/ def getAssociationRulesFromFP[T: ClassTag]( dataset: Dataset[_], itemsCol: String, freqCol: String, - minConfidence: Double): DataFrame = { + minConfidence: Double, + itemSupport: scala.collection.Map[T, Double]): DataFrame = { val freqItemSetRdd = dataset.select(itemsCol, freqCol).rdd .map(row => new FreqItemset(row.getSeq[T](0).toArray, row.getLong(1))) val rows = new MLlibAssociationRules() .setMinConfidence(minConfidence) - .run(freqItemSetRdd) - .map(r => Row(r.antecedent, r.consequent, r.confidence)) + .run(freqItemSetRdd, itemSupport) + .map(r => Row(r.antecedent, r.consequent, r.confidence, r.lift.orNull)) val dt = dataset.schema(itemsCol).dataType val schema = StructType(Seq( StructField("antecedent", dt, nullable = false), StructField("consequent", dt, nullable = false), - StructField("confidence", DoubleType, nullable = false))) + StructField("confidence", DoubleType, nullable = false), + StructField("lift", DoubleType))) val rules = dataset.sparkSession.createDataFrame(rows, schema) rules } diff --git a/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala index 8c975a2fba8ca..1fae1dc04ad7b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala @@ -38,13 +38,17 @@ private object RecursiveFlag { */ def withRecursiveFlag[T](value: Boolean, spark: SparkSession)(f: => T): T = { val flagName = FileInputFormat.INPUT_DIR_RECURSIVE + // scalastyle:off hadoopconfiguration val hadoopConf = spark.sparkContext.hadoopConfiguration + // scalastyle:on hadoopconfiguration val old = Option(hadoopConf.get(flagName)) hadoopConf.set(flagName, value.toString) try f finally { - old match { - case Some(v) => hadoopConf.set(flagName, v) - case None => hadoopConf.unset(flagName) + // avoid false positive of DLS_DEAD_LOCAL_STORE_IN_RETURN by SpotBugs + if (old.isDefined) { + hadoopConf.set(flagName, old.get) + } else { + hadoopConf.unset(flagName) } } } @@ -96,7 +100,9 @@ private object SamplePathFilter { val sampleImages = sampleRatio < 1 if (sampleImages) { val flagName = FileInputFormat.PATHFILTER_CLASS + // scalastyle:off hadoopconfiguration val hadoopConf = spark.sparkContext.hadoopConfiguration + // scalastyle:on hadoopconfiguration val old = Option(hadoopConf.getClass(flagName, null)) hadoopConf.setDouble(SamplePathFilter.ratioParam, sampleRatio) hadoopConf.setLong(SamplePathFilter.seedParam, seed) diff --git a/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala b/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala index dcc40b6668c7a..0b13eefdf3f5f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala @@ -198,6 +198,8 @@ object ImageSchema { * @return DataFrame with a single column "image" of images; * see ImageSchema for the details */ + @deprecated("use `spark.read.format(\"image\").load(path)` instead; this `readImages` will be " + + "removed in 3.0.0.", "2.4.0") def readImages(path: String): DataFrame = readImages(path, null, false, -1, false, 1.0, 0) /** * @@ -218,6 +220,8 @@ * @return DataFrame with a single column "image" of images; * see ImageSchema for the details */ + @deprecated("use `spark.read.format(\"image\").load(path)` instead; this `readImages` will be " + + "removed in 3.0.0.", "2.4.0") def readImages( path: String, sparkSession: SparkSession,
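For intuition on the "lift" column introduced above: confidence(A => B) = support(A and B) / support(A), and lift divides that confidence by support(B), so lift > 1 means the rule fires more often than independence of A and B would predict. A hedged numeric example with invented counts:

{{{
val n = 1000.0                   // numTrainingRecords (invented)
val suppA = 100 / n              // support of the antecedent
val suppB = 80 / n               // support of the consequent, as kept in itemSupport
val suppAB = 20 / n              // support of antecedent and consequent together
val confidence = suppAB / suppA  // 0.2
val lift = confidence / suppB    // 2.5: the rule holds 2.5x more often than chance
}}}

This also explains the backward-compatibility branch in the loader earlier: models saved before 2.4 carry no numTrainingRecords, so itemSupport stays empty and `r.lift.orNull` leaves the new column null rather than inventing a value.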
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala index 80d03ab03c87d..48485e02edda8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala @@ -59,13 +59,13 @@ private[r] class AFTSurvivalRegressionWrapper private ( private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalRegressionWrapper] { + private val FORMULA_REGEXP = """Surv\(([^,]+), ([^,]+)\) ~ (.+)""".r + private def formulaRewrite(formula: String): (String, String) = { var rewritedFormula: String = null var censorCol: String = null - - val regex = """Surv\(([^,]+), ([^,]+)\) ~ (.+)""".r try { - val regex(label, censor, features) = formula + val FORMULA_REGEXP(label, censor, features) = formula // TODO: Support dot operator. if (features.contains(".")) { throw new UnsupportedOperationException( diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index a23f9552b9e5f..ffe592789b3cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -39,6 +39,7 @@ import org.apache.spark.ml.linalg.BLAS import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.linalg.CholeskyDecomposition import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD @@ -654,7 +655,7 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] } @Since("2.0.0") - override def fit(dataset: Dataset[_]): ALSModel = { + override def fit(dataset: Dataset[_]): ALSModel = instrumented { instr => transformSchema(dataset.schema) import dataset.sparkSession.implicits._ @@ -666,8 +667,9 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) } - val instr = Instrumentation.create(this, ratings) - instr.logParams(rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha, userCol, + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha, userCol, itemCol, ratingCol, predictionCol, maxIter, regParam, nonnegative, checkpointInterval, seed, intermediateStorageLevel, finalStorageLevel) @@ -681,7 +683,6 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] val userDF = userFactors.toDF("id", "features") val itemDF = itemFactors.toDF("id", "features") val model = new ALSModel(uid, $(rank), userDF, itemDF).setParent(this) - instr.logSuccess(model) copyValues(model) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index e27a96e1f5dfc..8d6e36697d2cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -32,6 +32,7 @@ import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import
org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils @@ -210,7 +211,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S } @Since("2.0.0") - override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = { + override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = instrumented { instr => transformSchema(dataset.schema, logging = true) val instances = extractAFTPoints(dataset) val handlePersistence = dataset.storageLevel == StorageLevel.NONE @@ -229,11 +230,13 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) val numFeatures = featuresStd.size - val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, censorCol, predictionCol, quantilesCol, + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, censorCol, predictionCol, quantilesCol, fitIntercept, maxIter, tol, aggregationDepth) instr.logNamedValue("quantileProbabilities.size", $(quantileProbabilities).length) instr.logNumFeatures(numFeatures) + instr.logNumExamples(featuresSummarizer.count) if (!$(fitIntercept) && (0 until numFeatures).exists { i => featuresStd(i) == 0.0 && featuresSummarizer.mean(i) != 0.0 }) { @@ -284,10 +287,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S val coefficients = Vectors.dense(rawCoefficients) val intercept = parameters(1) val scale = math.exp(parameters(0)) - val model = copyValues(new AFTSurvivalRegressionModel(uid, coefficients, - intercept, scale).setParent(this)) - instr.logSuccess(model) - model + copyValues(new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale).setParent(this)) } @Since("1.6.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 8bcf0793a64c1..018290f81842f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -30,6 +30,7 @@ import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._ import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD @@ -99,37 +100,36 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S @Since("2.0.0") def setVarianceCol(value: String): this.type = set(varianceCol, value) - override protected def train(dataset: Dataset[_]): DecisionTreeRegressionModel = { + override protected def train( + dataset: Dataset[_]): DecisionTreeRegressionModel = instrumented { instr => val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = getOldStrategy(categoricalFeatures) - val instr = Instrumentation.create(this, oldDataset) - instr.logParams(params: _*) + instr.logPipelineStage(this) + 
instr.logDataset(oldDataset) + instr.logParams(this, params: _*) val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", seed = $(seed), instr = Some(instr), parentUID = Some(uid)) - val m = trees.head.asInstanceOf[DecisionTreeRegressionModel] - instr.logSuccess(m) - m + trees.head.asInstanceOf[DecisionTreeRegressionModel] } /** (private[ml]) Train a decision tree on an RDD */ private[ml] def train( data: RDD[LabeledPoint], oldStrategy: OldStrategy, - featureSubsetStrategy: String): DecisionTreeRegressionModel = { - val instr = Instrumentation.create(this, data) - instr.logParams(params: _*) + featureSubsetStrategy: String): DecisionTreeRegressionModel = instrumented { instr => + instr.logPipelineStage(this) + instr.logDataset(data) + instr.logParams(this, params: _*) val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy, seed = $(seed), instr = Some(instr), parentUID = Some(uid)) - val m = trees.head.asInstanceOf[DecisionTreeRegressionModel] - instr.logSuccess(m) - m + trees.head.asInstanceOf[DecisionTreeRegressionModel] } /** (private[ml]) Create a Strategy instance to use with the old API. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index eb8b3c001436a..3305881b0ccc6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -31,6 +31,7 @@ import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.GradientBoostedTrees import org.apache.spark.ml.util._ import org.apache.spark.ml.util.DefaultParamsReader.Metadata +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD @@ -151,7 +152,7 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) set(validationIndicatorCol, value) } - override protected def train(dataset: Dataset[_]): GBTRegressionModel = { + override protected def train(dataset: Dataset[_]): GBTRegressionModel = instrumented { instr => val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) @@ -168,8 +169,9 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) val numFeatures = trainDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) - val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType, + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy) instr.logNumFeatures(numFeatures) @@ -181,9 +183,7 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy)) } - val m = new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) - instr.logSuccess(m) - m + new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) } @Since("1.4.0") diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 143c8a3548b1f..abb60ea205751 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -34,6 +34,7 @@ import org.apache.spark.ml.optim._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -373,13 +374,15 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val @Since("2.0.0") def setLinkPredictionCol(value: String): this.type = set(linkPredictionCol, value) - override protected def train(dataset: Dataset[_]): GeneralizedLinearRegressionModel = { + override protected def train( + dataset: Dataset[_]): GeneralizedLinearRegressionModel = instrumented { instr => val familyAndLink = FamilyAndLink(this) val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size - val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, weightCol, offsetCol, predictionCol, linkPredictionCol, - family, solver, fitIntercept, link, maxIter, regParam, tol) + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, weightCol, offsetCol, predictionCol, + linkPredictionCol, family, solver, fitIntercept, link, maxIter, regParam, tol) instr.logNumFeatures(numFeatures) if (numFeatures > WeightedLeastSquares.MAX_NUM_FEATURES) { @@ -431,7 +434,6 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val model.setSummary(Some(trainingSummary)) } - instr.logSuccess(model) model } @@ -513,14 +515,13 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * The reweight function used to update working labels and weights * at each iteration of [[IterativelyReweightedLeastSquares]]. 
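* In GLM terms, for an instance with label y, weight w and offset o, the code below computes
* the linear predictor eta = x' b + o, the mean mu = g^-1(eta), the working label
* z = eta - o + (y - mu) * g'(mu), and the working weight w / (g'(mu)^2 * V(mu)),
* where g is the link function and V the family's variance function. This is the standard
* IRLS update; the only change in this hunk is turning the function-valued val into a def.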
*/ - val reweightFunc: (OffsetInstance, WeightedLeastSquaresModel) => (Double, Double) = { - (instance: OffsetInstance, model: WeightedLeastSquaresModel) => { - val eta = model.predict(instance.features) + instance.offset - val mu = fitted(eta) - val newLabel = eta - instance.offset + (instance.label - mu) * link.deriv(mu) - val newWeight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu)) - (newLabel, newWeight) - } + def reweightFunc( + instance: OffsetInstance, model: WeightedLeastSquaresModel): (Double, Double) = { + val eta = model.predict(instance.features) + instance.offset + val mu = fitted(eta) + val newLabel = eta - instance.offset + (instance.label - mu) * link.deriv(mu) + val newWeight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu)) + (newLabel, newWeight) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index b046897ab2b7e..8b9233dcdc4d1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -27,6 +27,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.regression.IsotonicRegressionModel.IsotonicRegressionModelWriter import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression} import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel} import org.apache.spark.rdd.RDD @@ -161,15 +162,16 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra) @Since("2.0.0") - override def fit(dataset: Dataset[_]): IsotonicRegressionModel = { + override def fit(dataset: Dataset[_]): IsotonicRegressionModel = instrumented { instr => transformSchema(dataset.schema, logging = true) // Extract columns from data. If dataset is persisted, do not persist oldDataset. 
val instances = extractWeightedLabeledPoints(dataset) val handlePersistence = dataset.storageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) - val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, weightCol, predictionCol, featureIndex, isotonic) + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, featureIndex, isotonic) instr.logNumFeatures(1) val isotonicRegression = new MLlibIsotonicRegression().setIsotonic($(isotonic)) @@ -177,9 +179,7 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri if (handlePersistence) instances.unpersist() - val model = copyValues(new IsotonicRegressionModel(uid, oldModel).setParent(this)) - instr.logSuccess(model) - model + copyValues(new IsotonicRegressionModel(uid, oldModel).setParent(this)) } @Since("1.5.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index c45ade94a4e33..ce6c12cc368dd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -37,6 +37,7 @@ import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction} import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.regression.{LinearRegressionModel => OldLinearRegressionModel} @@ -315,7 +316,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String def setEpsilon(value: Double): this.type = set(epsilon, value) setDefault(epsilon -> 1.35) - override protected def train(dataset: Dataset[_]): LinearRegressionModel = { + override protected def train(dataset: Dataset[_]): LinearRegressionModel = instrumented { instr => // Extract the number of features before deciding optimization solver. 
val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) @@ -326,9 +327,11 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String Instance(label, weight, features) } - val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, weightCol, predictionCol, solver, tol, elasticNetParam, - fitIntercept, maxIter, regParam, standardization, aggregationDepth, loss, epsilon) + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, solver, tol, + elasticNetParam, fitIntercept, maxIter, regParam, standardization, aggregationDepth, loss, + epsilon) instr.logNumFeatures(numFeatures) if ($(loss) == SquaredError && (($(solver) == Auto && @@ -353,9 +356,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String model.diagInvAtWA.toArray, model.objectiveHistory) - lrModel.setSummary(Some(trainingSummary)) - instr.logSuccess(lrModel) - return lrModel + return lrModel.setSummary(Some(trainingSummary)) } val handlePersistence = dataset.storageLevel == StorageLevel.NONE @@ -415,9 +416,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String Array(0D), Array(0D)) - model.setSummary(Some(trainingSummary)) - instr.logSuccess(model) - return model + return model.setSummary(Some(trainingSummary)) } else { require($(regParam) == 0.0, "The standard deviation of the label is zero. " + "Model cannot be regularized.") @@ -596,8 +595,6 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String objectiveHistory) model.setSummary(Some(trainingSummary)) - instr.logSuccess(model) - model } @Since("1.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 4509f85aafd12..35875724b3cfa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -29,6 +29,7 @@ import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util._ import org.apache.spark.ml.util.DefaultParamsReader.Metadata +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD @@ -114,15 +115,17 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S override def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) - override protected def train(dataset: Dataset[_]): RandomForestRegressionModel = { + override protected def train( + dataset: Dataset[_]): RandomForestRegressionModel = instrumented { instr => val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity) - val instr = Instrumentation.create(this, oldDataset) - instr.logParams(labelCol, featuresCol, predictionCol, impurity, numTrees, + instr.logPipelineStage(this) + instr.logDataset(dataset) + 
instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, subsamplingRate, cacheNodeIds, checkpointInterval) @@ -131,9 +134,8 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S .map(_.asInstanceOf[DecisionTreeRegressionModel]) val numFeatures = oldDataset.first().features.size - val m = new RandomForestRegressionModel(uid, trees, numFeatures) - instr.logSuccess(m) - m + instr.logNamedValue(Instrumentation.loggerTags.numFeatures, numFeatures) + new RandomForestRegressionModel(uid, trees, numFeatures) } @Since("1.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageDataSource.scala b/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageDataSource.scala new file mode 100644 index 0000000000000..a111c95248cf5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageDataSource.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.source.image + +/** + * The `image` package implements the Spark SQL data source API for loading image data as a `DataFrame`. + * The loaded `DataFrame` has one `StructType` column: `image`. + * The schema of the `image` column is: + * - origin: String (represents the file path of the image) + * - height: Int (height of the image) + * - width: Int (width of the image) + * - nChannels: Int (number of image channels) + * - mode: Int (OpenCV-compatible type) + * - data: BinaryType (Image bytes in OpenCV-compatible order: row-wise BGR in most cases) + * + * To use the image data source, set "image" as the format in `DataFrameReader` and + * optionally specify the data source options, for example: + * {{{ + * // Scala + * val df = spark.read.format("image") + * .option("dropInvalid", true) + * .load("data/mllib/images/partitioned") + * + * // Java + * Dataset<Row> df = spark.read().format("image") + * .option("dropInvalid", true) + * .load("data/mllib/images/partitioned"); + * }}} + * + * Image data source supports the following options: + * - "dropInvalid": Whether to drop the files that are not valid images from the result. + * + * @note This image data source does not support saving images to files. + * + * @note This class is public for documentation purposes. Please don't use this class directly. + * Rather, use the data source API as illustrated above.
+ */ +class ImageDataSource private() {} diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageFileFormat.scala b/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageFileFormat.scala new file mode 100644 index 0000000000000..c3321447e3c96 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageFileFormat.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.source.image + +import com.google.common.io.{ByteStreams, Closeables} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.mapreduce.Job + +import org.apache.spark.ml.image.ImageSchema +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, UnsafeRow} +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.execution.datasources.{DataSource, FileFormat, OutputWriterFactory, PartitionedFile} +import org.apache.spark.sql.sources.{DataSourceRegister, Filter} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +private[image] class ImageFileFormat extends FileFormat with DataSourceRegister { + + override def inferSchema( + sparkSession: SparkSession, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = Some(ImageSchema.imageSchema) + + override def prepareWrite( + sparkSession: SparkSession, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + throw new UnsupportedOperationException("Write is not supported for image data source") + } + + override def shortName(): String = "image" + + override protected def buildReader( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + assert( + requiredSchema.length <= 1, + "Image data source only produces a single data column named \"image\".") + + val broadcastedHadoopConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + + val imageSourceOptions = new ImageOptions(options) + + (file: PartitionedFile) => { + val emptyUnsafeRow = new UnsafeRow(0) + if (!imageSourceOptions.dropInvalid && requiredSchema.isEmpty) { + Iterator(emptyUnsafeRow) + } else { + val origin = file.filePath + val path = new Path(origin) + val fs = path.getFileSystem(broadcastedHadoopConf.value.value) + val stream = fs.open(path) + val bytes = try { + 
ByteStreams.toByteArray(stream) } finally { Closeables.close(stream, true) } val resultOpt = ImageSchema.decode(origin, bytes) val filteredResult = if (imageSourceOptions.dropInvalid) { resultOpt.toIterator } else { Iterator(resultOpt.getOrElse(ImageSchema.invalidImageRow(origin))) } if (requiredSchema.isEmpty) { filteredResult.map(_ => emptyUnsafeRow) } else { val converter = RowEncoder(requiredSchema) filteredResult.map(row => converter.toRow(row)) } } } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageOptions.scala b/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageOptions.scala new file mode 100644 index 0000000000000..7ff196907717e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageOptions.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.source.image + +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap + +private[image] class ImageOptions( + @transient private val parameters: CaseInsensitiveMap[String]) extends Serializable { + + def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters)) + + /** + * Whether to drop invalid images. If true, invalid images will be removed; otherwise + * invalid images will be returned with empty data and all other fields filled with `-1`. + */ + val dropInvalid = parameters.getOrElse("dropInvalid", "false").toBoolean +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 4e84ff044f55e..39dcd911a0814 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -154,7 +154,7 @@ private[libsvm] class LibSVMFileFormat (file: PartitionedFile) => { val linesReader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close())) val points = linesReader .map(_.toString.trim) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 905870178e549..4cdd17266b771 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -77,7 +77,7 @@ import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom} * the heaviest part of the computation.
In general, this implementation is bound by either * the cost of statistics computation on workers or by communicating the sufficient statistics. */ -private[spark] object RandomForest extends Logging { +private[spark] object RandomForest extends Logging with Serializable { /** * Train a random forest. @@ -91,7 +91,7 @@ private[spark] object RandomForest extends Logging { numTrees: Int, featureSubsetStrategy: String, seed: Long, - instr: Option[Instrumentation[_]], + instr: Option[Instrumentation], prune: Boolean = true, // exposed for testing only, real trees are always pruned parentUID: Option[String] = None): Array[DecisionTreeModel] = { @@ -407,7 +407,7 @@ private[spark] object RandomForest extends Logging { metadata.isMulticlassWithCategoricalFeatures) logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString) - /** + /* * Performs a sequential aggregation over a partition for a particular tree and node. * * For each feature, the aggregate sufficient statistics are updated for the relevant @@ -438,7 +438,7 @@ private[spark] object RandomForest extends Logging { } } - /** + /* * Performs a sequential aggregation over a partition. * * Each data point contributes to one node. For each feature, diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index f327f37bad204..e60a14f976a5c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -33,6 +33,7 @@ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasCollectSubModels, HasParallelism} import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType @@ -118,7 +119,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value) @Since("2.0.0") - override def fit(dataset: Dataset[_]): CrossValidatorModel = { + override def fit(dataset: Dataset[_]): CrossValidatorModel = instrumented { instr => val schema = dataset.schema transformSchema(schema, logging = true) val sparkSession = dataset.sparkSession @@ -129,8 +130,9 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) // Create execution context based on $(parallelism) val executionContext = getExecutionContext - val instr = Instrumentation.create(this, dataset) - instr.logParams(numFolds, seed, parallelism) + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, numFolds, seed, parallelism) logTuningParams(instr) val collectSubModelsParam = $(collectSubModels) @@ -176,7 +178,6 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) instr.logInfo(s"Best set of parameters:\n${epm(bestIndex)}") instr.logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] - instr.logSuccess(bestModel) copyValues(new CrossValidatorModel(uid, bestModel, metrics) .setSubModels(subModels).setParent(this)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 
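Stepping back to the LibSVMRelation one-liner earlier in this diff: TaskContext gained a function-taking overload addTaskCompletionListener[U](f: TaskContext => U) alongside the listener-taking one, and under Scala 2.12's SAM conversion an untyped lambda can match both, hence the explicit [Unit] at the call site. A hedged, self-contained sketch of the ambiguity (shapes simplified, not Spark's exact signatures):

{{{
trait CompletionListener { def onTaskCompletion(ctx: AnyRef): Unit }

class Ctx {
  def addTaskCompletionListener(listener: CompletionListener): Ctx = this
  def addTaskCompletionListener[U](f: Ctx => U): Ctx = this
}

val ctx = new Ctx
// In Scala 2.12 a bare `ctx.addTaskCompletionListener(_ => ())` can be ambiguous:
// the lambda SAM-converts to CompletionListener and also matches the generic overload.
ctx.addTaskCompletionListener[Unit](_ => ())  // the type argument pins the function overload
}}}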
14d6a69c36747..8b251197afbef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -34,6 +34,7 @@ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasCollectSubModels, HasParallelism} import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType import org.apache.spark.util.ThreadUtils @@ -117,7 +118,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value) @Since("2.0.0") - override def fit(dataset: Dataset[_]): TrainValidationSplitModel = { + override def fit(dataset: Dataset[_]): TrainValidationSplitModel = instrumented { instr => val schema = dataset.schema transformSchema(schema, logging = true) val est = $(estimator) @@ -127,8 +128,9 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St // Create execution context based on $(parallelism) val executionContext = getExecutionContext - val instr = Instrumentation.create(this, dataset) - instr.logParams(trainRatio, seed, parallelism) + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, trainRatio, seed, parallelism) logTuningParams(instr) val Array(trainingDataset, validationDataset) = @@ -172,7 +174,6 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St instr.logInfo(s"Best set of parameters:\n${epm(bestIndex)}") instr.logInfo(s"Best train validation split metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] - instr.logSuccess(bestModel) copyValues(new TrainValidationSplitModel(uid, bestModel, metrics) .setSubModels(subModels).setParent(this)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala index 363304ef10147..135828815504a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala @@ -80,7 +80,7 @@ private[ml] trait ValidatorParams extends HasSeed with Params { /** * Instrumentation logging for tuning params including the inner estimator and evaluator info. 
*/ - protected def logTuningParams(instrumentation: Instrumentation[_]): Unit = { + protected def logTuningParams(instrumentation: Instrumentation): Unit = { instrumentation.logNamedValue("estimator", $(estimator).getClass.getCanonicalName) instrumentation.logNamedValue("evaluator", $(evaluator).getClass.getCanonicalName) instrumentation.logNamedValue("estimatorParamMapsLength", $(estimatorParamMaps).length) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index 11f46eb9e4359..49654918bd8f8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -19,15 +19,16 @@ package org.apache.spark.ml.util import java.util.UUID -import scala.reflect.ClassTag +import scala.util.{Failure, Success, Try} +import scala.util.control.NonFatal import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.internal.Logging -import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.param.Param +import org.apache.spark.ml.PipelineStage +import org.apache.spark.ml.param.{Param, Params} import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset import org.apache.spark.util.Utils @@ -35,29 +36,34 @@ import org.apache.spark.util.Utils /** * A small wrapper that defines a training session for an estimator, and some methods to log * useful information during this session. - * - * A new instance is expected to be created within fit(). - * - * @param estimator the estimator that is being fit - * @param dataset the training dataset - * @tparam E the type of the estimator */ -private[spark] class Instrumentation[E <: Estimator[_]] private ( - val estimator: E, - val dataset: RDD[_]) extends Logging { +private[spark] class Instrumentation private () extends Logging { private val id = UUID.randomUUID() - private val prefix = { + private val shortId = id.toString.take(8) + private[util] val prefix = s"[$shortId] " + + /** + * Log some info about the pipeline stage being fit. + */ + def logPipelineStage(stage: PipelineStage): Unit = { // estimator.getClass.getSimpleName can cause Malformed class name error, // call safer `Utils.getSimpleName` instead - val className = Utils.getSimpleName(estimator.getClass) - s"$className-${estimator.uid}-${dataset.hashCode()}-$id: " + val className = Utils.getSimpleName(stage.getClass) + logInfo(s"Stage class: $className") + logInfo(s"Stage uid: ${stage.uid}") } - init() + /** + * Log some data about the dataset being fit. + */ + def logDataset(dataset: Dataset[_]): Unit = logDataset(dataset.rdd) - private def init(): Unit = { - log(s"training: numPartitions=${dataset.partitions.length}" + + /** + * Log some data about the dataset being fit. + */ + def logDataset(dataset: RDD[_]): Unit = { + logInfo(s"training: numPartitions=${dataset.partitions.length}" + s" storageLevel=${dataset.getStorageLevel}") } @@ -89,23 +95,18 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( super.logInfo(prefix + msg) } - /** - * Alias for logInfo, see above. - */ - def log(msg: String): Unit = logInfo(msg) - /** * Logs the value of the given parameters for the estimator being used in this session. 
*/ - def logParams(params: Param[_]*): Unit = { + def logParams(hasParams: Params, params: Param[_]*): Unit = { val pairs: Seq[(String, JValue)] = for { p <- params - value <- estimator.get(p) + value <- hasParams.get(p) } yield { val cast = p.asInstanceOf[Param[Any]] p.name -> parse(cast.jsonEncode(value)) } - log(compact(render(map2jvalue(pairs.toMap)))) + logInfo(compact(render(map2jvalue(pairs.toMap)))) } def logNumFeatures(num: Long): Unit = { @@ -124,35 +125,43 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( * Logs the value with customized name field. */ def logNamedValue(name: String, value: String): Unit = { - log(compact(render(name -> value))) + logInfo(compact(render(name -> value))) } def logNamedValue(name: String, value: Long): Unit = { - log(compact(render(name -> value))) + logInfo(compact(render(name -> value))) } def logNamedValue(name: String, value: Double): Unit = { - log(compact(render(name -> value))) + logInfo(compact(render(name -> value))) } def logNamedValue(name: String, value: Array[String]): Unit = { - log(compact(render(name -> compact(render(value.toSeq))))) + logInfo(compact(render(name -> compact(render(value.toSeq))))) } def logNamedValue(name: String, value: Array[Long]): Unit = { - log(compact(render(name -> compact(render(value.toSeq))))) + logInfo(compact(render(name -> compact(render(value.toSeq))))) } def logNamedValue(name: String, value: Array[Double]): Unit = { - log(compact(render(name -> compact(render(value.toSeq))))) + logInfo(compact(render(name -> compact(render(value.toSeq))))) } /** * Logs the successful completion of the training session. */ - def logSuccess(model: Model[_]): Unit = { - log(s"training finished") + def logSuccess(): Unit = { + logInfo("training finished") + } + + /** + * Logs an exception raised during a training session. + */ + def logFailure(e: Throwable): Unit = { + val msg = e.getStackTrace.mkString("\n") + super.logError(msg) } } @@ -169,22 +178,17 @@ private[spark] object Instrumentation { val varianceOfLabels = "varianceOfLabels" } - /** - * Creates an instrumentation object for a training session. - */ - def create[E <: Estimator[_]]( - estimator: E, dataset: Dataset[_]): Instrumentation[E] = { - create[E](estimator, dataset.rdd) - } - - /** - * Creates an instrumentation object for a training session. - */ - def create[E <: Estimator[_]]( - estimator: E, dataset: RDD[_]): Instrumentation[E] = { - new Instrumentation[E](estimator, dataset) + def instrumented[T](body: (Instrumentation => T)): T = { + val instr = new Instrumentation() + Try(body(instr)) match { + case Failure(NonFatal(e)) => + instr.logFailure(e) + throw e + case Success(result) => + instr.logSuccess() + result + } } - } /** @@ -193,7 +197,7 @@ private[spark] object Instrumentation { * will log via it, otherwise will log via common logger. */ private[spark] class OptionalInstrumentation private( - val instrumentation: Option[Instrumentation[_ <: Estimator[_]]], + val instrumentation: Option[Instrumentation], val className: String) extends Logging { protected override def logName: String = className @@ -225,9 +229,8 @@ private[spark] object OptionalInstrumentation { /** * Creates an `OptionalInstrumentation` object from an existing `Instrumentation` object. 
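Taken together, the refactored Instrumentation turns every fit() into the same shape, as the CrossValidator and TrainValidationSplit changes above show; a sketch of the intended call pattern (MyModel, train, maxIter and seed are hypothetical stand-ins):

  import org.apache.spark.ml.util.Instrumentation.instrumented

  override def fit(dataset: Dataset[_]): MyModel = instrumented { instr =>
    instr.logPipelineStage(this)
    instr.logDataset(dataset)
    instr.logParams(this, maxIter, seed)
    val model = train(dataset)
    // no explicit logSuccess(): instrumented() logs completion, and logs the
    // stack trace via logFailure() if the body throws
    copyValues(model.setParent(this))
  }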
*/ - def create[E <: Estimator[_]](instr: Instrumentation[E]): OptionalInstrumentation = { - new OptionalInstrumentation(Some(instr), - instr.estimator.getClass.getName.stripSuffix("$")) + def create(instr: Instrumentation): OptionalInstrumentation = { + new OptionalInstrumentation(Some(instr), instr.prefix) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index d9a3f85ef9a24..c3894ebdd1785 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -41,7 +41,8 @@ private[spark] object SchemaUtils { val actualDataType = schema(colName).dataType val message = if (msg != null && msg.trim.length > 0) " " + msg else "" require(actualDataType.equals(dataType), - s"Column $colName must be of type $dataType but was actually $actualDataType.$message") + s"Column $colName must be of type ${dataType.catalogString} but was actually " + + s"${actualDataType.catalogString}.$message") } /** @@ -58,7 +59,8 @@ private[spark] object SchemaUtils { val message = if (msg != null && msg.trim.length > 0) " " + msg else "" require(dataTypes.exists(actualDataType.equals), s"Column $colName must be of type equal to one of the following types: " + - s"${dataTypes.mkString("[", ", ", "]")} but was actually of type $actualDataType.$message") + s"${dataTypes.map(_.catalogString).mkString("[", ", ", "]")} but was actually of type " + + s"${actualDataType.catalogString}.$message") } /** @@ -71,8 +73,9 @@ private[spark] object SchemaUtils { msg: String = ""): Unit = { val actualDataType = schema(colName).dataType val message = if (msg != null && msg.trim.length > 0) " " + msg else "" - require(actualDataType.isInstanceOf[NumericType], s"Column $colName must be of type " + - s"NumericType but was actually of type $actualDataType.$message") + require(actualDataType.isInstanceOf[NumericType], + s"Column $colName must be of type ${NumericType.simpleString} but was actually of type " + + s"${actualDataType.catalogString}.$message") } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala index 98af487306dcc..80ab8eb9bc8b0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala @@ -25,6 +25,7 @@ import scala.collection.mutable import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging +import org.apache.spark.ml.util.Instrumentation import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD @@ -151,13 +152,10 @@ class BisectingKMeans private ( this } - /** - * Runs the bisecting k-means algorithm. 
- * @param input RDD of vectors - * @return model for the bisecting kmeans - */ - @Since("1.6.0") - def run(input: RDD[Vector]): BisectingKMeansModel = { + + private[spark] def run( + input: RDD[Vector], + instr: Option[Instrumentation]): BisectingKMeansModel = { if (input.getStorageLevel == StorageLevel.NONE) { logWarning(s"The input RDD ${input.id} is not directly cached, which may hurt performance if" + " its parent RDDs are also not cached.") @@ -171,6 +169,7 @@ class BisectingKMeans private ( val vectors = input.zip(norms).map { case (x, norm) => new VectorWithNorm(x, norm) } var assignments = vectors.map(v => (ROOT_INDEX, v)) var activeClusters = summarize(d, assignments, dMeasure) + instr.foreach(_.logNumExamples(activeClusters.values.map(_.size).sum)) val rootSummary = activeClusters(ROOT_INDEX) val n = rootSummary.size logInfo(s"Number of points: $n.") @@ -246,6 +245,16 @@ class BisectingKMeans private ( new BisectingKMeansModel(root, this.distanceMeasure) } + /** + * Runs the bisecting k-means algorithm. + * @param input RDD of vectors + * @return model for the bisecting kmeans + */ + @Since("1.6.0") + def run(input: RDD[Vector]): BisectingKMeansModel = { + run(input, None) + } + /** * Java-friendly version of `run()`. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index b5b1be3490497..d967c672c581f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -235,7 +235,7 @@ class KMeans private ( private[spark] def run( data: RDD[Vector], - instr: Option[Instrumentation[NewKMeans]]): KMeansModel = { + instr: Option[Instrumentation]): KMeansModel = { if (data.getStorageLevel == StorageLevel.NONE) { logWarning("The input data is not directly cached, which may hurt performance if its" @@ -264,7 +264,7 @@ class KMeans private ( */ private def runAlgorithm( data: RDD[VectorWithNorm], - instr: Option[Instrumentation[NewKMeans]]): KMeansModel = { + instr: Option[Instrumentation]): KMeansModel = { val sc = data.sparkContext @@ -299,7 +299,7 @@ class KMeans private ( val bcCenters = sc.broadcast(centers) // Find the new centers - val newCenters = data.mapPartitions { points => + val collected = data.mapPartitions { points => val thisCenters = bcCenters.value val dims = thisCenters.head.vector.size @@ -317,7 +317,13 @@ class KMeans private ( }.reduceByKey { case ((sum1, count1), (sum2, count2)) => axpy(1.0, sum2, sum1) (sum1, count1 + count2) - }.collectAsMap().mapValues { case (sum, count) => + }.collectAsMap() + + if (iteration == 0) { + instr.foreach(_.logNumExamples(collected.values.map(_._2).sum)) + } + + val newCenters = collected.mapValues { case (sum, count) => distanceMeasureInstance.centroid(sum, count) } @@ -348,7 +354,7 @@ class KMeans private ( logInfo(s"The cost is $cost.") - new KMeansModel(centers.map(_.vector), distanceMeasure) + new KMeansModel(centers.map(_.vector), distanceMeasure, cost, iteration) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index a78c21e838e44..d5c8188144ce2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -36,8 +36,10 @@ import org.apache.spark.sql.{Row, SparkSession} * A clustering model for K-means. 
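The BisectingKMeans change just above shows the pattern this patch uses for threading instrumentation into the old mllib implementations: the internal entry point takes an Option[Instrumentation] and the public API delegates with None, so logging is a no-op outside an ml session. A stripped-down sketch (Trainer is hypothetical, and since Instrumentation is private[spark] this only compiles inside the spark namespace):

  import org.apache.spark.ml.util.Instrumentation
  import org.apache.spark.mllib.linalg.Vector
  import org.apache.spark.rdd.RDD

  object Trainer {
    // internal entry point: logs through the caller's session when one exists
    def run(input: RDD[Vector], instr: Option[Instrumentation]): Long = {
      val numExamples = input.count()
      instr.foreach(_.logNumExamples(numExamples)) // silent when instr is None
      numExamples
    }
    // public entry point, unchanged for existing callers
    def run(input: RDD[Vector]): Long = run(input, None)
  }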
Each point belongs to the cluster with the closest center. */ @Since("0.8.0") -class KMeansModel @Since("2.4.0") (@Since("1.0.0") val clusterCenters: Array[Vector], - @Since("2.4.0") val distanceMeasure: String) +class KMeansModel (@Since("1.0.0") val clusterCenters: Array[Vector], + @Since("2.4.0") val distanceMeasure: String, + @Since("2.4.0") val trainingCost: Double, + private[spark] val numIter: Int) extends Saveable with Serializable with PMMLExportable { private val distanceMeasureInstance: DistanceMeasure = @@ -46,9 +48,13 @@ class KMeansModel @Since("2.4.0") (@Since("1.0.0") val clusterCenters: Array[Vec private val clusterCentersWithNorm = if (clusterCenters == null) null else clusterCenters.map(new VectorWithNorm(_)) + @Since("2.4.0") + private[spark] def this(clusterCenters: Array[Vector], distanceMeasure: String) = + this(clusterCenters: Array[Vector], distanceMeasure, 0.0, -1) + @Since("1.1.0") def this(clusterCenters: Array[Vector]) = - this(clusterCenters: Array[Vector], DistanceMeasure.EUCLIDEAN) + this(clusterCenters: Array[Vector], DistanceMeasure.EUCLIDEAN, 0.0, -1) /** * A Java-friendly constructor that takes an Iterable of Vectors. @@ -182,7 +188,8 @@ object KMeansModel extends Loader[KMeansModel] { val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) - ~ ("k" -> model.k) ~ ("distanceMeasure" -> model.distanceMeasure))) + ~ ("k" -> model.k) ~ ("distanceMeasure" -> model.distanceMeasure) + ~ ("trainingCost" -> model.trainingCost))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) val dataRDD = sc.parallelize(model.clusterCentersWithNorm.zipWithIndex).map { case (p, id) => Cluster(id, p.vector) @@ -202,7 +209,8 @@ object KMeansModel extends Loader[KMeansModel] { val localCentroids = centroids.rdd.map(Cluster.apply).collect() assert(k == localCentroids.length) val distanceMeasure = (metadata \ "distanceMeasure").extract[String] - new KMeansModel(localCentroids.sortBy(_.id).map(_.point), distanceMeasure) + val trainingCost = (metadata \ "trainingCost").extract[Double] + new KMeansModel(localCentroids.sortBy(_.id).map(_.point), distanceMeasure, trainingCost, -1) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 7a5e520d5818e..ed8543da4d4ce 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -227,7 +227,7 @@ class StreamingKMeans @Since("1.2.0") ( require(centers.size == k, s"Number of initial centers must be ${k} but got ${centers.size}") require(weights.forall(_ >= 0), - s"Weight for each inital center must be nonnegative but got [${weights.mkString(" ")}]") + s"Weight for each initial center must be nonnegative but got [${weights.mkString(" ")}]") model = new StreamingKMeansModel(centers, weights) this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index f923be871f438..aa78e91b679ac 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -28,6 +28,7 @@ import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, 
Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.mllib.stat.test.ChiSqTestResult import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SparkSession} @@ -272,13 +273,16 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { // https://en.wikipedia.org/wiki/False_discovery_rate#Benjamini.E2.80.93Hochberg_procedure val tempRes = chiSqTestResult .sortBy { case (res, _) => res.pValue } - val maxIndex = tempRes + val selected = tempRes .zipWithIndex .filter { case ((res, _), index) => res.pValue <= fdr * (index + 1) / chiSqTestResult.length } - .map { case (_, index) => index } - .max - tempRes.take(maxIndex + 1) + if (selected.isEmpty) { + Array.empty[(ChiSqTestResult, Int)] + } else { + val maxIndex = selected.map(_._2).max + tempRes.take(maxIndex + 1) + } case ChiSqSelector.FWE => chiSqTestResult .filter { case (res, _) => res.pValue < fwe / chiSqTestResult.length } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala index 7b73b286fb91c..8935c8496cdbb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala @@ -160,7 +160,7 @@ object HashingTF { case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed) case s: String => val utf8 = UTF8String.fromString(s) - hashUnsafeBytesBlock(utf8.getMemoryBlock(), seed) + hashUnsafeBytes(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed) case _ => throw new SparkException("HashingTF with murmur3 algorithm does not " + s"support type ${term.getClass.getCanonicalName} of input data.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala index acb83ac31affd..43d256bbc46c3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala @@ -56,11 +56,24 @@ class AssociationRules private[fpm] ( /** * Computes the association rules with confidence above `minConfidence`. * @param freqItemsets frequent itemset model obtained from [[FPGrowth]] - * @return a `Set[Rule[Item]]` containing the association rules. + * @return an `RDD[Rule[Item]]` containing the association rules. * */ @Since("1.5.0") def run[Item: ClassTag](freqItemsets: RDD[FreqItemset[Item]]): RDD[Rule[Item]] = { + run(freqItemsets, Map.empty[Item, Double]) + } + + /** + * Computes the association rules with confidence above `minConfidence`. + * @param freqItemsets frequent itemset model obtained from [[FPGrowth]] + * @param itemSupport map containing an item and its support + * @return an `RDD[Rule[Item]]` containing the association rules. The rules will also be able + * to compute the lift metric.
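Stepping back to the ChiSqSelector fix above: the FDR branch implements the Benjamini-Hochberg procedure, keeping every feature up to the largest rank k (1-based) whose p-value satisfies p(k) <= fdr * k / m. The old code took .max over the passing ranks, which throws on an empty collection when no feature passes; the fix returns an empty selection instead. A standalone sketch of the corrected logic (selectByFdr is a hypothetical name):

  def selectByFdr(pValues: Array[Double], fdr: Double): Array[Int] = {
    val m = pValues.length
    val byP = pValues.zipWithIndex.sortBy { case (p, _) => p }
    val passing = byP.zipWithIndex.filter { case ((p, _), k) => p <= fdr * (k + 1) / m }
    if (passing.isEmpty) {
      Array.empty[Int] // the SPARK-25289 case: nothing survives the threshold
    } else {
      val cutoff = passing.map(_._2).max
      byP.take(cutoff + 1).map(_._2) // original indices of the selected features
    }
  }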
+ */ + @Since("2.4.0") + def run[Item: ClassTag](freqItemsets: RDD[FreqItemset[Item]], + itemSupport: scala.collection.Map[Item, Double]): RDD[Rule[Item]] = { // For candidate rule X => Y, generate (X, (Y, freq(X union Y))) val candidates = freqItemsets.flatMap { itemset => val items = itemset.items @@ -76,8 +89,13 @@ class AssociationRules private[fpm] ( // Join to get (X, ((Y, freq(X union Y)), freq(X))), generate rules, and filter by confidence candidates.join(freqItemsets.map(x => (x.items.toSeq, x.freq))) .map { case (antecendent, ((consequent, freqUnion), freqAntecedent)) => - new Rule(antecendent.toArray, consequent.toArray, freqUnion, freqAntecedent) - }.filter(_.confidence >= minConfidence) + new Rule(antecendent.toArray, + consequent.toArray, + freqUnion, + freqAntecedent, + // the consequent always contains exactly one element + itemSupport.get(consequent.head)) + }.filter(_.confidence >= minConfidence) } /** @@ -107,14 +125,21 @@ object AssociationRules { @Since("1.5.0") val antecedent: Array[Item], @Since("1.5.0") val consequent: Array[Item], freqUnion: Double, - freqAntecedent: Double) extends Serializable { + freqAntecedent: Double, + freqConsequent: Option[Double]) extends Serializable { /** * Returns the confidence of the rule. * */ @Since("1.5.0") - def confidence: Double = freqUnion.toDouble / freqAntecedent + def confidence: Double = freqUnion / freqAntecedent + + /** + * Returns the lift of the rule. + */ + @Since("2.4.0") + def lift: Option[Double] = freqConsequent.map(fCons => confidence / fCons) require(antecedent.toSet.intersect(consequent.toSet).isEmpty, { val sharedItems = antecedent.toSet.intersect(consequent.toSet) @@ -142,7 +167,7 @@ object AssociationRules { override def toString: String = { s"${antecedent.mkString("{", ",", "}")} => " + - s"${consequent.mkString("{", ",", "}")}: ${confidence}" + s"${consequent.mkString("{", ",", "}")}: (confidence: $confidence; lift: $lift)" } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 4f2b7e6f0764e..3a1bc35186dc3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -48,9 +48,14 @@ import org.apache.spark.storage.StorageLevel * @tparam Item item type */ @Since("1.3.0") -class FPGrowthModel[Item: ClassTag] @Since("1.3.0") ( - @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]]) +class FPGrowthModel[Item: ClassTag] @Since("2.4.0") ( + @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]], + @Since("2.4.0") val itemSupport: Map[Item, Double]) extends Saveable with Serializable { + + @Since("1.3.0") + def this(freqItemsets: RDD[FreqItemset[Item]]) = this(freqItemsets, Map.empty) + /** * Generates association rules for the `Item`s in [[freqItemsets]].
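With Rule now carrying the consequent's support, the two metrics read as: confidence(X => Y) = freq(X union Y) / freq(X), and lift(X => Y) = confidence(X => Y) / support(Y). As a worked example, with four transactions where X appears in all four, Y in three, and X union Y in three: confidence = 3/4 = 0.75, support(Y) = 0.75, so lift = 1.0, i.e. X and Y co-occur exactly as often as independence would predict (lift above 1 indicates positive association). Consuming the new field might look like this (model is an assumed FPGrowthModel; lift prints as an Option):

  model.generateAssociationRules(0.5).collect().foreach { r =>
    println(s"${r.antecedent.mkString("{", ",", "}")} => " +
      s"${r.consequent.mkString("{", ",", "}")}: conf=${r.confidence}, lift=${r.lift}")
  }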
* @param confidence minimal confidence of the rules produced @@ -58,7 +63,7 @@ class FPGrowthModel[Item: ClassTag] @Since("1.3.0") ( @Since("1.5.0") def generateAssociationRules(confidence: Double): RDD[AssociationRules.Rule[Item]] = { val associationRules = new AssociationRules(confidence) - associationRules.run(freqItemsets) + associationRules.run(freqItemsets, itemSupport) } /** @@ -213,9 +218,12 @@ class FPGrowth private[spark] ( val minCount = math.ceil(minSupport * count).toLong val numParts = if (numPartitions > 0) numPartitions else data.partitions.length val partitioner = new HashPartitioner(numParts) - val freqItems = genFreqItems(data, minCount, partitioner) - val freqItemsets = genFreqItemsets(data, minCount, freqItems, partitioner) - new FPGrowthModel(freqItemsets) + val freqItemsCount = genFreqItems(data, minCount, partitioner) + val freqItemsets = genFreqItemsets(data, minCount, freqItemsCount.map(_._1), partitioner) + val itemSupport = freqItemsCount.map { + case (item, cnt) => item -> cnt.toDouble / count + }.toMap + new FPGrowthModel(freqItemsets, itemSupport) } /** @@ -231,12 +239,12 @@ class FPGrowth private[spark] ( * Generates frequent items by filtering the input data using minimal support level. * @param minCount minimum count for frequent itemsets * @param partitioner partitioner used to distribute items - * @return array of frequent pattern ordered by their frequencies + * @return array of frequent patterns and their frequencies ordered by their frequencies */ private def genFreqItems[Item: ClassTag]( data: RDD[Array[Item]], minCount: Long, - partitioner: Partitioner): Array[Item] = { + partitioner: Partitioner): Array[(Item, Long)] = { data.flatMap { t => val uniq = t.toSet if (t.length != uniq.size) { @@ -248,7 +256,6 @@ class FPGrowth private[spark] ( .filter(_._2 >= minCount) .collect() .sortBy(-_._2) - .map(_._1) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index ac709ad72f0c0..7b49d4d0812f9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -78,8 +78,13 @@ class MatrixFactorizationModel @Since("0.8.0") ( /** Predict the rating of one user for one product. */ @Since("0.8.0") def predict(user: Int, product: Int): Double = { - val userVector = userFeatures.lookup(user).head - val productVector = productFeatures.lookup(product).head + val userFeatureSeq = userFeatures.lookup(user) + require(userFeatureSeq.nonEmpty, s"userId: $user not found in the model") + val productFeatureSeq = productFeatures.lookup(product) + require(productFeatureSeq.nonEmpty, s"productId: $product not found in the model") + + val userVector = userFeatureSeq.head + val productVector = productFeatureSeq.head blas.ddot(rank, userVector, 1, productVector, 1) } @@ -164,9 +169,12 @@ class MatrixFactorizationModel @Since("0.8.0") ( * recommended the product is. 
*/ @Since("1.1.0") - def recommendProducts(user: Int, num: Int): Array[Rating] = - MatrixFactorizationModel.recommend(userFeatures.lookup(user).head, productFeatures, num) + def recommendProducts(user: Int, num: Int): Array[Rating] = { + val userFeatureSeq = userFeatures.lookup(user) + require(userFeatureSeq.nonEmpty, s"userId: $user not found in the model") + MatrixFactorizationModel.recommend(userFeatureSeq.head, productFeatures, num) .map(t => Rating(user, t._1, t._2)) + } /** * Recommends users to a product. That is, this returns users who are most likely to be @@ -181,9 +189,12 @@ class MatrixFactorizationModel @Since("0.8.0") ( * recommended the user is. */ @Since("1.1.0") - def recommendUsers(product: Int, num: Int): Array[Rating] = - MatrixFactorizationModel.recommend(productFeatures.lookup(product).head, userFeatures, num) + def recommendUsers(product: Int, num: Int): Array[Rating] = { + val productFeatureSeq = productFeatures.lookup(product) + require(productFeatureSeq.nonEmpty, s"productId: $product not found in the model") + MatrixFactorizationModel.recommend(productFeatureSeq.head, userFeatures, num) .map(t => Rating(t._1, product, t._2)) + } protected override val formatVersion: String = "1.0" diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 81842afbddbbb..1b7780e171e77 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -133,6 +133,7 @@ class BisectingKMeansSuite extends MLTest with DefaultReadWriteTest { assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + assert(summary.numIter == 20) model.setSummary(None) assert(!model.hasSummary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 0b91f502f615b..13bed9dbe3e89 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -145,6 +145,7 @@ class GaussianMixtureSuite extends MLTest with DefaultReadWriteTest { assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + assert(summary.numIter == 2) model.setSummary(None) assert(!model.hasSummary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 2569e7a432ca4..ccbceab53bb66 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -131,10 +131,13 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes assert(summary.predictions.columns.contains(c)) } assert(summary.cluster.columns === Array(predictionColName)) + assert(summary.trainingCost < 0.1) + assert(model.computeCost(dataset) == summary.trainingCost) val clusterSizes = summary.clusterSizes assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + assert(summary.numIter == 1) model.setSummary(None) assert(!model.hasSummary) @@ -231,7 +234,7 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes val 
oldKmeansModel = new MLlibKMeansModel(clusterCenters) val kmeansModel = new KMeansModel("", oldKmeansModel) def checkModel(pmml: PMML): Unit = { - // Check the header descripiton is what we expect + // Check the header description is what we expect assert(pmml.getHeader.getDescription === "k-means clustering") // check that the number of fields match the single vector size assert(pmml.getDataDictionary.getNumberOfFields === clusterCenters(0).size) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index db92132d18b7b..bbd5408c9fce3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -285,7 +285,7 @@ class LDASuite extends MLTest with DefaultReadWriteTest { // There should be 1 checkpoint remaining. assert(model.getCheckpointFiles.length === 1) val checkpointFile = new Path(model.getCheckpointFiles.head) - val fs = checkpointFile.getFileSystem(spark.sparkContext.hadoopConfiguration) + val fs = checkpointFile.getFileSystem(spark.sessionState.newHadoopConf()) assert(fs.exists(checkpointFile)) model.deleteCheckpointFiles() assert(model.getCheckpointFiles.isEmpty) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala index b7072728d48f0..55b460f1a4524 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.clustering +import scala.collection.mutable + import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -76,12 +78,15 @@ class PowerIterationClusteringSuite extends SparkFunSuite .setMaxIter(40) .setWeightCol("weight") .assignClusters(data) - val localAssignments = assignments - .select('id, 'cluster) - .as[(Long, Int)].collect().toSet - val expectedResult = (0 until n1).map(x => (x, 1)).toSet ++ - (n1 until n).map(x => (x, 0)).toSet - assert(localAssignments === expectedResult) + .select("id", "cluster") + .as[(Long, Int)] + .collect() + + val predictions = Array.fill(2)(mutable.Set.empty[Long]) + assignments.foreach { + case (id, cluster) => predictions(cluster) += id + } + assert(predictions.toSet === Set((0 until n1).toSet, (n1 until n).toSet)) val assignments2 = new PowerIterationClustering() .setK(2) @@ -89,10 +94,15 @@ class PowerIterationClusteringSuite extends SparkFunSuite .setInitMode("degree") .setWeightCol("weight") .assignClusters(data) - val localAssignments2 = assignments2 - .select('id, 'cluster) - .as[(Long, Int)].collect().toSet - assert(localAssignments2 === expectedResult) + .select("id", "cluster") + .as[(Long, Int)] + .collect() + + val predictions2 = Array.fill(2)(mutable.Set.empty[Long]) + assignments2.foreach { + case (id, cluster) => predictions2(cluster) += id + } + assert(predictions2.toSet === Set((0 until n1).toSet, (n1 until n).toSet)) } test("supported input types") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala index ede284712b1c0..2b0909acf69c3 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -67,8 +67,8 @@ class BinaryClassificationEvaluatorSuite evaluator.evaluate(stringDF) } assert(thrown.getMessage.replace("\n", "") contains "Column rawPrediction must be of type " + - "equal to one of the following types: [DoubleType, ") - assert(thrown.getMessage.replace("\n", "") contains "but was actually of type StringType.") + "equal to one of the following types: [double, ") + assert(thrown.getMessage.replace("\n", "") contains "but was actually of type string.") } test("should support all NumericType labels and not support other types") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala index 2c175ff68e0b8..e2d77560293fa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset @@ -33,10 +33,17 @@ class ClusteringEvaluatorSuite import testImplicits._ @transient var irisDataset: Dataset[_] = _ + @transient var newIrisDataset: Dataset[_] = _ + @transient var newIrisDatasetD: Dataset[_] = _ + @transient var newIrisDatasetF: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() irisDataset = spark.read.format("libsvm").load("../data/mllib/iris_libsvm.txt") + val datasets = MLTestingUtils.generateArrayFeatureDataset(irisDataset) + newIrisDataset = datasets._1 + newIrisDatasetD = datasets._2 + newIrisDatasetF = datasets._3 } test("params") { @@ -66,6 +73,9 @@ class ClusteringEvaluatorSuite .setPredictionCol("label") assert(evaluator.evaluate(irisDataset) ~== 0.6564679231 relTol 1e-5) + assert(evaluator.evaluate(newIrisDataset) ~== 0.6564679231 relTol 1e-5) + assert(evaluator.evaluate(newIrisDatasetD) ~== 0.6564679231 relTol 1e-5) + assert(evaluator.evaluate(newIrisDatasetF) ~== 0.6564679231 relTol 1e-5) } /* @@ -85,6 +95,9 @@ class ClusteringEvaluatorSuite .setDistanceMeasure("cosine") assert(evaluator.evaluate(irisDataset) ~== 0.7222369298 relTol 1e-5) + assert(evaluator.evaluate(newIrisDataset) ~== 0.7222369298 relTol 1e-5) + assert(evaluator.evaluate(newIrisDatasetD) ~== 0.7222369298 relTol 1e-5) + assert(evaluator.evaluate(newIrisDatasetF) ~== 0.7222369298 relTol 1e-5) } test("number of clusters must be greater than one") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index c843df9f33e3e..80499e79e3bd6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -163,6 +163,17 @@ class ChiSqSelectorSuite extends MLTest with DefaultReadWriteTest { } } + test("SPARK-25289: ChiSqSelector should not fail when selecting no features with 
FDR") { + val labeledPoints = (0 to 1).map { n => + val v = Vectors.dense((1 to 3).map(_ => n * 1.0).toArray) + (n.toDouble, v) + } + val inputDF = spark.createDataFrame(labeledPoints).toDF("label", "features") + val selector = new ChiSqSelector().setSelectorType("fdr").setFdr(0.05) + val model = selector.fit(inputDF) + assert(model.selectedFeatures.isEmpty) + } + private def testSelector(selector: ChiSqSelector, data: Dataset[_]): ChiSqSelectorModel = { val selectorModel = selector.fit(data) testTransformer[(Double, Vector, Vector)](data.toDF(), selectorModel, diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index 61217669d9277..bca580d411373 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -289,4 +289,20 @@ class CountVectorizerSuite extends MLTest with DefaultReadWriteTest { val newInstance = testDefaultReadWrite(instance) assert(newInstance.vocabulary === instance.vocabulary) } + + test("SPARK-22974: CountVectorModel should attach proper attribute to output column") { + val df = spark.createDataFrame(Seq( + (0, 1.0, Array("a", "b", "c")), + (1, 2.0, Array("a", "b", "b", "c", "a", "d")) + )).toDF("id", "features1", "words") + + val cvm = new CountVectorizerModel(Array("a", "b", "c")) + .setInputCol("words") + .setOutputCol("features2") + + val df1 = cvm.transform(df) + val interaction = new Interaction().setInputCols(Array("features1", "features2")) + .setOutputCol("features") + interaction.transform(df1) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index a250331efeb1d..0de6528c4cf22 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -105,7 +105,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { testTransformerByInterceptingException[(Int, Boolean)]( original, model, - "Label column already exists and is not of type NumericType.", + "Label column already exists and is not of type numeric.", "x") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index 91fb24a268b8c..a4d388fd321db 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -99,9 +99,9 @@ class VectorAssemblerSuite assembler.transform(df) } assert(thrown.getMessage contains - "Data type StringType of column a is not supported.\n" + - "Data type StringType of column b is not supported.\n" + - "Data type StringType of column c is not supported.") + "Data type string of column a is not supported.\n" + + "Data type string of column b is not supported.\n" + + "Data type string of column c is not supported.") } test("ML attributes") { @@ -256,4 +256,9 @@ class VectorAssemblerSuite assert(runWithMetadata("keep", additional_filter = "id1 > 2").count() == 4) } + test("SPARK-25371: VectorAssembler with empty inputCols") { + val vectorAssembler = new VectorAssembler().setInputCols(Array()).setOutputCol("a") + val output = vectorAssembler.transform(dfWithNullsAndNaNs) + assert(output.select("a").limit(1).collect().head == 
Row(Vectors.sparse(0, Seq.empty))) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala index 87f8b9034dde8..b75526a48371a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala @@ -39,9 +39,9 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val model = new FPGrowth().setMinSupport(0.5).fit(data) val generatedRules = model.setMinConfidence(0.5).associationRules val expectedRules = spark.createDataFrame(Seq( - (Array("2"), Array("1"), 1.0), - (Array("1"), Array("2"), 0.75) - )).toDF("antecedent", "consequent", "confidence") + (Array("2"), Array("1"), 1.0, 1.0), + (Array("1"), Array("2"), 0.75, 1.0) + )).toDF("antecedent", "consequent", "confidence", "lift") .withColumn("antecedent", col("antecedent").cast(ArrayType(dt))) .withColumn("consequent", col("consequent").cast(ArrayType(dt))) assert(expectedRules.sort("antecedent").rdd.collect().sameElements( diff --git a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala index 527b3f8955968..e16ec906c90b1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types._ class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext { // Single column of images named "image" - private lazy val imagePath = "../data/mllib/images" + private lazy val imagePath = "../data/mllib/images/origin" test("Smoke test: create basic ImageSchema dataframe") { val origin = "path" diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index e3dfe2faf5698..9a59c41740daf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -594,11 +594,12 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { (check: (ALSModel, ALSModel) => Unit) (check2: (ALSModel, ALSModel, DataFrame, Encoder[_]) => Unit): Unit = { val dfs = genRatingsDFWithNumericCols(spark, column) - val df = dfs.find { - case (numericTypeWithEncoder, _) => numericTypeWithEncoder.numericType == baseType - } match { - case Some((_, df)) => df + val maybeDf = dfs.find { case (numericTypeWithEncoder, _) => + numericTypeWithEncoder.numericType == baseType } + assert(maybeDf.isDefined) + val df = maybeDf.get._2 + val expected = estimator.fit(df) val actuals = dfs.filter(_ != baseType).map(t => (t, estimator.fit(t._2))) actuals.foreach { case (_, actual) => check(expected, actual) } @@ -612,7 +613,7 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { estimator.fit(strDF) } assert(thrown.getMessage.contains( - s"$column must be of type NumericType but was actually of type StringType")) + s"$column must be of type numeric but was actually of type string")) } private class NumericTypeWithEncoder[A](val numericType: NumericType) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 4e4ff71c9de90..6cc73e040e82c 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -385,7 +385,7 @@ class AFTSurvivalRegressionSuite extends MLTest with DefaultReadWriteTest { aft.fit(dfWithStringCensors) } assert(thrown.getMessage.contains( - "Column censor must be of type NumericType but was actually of type StringType")) + "Column censor must be of type numeric but was actually of type string")) } test("numerical stability of standardization") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 997c50157dcda..600a43242751f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.regression import scala.util.Random -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.{Instance, OffsetInstance} import org.apache.spark.ml.feature.{LabeledPoint, RFormula} @@ -29,6 +29,7 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.random._ import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.serializer.KryoSerializer import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.FloatType @@ -1687,6 +1688,14 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest assert(evalSummary.deviance === summary.deviance) assert(evalSummary.aic === summary.aic) } + + test("SPARK-23131 Kryo raises StackOverflow during serializing GLR model") { + val conf = new SparkConf(false) + val ser = new KryoSerializer(conf).newInstance() + val trainer = new GeneralizedLinearRegression() + val model = trainer.fit(Seq(Instance(1.0, 1.0, Vectors.dense(1.0, 7.0))).toDF) + ser.serialize[GeneralizedLinearRegressionModel](model) + } } object GeneralizedLinearRegressionSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/image/ImageFileFormatSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/image/ImageFileFormatSuite.scala new file mode 100644 index 0000000000000..1a6a8d67d8d66 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/source/image/ImageFileFormatSuite.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.source.image + +import java.nio.file.Paths + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.image.ImageSchema._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.{col, substring_index} + +class ImageFileFormatSuite extends SparkFunSuite with MLlibTestSparkContext { + + // Single column of images named "image" + private lazy val imagePath = "../data/mllib/images/partitioned" + + test("image datasource count test") { + val df1 = spark.read.format("image").load(imagePath) + assert(df1.count === 9) + + val df2 = spark.read.format("image").option("dropInvalid", true).load(imagePath) + assert(df2.count === 8) + } + + test("image datasource test: read jpg image") { + val df = spark.read.format("image").load(imagePath + "/cls=kittens/date=2018-02/DP153539.jpg") + assert(df.count() === 1) + } + + test("image datasource test: read png image") { + val df = spark.read.format("image").load(imagePath + "/cls=multichannel/date=2018-01/BGRA.png") + assert(df.count() === 1) + } + + test("image datasource test: read non image") { + val filePath = imagePath + "/cls=kittens/date=2018-01/not-image.txt" + val df = spark.read.format("image").option("dropInvalid", true) + .load(filePath) + assert(df.count() === 0) + + val df2 = spark.read.format("image").option("dropInvalid", false) + .load(filePath) + assert(df2.count() === 1) + val result = df2.head() + assert(result === invalidImageRow( + Paths.get(filePath).toAbsolutePath().normalize().toUri().toString)) + } + + test("image datasource partition test") { + val result = spark.read.format("image") + .option("dropInvalid", true).load(imagePath) + .select(substring_index(col("image.origin"), "/", -1).as("origin"), col("cls"), col("date")) + .collect() + + assert(Set(result: _*) === Set( + Row("29.5.a_b_EGDP022204.jpg", "kittens", "2018-01"), + Row("54893.jpg", "kittens", "2018-02"), + Row("DP153539.jpg", "kittens", "2018-02"), + Row("DP802813.jpg", "kittens", "2018-02"), + Row("BGRA.png", "multichannel", "2018-01"), + Row("BGRA_alpha_60.png", "multichannel", "2018-01"), + Row("chr30.4.184.jpg", "multichannel", "2018-02"), + Row("grayscale.jpg", "multichannel", "2018-02") + )) + } + + // Images with the different number of channels + test("readImages pixel values test") { + val images = spark.read.format("image").option("dropInvalid", true) + .load(imagePath + "/cls=multichannel/").collect() + + val firstBytes20Set = images.map { rrow => + val row = rrow.getAs[Row]("image") + val filename = Paths.get(getOrigin(row)).getFileName().toString() + val mode = getMode(row) + val bytes20 = getData(row).slice(0, 20).toList + filename -> Tuple2(mode, bytes20) // Cannot remove `Tuple2`, otherwise `->` operator + // will match 2 arguments + }.toSet + + assert(firstBytes20Set === expectedFirstBytes20Set) + } + + // number of channels and first 20 bytes of OpenCV representation + // - default representation for 3-channel RGB images is BGR row-wise: + // (B00, G00, R00, B10, G10, R10, ...) + // - default representation for 4-channel RGB images is BGRA row-wise: + // (B00, G00, R00, A00, B10, G10, R10, A10, ...) 
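  // Sketch (not part of the patch): in this row-wise OpenCV layout, channel c of
  // pixel (x, y) lives at a fixed offset in the byte array, so the 20-byte prefixes
  // below are simply the first few pixels of row 0:
  //   offset = (y * width + x) * nChannels + c
  def pixelChannel(data: Array[Byte], width: Int, nChannels: Int, x: Int, y: Int, c: Int): Byte =
    data((y * width + x) * nChannels + c)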
+ private val expectedFirstBytes20Set = Set( + "grayscale.jpg" -> + ((0, List[Byte](-2, -33, -61, -60, -59, -59, -64, -59, -66, -67, -73, -73, -62, + -57, -60, -63, -53, -49, -55, -69))), + "chr30.4.184.jpg" -> ((16, + List[Byte](-9, -3, -1, -43, -32, -28, -75, -60, -57, -78, -59, -56, -74, -59, -57, + -71, -58, -56, -73, -64))), + "BGRA.png" -> ((24, + List[Byte](-128, -128, -8, -1, -128, -128, -8, -1, -128, + -128, -8, -1, 127, 127, -9, -1, 127, 127, -9, -1))), + "BGRA_alpha_60.png" -> ((24, + List[Byte](-128, -128, -8, 60, -128, -128, -8, 60, -128, + -128, -8, 60, 127, 127, -9, 60, 127, 127, -9, 60))) + ) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala index 76d41f9b23715..acac171346a85 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala @@ -21,12 +21,13 @@ import java.io.File import org.scalatest.Suite -import org.apache.spark.SparkContext +import org.apache.spark.{DebugFilesystem, SparkConf, SparkContext} import org.apache.spark.ml.{PredictionModel, Transformer} import org.apache.spark.ml.linalg.Vector import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions.col +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.test.TestSparkSession import org.apache.spark.util.Utils @@ -36,6 +37,13 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite => @transient var sc: SparkContext = _ @transient var checkpointDir: String = _ + protected override def sparkConf = { + new SparkConf() + .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) + .set("spark.unsafe.exceptionOnMemoryLeak", "true") + .set(SQLConf.CODEGEN_FALLBACK.key, "false") + } + protected override def createSparkSession: TestSparkSession = { new TestSparkSession(new SparkContext("local[2]", "MLlibUnitTest", sparkConf)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index 5e72b4d864c1d..91a8b14625a86 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -74,7 +74,7 @@ object MLTestingUtils extends SparkFunSuite { estimator.fit(dfWithStringLabels) } assert(thrown.getMessage.contains( - "Column label must be of type NumericType but was actually of type StringType")) + "Column label must be of type numeric but was actually of type string")) estimator match { case weighted: Estimator[M] with HasWeightCol => @@ -86,7 +86,7 @@ object MLTestingUtils extends SparkFunSuite { weighted.fit(dfWithStringWeights) } assert(thrown.getMessage.contains( - "Column weight must be of type NumericType but was actually of type StringType")) + "Column weight must be of type numeric but was actually of type string")) case _ => } } @@ -104,7 +104,7 @@ object MLTestingUtils extends SparkFunSuite { evaluator.evaluate(dfWithStringLabels) } assert(thrown.getMessage.contains( - "Column label must be of type NumericType but was actually of type StringType")) + "Column label must be of type numeric but was actually of type string")) } def genClassifDFWithNumericLabelCol( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala 
b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala index 2c8ed057a516a..5ed9d077afe78 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala @@ -72,6 +72,27 @@ class MatrixFactorizationModelSuite extends SparkFunSuite with MLlibTestSparkCon } } + test("invalid user and product") { + val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures) + + intercept[IllegalArgumentException] { + // invalid user + model.predict(5, 2) + } + intercept[IllegalArgumentException] { + // invalid product + model.predict(0, 5) + } + intercept[IllegalArgumentException] { + // invalid user + model.recommendProducts(5, 2) + } + intercept[IllegalArgumentException] { + // invalid product + model.recommendUsers(5, 2) + } + } + test("batch predict API recommendProductsForUsers") { val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures) val topK = 10 diff --git a/pom.xml b/pom.xml index 90e64ff71d229..05e3b05613efd 100644 --- a/pom.xml +++ b/pom.xml @@ -104,6 +104,7 @@ external/kafka-0-10 external/kafka-0-10-assembly external/kafka-0-10-sql + external/avro @@ -113,7 +114,7 @@ 1.8 ${java.version} ${java.version} - 3.3.9 + 3.5.4 spark 1.7.16 1.2.17 @@ -130,26 +131,25 @@ 1.2.1 10.12.1.1 1.10.0 - 1.4.4 + 1.5.2 nohive 1.6.0 - 9.3.20.v20170531 + 9.3.24.v20180605 3.1.0 - 0.8.4 + 0.9.3 2.4.0 2.0.8 3.1.5 - 1.7.7 + 1.8.2 hadoop2 - 0.9.4 - 1.7.3 + 1.8.10 - 1.11.76 + 1.11.271 - 0.10.2 + 0.12.8 - 4.5.4 - 4.4.8 + 4.5.6 + 4.4.10 3.1 3.4.1 @@ -170,7 +170,7 @@ 3.5 3.2.10 - 3.0.8 + 3.0.9 2.22.2 2.9.3 3.5.2 @@ -189,10 +189,11 @@ If you are changing Arrow version specification, please check ./python/pyspark/sql/utils.py, ./python/run-tests.py and ./python/setup.py too. --> - 0.8.0 + 0.10.0 ${java.home} + org.spark_project @@ -313,13 +314,13 @@ chill-java ${chill.version}
      - org.apache.xbean - xbean-asm5-shaded - 4.4 + xbean-asm6-shaded + 4.8 + - net.java.dev.jets3t - jets3t - ${jets3t.version} + javax.activation + activation + 1.1.1 ${hadoop.deps.scope} - - - commons-logging - commons-logging - - - - - org.bouncycastle - bcprov-jdk15on - - 1.58 org.apache.hadoop @@ -1742,6 +1737,10 @@ org.apache.hadoop hadoop-common + + org.apache.hadoop + hadoop-hdfs + org.apache.hive hive-storage-api @@ -1771,6 +1770,10 @@ org.apache.hive hive-storage-api + + com.esotericsoftware + kryo-shaded + @@ -2122,7 +2125,7 @@ org.apache.maven.plugins maven-surefire-plugin - 2.20.1 + 2.22.0 @@ -2161,6 +2164,7 @@ false ${test.exclude.tags} + ${test.include.tags} @@ -2208,6 +2212,7 @@ __not_used__ ${test.exclude.tags} + ${test.include.tags} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index eeb097ef153ad..55dc2b81cfe2f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,29 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-23429][CORE] Add executor memory metrics to heartbeat and expose in executors REST API + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate.apply"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate.copy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate.this"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate$"), + + // [SPARK-25248] add package private methods to TaskContext + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.markTaskFailed"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.markInterrupted"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.fetchFailed"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.markTaskCompleted"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getLocalProperties"), + + // [SPARK-10697][ML] Add lift to Association rules + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.fpm.FPGrowthModel.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.fpm.AssociationRules#Rule.this"), + + // [SPARK-24296][CORE] Replicate large blocks as a stream. 
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockRpcServer.this"), + // [SPARK-23528] Add numIter to ClusteringSummary + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.ClusteringSummary.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.KMeansSummary.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.BisectingKMeansSummary.this"), // [SPARK-6237][NETWORK] Network-layer changes to allow stream upload ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockRpcServer.receive"), @@ -92,7 +115,10 @@ object MimaExcludes { ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol") + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"), + + // [SPARK-23042] Use OneHotEncoderModel to encode labels in MultilayerPerceptronClassifier + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.classification.LabelConverter") ) // Exclude rules for 2.3.x diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b606f9355e03b..a5ed9088eaa4d 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -40,8 +40,8 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile - val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, sqlKafka010) = Seq( - "catalyst", "sql", "hive", "hive-thriftserver", "sql-kafka-0-10" + val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, sqlKafka010, avro) = Seq( + "catalyst", "sql", "hive", "hive-thriftserver", "sql-kafka-0-10", "avro" ).map(ProjectRef(buildLocation, _)) val streamingProjects@Seq(streaming, streamingKafka010) = @@ -94,6 +94,12 @@ object SparkBuild extends PomBuild { case Some(v) => v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.trim.replaceAll("-P", "")).toSeq } + + Option(System.getProperty("scala.version")) + .filter(_.startsWith("2.12")) + .foreach { versionString => + System.setProperty("scala-2.12", "true") + } if (System.getProperty("scala-2.12") == "") { // To activate scala-2.10 profile, replace empty property value to non-empty value // in the same way as Maven which handles -Dname as -Dname=true before executes build process. 
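The hunk above wires `-Dscala.version=2.12.x` through to the `scala-2.12` profile flag, relying on the convention noted in the comment that a bare `-Dname` is handled like `name=true`. A minimal sketch of that activation rule, written in Python purely for illustration (the real logic is the Scala code in SparkBuild.scala shown above, and the helper name here is hypothetical):

```python
# Illustrative only: mirrors the profile-activation rule described above.
def normalize_profiles(props):
    """props: system properties as parsed from the command line."""
    # -Dscala.version=2.12.x implies the scala-2.12 profile.
    if props.get("scala.version", "").startswith("2.12"):
        props["scala-2.12"] = "true"
    # A bare -Dscala-2.12 arrives as an empty string; Maven (and this build)
    # treat -Dname as -Dname=true, so promote it to a non-empty value.
    if props.get("scala-2.12") == "":
        props["scala-2.12"] = "true"
    return props

assert normalize_profiles({"scala.version": "2.12.6"})["scala-2.12"] == "true"
assert normalize_profiles({"scala-2.12": ""})["scala-2.12"] == "true"
```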
@@ -212,7 +218,7 @@ object SparkBuild extends PomBuild { .map(file), incOptions := incOptions.value.withNameHashing(true), publishMavenStyle := true, - unidocGenjavadocVersion := "0.10", + unidocGenjavadocVersion := "0.11", // Override SBT's default resolvers: resolvers := Seq( @@ -326,7 +332,7 @@ object SparkBuild extends PomBuild { val mimaProjects = allProjects.filterNot { x => Seq( spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn, - unsafe, tags, sqlKafka010, kvstore + unsafe, tags, sqlKafka010, kvstore, avro ).contains(x) } @@ -464,7 +470,8 @@ object DockerIntegrationTests { */ object DependencyOverrides { lazy val settings = Seq( - dependencyOverrides += "com.google.guava" % "guava" % "14.0.1") + dependencyOverrides += "com.google.guava" % "guava" % "14.0.1", + dependencyOverrides += "jline" % "jline" % "2.14.6") } /** @@ -687,9 +694,11 @@ object Unidoc { publish := {}, unidocProjectFilter in(ScalaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, yarn, tags, streamingKafka010, sqlKafka010), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, + yarn, tags, streamingKafka010, sqlKafka010, avro), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, yarn, tags, streamingKafka010, sqlKafka010), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, + yarn, tags, streamingKafka010, sqlKafka010, avro), unidocAllClasspaths in (ScalaUnidoc, unidoc) := { ignoreClasspaths((unidocAllClasspaths in (ScalaUnidoc, unidoc)).value) diff --git a/python/docs/Makefile b/python/docs/Makefile index b8e079483c90c..1ed1f33af2326 100644 --- a/python/docs/Makefile +++ b/python/docs/Makefile @@ -1,19 +1,44 @@ # Makefile for Sphinx documentation # +ifndef SPHINXBUILD +ifndef SPHINXPYTHON +SPHINXBUILD = sphinx-build +endif +endif + +ifdef SPHINXBUILD +# User-friendly check for sphinx-build. +ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) +$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) +endif +else +# Note that there is an issue with Python version and Sphinx in PySpark documentation generation. +# Please remove this check below when this issue is fixed. See SPARK-24530 for more details. +PYTHON_VERSION_CHECK = $(shell $(SPHINXPYTHON) -c 'import sys; print(sys.version_info < (3, 0, 0))') +ifeq ($(PYTHON_VERSION_CHECK), True) +$(error Note that Python 3 is required to generate PySpark documentation correctly for now. Current Python executable was less than Python 3. See SPARK-24530. To force Sphinx to use a specific Python executable, please set SPHINXPYTHON to point to the Python 3 executable.) +endif +# Check if Sphinx is installed. +ifeq ($(shell $(SPHINXPYTHON) -c 'import sphinx' >/dev/null 2>&1; echo $$?), 1) +$(error Python executable '$(SPHINXPYTHON)' did not have Sphinx installed. Make sure you have Sphinx installed, then set the SPHINXPYTHON environment variable to point to the Python executable having Sphinx installed. 
If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) +endif +# Use 'SPHINXPYTHON -msphinx' instead of 'sphinx-build'. See https://github.com/sphinx-doc/sphinx/pull/3523 for more details. +SPHINXBUILD = $(SPHINXPYTHON) -msphinx +endif + # You can set these variables from the command line. SPHINXOPTS ?= -SPHINXBUILD ?= sphinx-build PAPER ?= BUILDDIR ?= _build +# You can set SPHINXBUILD to specify Sphinx build executable or SPHINXPYTHON to specify the Python executable used in Sphinx. +# They follow: +# 1. if SPHINXPYTHON is set, use Python. If SPHINXBUILD is set, use sphinx-build. +# 2. If both are set, SPHINXBUILD has a higher priority over SPHINXPYTHON +# 3. By default, SPHINXBUILD is used as 'sphinx-build'. export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.7-src.zip) -# User-friendly check for sphinx-build -ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) -$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) -endif - # Internal variables. PAPEROPT_a4 = -D latex_paper_size=a4 PAPEROPT_letter = -D latex_paper_size=letter diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 58218918693ca..ee153af18c88c 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -36,7 +36,12 @@ Finer-grained cache persistence levels. - :class:`TaskContext`: Information about the current running task, available on the workers and experimental. - + - :class:`RDDBarrier`: + Wraps an RDD under a barrier stage for barrier execution. + - :class:`BarrierTaskContext`: + A :class:`TaskContext` that provides extra info and tooling for barrier execution. + - :class:`BarrierTaskInfo`: + Information about a barrier task. 
""" from functools import wraps @@ -44,14 +49,14 @@ from pyspark.conf import SparkConf from pyspark.context import SparkContext -from pyspark.rdd import RDD +from pyspark.rdd import RDD, RDDBarrier from pyspark.files import SparkFiles from pyspark.storagelevel import StorageLevel from pyspark.accumulators import Accumulator, AccumulatorParam from pyspark.broadcast import Broadcast from pyspark.serializers import MarshalSerializer, PickleSerializer from pyspark.status import * -from pyspark.taskcontext import TaskContext +from pyspark.taskcontext import TaskContext, BarrierTaskContext, BarrierTaskInfo from pyspark.profiler import Profiler, BasicProfiler from pyspark.version import __version__ from pyspark._globals import _NoValue @@ -113,4 +118,5 @@ def wrapper(self, *args, **kwargs): "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast", "Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer", "StatusTracker", "SparkJobInfo", "SparkStageInfo", "Profiler", "BasicProfiler", "TaskContext", + "RDDBarrier", "BarrierTaskContext", "BarrierTaskInfo", ] diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index f730d290273fe..30ad04297c682 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -227,20 +227,49 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler): def handle(self): from pyspark.accumulators import _accumulatorRegistry - while not self.server.server_shutdown: - # Poll every 1 second for new data -- don't block in case of shutdown. - r, _, _ = select.select([self.rfile], [], [], 1) - if self.rfile in r: - num_updates = read_int(self.rfile) - for _ in range(num_updates): - (aid, update) = pickleSer._read_with_length(self.rfile) - _accumulatorRegistry[aid] += update - # Write a byte in acknowledgement - self.wfile.write(struct.pack("!b", 1)) + auth_token = self.server.auth_token + + def poll(func): + while not self.server.server_shutdown: + # Poll every 1 second for new data -- don't block in case of shutdown. + r, _, _ = select.select([self.rfile], [], [], 1) + if self.rfile in r: + if func(): + break + + def accum_updates(): + num_updates = read_int(self.rfile) + for _ in range(num_updates): + (aid, update) = pickleSer._read_with_length(self.rfile) + _accumulatorRegistry[aid] += update + # Write a byte in acknowledgement + self.wfile.write(struct.pack("!b", 1)) + return False + + def authenticate_and_accum_updates(): + received_token = self.rfile.read(len(auth_token)) + if isinstance(received_token, bytes): + received_token = received_token.decode("utf-8") + if (received_token == auth_token): + accum_updates() + # we've authenticated, we can break out of the first loop now + return True + else: + raise Exception( + "The value of the provided token to the AccumulatorServer is not correct.") + + # first we keep polling till we've received the authentication token + poll(authenticate_and_accum_updates) + # now we've authenticated, don't need to check for the token anymore + poll(accum_updates) class AccumulatorServer(SocketServer.TCPServer): + def __init__(self, server_address, RequestHandlerClass, auth_token): + SocketServer.TCPServer.__init__(self, server_address, RequestHandlerClass) + self.auth_token = auth_token + """ A simple TCP server that intercepts shutdown() in order to interrupt our continuous polling on the handler. 
@@ -253,9 +282,9 @@ def shutdown(self): self.server_close() -def _start_update_server(): +def _start_update_server(auth_token): """Start a TCP server to receive accumulator updates in a daemon thread, and returns it""" - server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler) + server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler, auth_token) thread = threading.Thread(target=server.serve_forever) thread.daemon = True thread.start() diff --git a/python/pyspark/context.py b/python/pyspark/context.py index ede3b6af0a8cf..4cabae4b2f50b 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -126,7 +126,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self.environment = environment or {} # java gateway must have been launched at this point. if conf is not None and conf._jconf is not None: - # conf has been initialized in JVM properly, so use conf directly. This represent the + # conf has been initialized in JVM properly, so use conf directly. This represents the # scenario that JVM has been launched before SparkConf is created (e.g. SparkContext is # created and then stopped, and we create a new SparkConf and new SparkContext again) self._conf = conf @@ -183,9 +183,10 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, # Create a single Accumulator in Java that we'll send all our updates through; # they will be passed back to us through a TCP server - self._accumulatorServer = accumulators._start_update_server() + auth_token = self._gateway.gateway_parameters.auth_token + self._accumulatorServer = accumulators._start_update_server(auth_token) (host, port) = self._accumulatorServer.server_address - self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port) + self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port, auth_token) self._jsc.sc().register(self._javaAccumulator) self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') @@ -493,10 +494,14 @@ def f(split, iterator): c = list(c) # Make it a list so we can compute its length batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024)) serializer = BatchedSerializer(self._unbatched_serializer, batchSize) - jrdd = self._serialize_to_jvm(c, numSlices, serializer) + + def reader_func(temp_filename): + return self._jvm.PythonRDD.readRDDFromFile(self._jsc, temp_filename, numSlices) + + jrdd = self._serialize_to_jvm(c, serializer, reader_func) return RDD(jrdd, self, serializer) - def _serialize_to_jvm(self, data, parallelism, serializer): + def _serialize_to_jvm(self, data, serializer, reader_func): """ Calling the Java parallelize() method with an ArrayList is too slow, because it sends O(n) Py4J commands. As an alternative, serialized @@ -506,8 +511,7 @@ def _serialize_to_jvm(self, data, parallelism, serializer): try: serializer.dump_stream(data, tempFile) tempFile.close() - readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile - return readRDDFromFile(self._jsc, tempFile.name, parallelism) + return reader_func(tempFile.name) finally: # readRDDFromFile eagerily reads the file so we can delete right after. os.unlink(tempFile.name) @@ -847,6 +851,8 @@ def addFile(self, path, recursive=False): A directory can be given if the recursive option is set to True. Currently directories are only supported for Hadoop-supported filesystems. + .. note:: A path can be added only once. Subsequent additions of the same path are ignored. 
+ >>> from pyspark import SparkFiles >>> path = os.path.join(tempdir, "test.txt") >>> with open(path, "w") as testFile: @@ -867,6 +873,8 @@ def addPyFile(self, path): SparkContext in the future. The C{path} passed can be either a local file, a file in HDFS (or other Hadoop-supported filesystems), or an HTTP, HTTPS or FTP URI. + + .. note:: A path can be added only once. Subsequent additions of the same path are ignored. """ self.addFile(path) (dirname, filename) = os.path.split(path) # dirname may be directory or HDFS/S3 prefix @@ -929,10 +937,10 @@ def setJobGroup(self, groupId, description, interruptOnCancel=False): >>> def stop_job(): ... sleep(5) ... sc.cancelJobGroup("job_to_cancel") - >>> supress = lock.acquire() - >>> supress = threading.Thread(target=start_job, args=(10,)).start() - >>> supress = threading.Thread(target=stop_job).start() - >>> supress = lock.acquire() + >>> suppress = lock.acquire() + >>> suppress = threading.Thread(target=start_job, args=(10,)).start() + >>> suppress = threading.Thread(target=stop_job).start() + >>> suppress = lock.acquire() >>> print(result) Cancelled diff --git a/python/pyspark/find_spark_home.py b/python/pyspark/find_spark_home.py index 9cf0e8c8d2fe9..9c4ed46598632 100755 --- a/python/pyspark/find_spark_home.py +++ b/python/pyspark/find_spark_home.py @@ -27,7 +27,7 @@ def _find_spark_home(): """Find the SPARK_HOME.""" - # If the enviroment has SPARK_HOME set trust it. + # If the environment has SPARK_HOME set trust it. if "SPARK_HOME" in os.environ: return os.environ["SPARK_HOME"] diff --git a/python/pyspark/heapq3.py b/python/pyspark/heapq3.py index 6af084adcf373..37a2914ebac05 100644 --- a/python/pyspark/heapq3.py +++ b/python/pyspark/heapq3.py @@ -710,7 +710,7 @@ def merge(iterables, key=None, reverse=False): # value seen being in the 100 most extreme values is 100/101. # * If the value is a new extreme value, the cost of inserting it into the # heap is 1 + log(k, 2). -# * The probabilty times the cost gives: +# * The probability times the cost gives: # (k/i) * (1 + log(k, 2)) # * Summing across the remaining n-k elements gives: # sum((k/i) * (1 + log(k, 2)) for i in range(k+1, n+1)) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index fa2d5e8db716a..c8c5f801f89bb 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -34,6 +34,7 @@ from py4j.java_gateway import java_import, JavaGateway, JavaObject, GatewayParameters from pyspark.find_spark_home import _find_spark_home from pyspark.serializers import read_int, write_with_length, UTF8Deserializer +from pyspark.util import _exception_message def launch_gateway(conf=None): @@ -134,7 +135,7 @@ def killChild(): return gateway -def do_server_auth(conn, auth_secret): +def _do_server_auth(conn, auth_secret): """ Performs the authentication protocol defined by the SocketAuthHelper class on the given file-like object 'conn'. @@ -147,6 +148,36 @@ def do_server_auth(conn, auth_secret): raise Exception("Unexpected reply from iterator server.") +def local_connect_and_auth(port, auth_secret): + """ + Connect to local host, authenticate with it, and return a (sockfile,sock) for that connection. + Handles IPV4 & IPV6, does some error handling. + :param port + :param auth_secret + :return: a tuple with (sockfile, sock) + """ + sock = None + errors = [] + # Support for both IPv4 and IPv6. + # On most of IPv6-ready systems, IPv6 will take precedence. 
+ for res in socket.getaddrinfo("127.0.0.1", port, socket.AF_UNSPEC, socket.SOCK_STREAM): + af, socktype, proto, _, sa = res + try: + sock = socket.socket(af, socktype, proto) + sock.settimeout(15) + sock.connect(sa) + sockfile = sock.makefile("rwb", 65536) + _do_server_auth(sockfile, auth_secret) + return (sockfile, sock) + except socket.error as e: + emsg = _exception_message(e) + errors.append("tried to connect to %s, but an error occured: %s" % (sa, emsg)) + sock.close() + sock = None + else: + raise Exception("could not open socket: %s" % errors) + + def ensure_callback_server_started(gw): """ Start callback server if not already started. The callback server is needed if the Java diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index d5963f4f7042c..ce028512357f2 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -773,8 +773,8 @@ def roc(self): which is a Dataframe having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. - .. seealso:: `Wikipedia reference \ - `_ + .. seealso:: `Wikipedia reference + `_ .. note:: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. This will change in later Spark diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 6d77baf7349e4..5ef4e765ea4e1 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -16,6 +16,7 @@ # import sys +import warnings from pyspark import since, keyword_only from pyspark.ml.util import * @@ -87,6 +88,14 @@ def clusterSizes(self): """ return self._call_java("clusterSizes") + @property + @since("2.4.0") + def numIter(self): + """ + Number of iterations. + """ + return self._call_java("numIter") + class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable): """ @@ -303,7 +312,15 @@ class KMeansSummary(ClusteringSummary): .. versionadded:: 2.1.0 """ - pass + + @property + @since("2.4.0") + def trainingCost(self): + """ + K-means cost (sum of squared distances to the nearest centroid for all points in the + training dataset). This is equivalent to sklearn's inertia. + """ + return self._call_java("trainingCost") class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable): @@ -323,7 +340,13 @@ def computeCost(self, dataset): """ Return the K-means cost (sum of squared distances of points to their nearest center) for this model on the given data. + + ..note:: Deprecated in 2.4.0. It will be removed in 3.0.0. Use ClusteringEvaluator instead. + You can also get the cost on the training dataset in the summary. """ + warnings.warn("Deprecated in 2.4.0. It will be removed in 3.0.0. Use ClusteringEvaluator " + "instead. You can also get the cost on the training dataset in the summary.", + DeprecationWarning) return self._call_java("computeCost", dataset) @property @@ -379,6 +402,8 @@ class KMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPredictionCol 2 >>> summary.clusterSizes [2, 2] + >>> summary.trainingCost + 2.000... >>> kmeans_path = temp_path + "/kmeans" >>> kmeans.save(kmeans_path) >>> kmeans2 = KMeans.load(kmeans_path) @@ -1010,7 +1035,7 @@ def getK(self): def setOptimizer(self, value): """ Sets the value of :py:attr:`optimizer`. - Currenlty only support 'em' and 'online'. + Currently only support 'em' and 'online'. >>> algo = LDA().setOptimizer("em") >>> algo.getOptimizer() @@ -1177,21 +1202,21 @@ class PowerIterationClustering(HasMaxIter, HasWeightCol, JavaParams, JavaMLReada .. 
note:: Experimental Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by - Lin and Cohen. From the abstract: + `Lin and Cohen `_. From the abstract: PIC finds a very low-dimensional embedding of a dataset using truncated power iteration on a normalized pair-wise similarity matrix of the data. This class is not yet an Estimator/Transformer, use :py:func:`assignClusters` method to run the PowerIterationClustering algorithm. - .. seealso:: `Wikipedia on Spectral clustering \ - `_ + .. seealso:: `Wikipedia on Spectral clustering + `_ - >>> data = [(1, 0, 0.5), \ - (2, 0, 0.5), (2, 1, 0.7), \ - (3, 0, 0.5), (3, 1, 0.7), (3, 2, 0.9), \ - (4, 0, 0.5), (4, 1, 0.7), (4, 2, 0.9), (4, 3, 1.1), \ - (5, 0, 0.5), (5, 1, 0.7), (5, 2, 0.9), (5, 3, 1.1), (5, 4, 1.3)] + >>> data = [(1, 0, 0.5), + ... (2, 0, 0.5), (2, 1, 0.7), + ... (3, 0, 0.5), (3, 1, 0.7), (3, 2, 0.9), + ... (4, 0, 0.5), (4, 1, 0.7), (4, 2, 0.9), (4, 3, 1.1), + ... (5, 0, 0.5), (5, 1, 0.7), (5, 2, 0.9), (5, 3, 1.1), (5, 4, 1.3)] >>> df = spark.createDataFrame(data).toDF("src", "dst", "weight") >>> pic = PowerIterationClustering(k=2, maxIter=40, weightCol="weight") >>> assignments = pic.assignClusters(df) @@ -1345,8 +1370,14 @@ def assignClusters(self, dataset): if __name__ == "__main__": import doctest + import numpy import pyspark.ml.clustering from pyspark.sql import SparkSession + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.ml.clustering.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 14800d4d9327a..eccb7acae5b98 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -207,8 +207,8 @@ class BucketedRandomProjectionLSH(JavaEstimator, LSHParams, HasInputCol, HasOutp distance space. The output will be vectors of configurable dimension. Hash values in the same dimension are calculated by the same hash function. - .. seealso:: `Stable Distributions \ - `_ + .. seealso:: `Stable Distributions + `_ .. seealso:: `Hashing for Similarity Search: A Survey `_ >>> from pyspark.ml.linalg import Vectors @@ -303,7 +303,7 @@ def _create_model(self, java_model): class BucketedRandomProjectionLSHModel(LSHModel, JavaMLReadable, JavaMLWritable): - """ + r""" .. note:: Experimental Model fitted by :py:class:`BucketedRandomProjectionLSH`, where multiple random vectors are @@ -653,8 +653,8 @@ class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWrit The return vector is scaled such that the transform matrix is unitary (aka scaled DCT-II). - .. seealso:: `More information on Wikipedia \ - `_. + .. seealso:: `More information on Wikipedia + `_. >>> from pyspark.ml.linalg import Vectors >>> df1 = spark.createDataFrame([(Vectors.dense([5.0, 8.0, 6.0]),)], ["vec"]) @@ -1294,14 +1294,14 @@ class MinHashLSH(JavaEstimator, LSHParams, HasInputCol, HasOutputCol, HasSeed, >>> mh = MinHashLSH(inputCol="features", outputCol="hashes", seed=12345) >>> model = mh.fit(df) >>> model.transform(df).head() - Row(id=0, features=SparseVector(6, {0: 1.0, 1: 1.0, 2: 1.0}), hashes=[DenseVector([-1638925... + Row(id=0, features=SparseVector(6, {0: 1.0, 1: 1.0, 2: 1.0}), hashes=[DenseVector([6179668... >>> data2 = [(3, Vectors.sparse(6, [1, 3, 5], [1.0, 1.0, 1.0]),), ... (4, Vectors.sparse(6, [2, 3, 5], [1.0, 1.0, 1.0]),), ... 
(5, Vectors.sparse(6, [1, 2, 4], [1.0, 1.0, 1.0]),)] >>> df2 = spark.createDataFrame(data2, ["id", "features"]) >>> key = Vectors.sparse(6, [1, 2], [1.0, 1.0]) >>> model.approxNearestNeighbors(df2, key, 1).collect() - [Row(id=5, features=SparseVector(6, {1: 1.0, 2: 1.0, 4: 1.0}), hashes=[DenseVector([-163892... + [Row(id=5, features=SparseVector(6, {1: 1.0, 2: 1.0, 4: 1.0}), hashes=[DenseVector([6179668... >>> model.approxSimilarityJoin(df, df2, 0.6, distCol="JaccardDistance").select( ... col("datasetA.id").alias("idA"), ... col("datasetB.id").alias("idB"), @@ -1309,8 +1309,8 @@ class MinHashLSH(JavaEstimator, LSHParams, HasInputCol, HasOutputCol, HasSeed, +---+---+---------------+ |idA|idB|JaccardDistance| +---+---+---------------+ - | 1| 4| 0.5| | 0| 5| 0.5| + | 1| 4| 0.5| +---+---+---------------+ ... >>> mhPath = temp_path + "/mh" @@ -1353,7 +1353,7 @@ def _create_model(self, java_model): class MinHashLSHModel(LSHModel, JavaMLReadable, JavaMLWritable): - """ + r""" .. note:: Experimental Model produced by :py:class:`MinHashLSH`, where where multiple hash functions are stored. Each @@ -1362,8 +1362,8 @@ class MinHashLSHModel(LSHModel, JavaMLReadable, JavaMLWritable): :math:`h_i(x) = ((x \cdot a_i + b_i) \mod prime)` This hash family is approximately min-wise independent according to the reference. - .. seealso:: Tom Bohman, Colin Cooper, and Alan Frieze. "Min-wise independent linear \ - permutations." Electronic Journal of Combinatorics 7 (2000): R26. + .. seealso:: Tom Bohman, Colin Cooper, and Alan Frieze. "Min-wise independent linear + permutations." Electronic Journal of Combinatorics 7 (2000): R26. .. versionadded:: 2.2.0 """ @@ -3843,12 +3843,12 @@ def setParams(self, inputCol=None, size=None, handleInvalid="error"): @since("2.3.0") def getSize(self): """ Gets size param, the size of vectors in `inputCol`.""" - self.getOrDefault(self.size) + return self.getOrDefault(self.size) @since("2.3.0") def setSize(self, value): """ Sets size param, the size of vectors in `inputCol`.""" - self._set(size=value) + return self._set(size=value) if __name__ == "__main__": diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py index fd19fd96c4df6..886ad8409ca66 100644 --- a/python/pyspark/ml/fpm.py +++ b/python/pyspark/ml/fpm.py @@ -21,7 +21,7 @@ from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, _jvm from pyspark.ml.param.shared import * -__all__ = ["FPGrowth", "FPGrowthModel"] +__all__ = ["FPGrowth", "FPGrowthModel", "PrefixSpan"] class HasMinSupport(Params): @@ -145,10 +145,11 @@ def freqItemsets(self): @since("2.2.0") def associationRules(self): """ - DataFrame with three columns: + DataFrame with four columns: * `antecedent` - Array of the same type as the input column. * `consequent` - Array of the same type as the input column. * `confidence` - Confidence for the rule (`DoubleType`). + * `lift` - Lift for the rule (`DoubleType`). """ return self._call_java("associationRules") @@ -157,7 +158,7 @@ class FPGrowth(JavaEstimator, HasItemsCol, HasPredictionCol, HasMinSupport, HasNumPartitions, HasMinConfidence, JavaMLWritable, JavaMLReadable): - """ + r""" .. note:: Experimental A parallel FP-growth algorithm to mine frequent itemsets. The algorithm is described in @@ -313,14 +314,15 @@ def setParams(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=3200 def findFrequentSequentialPatterns(self, dataset): """ .. note:: Experimental + Finds the complete set of frequent sequential patterns in the input sequences of itemsets. 
:param dataset: A dataframe containing a sequence column which is `ArrayType(ArrayType(T))` type, T is the item type for the input dataset. :return: A `DataFrame` that contains columns of sequence and corresponding frequency. The schema of it will be: - - `sequence: ArrayType(ArrayType(T))` (T is the item type) - - `freq: Long` + - `sequence: ArrayType(ArrayType(T))` (T is the item type) + - `freq: Long` >>> from pyspark.ml.fpm import PrefixSpan >>> from pyspark.sql import Row diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index 5f0c57ee3cc67..edb90a3578546 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -25,8 +25,10 @@ """ import sys +import warnings import numpy as np + from pyspark import SparkContext from pyspark.sql.types import Row, _create_row, _parse_datatype_json_string from pyspark.sql import DataFrame, SparkSession @@ -207,6 +209,9 @@ def readImages(self, path, recursive=False, numPartitions=-1, .. note:: If sample ratio is less than 1, sampling uses a PathFilter that is efficient but potentially non-deterministic. + .. note:: Deprecated in 2.4.0. Use `spark.read.format("image").load(path)` instead and + this `readImages` will be removed in 3.0.0. + :param str path: Path to the image directory. :param bool recursive: Recursive search flag. :param int numPartitions: Number of DataFrame partitions. @@ -216,13 +221,14 @@ def readImages(self, path, recursive=False, numPartitions=-1, :return: a :class:`DataFrame` with a single column of "images", see ImageSchema for details. - >>> df = ImageSchema.readImages('data/mllib/images/kittens', recursive=True) + >>> df = ImageSchema.readImages('data/mllib/images/origin/kittens', recursive=True) >>> df.count() 5 .. versionadded:: 2.3.0 """ - + warnings.warn("`ImageSchema.readImage` is deprecated. " + + "Use `spark.read.format(\"image\").load(path)` instead.", DeprecationWarning) spark = SparkSession.builder.getOrCreate() image_schema = spark._jvm.org.apache.spark.ml.image.ImageSchema jsession = spark._jsparkSession diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py index 6a611a2b5b59d..2548fd0f50b33 100644 --- a/python/pyspark/ml/linalg/__init__.py +++ b/python/pyspark/ml/linalg/__init__.py @@ -1156,6 +1156,11 @@ def sparse(numRows, numCols, colPtrs, rowIndices, values): def _test(): import doctest + try: + # Numpy 1.14+ changed it's string format. + np.set_printoptions(legacy='1.13') + except TypeError: + pass (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) if failure_count: sys.exit(-1) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 83f0edb397271..98f4361351847 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -188,8 +188,8 @@ def intercept(self): @property @since("2.3.0") def scale(self): - """ - The value by which \|y - X'w\| is scaled down when loss is "huber", otherwise 1.0. + r""" + The value by which :math:`\|y - X'w\|` is scaled down when loss is "huber", otherwise 1.0. """ return self._call_java("scale") @@ -279,12 +279,12 @@ def featuresCol(self): @property @since("2.0.0") def explainedVariance(self): - """ + r""" Returns the explained variance regression score. - explainedVariance = 1 - variance(y - \hat{y}) / variance(y) + explainedVariance = :math:`1 - \frac{variance(y - \hat{y})}{variance(y)}` - .. seealso:: `Wikipedia explain variation \ - `_ + .. seealso:: `Wikipedia explain variation + `_ .. 
note:: This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`. This will change in later Spark @@ -339,8 +339,8 @@ def r2(self): """ Returns R^2, the coefficient of determination. - .. seealso:: `Wikipedia coefficient of determination \ - `_ + .. seealso:: `Wikipedia coefficient of determination + `_ .. note:: This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`. This will change in later Spark @@ -354,8 +354,8 @@ def r2adj(self): """ Returns Adjusted R^2, the adjusted coefficient of determination. - .. seealso:: `Wikipedia coefficient of determination, Adjusted R^2 \ - `_ + .. seealso:: `Wikipedia coefficient of determination, Adjusted R^2 + `_ .. note:: This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`. This will change in later Spark versions. @@ -608,8 +608,13 @@ class TreeEnsembleParams(DecisionTreeParams): featureSubsetStrategy = \ Param(Params._dummy(), "featureSubsetStrategy", "The number of features to consider for splits at each tree node. Supported " + - "options: " + ", ".join(supportedFeatureSubsetStrategies) + ", (0.0-1.0], [1-n].", - typeConverter=TypeConverters.toString) + "options: 'auto' (choose automatically for task: If numTrees == 1, set to " + + "'all'. If numTrees > 1 (forest), set to 'sqrt' for classification and to " + + "'onethird' for regression), 'all' (use all features), 'onethird' (use " + + "1/3 of the features), 'sqrt' (use sqrt(number of features)), 'log2' (use " + + "log2(number of features)), 'n' (when n is in the range (0, 1.0], use " + + "n * number of features. When n is in the range (1, number of features), use" + + " n features). default = 'auto'", typeConverter=TypeConverters.toString) def __init__(self): super(TreeEnsembleParams, self).__init__() @@ -1370,7 +1375,7 @@ def intercept(self): @since("1.6.0") def scale(self): """ - Model scale paramter. + Model scale parameter. """ return self._call_java("scale") diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py index a06ab31a7a56a..370154fc6d62a 100644 --- a/python/pyspark/ml/stat.py +++ b/python/pyspark/ml/stat.py @@ -388,8 +388,14 @@ def summary(self, featuresCol, weightCol=None): if __name__ == "__main__": import doctest + import numpy import pyspark.ml.stat from pyspark.sql import SparkSession + try: + # Numpy 1.14+ changed it's string format. 
+ numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.ml.stat.__dict__.copy() # The small batch size here ensures that we see multiple batches, diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index bc782138292bf..821e037af0271 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -844,6 +844,23 @@ def test_string_indexer_from_labels(self): .select(model_default.getOrDefault(model_default.outputCol)).collect() self.assertEqual(len(transformed_list), 5) + def test_vector_size_hint(self): + df = self.spark.createDataFrame( + [(0, Vectors.dense([0.0, 10.0, 0.5])), + (1, Vectors.dense([1.0, 11.0, 0.5, 0.6])), + (2, Vectors.dense([2.0, 12.0]))], + ["id", "vector"]) + + sizeHint = VectorSizeHint( + inputCol="vector", + handleInvalid="skip") + sizeHint.setSize(3) + self.assertEqual(sizeHint.getSize(), 3) + + output = sizeHint.transform(df).head().vector + expected = DenseVector([0.0, 10.0, 0.5]) + self.assertEqual(output, expected) + class HasInducedError(Params): @@ -950,6 +967,13 @@ def test_fit_maximize_metric(self): "Best model should have zero induced error") self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") + def test_param_grid_type_coercion(self): + lr = LogisticRegression(maxIter=10) + paramGrid = ParamGridBuilder().addGrid(lr.regParam, [0.5, 1]).build() + for param in paramGrid: + for v in param.values(): + assert(type(v) == float) + def test_save_load_trained_model(self): # This tests saving and loading the trained model only. # Save/load for CrossValidator will be added later: SPARK-13786 @@ -1888,6 +1912,7 @@ def test_gaussian_mixture_summary(self): self.assertTrue(isinstance(s.cluster, DataFrame)) self.assertEqual(len(s.clusterSizes), 2) self.assertEqual(s.k, 2) + self.assertEqual(s.numIter, 3) def test_bisecting_kmeans_summary(self): data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), @@ -1903,6 +1928,7 @@ def test_bisecting_kmeans_summary(self): self.assertTrue(isinstance(s.cluster, DataFrame)) self.assertEqual(len(s.clusterSizes), 2) self.assertEqual(s.k, 2) + self.assertEqual(s.numIter, 20) def test_kmeans_summary(self): data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),), @@ -1918,6 +1944,7 @@ def test_kmeans_summary(self): self.assertTrue(isinstance(s.cluster, DataFrame)) self.assertEqual(len(s.clusterSizes), 2) self.assertEqual(s.k, 2) + self.assertEqual(s.numIter, 1) class KMeansTests(SparkSessionTestCase): @@ -2131,8 +2158,8 @@ def test_association_rules(self): fpm = fp.fit(self.data) expected_association_rules = self.spark.createDataFrame( - [([3], [1], 1.0), ([2], [1], 1.0)], - ["antecedent", "consequent", "confidence"] + [([3], [1], 1.0, 1.0), ([2], [1], 1.0, 1.0)], + ["antecedent", "consequent", "confidence", "lift"] ) actual_association_rules = fpm.associationRules @@ -2159,7 +2186,7 @@ def tearDown(self): class ImageReaderTest(SparkSessionTestCase): def test_read_images(self): - data_path = 'data/mllib/images/kittens' + data_path = 'data/mllib/images/origin/kittens' df = ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) self.assertEqual(df.count(), 4) first_row = df.take(1)[0][0] @@ -2226,7 +2253,7 @@ def tearDownClass(cls): def test_read_images_multiple_times(self): # This test case is to check if `ImageSchema.readImages` tries to # initiate Hive client multiple times. See SPARK-22651. 
- data_path = 'data/mllib/images/kittens' + data_path = 'data/mllib/images/origin/kittens' ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 0c8029f293cfe..1f4abf5157335 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -115,7 +115,11 @@ def build(self): """ keys = self._param_grid.keys() grid_values = self._param_grid.values() - return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)] + + def to_key_value_pairs(keys, values): + return [(key, key.typeConverter(value)) for key, value in zip(keys, values)] + + return [dict(to_key_value_pairs(keys, prod)) for prod in itertools.product(*grid_values)] class ValidatorParams(HasSeed): diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 080cd299f4fde..e846834761e49 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -63,7 +63,7 @@ def _randomUID(cls): Generate a unique unicode id for the object. The default implementation concatenates the class name, "_", and 12 random hex chars. """ - return unicode(cls.__name__ + "_" + uuid.uuid4().hex[12:]) + return unicode(cls.__name__ + "_" + uuid.uuid4().hex[-12:]) @inherit_doc diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 0cbabab13a896..b1a8af6bcc094 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -647,7 +647,7 @@ class PowerIterationClustering(object): @classmethod @since('1.5.0') def train(cls, rdd, k, maxIterations=100, initMode="random"): - """ + r""" :param rdd: An RDD of (i, j, s\ :sub:`ij`\) tuples representing the affinity matrix, which is the matrix A in the PIC paper. The @@ -1042,7 +1042,13 @@ def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0, def _test(): import doctest + import numpy import pyspark.mllib.clustering + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.mllib.clustering.__dict__.copy() globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 36cb03369b8c0..0bb0ca37c1ab6 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -117,9 +117,9 @@ def __init__(self, predictionAndObservations): @property @since('1.4.0') def explainedVariance(self): - """ + r""" Returns the explained variance regression score. - explainedVariance = 1 - variance(y - \hat{y}) / variance(y) + explainedVariance = :math:`1 - \frac{variance(y - \hat{y})}{variance(y)}` """ return self.call("explainedVariance") @@ -532,8 +532,14 @@ def accuracy(self): def _test(): import doctest + import numpy from pyspark.sql import SparkSession import pyspark.mllib.evaluation + try: + # Numpy 1.14+ changed it's string format. 
+ numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.mllib.evaluation.__dict__.copy() spark = SparkSession.builder\ .master("local[4]")\ diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 40ecd2e0ff4be..6d7d4d61db043 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -59,7 +59,7 @@ def transform(self, vector): class Normalizer(VectorTransformer): - """ + r""" Normalizes samples individually to unit L\ :sup:`p`\ norm For any 1 <= `p` < float('inf'), normalizes samples using diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 60d96d8d5ceb8..4afd6666400b0 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -1368,6 +1368,12 @@ def R(self): def _test(): import doctest + import numpy + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) if failure_count: sys.exit(-1) diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index bba88542167ad..7e8b15056cabe 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -1364,9 +1364,15 @@ def toCoordinateMatrix(self): def _test(): import doctest + import numpy from pyspark.sql import SparkSession from pyspark.mllib.linalg import Matrices import pyspark.mllib.linalg.distributed + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.mllib.linalg.distributed.__dict__.copy() spark = SparkSession.builder\ .master("local[2]")\ diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py index 3c75b132ecad2..6e89bfd691d16 100644 --- a/python/pyspark/mllib/stat/_statistics.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -259,7 +259,7 @@ def kolmogorovSmirnovTest(data, distName="norm", *params): The KS statistic gives us the maximum distance between the ECDF and the CDF. Intuitively if this statistic is large, the - probabilty that the null hypothesis is true becomes small. + probability that the null hypothesis is true becomes small. For specific details of the implementation, please have a look at the Scala documentation. @@ -303,7 +303,13 @@ def kolmogorovSmirnovTest(data, distName="norm", *params): def _test(): import doctest + import numpy from pyspark.sql import SparkSession + try: + # Numpy 1.14+ changed it's string format. 
+ numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = globals().copy() spark = SparkSession.builder\ .master("local[4]")\ diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 7e7e5822a6b20..ccf39e1ffbe96 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -39,7 +39,7 @@ else: from itertools import imap as map, ifilter as filter -from pyspark.java_gateway import do_server_auth +from pyspark.java_gateway import local_connect_and_auth from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ PickleSerializer, pack_long, AutoBatchedSerializer, write_with_length, \ @@ -53,7 +53,7 @@ from pyspark.shuffle import Aggregator, ExternalMerger, \ get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync -from pyspark.util import fail_on_stopiteration +from pyspark.util import fail_on_stopiteration, _exception_message __all__ = ["RDD"] @@ -141,30 +141,10 @@ def _parse_memory(s): def _load_from_socket(sock_info, serializer): - port, auth_secret = sock_info - sock = None - # Support for both IPv4 and IPv6. - # On most of IPv6-ready systems, IPv6 will take precedence. - for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM): - af, socktype, proto, canonname, sa = res - sock = socket.socket(af, socktype, proto) - try: - sock.settimeout(15) - sock.connect(sa) - except socket.error: - sock.close() - sock = None - continue - break - if not sock: - raise Exception("could not open socket") + (sockfile, sock) = local_connect_and_auth(*sock_info) # The RDD materialization time is unpredicable, if we set a timeout for socket reading # operation, it will very possibly fail. See SPARK-18281. sock.settimeout(None) - - sockfile = sock.makefile("rwb", 65536) - do_server_auth(sockfile, auth_secret) - # The socket will be automatically closed when garbage-collected. return serializer.load_stream(sockfile) @@ -1360,7 +1340,7 @@ def take(self, num): if len(items) == 0: numPartsToTry = partsScanned * 4 else: - # the first paramter of max is >=1 whenever partsScanned >= 2 + # the first parameter of max is >=1 whenever partsScanned >= 2 numPartsToTry = int(1.5 * num * partsScanned / len(items)) - partsScanned numPartsToTry = min(max(numPartsToTry, 1), partsScanned * 4) @@ -1370,7 +1350,10 @@ def takeUpToNumLeft(iterator): iterator = iter(iterator) taken = 0 while taken < left: - yield next(iterator) + try: + yield next(iterator) + except StopIteration: + return taken += 1 p = range(partsScanned, min(partsScanned + numPartsToTry, totalParts)) @@ -2403,6 +2386,33 @@ def toLocalIterator(self): sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd()) return _load_from_socket(sock_info, self._jrdd_deserializer) + def barrier(self): + """ + .. note:: Experimental + + Marks the current stage as a barrier stage, where Spark must launch all tasks together. + In case of a task failure, instead of only restarting the failed task, Spark will abort the + entire stage and relaunch all tasks for this stage. + The barrier execution mode feature is experimental and it only handles limited scenarios. + Please read the linked SPIP and design docs to understand the limitations and future plans. + + :return: an :class:`RDDBarrier` instance that provides actions within a barrier stage. + + .. seealso:: :class:`BarrierTaskContext` + .. seealso:: `SPIP: Barrier Execution Mode + `_ + .. seealso:: `Design Doc `_ + + .. 
versionadded:: 2.4.0 + """ + return RDDBarrier(self) + + def _is_barrier(self): + """ + Whether this RDD is in a barrier stage. + """ + return self._jrdd.rdd().isBarrier() + def _prepare_for_python_RDD(sc, command): # the serialized command will be compressed by broadcast @@ -2426,6 +2436,36 @@ def _wrap_function(sc, func, deserializer, serializer, profiler=None): sc.pythonVer, broadcast_vars, sc._javaAccumulator) +class RDDBarrier(object): + + """ + .. note:: Experimental + + Wraps an RDD in a barrier stage, which forces Spark to launch tasks of this stage together. + :class:`RDDBarrier` instances are created by :func:`RDD.barrier`. + + .. versionadded:: 2.4.0 + """ + + def __init__(self, rdd): + self.rdd = rdd + + def mapPartitions(self, f, preservesPartitioning=False): + """ + .. note:: Experimental + + Returns a new RDD by applying a function to each partition of the wrapped RDD, + where tasks are launched together in a barrier stage. + The interface is the same as :func:`RDD.mapPartitions`. + Please see the API doc there. + + .. versionadded:: 2.4.0 + """ + def func(s, iterator): + return f(iterator) + return PipelinedRDD(self.rdd, func, preservesPartitioning, isFromBarrier=True) + + class PipelinedRDD(RDD): """ @@ -2445,7 +2485,7 @@ class PipelinedRDD(RDD): 20 """ - def __init__(self, prev, func, preservesPartitioning=False): + def __init__(self, prev, func, preservesPartitioning=False, isFromBarrier=False): if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable(): # This transformation is the first in its stage: self.func = func @@ -2471,6 +2511,7 @@ def pipeline_func(split, iterator): self._jrdd_deserializer = self.ctx.serializer self._bypass_serializer = False self.partitioner = prev.partitioner if self.preservesPartitioning else None + self.is_barrier = prev._is_barrier() or isFromBarrier def getNumPartitions(self): return self._prev_jrdd.partitions().size() @@ -2490,7 +2531,7 @@ def _jrdd(self): wrapped_func = _wrap_function(self.ctx, self.func, self._prev_jrdd_deserializer, self._jrdd_deserializer, profiler) python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), wrapped_func, - self.preservesPartitioning) + self.preservesPartitioning, self.is_barrier) self._jrdd_val = python_rdd.asJavaRDD() if profiler: @@ -2506,6 +2547,9 @@ def id(self): def _is_pipelinable(self): return not (self.is_cached or self.is_checkpointed) + def _is_barrier(self): + return self.is_barrier + def _test(): import doctest diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 4c16b5fc26f3d..48006778e86f2 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -185,27 +185,31 @@ def loads(self, obj): raise NotImplementedError -class ArrowSerializer(FramedSerializer): +class ArrowStreamSerializer(Serializer): """ - Serializes bytes as Arrow data with the Arrow file format. + Serializes Arrow record batches as a stream. 
""" - def dumps(self, batch): + def dump_stream(self, iterator, stream): import pyarrow as pa - import io - sink = io.BytesIO() - writer = pa.RecordBatchFileWriter(sink, batch.schema) - writer.write_batch(batch) - writer.close() - return sink.getvalue() + writer = None + try: + for batch in iterator: + if writer is None: + writer = pa.RecordBatchStreamWriter(stream, batch.schema) + writer.write_batch(batch) + finally: + if writer is not None: + writer.close() - def loads(self, obj): + def load_stream(self, stream): import pyarrow as pa - reader = pa.RecordBatchFileReader(pa.BufferReader(obj)) - return reader.read_all() + reader = pa.open_stream(stream) + for batch in reader: + yield batch def __repr__(self): - return "ArrowSerializer" + return "ArrowStreamSerializer" def _create_batch(series, timezone): @@ -216,9 +220,10 @@ def _create_batch(series, timezone): :param timezone: A timezone to respect when handling timestamp values :return: Arrow RecordBatch """ - - from pyspark.sql.types import _check_series_convert_timestamps_internal + import decimal + from distutils.version import LooseVersion import pyarrow as pa + from pyspark.sql.types import _check_series_convert_timestamps_internal # Make input conform to [(series1, type1), (series2, type2), ...] if not isinstance(series, (list, tuple)) or \ (len(series) == 2 and isinstance(series[1], pa.DataType)): @@ -228,14 +233,21 @@ def _create_batch(series, timezone): def create_array(s, t): mask = s.isnull() # Ensure timestamp series are in expected form for Spark internal representation + # TODO: maybe don't need None check anymore as of Arrow 0.9.1 if t is not None and pa.types.is_timestamp(t): s = _check_series_convert_timestamps_internal(s.fillna(0), timezone) # TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2 return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False) elif t is not None and pa.types.is_string(t) and sys.version < '3': # TODO: need decode before converting to Arrow in Python 2 + # TODO: don't need as of Arrow 0.9.1 return pa.Array.from_pandas(s.apply( lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t) + elif t is not None and pa.types.is_decimal(t) and \ + LooseVersion("0.9.0") <= LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + # TODO: see ARROW-2432. Remove when the minimum PyArrow version becomes 0.10.0. 
+ return pa.Array.from_pandas(s.apply( + lambda v: decimal.Decimal('NaN') if v is None else v), mask=mask, type=t) return pa.Array.from_pandas(s, mask=mask, type=t) arrs = [create_array(s, t) for s, t in series] @@ -707,6 +719,13 @@ def write_int(value, stream): stream.write(struct.pack("!i", value)) +def read_bool(stream): + length = stream.read(1) + if not length: + raise EOFError + return struct.unpack("!?", length)[0] + + def write_with_length(obj, stream): write_int(len(obj), stream) stream.write(obj) diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 472c3cd4452f0..65e3bdbc05ce8 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -54,7 +54,7 @@ sqlContext = spark._wrapped sqlCtx = sqlContext -print("""Welcome to +print(r"""Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index b0d8357f4feec..974251f63b37a 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -177,8 +177,7 @@ def createTable(self, tableName, path=None, source=None, schema=None, **options) if path is not None: options["path"] = path if source is None: - source = self._sparkSession.conf.get( - "spark.sql.sources.default", "org.apache.spark.sql.parquet") + source = self._sparkSession._wrapped._conf.defaultDataSourceName() if schema is None: df = self._jcatalog.createTable(tableName, source, options) else: diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py index db49040e17b63..71ea1631718f1 100644 --- a/python/pyspark/sql/conf.py +++ b/python/pyspark/sql/conf.py @@ -20,6 +20,9 @@ from pyspark import since, _NoValue from pyspark.rdd import ignore_unicode_prefix +if sys.version_info[0] >= 3: + basestring = str + class RuntimeConfig(object): """User-facing configuration API, accessible through `SparkSession.conf`. @@ -59,10 +62,18 @@ def unset(self, key): def _checkType(self, obj, identifier): """Assert that an object is of type str.""" - if not isinstance(obj, str) and not isinstance(obj, unicode): + if not isinstance(obj, basestring): raise TypeError("expected %s '%s' to be a string (was '%s')" % (identifier, obj, type(obj).__name__)) + @ignore_unicode_prefix + @since(2.4) + def isModifiable(self, key): + """Indicates whether the configuration property with the given key + is modifiable in the current session. 
+ """ + return self._jconf.isModifiable(key) + def _test(): import os diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index e9ec7ba866761..9c094dd9a9033 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -93,6 +93,11 @@ def _ssql_ctx(self): """ return self._jsqlContext + @property + def _conf(self): + """Accessor for the JVM SQL-specific configurations""" + return self.sparkSession._jsparkSession.sessionState().conf() + @classmethod @since(1.6) def getOrCreate(cls, sc): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index cb3fe448b6fc7..1affc9b4fcf6c 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -29,7 +29,7 @@ from pyspark import copy_func, since, _NoValue from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix -from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \ +from pyspark.serializers import ArrowStreamSerializer, BatchedSerializer, PickleSerializer, \ UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync @@ -293,6 +293,31 @@ def explain(self, extended=False): else: print(self._jdf.queryExecution().simpleString()) + @since(2.4) + def exceptAll(self, other): + """Return a new :class:`DataFrame` containing rows in this :class:`DataFrame` but + not in another :class:`DataFrame` while preserving duplicates. + + This is equivalent to `EXCEPT ALL` in SQL. + + >>> df1 = spark.createDataFrame( + ... [("a", 1), ("a", 1), ("a", 1), ("a", 2), ("b", 3), ("c", 4)], ["C1", "C2"]) + >>> df2 = spark.createDataFrame([("a", 1), ("b", 3)], ["C1", "C2"]) + + >>> df1.exceptAll(df2).show() + +---+---+ + | C1| C2| + +---+---+ + | a| 1| + | a| 1| + | a| 2| + | c| 4| + +---+---+ + + Also as standard in SQL, this function resolves columns by position (not by name). + """ + return DataFrame(self._jdf.exceptAll(other._jdf), self.sql_ctx) + @since(1.3) def isLocal(self): """Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally @@ -354,32 +379,12 @@ def show(self, n=20, truncate=True, vertical=False): else: print(self._jdf.showString(n, int(truncate), vertical)) - @property - def _eager_eval(self): - """Returns true if the eager evaluation enabled. - """ - return self.sql_ctx.getConf( - "spark.sql.repl.eagerEval.enabled", "false").lower() == "true" - - @property - def _max_num_rows(self): - """Returns the max row number for eager evaluation. - """ - return int(self.sql_ctx.getConf( - "spark.sql.repl.eagerEval.maxNumRows", "20")) - - @property - def _truncate(self): - """Returns the truncate length for eager evaluation. 
- """ - return int(self.sql_ctx.getConf( - "spark.sql.repl.eagerEval.truncate", "20")) - def __repr__(self): - if not self._support_repr_html and self._eager_eval: + if not self._support_repr_html and self.sql_ctx._conf.isReplEagerEvalEnabled(): vertical = False return self._jdf.showString( - self._max_num_rows, self._truncate, vertical) + self.sql_ctx._conf.replEagerEvalMaxNumRows(), + self.sql_ctx._conf.replEagerEvalTruncate(), vertical) else: return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) @@ -391,10 +396,10 @@ def _repr_html_(self): import cgi if not self._support_repr_html: self._support_repr_html = True - if self._eager_eval: - max_num_rows = max(self._max_num_rows, 0) + if self.sql_ctx._conf.isReplEagerEvalEnabled(): + max_num_rows = max(self.sql_ctx._conf.replEagerEvalMaxNumRows(), 0) sock_info = self._jdf.getRowsToPython( - max_num_rows, self._truncate) + max_num_rows, self.sql_ctx._conf.replEagerEvalTruncate()) rows = list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))) head = rows[0] row_data = rows[1:] @@ -1495,6 +1500,28 @@ def intersect(self, other): """ return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx) + @since(2.4) + def intersectAll(self, other): + """ Return a new :class:`DataFrame` containing rows in both this dataframe and other + dataframe while preserving duplicates. + + This is equivalent to `INTERSECT ALL` in SQL. + >>> df1 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3), ("c", 4)], ["C1", "C2"]) + >>> df2 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3)], ["C1", "C2"]) + + >>> df1.intersectAll(df2).sort("C1", "C2").show() + +---+---+ + | C1| C2| + +---+---+ + | a| 1| + | a| 1| + | b| 3| + +---+---+ + + Also as standard in SQL, this function resolves columns by position (not by name). 
+ """ + return DataFrame(self._jdf.intersectAll(other._jdf), self.sql_ctx) + @since(1.3) def subtract(self, other): """ Return a new :class:`DataFrame` containing rows in this frame @@ -2049,13 +2076,12 @@ def toPandas(self): import pandas as pd - if self.sql_ctx.getConf("spark.sql.execution.pandas.respectSessionTimeZone").lower() \ - == "true": - timezone = self.sql_ctx.getConf("spark.sql.session.timeZone") + if self.sql_ctx._conf.pandasRespectSessionTimeZone(): + timezone = self.sql_ctx._conf.sessionLocalTimeZone() else: timezone = None - if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": + if self.sql_ctx._conf.arrowEnabled(): use_arrow = True try: from pyspark.sql.types import to_arrow_schema @@ -2065,8 +2091,7 @@ def toPandas(self): to_arrow_schema(self.schema) except Exception as e: - if self.sql_ctx.getConf("spark.sql.execution.arrow.fallback.enabled", "true") \ - .lower() == "true": + if self.sql_ctx._conf.arrowFallbackEnabled(): msg = ( "toPandas attempted Arrow optimization because " "'spark.sql.execution.arrow.enabled' is set to true; however, " @@ -2093,10 +2118,9 @@ def toPandas(self): from pyspark.sql.types import _check_dataframe_convert_date, \ _check_dataframe_localize_timestamps import pyarrow - - tables = self._collectAsArrow() - if tables: - table = pyarrow.concat_tables(tables) + batches = self._collectAsArrow() + if len(batches) > 0: + table = pyarrow.Table.from_batches(batches) pdf = table.to_pandas() pdf = _check_dataframe_convert_date(pdf, self.schema) return _check_dataframe_localize_timestamps(pdf, timezone) @@ -2145,14 +2169,14 @@ def toPandas(self): def _collectAsArrow(self): """ - Returns all records as list of deserialized ArrowPayloads, pyarrow must be installed - and available. + Returns all records as a list of ArrowRecordBatches, pyarrow must be installed + and available on driver and worker Python environments. .. note:: Experimental. """ with SCCallSiteSync(self._sc) as css: sock_info = self._jdf.collectAsArrowToPython() - return list(_load_from_socket(sock_info, ArrowSerializer())) + return list(_load_from_socket(sock_info, ArrowStreamSerializer())) ########################################################################################## # Pandas compatibility diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 9652d3e79b875..6da5237d18de4 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -283,7 +283,8 @@ def approxCountDistinct(col, rsd=None): @since(2.1) def approx_count_distinct(col, rsd=None): - """Aggregate function: returns a new :class:`Column` for approximate distinct count of column `col`. + """Aggregate function: returns a new :class:`Column` for approximate distinct count of + column `col`. :param rsd: maximum estimation error allowed (default = 0.05). For rsd < 0.01, it is more efficient to use :func:`countDistinct` @@ -346,7 +347,8 @@ def coalesce(*cols): @since(1.6) def corr(col1, col2): - """Returns a new :class:`Column` for the Pearson Correlation Coefficient for ``col1`` and ``col2``. + """Returns a new :class:`Column` for the Pearson Correlation Coefficient for ``col1`` + and ``col2``. >>> a = range(20) >>> b = [2 * x for x in range(20)] @@ -1285,11 +1287,21 @@ def from_utc_timestamp(timestamp, tz): that time as a timestamp in the given time zone. For example, 'GMT+1' would yield '2017-07-14 03:40:00.0'. 
- >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(from_utc_timestamp(df.t, "PST").alias('local_time')).collect() + :param timestamp: the column that contains timestamps + :param tz: a string that has the ID of timezone, e.g. "GMT", "America/Los_Angeles", etc + + .. versionchanged:: 2.4 + `tz` can take a :class:`Column` containing timezone ID strings. + + >>> df = spark.createDataFrame([('1997-02-28 10:30:00', 'JST')], ['ts', 'tz']) + >>> df.select(from_utc_timestamp(df.ts, "PST").alias('local_time')).collect() [Row(local_time=datetime.datetime(1997, 2, 28, 2, 30))] + >>> df.select(from_utc_timestamp(df.ts, df.tz).alias('local_time')).collect() + [Row(local_time=datetime.datetime(1997, 2, 28, 19, 30))] """ sc = SparkContext._active_spark_context + if isinstance(tz, Column): + tz = _to_java_column(tz) return Column(sc._jvm.functions.from_utc_timestamp(_to_java_column(timestamp), tz)) @@ -1300,11 +1312,21 @@ def to_utc_timestamp(timestamp, tz): zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield '2017-07-14 01:40:00.0'. - >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['ts']) + :param timestamp: the column that contains timestamps + :param tz: a string that has the ID of timezone, e.g. "GMT", "America/Los_Angeles", etc + + .. versionchanged:: 2.4 + `tz` can take a :class:`Column` containing timezone ID strings. + + >>> df = spark.createDataFrame([('1997-02-28 10:30:00', 'JST')], ['ts', 'tz']) >>> df.select(to_utc_timestamp(df.ts, "PST").alias('utc_time')).collect() [Row(utc_time=datetime.datetime(1997, 2, 28, 18, 30))] + >>> df.select(to_utc_timestamp(df.ts, df.tz).alias('utc_time')).collect() + [Row(utc_time=datetime.datetime(1997, 2, 28, 1, 30))] """ sc = SparkContext._active_spark_context + if isinstance(tz, Column): + tz = _to_java_column(tz) return Column(sc._jvm.functions.to_utc_timestamp(_to_java_column(timestamp), tz)) @@ -1668,14 +1690,14 @@ def split(str, pattern): @ignore_unicode_prefix @since(1.5) def regexp_extract(str, pattern, idx): - """Extract a specific group matched by a Java regex, from the specified string column. + r"""Extract a specific group matched by a Java regex, from the specified string column. If the regex did not match, or the specified group did not match, an empty string is returned. >>> df = spark.createDataFrame([('100-200',)], ['str']) - >>> df.select(regexp_extract('str', '(\d+)-(\d+)', 1).alias('d')).collect() + >>> df.select(regexp_extract('str', r'(\d+)-(\d+)', 1).alias('d')).collect() [Row(d=u'100')] >>> df = spark.createDataFrame([('foo',)], ['str']) - >>> df.select(regexp_extract('str', '(\d+)', 1).alias('d')).collect() + >>> df.select(regexp_extract('str', r'(\d+)', 1).alias('d')).collect() [Row(d=u'')] >>> df = spark.createDataFrame([('aaaac',)], ['str']) >>> df.select(regexp_extract('str', '(a+)(b)?(c)', 2).alias('d')).collect() @@ -1689,10 +1711,10 @@ def regexp_extract(str, pattern, idx): @ignore_unicode_prefix @since(1.5) def regexp_replace(str, pattern, replacement): - """Replace all substrings of the specified string value that match regexp with rep. + r"""Replace all substrings of the specified string value that match regexp with rep. 
>>> df = spark.createDataFrame([('100-200',)], ['str']) - >>> df.select(regexp_replace('str', '(\\d+)', '--').alias('d')).collect() + >>> df.select(regexp_replace('str', r'(\d+)', '--').alias('d')).collect() [Row(d=u'-----')] """ sc = SparkContext._active_spark_context @@ -2013,6 +2035,63 @@ def array_distinct(col): return Column(sc._jvm.functions.array_distinct(_to_java_column(col))) +@ignore_unicode_prefix +@since(2.4) +def array_intersect(col1, col2): + """ + Collection function: returns an array of the elements in the intersection of col1 and col2, + without duplicates. + + :param col1: name of column containing array + :param col2: name of column containing array + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) + >>> df.select(array_intersect(df.c1, df.c2)).collect() + [Row(array_intersect(c1, c2)=[u'a', u'c'])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_intersect(_to_java_column(col1), _to_java_column(col2))) + + +@ignore_unicode_prefix +@since(2.4) +def array_union(col1, col2): + """ + Collection function: returns an array of the elements in the union of col1 and col2, + without duplicates. + + :param col1: name of column containing array + :param col2: name of column containing array + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) + >>> df.select(array_union(df.c1, df.c2)).collect() + [Row(array_union(c1, c2)=[u'b', u'a', u'c', u'd', u'f'])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_union(_to_java_column(col1), _to_java_column(col2))) + + +@ignore_unicode_prefix +@since(2.4) +def array_except(col1, col2): + """ + Collection function: returns an array of the elements in col1 but not in col2, + without duplicates. + + :param col1: name of column containing array + :param col2: name of column containing array + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) + >>> df.select(array_except(df.c1, df.c2)).collect() + [Row(array_except(c1, c2)=[u'b'])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_except(_to_java_column(col1), _to_java_column(col2))) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. @@ -2164,7 +2243,7 @@ def json_tuple(col, *fields): def from_json(col, schema, options={}): """ Parses a column containing a JSON string into a :class:`MapType` with :class:`StringType` - as keys type, :class:`StructType` or :class:`ArrayType` of :class:`StructType`\\s with + as keys type, :class:`StructType` or :class:`ArrayType` with the specified schema. Returns `null`, in the case of an unparseable string. 
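A combined usage sketch for the three set-like collection functions added above, on hypothetical data and assuming an active `spark` session:

    from pyspark.sql import Row
    from pyspark.sql.functions import array_except, array_intersect, array_union

    df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])])
    df.select(
        array_union(df.c1, df.c2).alias("u"),      # [b, a, c, d, f]
        array_intersect(df.c1, df.c2).alias("i"),  # [a, c]
        array_except(df.c1, df.c2).alias("e"),     # [b]
    ).show(truncate=False)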
:param col: string column in json format @@ -2189,11 +2268,21 @@ >>> df = spark.createDataFrame(data, ("key", "value")) >>> df.select(from_json(df.value, schema).alias("json")).collect() [Row(json=[Row(a=1)])] + >>> schema = schema_of_json(lit('''{"a": 0}''')) + >>> df.select(from_json(df.value, schema).alias("json")).collect() + [Row(json=Row(a=1))] + >>> data = [(1, '''[1, 2, 3]''')] + >>> schema = ArrayType(IntegerType()) + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(from_json(df.value, schema).alias("json")).collect() + [Row(json=[1, 2, 3])] """ sc = SparkContext._active_spark_context if isinstance(schema, DataType): schema = schema.json() + elif isinstance(schema, Column): + schema = _to_java_column(schema) jc = sc._jvm.functions.from_json(_to_java_column(col), schema, options) return Column(jc) @@ -2202,13 +2291,11 @@ @since(2.1) def to_json(col, options={}): """ - Converts a column containing a :class:`StructType`, :class:`ArrayType` of - :class:`StructType`\\s, a :class:`MapType` or :class:`ArrayType` of :class:`MapType`\\s + Converts a column containing a :class:`StructType`, :class:`ArrayType` or a :class:`MapType` into a JSON string. Throws an exception, in the case of an unsupported type. - :param col: name of column containing the struct, array of the structs, the map or - array of the maps. - :param options: options to control converting. accepts the same options as the json datasource + :param col: name of column containing a struct, an array or a map. + :param options: options to control converting. accepts the same options as the JSON datasource >>> from pyspark.sql import Row >>> from pyspark.sql.types import * @@ -2228,6 +2315,10 @@ >>> df = spark.createDataFrame(data, ("key", "value")) >>> df.select(to_json(df.value).alias("json")).collect() [Row(json=u'[{"name":"Alice"},{"name":"Bob"}]')] + >>> data = [(1, ["Alice", "Bob"])] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(to_json(df.value).alias("json")).collect() + [Row(json=u'["Alice","Bob"]')] """ sc = SparkContext._active_spark_context @@ -2235,6 +2326,28 @@ return Column(jc) + +@ignore_unicode_prefix +@since(2.4) +def schema_of_json(col): + """ + Parses a column containing a JSON string and infers its schema in DDL format. + + :param col: string column in json format + + >>> from pyspark.sql.types import * + >>> data = [(1, '{"a": 1}')] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(schema_of_json(df.value).alias("json")).collect() + [Row(json=u'struct<a:bigint>')] + >>> df.select(schema_of_json(lit('{"a": 0}')).alias("json")).collect() + [Row(json=u'struct<a:bigint>')] + """ + + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.schema_of_json(_to_java_column(col)) + return Column(jc) + + @since(1.5) def size(col): """ @@ -2316,6 +2429,23 @@ return Column(sc._jvm.functions.array_sort(_to_java_column(col))) + +@since(2.4) +def shuffle(col): + """ + Collection function: Generates a random permutation of the given array. + + .. note:: The function is non-deterministic.
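Tying the new pieces above together, a short sketch of the Column-based schema path: infer a DDL schema with `schema_of_json`, then feed that column straight to `from_json` (assumes an active `spark` session):

    from pyspark.sql.functions import from_json, lit, schema_of_json

    df = spark.createDataFrame([(1, '{"a": 1}')], ("key", "value"))

    # Infers u'struct<a:bigint>' from the sample literal; from_json now
    # accepts this Column directly instead of requiring a DataType.
    schema = schema_of_json(lit('{"a": 0}'))
    df.select(from_json(df.value, schema).alias("json")).show()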
+ + :param col: name of column or expression + + >>> df = spark.createDataFrame([([1, 20, 3, 5],), ([1, 20, None, 3],)], ['data']) + >>> df.select(shuffle(df.data).alias('s')).collect() # doctest: +SKIP + [Row(s=[3, 1, 5, 20]), Row(s=[20, None, 3, 1])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.shuffle(_to_java_column(col))) + + @since(1.5) @ignore_unicode_prefix def reverse(col): @@ -2463,6 +2593,50 @@ def arrays_zip(*cols): return Column(sc._jvm.functions.arrays_zip(_to_seq(sc, cols, _to_java_column))) +@since(2.4) +def map_concat(*cols): + """Returns the union of all the given maps. + + :param cols: list of column names (string) or list of :class:`Column` expressions + + >>> from pyspark.sql.functions import map_concat + >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as map1, map(3, 'c', 1, 'd') as map2") + >>> df.select(map_concat("map1", "map2").alias("map3")).show(truncate=False) + +--------------------------------+ + |map3 | + +--------------------------------+ + |[1 -> a, 2 -> b, 3 -> c, 1 -> d]| + +--------------------------------+ + """ + sc = SparkContext._active_spark_context + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + jc = sc._jvm.functions.map_concat(_to_seq(sc, cols, _to_java_column)) + return Column(jc) + + +@since(2.4) +def sequence(start, stop, step=None): + """ + Generate a sequence of integers from `start` to `stop`, incrementing by `step`. + If `step` is not set, incrementing by 1 if `start` is less than or equal to `stop`, + otherwise -1. + + >>> df1 = spark.createDataFrame([(-2, 2)], ('C1', 'C2')) + >>> df1.select(sequence('C1', 'C2').alias('r')).collect() + [Row(r=[-2, -1, 0, 1, 2])] + >>> df2 = spark.createDataFrame([(4, -4, -2)], ('C1', 'C2', 'C3')) + >>> df2.select(sequence('C1', 'C2', 'C3').alias('r')).collect() + [Row(r=[4, 2, 0, -2, -4])] + """ + sc = SparkContext._active_spark_context + if step is None: + return Column(sc._jvm.functions.sequence(_to_java_column(start), _to_java_column(stop))) + else: + return Column(sc._jvm.functions.sequence( + _to_java_column(start), _to_java_column(stop), _to_java_column(step))) + + # ---------------------------- User Defined Function ---------------------------------- class PandasUDFType(object): @@ -2548,9 +2722,10 @@ def pandas_udf(f=None, returnType=None, functionType=None): 1. SCALAR A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`. - The returnType should be a primitive data type, e.g., :class:`DoubleType`. The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. + :class:`MapType`, :class:`StructType` are currently not supported as output types. + Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and :meth:`pyspark.sql.DataFrame.select`. @@ -2611,14 +2786,14 @@ def pandas_udf(f=None, returnType=None, functionType=None): +---+-------------------+ Alternatively, the user can define a function that takes two arguments. - In this case, the grouping key will be passed as the first argument and the data will - be passed as the second argument. The grouping key will be passed as a tuple of numpy + In this case, the grouping key(s) will be passed as the first argument and the data will + be passed as the second argument. The grouping key(s) will be passed as a tuple of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in as a `pandas.DataFrame` containing all columns from the original Spark DataFrame. 
- This is useful when the user does not want to hardcode grouping key in the function. + This is useful when the user does not want to hardcode grouping key(s) in the function. - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> import pandas as pd # doctest: +SKIP + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) # doctest: +SKIP @@ -2634,6 +2809,22 @@ def pandas_udf(f=None, returnType=None, functionType=None): | 1|1.5| | 2|6.0| +---+---+ + >>> @pandas_udf( + ... "id long, `ceil(v / 2)` long, v double", + ... PandasUDFType.GROUPED_MAP) # doctest: +SKIP + >>> def sum_udf(key, pdf): + ... # key is a tuple of two numpy.int64s, which is the values + ... # of 'id' and 'ceil(df.v / 2)' for the current group + ... return pd.DataFrame([key + (pdf.v.sum(),)]) + >>> df.groupby(df.id, ceil(df.v / 2)).apply(sum_udf).show() # doctest: +SKIP + +---+-----------+----+ + | id|ceil(v / 2)| v| + +---+-----------+----+ + | 2| 5|10.0| + | 1| 1| 3.0| + | 2| 3| 5.0| + | 2| 2| 3.0| + +---+-----------+----+ .. note:: If returning a new `pandas.DataFrame` constructed with a dictionary, it is recommended to explicitly index the columns by name to ensure the positions are correct, @@ -2683,8 +2874,9 @@ def pandas_udf(f=None, returnType=None, functionType=None): >>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP ... def mean_udf(v): ... return v.mean() - >>> w = Window.partitionBy('id') \\ - ... .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + >>> w = Window \\ + ... .partitionBy('id') \\ + ... .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) >>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() # doctest: +SKIP +---+----+------+ | id| v|mean_v| @@ -2760,6 +2952,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): blacklist = ['map', 'since', 'ignore_unicode_prefix'] __all__ = [k for k, v in globals().items() if not k.startswith('_') and k[0].islower() and callable(v) and k not in blacklist] +__all__ += ["PandasUDFType"] __all__.sort() diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 0906c9c6b329a..cc1da8e7c1f72 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -211,6 +211,8 @@ def pivot(self, pivot_col, values=None): >>> df4.groupBy("year").pivot("course").sum("earnings").collect() [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)] + >>> df5.groupBy("sales.year").pivot("sales.course").sum("sales.earnings").collect() + [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)] """ if values is None: jgd = self._jgd.pivot(pivot_col) @@ -296,6 +298,12 @@ def _test(): Row(course="dotNET", year=2012, earnings=5000), Row(course="dotNET", year=2013, earnings=48000), Row(course="Java", year=2013, earnings=30000)]).toDF() + globs['df5'] = sc.parallelize([ + Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=10000)), + Row(training="junior", sales=Row(course="Java", year=2012, earnings=20000)), + Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=5000)), + Row(training="junior", sales=Row(course="dotNET", year=2013, earnings=48000)), + Row(training="expert", sales=Row(course="Java", year=2013, earnings=30000))]).toDF() (failure_count, test_count) = doctest.testmod( pyspark.sql.group, globs=globs, diff --git a/python/pyspark/sql/readwriter.py 
b/python/pyspark/sql/readwriter.py index 3efe2adb6e2a4..690b13072244b 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -267,7 +267,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, - samplingRatio=samplingRatio, encoding=encoding) + samplingRatio=samplingRatio, dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -349,8 +349,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, - samplingRatio=None, enforceSchema=None): - """Loads a CSV file and returns the result as a :class:`DataFrame`. + samplingRatio=None, enforceSchema=None, emptyValue=None): + r"""Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if ``inferSchema`` is enabled. To avoid going through the entire data once, disable @@ -444,6 +444,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non different, ``\0`` otherwise. :param samplingRatio: defines fraction of rows used for schema inferring. If None is set, it uses the default value, ``1.0``. + :param emptyValue: sets the string representation of an empty value. If None is set, it uses + the default value, empty string. >>> df = spark.read.csv('python/test_support/sql/ages.csv') >>> df.dtypes @@ -463,7 +465,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio, - enforceSchema=enforceSchema) + enforceSchema=enforceSchema, emptyValue=emptyValue) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -517,8 +519,8 @@ def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPar If both ``column`` and ``predicates`` are specified, ``column`` will be used. - .. note:: Don't create too many partitions in parallel on a large cluster; \ - otherwise Spark might crash your external database systems. + .. note:: Don't create too many partitions in parallel on a large cluster; + otherwise Spark might crash your external database systems. :param url: a JDBC URL of the form ``jdbc:subprotocol:subname`` :param table: the name of the table @@ -825,10 +827,10 @@ def parquet(self, path, mode=None, partitionBy=None, compression=None): exists. :param partitionBy: names of partitioning columns :param compression: compression codec to use when saving to file. This can be one of the - known case-insensitive shorten names (none, snappy, gzip, and lzo). - This will override ``spark.sql.parquet.compression.codec``. If None - is set, it uses the value specified in - ``spark.sql.parquet.compression.codec``. + known case-insensitive shorten names (none, uncompressed, snappy, gzip, + lzo, brotli, lz4, and zstd). This will override + ``spark.sql.parquet.compression.codec``. 
If None is set, it uses the + value specified in ``spark.sql.parquet.compression.codec``. >>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ @@ -859,8 +861,8 @@ def text(self, path, compression=None, lineSep=None): def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None, header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None, timestampFormat=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, - charToEscapeQuoteEscaping=None): - """Saves the content of the :class:`DataFrame` in CSV format at the specified path. + charToEscapeQuoteEscaping=None, encoding=None, emptyValue=None): + r"""Saves the content of the :class:`DataFrame` in CSV format at the specified path. :param path: the path in any Hadoop supported file system :param mode: specifies the behavior of the save operation when data already exists. @@ -909,6 +911,10 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No the quote character. If None is set, the default value is escape character when escape and quote characters are different, ``\0`` otherwise.. + :param encoding: sets the encoding (charset) of saved csv files. If None is set, + the default UTF-8 charset will be used. + :param emptyValue: sets the string representation of an empty value. If None is set, it uses + the default value, ``""``. >>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data')) """ @@ -918,7 +924,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No dateFormat=dateFormat, timestampFormat=timestampFormat, ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace, ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, - charToEscapeQuoteEscaping=charToEscapeQuoteEscaping) + charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, + encoding=encoding, emptyValue=emptyValue) self._jwrite.csv(path) @since(1.5) @@ -955,8 +962,8 @@ def orc(self, path, mode=None, partitionBy=None, compression=None): def jdbc(self, url, table, mode=None, properties=None): """Saves the content of the :class:`DataFrame` to an external database table via JDBC. - .. note:: Don't create too many partitions in parallel on a large cluster; \ - otherwise Spark might crash your external database systems. + .. note:: Don't create too many partitions in parallel on a large cluster; + otherwise Spark might crash your external database systems. :param url: a JDBC URL of the form ``jdbc:subprotocol:subname`` :param table: Name of the table in the external database. diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index f1ad6b1212ed9..87d8d6a59a6e9 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -501,7 +501,7 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the data types will be used to coerce the data in Pandas to Arrow conversion. 
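For the `_create_from_pandas_with_arrow` change above, a rough sketch of the slicing step that yields one Arrow batch per partition; `pandas_to_batches`, `pdf`, and `parallelism` are placeholder names, and the real code routes each slice through `_create_batch` to honor the schema and timezone:

    import pyarrow as pa

    def pandas_to_batches(pdf, parallelism):
        """Slice a pandas DataFrame into one Arrow RecordBatch per partition."""
        step = max(1, -(-len(pdf) // parallelism))  # ceiling division
        for start in range(0, len(pdf), step):
            yield pa.RecordBatch.from_pandas(pdf[start:start + step],
                                             preserve_index=False)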
""" - from pyspark.serializers import ArrowSerializer, _create_batch + from pyspark.serializers import ArrowStreamSerializer, _create_batch from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType from pyspark.sql.utils import require_minimum_pandas_version, \ require_minimum_pyarrow_version @@ -539,10 +539,12 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): struct.names[i] = name schema = struct - # Create the Spark DataFrame directly from the Arrow data and schema - jrdd = self._sc._serialize_to_jvm(batches, len(batches), ArrowSerializer()) - jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame( - jrdd, schema.json(), self._wrapped._jsqlContext) + def reader_func(temp_filename): + return self._jvm.PythonSQLUtils.arrowReadStreamFromFile( + self._wrapped._jsqlContext, temp_filename, schema.json()) + + # Create Spark DataFrame from Arrow stream file, using one batch per partition + jdf = self._sc._serialize_to_jvm(batches, ArrowStreamSerializer(), reader_func) df = DataFrame(jdf, self._wrapped) df._schema = schema return df @@ -678,9 +680,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr from pyspark.sql.utils import require_minimum_pandas_version require_minimum_pandas_version() - if self.conf.get("spark.sql.execution.pandas.respectSessionTimeZone").lower() \ - == "true": - timezone = self.conf.get("spark.sql.session.timeZone") + if self._wrapped._conf.pandasRespectSessionTimeZone(): + timezone = self._wrapped._conf.sessionLocalTimeZone() else: timezone = None @@ -690,15 +691,13 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr (x.encode('utf-8') if not isinstance(x, str) else x) for x in data.columns] - if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \ - and len(data) > 0: + if self._wrapped._conf.arrowEnabled() and len(data) > 0: try: return self._create_from_pandas_with_arrow(data, schema, timezone) except Exception as e: from pyspark.util import _exception_message - if self.conf.get("spark.sql.execution.arrow.fallback.enabled", "true") \ - .lower() == "true": + if self._wrapped._conf.arrowFallbackEnabled(): msg = ( "createDataFrame attempted Arrow optimization because " "'spark.sql.execution.arrow.enabled' is set to true; however, " diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 8c1fd4af674d7..b18453b2a4f96 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -19,10 +19,7 @@ import json if sys.version >= '3': - intlike = int - basestring = unicode = str -else: - intlike = (int, long) + basestring = str from py4j.java_gateway import java_import @@ -567,8 +564,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, - enforceSchema=None): - """Loads a CSV file stream and returns the result as a :class:`DataFrame`. + enforceSchema=None, emptyValue=None): + r"""Loads a CSV file stream and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if ``inferSchema`` is enabled. To avoid going through the entire data once, disable @@ -661,6 +658,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non the quote character. 
If None is set, the default value is escape character when escape and quote characters are different, ``\0`` otherwise.. + :param emptyValue: sets the string representation of an empty value. If None is set, it uses + the default value, empty string. >>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema) >>> csv_sdf.isStreaming @@ -677,7 +676,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, - charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema) + charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema, + emptyValue=emptyValue) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) else: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 8d738069adb3d..8e5bc6729dfa4 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -68,8 +68,16 @@ # If Arrow version requirement is not satisfied, skip related tests. _pyarrow_requirement_message = _exception_message(e) +_test_not_compiled_message = None +try: + from pyspark.sql.utils import require_test_compiled + require_test_compiled() +except Exception as e: + _test_not_compiled_message = _exception_message(e) + _have_pandas = _pandas_requirement_message is None _have_pyarrow = _pyarrow_requirement_message is None +_test_compiled = _test_not_compiled_message is None from pyspark import SparkContext from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row @@ -269,6 +277,10 @@ def test_struct_field_type_name(self): struct_field = StructField("a", IntegerType()) self.assertRaises(TypeError, struct_field.typeName) + def test_invalid_create_row(self): + row_class = Row("c1", "c2") + self.assertRaises(ValueError, lambda: row_class(1, 2, 3)) + class SQLTests(ReusedSQLTestCase): @@ -765,7 +777,7 @@ def filename(path): row2 = df2.select(sameText(df2['file'])).first() self.assertTrue(row2[0].find("people.json") != -1) - def test_udf_defers_judf_initalization(self): + def test_udf_defers_judf_initialization(self): # This is separate of UDFInitializationTests # to avoid context initialization # when udf is called @@ -3351,6 +3363,63 @@ def test_checking_csv_header(self): finally: shutil.rmtree(path) + def test_ignore_column_of_all_nulls(self): + path = tempfile.mkdtemp() + shutil.rmtree(path) + try: + df = self.spark.createDataFrame([["""{"a":null, "b":1, "c":3.0}"""], + ["""{"a":null, "b":null, "c":"string"}"""], + ["""{"a":null, "b":null, "c":null}"""]]) + df.write.text(path) + schema = StructType([ + StructField('b', LongType(), nullable=True), + StructField('c', StringType(), nullable=True)]) + readback = self.spark.read.json(path, dropFieldIfAllNull=True) + self.assertEquals(readback.schema, schema) + finally: + shutil.rmtree(path) + + # SPARK-24721 + @unittest.skipIf(not _test_compiled, _test_not_compiled_message) + def test_datasource_with_udf(self): + from pyspark.sql.functions import udf, lit, col + + path = tempfile.mkdtemp() + shutil.rmtree(path) + + try: + self.spark.range(1).write.mode("overwrite").format('csv').save(path) + filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i') + datasource_df = self.spark.read \ + .format("org.apache.spark.sql.sources.SimpleScanSource") \ + .option('from', 0).option('to', 1).load().toDF('i') + datasource_v2_df = 
self.spark.read \ + .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ + .load().toDF('i', 'j') + + c1 = udf(lambda x: x + 1, 'int')(lit(1)) + c2 = udf(lambda x: x + 1, 'int')(col('i')) + + f1 = udf(lambda x: False, 'boolean')(lit(1)) + f2 = udf(lambda x: False, 'boolean')(col('i')) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + result = df.withColumn('c', c1) + expected = df.withColumn('c', lit(2)) + self.assertEquals(expected.collect(), result.collect()) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + result = df.withColumn('c', c2) + expected = df.withColumn('c', col('i') + 1) + self.assertEquals(expected.collect(), result.collect()) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + for f in [f1, f2]: + result = df.filter(f) + self.assertEquals(0, result.count()) + finally: + shutil.rmtree(path) + def test_repr_behaviors(self): import re pattern = re.compile(r'^ *\|', re.MULTILINE) @@ -3595,9 +3664,9 @@ def tearDown(self): SparkSession._instantiatedSession.stop() if SparkContext._active_spark_context is not None: - SparkContext._active_spark_contex.stop() + SparkContext._active_spark_context.stop() - def test_udf_init_shouldnt_initalize_context(self): + def test_udf_init_shouldnt_initialize_context(self): from pyspark.sql.functions import UserDefinedFunction UserDefinedFunction(lambda x: x, StringType()) @@ -4034,6 +4103,8 @@ class ArrowTests(ReusedSQLTestCase): def setUpClass(cls): from datetime import date, datetime from decimal import Decimal + from distutils.version import LooseVersion + import pyarrow as pa ReusedSQLTestCase.setUpClass() # Synchronize default timezone between Python and Java @@ -4062,6 +4133,13 @@ def setUpClass(cls): (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"), date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] + # TODO: remove version check once minimum pyarrow version is 0.10.0 + if LooseVersion("0.10.0") <= LooseVersion(pa.__version__): + cls.schema.add(StructField("9_binary_t", BinaryType(), True)) + cls.data[0] = cls.data[0] + (bytearray(b"a"),) + cls.data[1] = cls.data[1] + (bytearray(b"bb"),) + cls.data[2] = cls.data[2] + (bytearray(b"ccc"),) + @classmethod def tearDownClass(cls): del os.environ["TZ"] @@ -4099,12 +4177,23 @@ def test_toPandas_fallback_enabled(self): self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]})) def test_toPandas_fallback_disabled(self): + from distutils.version import LooseVersion + import pyarrow as pa + schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(None,)], schema=schema) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported type'): df.toPandas() + # TODO: remove BinaryType check once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + schema = StructType([StructField("binary", BinaryType(), True)]) + df = self.spark.createDataFrame([(None,)], schema=schema) + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Unsupported type.*BinaryType'): + df.toPandas() + def test_null_conversion(self): df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + self.data) @@ -4216,19 +4305,22 @@ def test_createDataFrame_with_schema(self): def test_createDataFrame_with_incorrect_schema(self): pdf = self.create_pandas_data_frame() - wrong_schema = StructType(list(reversed(self.schema))) + fields = list(self.schema) + fields[0], fields[7] = fields[7], fields[0] # swap str with timestamp + 
wrong_schema = StructType(fields) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, ".*No cast.*string.*timestamp.*"): self.spark.createDataFrame(pdf, schema=wrong_schema) def test_createDataFrame_with_names(self): pdf = self.create_pandas_data_frame() + new_names = list(map(str, range(len(self.schema.fieldNames())))) # Test that schema as a list of column names gets applied - df = self.spark.createDataFrame(pdf, schema=list('abcdefgh')) - self.assertEquals(df.schema.fieldNames(), list('abcdefgh')) + df = self.spark.createDataFrame(pdf, schema=list(new_names)) + self.assertEquals(df.schema.fieldNames(), new_names) # Test that schema as tuple of column names gets applied - df = self.spark.createDataFrame(pdf, schema=tuple('abcdefgh')) - self.assertEquals(df.schema.fieldNames(), list('abcdefgh')) + df = self.spark.createDataFrame(pdf, schema=tuple(new_names)) + self.assertEquals(df.schema.fieldNames(), new_names) def test_createDataFrame_column_name_encoding(self): import pandas as pd @@ -4315,13 +4407,22 @@ self.assertEqual(df.collect(), [Row(a={u'a': 1})]) def test_createDataFrame_fallback_disabled(self): + from distutils.version import LooseVersion import pandas as pd + import pyarrow as pa with QuietTest(self.sc): with self.assertRaisesRegexp(TypeError, 'Unsupported type'): self.spark.createDataFrame( pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>") + # TODO: remove BinaryType check once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + with QuietTest(self.sc): + with self.assertRaisesRegexp(TypeError, 'Unsupported type.*BinaryType'): + self.spark.createDataFrame( + pd.DataFrame([[{'a': b'aaa'}]]), "a: binary") + # Regression test for SPARK-23314 def test_timestamp_dst(self): import pandas as pd @@ -4342,6 +4443,7 @@ not _have_pandas or not _have_pyarrow, _pandas_requirement_message or _pyarrow_requirement_message) class PandasUDFTests(ReusedSQLTestCase): + def test_pandas_udf_basic(self): from pyspark.rdd import PythonEvalType from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4557,6 +4659,24 @@ random_udf = random_udf.asNondeterministic() return random_udf + def test_pandas_udf_tokenize(self): + from pyspark.sql.functions import pandas_udf + tokenize = pandas_udf(lambda s: s.apply(lambda str: str.split(' ')), + ArrayType(StringType())) + self.assertEqual(tokenize.returnType, ArrayType(StringType())) + df = self.spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"]) + result = df.select(tokenize("vals").alias("hi")) + self.assertEqual([Row(hi=[u'hi', u'boo']), Row(hi=[u'bye', u'boo'])], result.collect()) + + def test_pandas_udf_nested_arrays(self): + from pyspark.sql.functions import pandas_udf + tokenize = pandas_udf(lambda s: s.apply(lambda str: [str.split(' ')]), + ArrayType(ArrayType(StringType()))) + self.assertEqual(tokenize.returnType, ArrayType(ArrayType(StringType()))) + df = self.spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"]) + result = df.select(tokenize("vals").alias("hi")) + self.assertEqual([Row(hi=[[u'hi', u'boo']]), Row(hi=[[u'bye', u'boo']])], result.collect()) + def test_vectorized_udf_basic(self): from pyspark.sql.functions import pandas_udf, col, array df = self.spark.range(10).select( @@ -4713,6 +4833,24 @@ bool_f(col('bool'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_null_binary(self): + from
distutils.version import LooseVersion + import pyarrow as pa + from pyspark.sql.functions import pandas_udf, col + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + with QuietTest(self.sc): + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*scalar Pandas UDF.*BinaryType'): + pandas_udf(lambda x: x, BinaryType()) + else: + data = [(bytearray(b"a"),), (None,), (bytearray(b"bb"),), (bytearray(b"ccc"),)] + schema = StructType().add("binary", BinaryType()) + df = self.spark.createDataFrame(data, schema) + str_f = pandas_udf(lambda x: x, BinaryType()) + res = df.select(str_f(col('binary'))) + self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_array_type(self): from pyspark.sql.functions import pandas_udf, col data = [([1, 2],), ([3, 4],)] @@ -4763,17 +4901,6 @@ def test_vectorized_udf_invalid_length(self): 'Result vector from pandas_udf was not the required length'): df.select(raise_exception(col('id'))).collect() - def test_vectorized_udf_mix_udf(self): - from pyspark.sql.functions import pandas_udf, udf, col - df = self.spark.range(10) - row_by_row_udf = udf(lambda x: x, LongType()) - pd_udf = pandas_udf(lambda x: x, LongType()) - with QuietTest(self.sc): - with self.assertRaisesRegexp( - Exception, - 'Can not mix vectorized and non-vectorized UDFs'): - df.select(row_by_row_udf(col('id')), pd_udf(col('id'))).collect() - def test_vectorized_udf_chained(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10) @@ -4830,12 +4957,6 @@ def test_vectorized_udf_unsupported_types(self): 'Invalid returnType.*scalar Pandas UDF.*MapType'): pandas_udf(lambda x: x, MapType(StringType(), IntegerType())) - with QuietTest(self.sc): - with self.assertRaisesRegexp( - NotImplementedError, - 'Invalid returnType.*scalar Pandas UDF.*BinaryType'): - pandas_udf(lambda x: x, BinaryType()) - def test_vectorized_udf_dates(self): from pyspark.sql.functions import pandas_udf, col from datetime import date @@ -5060,6 +5181,211 @@ def test_type_annotation(self): df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id')) self.assertEqual(df.first()[0], 0) + def test_mixed_udf(self): + import pandas as pd + from pyspark.sql.functions import col, udf, pandas_udf + + df = self.spark.range(0, 1).toDF('v') + + # Test mixture of multiple UDFs and Pandas UDFs. 
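Note the removal of `test_vectorized_udf_mix_udf` above: combining row-at-a-time and vectorized UDFs in one projection used to raise 'Can not mix vectorized and non-vectorized UDFs' and is now supported, which the `test_mixed_udf` cases that follow exercise exhaustively. A minimal sketch of the newly allowed pattern, assuming an active `spark` session:

    from pyspark.sql.functions import col, pandas_udf, udf

    df = spark.range(3)
    row_by_row = udf(lambda x: x + 1, 'long')         # one row at a time
    vectorized = pandas_udf(lambda s: s * 2, 'long')  # whole pandas.Series

    df.select(row_by_row(col('id')), vectorized(col('id'))).show()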
+ + @udf('int') + def f1(x): + assert type(x) == int + return x + 1 + + @pandas_udf('int') + def f2(x): + assert type(x) == pd.Series + return x + 10 + + @udf('int') + def f3(x): + assert type(x) == int + return x + 100 + + @pandas_udf('int') + def f4(x): + assert type(x) == pd.Series + return x + 1000 + + # Test single expression with chained UDFs + df_chained_1 = df.withColumn('f2_f1', f2(f1(df['v']))) + df_chained_2 = df.withColumn('f3_f2_f1', f3(f2(f1(df['v'])))) + df_chained_3 = df.withColumn('f4_f3_f2_f1', f4(f3(f2(f1(df['v']))))) + df_chained_4 = df.withColumn('f4_f2_f1', f4(f2(f1(df['v'])))) + df_chained_5 = df.withColumn('f4_f3_f1', f4(f3(f1(df['v'])))) + + expected_chained_1 = df.withColumn('f2_f1', df['v'] + 11) + expected_chained_2 = df.withColumn('f3_f2_f1', df['v'] + 111) + expected_chained_3 = df.withColumn('f4_f3_f2_f1', df['v'] + 1111) + expected_chained_4 = df.withColumn('f4_f2_f1', df['v'] + 1011) + expected_chained_5 = df.withColumn('f4_f3_f1', df['v'] + 1101) + + self.assertEquals(expected_chained_1.collect(), df_chained_1.collect()) + self.assertEquals(expected_chained_2.collect(), df_chained_2.collect()) + self.assertEquals(expected_chained_3.collect(), df_chained_3.collect()) + self.assertEquals(expected_chained_4.collect(), df_chained_4.collect()) + self.assertEquals(expected_chained_5.collect(), df_chained_5.collect()) + + # Test multiple mixed UDF expressions in a single projection + df_multi_1 = df \ + .withColumn('f1', f1(col('v'))) \ + .withColumn('f2', f2(col('v'))) \ + .withColumn('f3', f3(col('v'))) \ + .withColumn('f4', f4(col('v'))) \ + .withColumn('f2_f1', f2(col('f1'))) \ + .withColumn('f3_f1', f3(col('f1'))) \ + .withColumn('f4_f1', f4(col('f1'))) \ + .withColumn('f3_f2', f3(col('f2'))) \ + .withColumn('f4_f2', f4(col('f2'))) \ + .withColumn('f4_f3', f4(col('f3'))) \ + .withColumn('f3_f2_f1', f3(col('f2_f1'))) \ + .withColumn('f4_f2_f1', f4(col('f2_f1'))) \ + .withColumn('f4_f3_f1', f4(col('f3_f1'))) \ + .withColumn('f4_f3_f2', f4(col('f3_f2'))) \ + .withColumn('f4_f3_f2_f1', f4(col('f3_f2_f1'))) + + # Test mixed udfs in a single expression + df_multi_2 = df \ + .withColumn('f1', f1(col('v'))) \ + .withColumn('f2', f2(col('v'))) \ + .withColumn('f3', f3(col('v'))) \ + .withColumn('f4', f4(col('v'))) \ + .withColumn('f2_f1', f2(f1(col('v')))) \ + .withColumn('f3_f1', f3(f1(col('v')))) \ + .withColumn('f4_f1', f4(f1(col('v')))) \ + .withColumn('f3_f2', f3(f2(col('v')))) \ + .withColumn('f4_f2', f4(f2(col('v')))) \ + .withColumn('f4_f3', f4(f3(col('v')))) \ + .withColumn('f3_f2_f1', f3(f2(f1(col('v'))))) \ + .withColumn('f4_f2_f1', f4(f2(f1(col('v'))))) \ + .withColumn('f4_f3_f1', f4(f3(f1(col('v'))))) \ + .withColumn('f4_f3_f2', f4(f3(f2(col('v'))))) \ + .withColumn('f4_f3_f2_f1', f4(f3(f2(f1(col('v')))))) + + expected = df \ + .withColumn('f1', df['v'] + 1) \ + .withColumn('f2', df['v'] + 10) \ + .withColumn('f3', df['v'] + 100) \ + .withColumn('f4', df['v'] + 1000) \ + .withColumn('f2_f1', df['v'] + 11) \ + .withColumn('f3_f1', df['v'] + 101) \ + .withColumn('f4_f1', df['v'] + 1001) \ + .withColumn('f3_f2', df['v'] + 110) \ + .withColumn('f4_f2', df['v'] + 1010) \ + .withColumn('f4_f3', df['v'] + 1100) \ + .withColumn('f3_f2_f1', df['v'] + 111) \ + .withColumn('f4_f2_f1', df['v'] + 1011) \ + .withColumn('f4_f3_f1', df['v'] + 1101) \ + .withColumn('f4_f3_f2', df['v'] + 1110) \ + .withColumn('f4_f3_f2_f1', df['v'] + 1111) + + self.assertEquals(expected.collect(), df_multi_1.collect()) + self.assertEquals(expected.collect(), df_multi_2.collect()) + + def 
test_mixed_udf_and_sql(self): + import pandas as pd + from pyspark.sql import Column + from pyspark.sql.functions import udf, pandas_udf + + df = self.spark.range(0, 1).toDF('v') + + # Test mixture of UDFs, Pandas UDFs and SQL expression. + + @udf('int') + def f1(x): + assert type(x) == int + return x + 1 + + def f2(x): + assert type(x) == Column + return x + 10 + + @pandas_udf('int') + def f3(x): + assert type(x) == pd.Series + return x + 100 + + df1 = df.withColumn('f1', f1(df['v'])) \ + .withColumn('f2', f2(df['v'])) \ + .withColumn('f3', f3(df['v'])) \ + .withColumn('f1_f2', f1(f2(df['v']))) \ + .withColumn('f1_f3', f1(f3(df['v']))) \ + .withColumn('f2_f1', f2(f1(df['v']))) \ + .withColumn('f2_f3', f2(f3(df['v']))) \ + .withColumn('f3_f1', f3(f1(df['v']))) \ + .withColumn('f3_f2', f3(f2(df['v']))) \ + .withColumn('f1_f2_f3', f1(f2(f3(df['v'])))) \ + .withColumn('f1_f3_f2', f1(f3(f2(df['v'])))) \ + .withColumn('f2_f1_f3', f2(f1(f3(df['v'])))) \ + .withColumn('f2_f3_f1', f2(f3(f1(df['v'])))) \ + .withColumn('f3_f1_f2', f3(f1(f2(df['v'])))) \ + .withColumn('f3_f2_f1', f3(f2(f1(df['v'])))) + + expected = df.withColumn('f1', df['v'] + 1) \ + .withColumn('f2', df['v'] + 10) \ + .withColumn('f3', df['v'] + 100) \ + .withColumn('f1_f2', df['v'] + 11) \ + .withColumn('f1_f3', df['v'] + 101) \ + .withColumn('f2_f1', df['v'] + 11) \ + .withColumn('f2_f3', df['v'] + 110) \ + .withColumn('f3_f1', df['v'] + 101) \ + .withColumn('f3_f2', df['v'] + 110) \ + .withColumn('f1_f2_f3', df['v'] + 111) \ + .withColumn('f1_f3_f2', df['v'] + 111) \ + .withColumn('f2_f1_f3', df['v'] + 111) \ + .withColumn('f2_f3_f1', df['v'] + 111) \ + .withColumn('f3_f1_f2', df['v'] + 111) \ + .withColumn('f3_f2_f1', df['v'] + 111) + + self.assertEquals(expected.collect(), df1.collect()) + + # SPARK-24721 + @unittest.skipIf(not _test_compiled, _test_not_compiled_message) + def test_datasource_with_udf(self): + # Same as SQLTests.test_datasource_with_udf, but with Pandas UDF + # This needs to a separate test because Arrow dependency is optional + import pandas as pd + import numpy as np + from pyspark.sql.functions import pandas_udf, lit, col + + path = tempfile.mkdtemp() + shutil.rmtree(path) + + try: + self.spark.range(1).write.mode("overwrite").format('csv').save(path) + filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i') + datasource_df = self.spark.read \ + .format("org.apache.spark.sql.sources.SimpleScanSource") \ + .option('from', 0).option('to', 1).load().toDF('i') + datasource_v2_df = self.spark.read \ + .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ + .load().toDF('i', 'j') + + c1 = pandas_udf(lambda x: x + 1, 'int')(lit(1)) + c2 = pandas_udf(lambda x: x + 1, 'int')(col('i')) + + f1 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(lit(1)) + f2 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(col('i')) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + result = df.withColumn('c', c1) + expected = df.withColumn('c', lit(2)) + self.assertEquals(expected.collect(), result.collect()) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + result = df.withColumn('c', c2) + expected = df.withColumn('c', col('i') + 1) + self.assertEquals(expected.collect(), result.collect()) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + for f in [f1, f2]: + result = df.filter(f) + self.assertEquals(0, result.count()) + finally: + shutil.rmtree(path) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, @@ 
-5471,6 +5797,37 @@ def foo(_): self.assertEqual(r.a, 'hi') self.assertEqual(r.b, 1) + def test_self_join_with_pandas(self): + import pyspark.sql.functions as F + + @F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP) + def dummy_pandas_udf(df): + return df[['key', 'col']] + + df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'), + Row(key=2, col='C')]) + df_with_pandas = df.groupBy('key').apply(dummy_pandas_udf) + + # this was throwing an AnalysisException before SPARK-24208 + res = df_with_pandas.alias('temp0').join(df_with_pandas.alias('temp1'), + F.col('temp0.key') == F.col('temp1.key')) + self.assertEquals(res.count(), 5) + + def test_mixed_scalar_udfs_followed_by_grouby_apply(self): + import pandas as pd + from pyspark.sql.functions import udf, pandas_udf, PandasUDFType + + df = self.spark.range(0, 10).toDF('v1') + df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \ + .withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1'])) + + result = df.groupby() \ + .apply(pandas_udf(lambda x: pd.DataFrame([x.sum().sum()]), + 'sum int', + PandasUDFType.GROUPED_MAP)) + + self.assertEquals(result.collect()[0]['sum'], 165) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 3cd7a2ef115af..1d24c40e5858e 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -206,7 +206,7 @@ class DecimalType(FractionalType): and scale (the number of digits on the right of dot). For example, (5, 2) can support the value from [-999.99 to 999.99]. - The precision can be up to 38, the scale must less or equal to precision. + The precision can be up to 38, the scale must be less or equal to precision. When create a DecimalType, the default precision and scale is (10, 0). When infer schema from decimal.Decimal objects, it will be DecimalType(38, 18). @@ -752,7 +752,7 @@ def __eq__(self, other): for v in [ArrayType, MapType, StructType]) -_FIXED_DECIMAL = re.compile("decimal\\(\\s*(\\d+)\\s*,\\s*(\\d+)\\s*\\)") +_FIXED_DECIMAL = re.compile(r"decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)") def _parse_datatype_string(s): @@ -1500,6 +1500,9 @@ def __contains__(self, item): # let object acts like class def __call__(self, *args): """create new Row object""" + if len(args) > len(self): + raise ValueError("Can not create Row with fields %s, expected %d values " + "but got %s" % (self, len(self), args)) return _create_row(self, args) def __getitem__(self, item): @@ -1578,6 +1581,7 @@ def convert(self, obj, gateway_client): def to_arrow_type(dt): """ Convert Spark data type to pyarrow type """ + from distutils.version import LooseVersion import pyarrow as pa if type(dt) == BooleanType: arrow_type = pa.bool_() @@ -1597,6 +1601,12 @@ def to_arrow_type(dt): arrow_type = pa.decimal128(dt.precision, dt.scale) elif type(dt) == StringType: arrow_type = pa.string() + elif type(dt) == BinaryType: + # TODO: remove version check once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + raise TypeError("Unsupported type in conversion to Arrow: " + str(dt) + + "\nPlease install pyarrow >= 0.10.0 for BinaryType support.") + arrow_type = pa.binary() elif type(dt) == DateType: arrow_type = pa.date32() elif type(dt) == TimestampType: @@ -1623,6 +1633,8 @@ def to_arrow_schema(schema): def from_arrow_type(at): """ Convert pyarrow type to Spark data type. 
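The `Row.__call__` guard added above (exercised by the new `test_invalid_create_row`) makes surplus values fail fast instead of producing a row whose values outnumber its fields; a quick sketch of the effect:

    from pyspark.sql import Row

    Person = Row("name", "age")
    print(Person("Alice", 30))  # Row(name='Alice', age=30)

    try:
        Person("Bob", 40, "extra")
    except ValueError as e:
        print(e)  # expected 2 values but got ('Bob', 40, 'extra')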
""" + from distutils.version import LooseVersion + import pyarrow as pa import pyarrow.types as types if types.is_boolean(at): spark_type = BooleanType() @@ -1642,6 +1654,12 @@ def from_arrow_type(at): spark_type = DecimalType(precision=at.precision, scale=at.scale) elif types.is_string(at): spark_type = StringType() + elif types.is_binary(at): + # TODO: remove version check once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + raise TypeError("Unsupported type in conversion from Arrow: " + str(at) + + "\nPlease install pyarrow >= 0.10.0 for BinaryType support.") + spark_type = BinaryType() elif types.is_date32(at): spark_type = DateType() elif types.is_timestamp(at): diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index bb9ce02c4b60f..bdb3a1467f1d8 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -152,6 +152,25 @@ def require_minimum_pyarrow_version(): "your version was %s." % (minimum_pyarrow_version, pyarrow.__version__)) +def require_test_compiled(): + """ Raise Exception if test classes are not compiled + """ + import os + import glob + try: + spark_home = os.environ['SPARK_HOME'] + except KeyError: + raise RuntimeError('SPARK_HOME is not defined in environment') + + test_class_path = os.path.join( + spark_home, 'sql', 'core', 'target', '*', 'test-classes') + paths = glob.glob(test_class_path) + + if len(paths) == 0: + raise RuntimeError( + "%s doesn't exist. Spark sql test classes are not compiled." % test_class_path) + + class ForeachBatchFunction(object): """ This is the Python implementation of Java interface 'ForeachBatchFunction'. This wraps diff --git a/python/pyspark/storagelevel.py b/python/pyspark/storagelevel.py index ef012d27cb22f..7f29646c07432 100644 --- a/python/pyspark/storagelevel.py +++ b/python/pyspark/storagelevel.py @@ -58,8 +58,8 @@ def __str__(self): StorageLevel.OFF_HEAP = StorageLevel(True, True, True, False, 1) """ -.. note:: The following four storage level constants are deprecated in 2.0, since the records \ -will always be serialized in Python. +.. note:: The following four storage level constants are deprecated in 2.0, since the records + will always be serialized in Python. """ StorageLevel.MEMORY_ONLY_SER = StorageLevel.MEMORY_ONLY """.. note:: Deprecated in 2.0, use ``StorageLevel.MEMORY_ONLY`` instead.""" diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index a4515828d180c..3fa57ca85b37b 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -222,7 +222,7 @@ def remember(self, duration): Set each DStreams in this context to remember RDDs it generated in the last given duration. DStreams remember RDDs only for a limited duration of time and releases them for garbage collection. - This method allows the developer to specify how to long to remember + This method allows the developer to specify how long to remember the RDDs (if the developer wishes to query old data outside the DStream computation). @@ -287,7 +287,7 @@ def _check_serializers(self, rdds): def queueStream(self, rdds, oneAtATime=True, default=None): """ - Create an input stream from an queue of RDDs or list. In each batch, + Create an input stream from a queue of RDDs or list. In each batch, it will process either one or all of the RDDs returned by the queue. .. note:: Changes to the queue after the stream is created will not be recognized. 
diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py
index 59977dcb435a8..ce42a857d0c06 100644
--- a/python/pyspark/streaming/dstream.py
+++ b/python/pyspark/streaming/dstream.py
@@ -23,6 +23,8 @@
 if sys.version < "3":
     from itertools import imap as map, ifilter as filter
+else:
+    long = int
 
 from py4j.protocol import Py4JJavaError
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 373784f826677..5cef621a28e6e 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -179,7 +179,7 @@ def func(dstream):
         self._test_func(input, func, expected)
 
     def test_flatMap(self):
-        """Basic operation test for DStream.faltMap."""
+        """Basic operation test for DStream.flatMap."""
         input = [range(1, 5), range(5, 9), range(9, 13)]
 
         def func(dstream):
@@ -206,6 +206,38 @@ def func(dstream):
         expected = [[len(x)] for x in input]
         self._test_func(input, func, expected)
 
+    def test_slice(self):
+        """Basic operation test for DStream.slice."""
+        import datetime as dt
+        self.ssc = StreamingContext(self.sc, 1.0)
+        self.ssc.remember(4.0)
+        input = [[1], [2], [3], [4]]
+        stream = self.ssc.queueStream([self.sc.parallelize(d, 1) for d in input])
+
+        time_vals = []
+
+        def get_times(t, rdd):
+            if rdd and len(time_vals) < len(input):
+                time_vals.append(t)
+
+        stream.foreachRDD(get_times)
+
+        self.ssc.start()
+        self.wait_for(time_vals, 4)
+        begin_time = time_vals[0]
+
+        def get_sliced(begin_delta, end_delta):
+            begin = begin_time + dt.timedelta(seconds=begin_delta)
+            end = begin_time + dt.timedelta(seconds=end_delta)
+            rdds = stream.slice(begin, end)
+            result_list = [rdd.collect() for rdd in rdds]
+            return [r for result in result_list for r in result]
+
+        self.assertEqual(set([1]), set(get_sliced(0, 0)))
+        self.assertEqual(set([2, 3]), set(get_sliced(1, 2)))
+        self.assertEqual(set([2, 3, 4]), set(get_sliced(1, 4)))
+        self.assertEqual(set([1, 2, 3, 4]), set(get_sliced(0, 4)))
+
     def test_reduce(self):
         """Basic operation test for DStream.reduce."""
         input = [range(1, 5), range(5, 9), range(9, 13)]
@@ -822,7 +854,7 @@ def setupFunc():
         self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc)
         self.assertTrue(self.setupCalled)
 
-        # Verify that getActiveOrCreate() retuns active context and does not call the setupFunc
+        # Verify that getActiveOrCreate() returns active context and does not call the setupFunc
         self.ssc.start()
         self.setupCalled = False
         self.assertEqual(StreamingContext.getActiveOrCreate(None, setupFunc), self.ssc)
diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py
index 63ae1f30e17ca..b61643eb0a16e 100644
--- a/python/pyspark/taskcontext.py
+++ b/python/pyspark/taskcontext.py
@@ -16,6 +16,10 @@
 #
 from __future__ import print_function
+import socket
+
+from pyspark.java_gateway import local_connect_and_auth
+from pyspark.serializers import write_int, UTF8Deserializer
 
 
 class TaskContext(object):
@@ -95,3 +99,127 @@ def getLocalProperty(self, key):
         Get a local property set upstream in the driver, or None if it is missing.
         """
         return self._localProperties.get(key, None)
+
+
+BARRIER_FUNCTION = 1
+
+
+def _load_from_socket(port, auth_secret):
+    """
+    Load data from a given socket. This is a blocking method; it returns only when the
+    socket connection has been closed.
+    """
+    (sockfile, sock) = local_connect_and_auth(port, auth_secret)
+    # The barrier() call may block forever, so no timeout
+    sock.settimeout(None)
+    # Make a barrier() function call.
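The `_load_from_socket` helper beginning here is the worker-side half of `barrier()`; the user-facing API it backs (defined just below) can be used roughly as in this sketch, assuming a running SparkContext `sc`:

    from pyspark.taskcontext import BarrierTaskContext

    def stage_body(iterator):
        tc = BarrierTaskContext.get()   # worker-side only
        tc.barrier()                    # blocks until all tasks in the stage arrive
        addresses = [i.address for i in tc.getTaskInfos()]
        yield (tc.partitionId(), addresses)

    rdd = sc.parallelize(range(8), 4)
    print(rdd.barrier().mapPartitions(stage_body).collect())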
+    write_int(BARRIER_FUNCTION, sockfile)
+    sockfile.flush()
+
+    # Collect result.
+    res = UTF8Deserializer().loads(sockfile)
+
+    # Release resources.
+    sockfile.close()
+    sock.close()
+
+    return res
+
+
+class BarrierTaskContext(TaskContext):
+
+    """
+    .. note:: Experimental
+
+    A :class:`TaskContext` with extra contextual info and tooling for tasks in a barrier stage.
+    Use :func:`BarrierTaskContext.get` to obtain the barrier context for a running barrier task.
+
+    .. versionadded:: 2.4.0
+    """
+
+    _port = None
+    _secret = None
+
+    def __init__(self):
+        """Construct a BarrierTaskContext, use get instead"""
+        pass
+
+    @classmethod
+    def _getOrCreate(cls):
+        """Internal function to get or create global BarrierTaskContext."""
+        if cls._taskContext is None:
+            cls._taskContext = BarrierTaskContext()
+        return cls._taskContext
+
+    @classmethod
+    def get(cls):
+        """
+        .. note:: Experimental
+
+        Return the currently active :class:`BarrierTaskContext`.
+        This can be called inside of user functions to access contextual information about
+        running tasks.
+
+        .. note:: Must be called on the worker, not the driver. Returns None if not initialized.
+        """
+        return cls._taskContext
+
+    @classmethod
+    def _initialize(cls, port, secret):
+        """
+        Initialize BarrierTaskContext; other methods within BarrierTaskContext can only be
+        called after BarrierTaskContext is initialized.
+        """
+        cls._port = port
+        cls._secret = secret
+
+    def barrier(self):
+        """
+        .. note:: Experimental
+
+        Sets a global barrier and waits until all tasks in this stage hit this barrier.
+        Similar to the `MPI_Barrier` function in MPI, this function blocks until all tasks
+        in the same stage have reached this routine.
+
+        .. warning:: In a barrier stage, each task must have the same number of `barrier()`
+            calls, in all possible code branches.
+            Otherwise, the job may hang or a SparkException may be thrown after a timeout.
+
+        .. versionadded:: 2.4.0
+        """
+        if self._port is None or self._secret is None:
+            raise Exception("barrier() cannot be called before BarrierTaskContext " +
+                            "is initialized.")
+        else:
+            _load_from_socket(self._port, self._secret)
+
+    def getTaskInfos(self):
+        """
+        .. note:: Experimental
+
+        Returns :class:`BarrierTaskInfo` for all tasks in this barrier stage,
+        ordered by partition ID.
+
+        .. versionadded:: 2.4.0
+        """
+        if self._port is None or self._secret is None:
+            raise Exception("getTaskInfos() cannot be called before BarrierTaskContext " +
+                            "is initialized.")
+        else:
+            addresses = self._localProperties.get("addresses", "")
+            return [BarrierTaskInfo(h.strip()) for h in addresses.split(",")]
+
+
+class BarrierTaskInfo(object):
+    """
+    .. note:: Experimental
+
+    Carries all task infos of a barrier task.
+
+    :var address: The IPv4 address (host:port) of the executor that the barrier task is running on
+
+    ..
versionadded:: 2.4.0 + """ + + def __init__(self, address): + self.address = address diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index a4c5fb1db8b37..8ac1df52fc597 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -70,7 +70,7 @@ from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter from pyspark import shuffle from pyspark.profiler import BasicProfiler -from pyspark.taskcontext import TaskContext +from pyspark.taskcontext import BarrierTaskContext, TaskContext _have_scipy = False _have_numpy = False @@ -588,6 +588,40 @@ def test_get_local_property(self): finally: self.sc.setLocalProperty(key, None) + def test_barrier(self): + """ + Verify that BarrierTaskContext.barrier() performs global sync among all barrier tasks + within a stage. + """ + rdd = self.sc.parallelize(range(10), 4) + + def f(iterator): + yield sum(iterator) + + def context_barrier(x): + tc = BarrierTaskContext.get() + time.sleep(random.randint(1, 10)) + tc.barrier() + return time.time() + + times = rdd.barrier().mapPartitions(f).map(context_barrier).collect() + self.assertTrue(max(times) - min(times) < 1) + + def test_barrier_infos(self): + """ + Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the + barrier stage. + """ + rdd = self.sc.parallelize(range(10), 4) + + def f(iterator): + yield sum(iterator) + + taskInfos = rdd.barrier().mapPartitions(f).map(lambda x: BarrierTaskContext.get() + .getTaskInfos()).collect() + self.assertTrue(len(taskInfos) == 4) + self.assertTrue(len(taskInfos[0]) == 4) + class RDDTests(ReusedPySparkTestCase): diff --git a/python/pyspark/util.py b/python/pyspark/util.py index f015542c8799d..f906f49595438 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -80,7 +80,7 @@ def majorMinorVersion(sparkVersion): (2, 3) """ - m = re.search('^(\d+)\.(\d+)(\..*)?$', sparkVersion) + m = re.search(r'^(\d+)\.(\d+)(\..*)?$', sparkVersion) if m is not None: return (int(m.group(1)), int(m.group(2))) else: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index eaaae2b14e107..e934da4d2eb6e 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -22,16 +22,17 @@ import os import sys import time +import resource import socket import traceback from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry -from pyspark.java_gateway import do_server_auth -from pyspark.taskcontext import TaskContext +from pyspark.java_gateway import local_connect_and_auth +from pyspark.taskcontext import BarrierTaskContext, TaskContext from pyspark.files import SparkFiles from pyspark.rdd import PythonEvalType -from pyspark.serializers import write_with_length, write_int, read_long, \ +from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer from pyspark.sql.types import to_arrow_type @@ -259,8 +260,40 @@ def main(infile, outfile): "PYSPARK_DRIVER_PYTHON are correctly set.") % ("%d.%d" % sys.version_info[:2], version)) + # read inputs only for a barrier task + isBarrier = read_bool(infile) + boundPort = read_int(infile) + secret = UTF8Deserializer().loads(infile) + + # set up memory limits + memory_limit_mb = int(os.environ.get('PYSPARK_EXECUTOR_MEMORY_MB', "-1")) + total_memory = resource.RLIMIT_AS + try: + if memory_limit_mb > 0: + (soft_limit, hard_limit) = resource.getrlimit(total_memory) + msg = 
"Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit) + print(msg, file=sys.stderr) + + # convert to bytes + new_limit = memory_limit_mb * 1024 * 1024 + + if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit: + msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit) + print(msg, file=sys.stderr) + resource.setrlimit(total_memory, (new_limit, new_limit)) + + except (resource.error, OSError, ValueError) as e: + # not all systems support resource limits, so warn instead of failing + print("WARN: Failed to set memory limit: {0}\n".format(e), file=sys.stderr) + # initialize global state - taskContext = TaskContext._getOrCreate() + taskContext = None + if isBarrier: + taskContext = BarrierTaskContext._getOrCreate() + BarrierTaskContext._initialize(boundPort, secret) + else: + taskContext = TaskContext._getOrCreate() + # read inputs for TaskContext info taskContext._stageId = read_int(infile) taskContext._partitionId = read_int(infile) taskContext._attemptNumber = read_int(infile) @@ -354,8 +387,5 @@ def process(): # Read information about how to connect back to the JVM from the environment. java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.connect(("127.0.0.1", java_port)) - sock_file = sock.makefile("rwb", 65536) - do_server_auth(sock_file, auth_secret) + (sock_file, _) = local_connect_and_auth(java_port, auth_secret) main(sock_file, sock_file) diff --git a/python/run-tests.py b/python/run-tests.py index 4c90926cfa350..ccbdfac3f3850 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -138,7 +138,7 @@ def run_individual_python_test(target_dir, test_name, pyspark_python): # 2 (or --verbose option is enabled). decoded_lines = map(lambda line: line.decode(), iter(per_test_output)) skipped_tests = list(filter( - lambda line: re.search('test_.* \(pyspark\..*\) ... skipped ', line), + lambda line: re.search(r'test_.* \(pyspark\..*\) ... skipped ', line), decoded_lines)) skipped_counts = len(skipped_tests) if skipped_counts > 0: diff --git a/python/setup.py b/python/setup.py index d309e0564530a..c447f2d40343d 100644 --- a/python/setup.py +++ b/python/setup.py @@ -34,7 +34,7 @@ print("Failed to load PySpark version file for packaging. 
You must be in Spark's python dir.", file=sys.stderr) sys.exit(-1) -VERSION = __version__ +VERSION = __version__ # noqa # A temporary path so we can access above the Python project root and fetch scripts and jars we need TEMP_PATH = "deps" SPARK_HOME = os.path.abspath("../") @@ -219,6 +219,7 @@ def _supports_symlinks(): 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: Implementation :: CPython', 'Programming Language :: Python :: Implementation :: PyPy'] ) diff --git a/repl/pom.xml b/repl/pom.xml index 6f4a863c48bc7..e8464a688336b 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -32,8 +32,8 @@ repl - scala-2.11/src/main/scala - scala-2.11/src/test/scala + src/main/scala-${scala.binary.version} + src/test/scala-${scala.binary.version} @@ -102,7 +102,7 @@ org.apache.xbean - xbean-asm5-shaded + xbean-asm6-shaded @@ -166,15 +166,5 @@ - - - - scala-2.12 - - scala-2.12/src/main/scala - scala-2.12/src/test/scala - - - diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala deleted file mode 100644 index a44051b351e19..0000000000000 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.repl - -import java.io.BufferedReader - -// scalastyle:off println -import scala.Predef.{println => _, _} -// scalastyle:on println -import scala.tools.nsc.Settings -import scala.tools.nsc.interpreter.{ILoop, JPrintWriter} -import scala.tools.nsc.util.stringFromStream -import scala.util.Properties.{javaVersion, javaVmName, versionString} - -/** - * A Spark-specific interactive shell. 
- */ -class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) - extends ILoop(in0, out) { - def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out) - def this() = this(None, new JPrintWriter(Console.out, true)) - - override def createInterpreter(): Unit = { - intp = new SparkILoopInterpreter(settings, out, initializeSpark) - } - - val initializationCommands: Seq[String] = Seq( - """ - @transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) { - org.apache.spark.repl.Main.sparkSession - } else { - org.apache.spark.repl.Main.createSparkSession() - } - @transient val sc = { - val _sc = spark.sparkContext - if (_sc.getConf.getBoolean("spark.ui.reverseProxy", false)) { - val proxyUrl = _sc.getConf.get("spark.ui.reverseProxyUrl", null) - if (proxyUrl != null) { - println( - s"Spark Context Web UI is available at ${proxyUrl}/proxy/${_sc.applicationId}") - } else { - println(s"Spark Context Web UI is available at Spark Master Public URL") - } - } else { - _sc.uiWebUrl.foreach { - webUrl => println(s"Spark context Web UI available at ${webUrl}") - } - } - println("Spark context available as 'sc' " + - s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") - println("Spark session available as 'spark'.") - _sc - } - """, - "import org.apache.spark.SparkContext._", - "import spark.implicits._", - "import spark.sql", - "import org.apache.spark.sql.functions._" - ) - - def initializeSpark(): Unit = { - if (!intp.reporter.hasErrors) { - // `savingReplayStack` removes the commands from session history. - savingReplayStack { - initializationCommands.foreach(intp quietRun _) - } - } else { - throw new RuntimeException(s"Scala $versionString interpreter encountered " + - "errors during initialization") - } - } - - /** Print a welcome message */ - override def printWelcome() { - import org.apache.spark.SPARK_VERSION - echo("""Welcome to - ____ __ - / __/__ ___ _____/ /__ - _\ \/ _ \/ _ `/ __/ '_/ - /___/ .__/\_,_/_/ /_/\_\ version %s - /_/ - """.format(SPARK_VERSION)) - val welcomeMsg = "Using Scala %s (%s, Java %s)".format( - versionString, javaVmName, javaVersion) - echo(welcomeMsg) - echo("Type in expressions to have them evaluated.") - echo("Type :help for more information.") - } - - /** Available commands */ - override def commands: List[LoopCommand] = standardCommands - - override def resetCommand(line: String): Unit = { - super.resetCommand(line) - initializeSpark() - echo("Note that after :reset, state of SparkSession and SparkContext is unchanged.") - } - - override def replay(): Unit = { - initializeSpark() - super.replay() - } - -} - -object SparkILoop { - - /** - * Creates an interpreter loop with default settings and feeds - * the given code to it as input. 
- */ - def run(code: String, sets: Settings = new Settings): String = { - import java.io.{ BufferedReader, StringReader, OutputStreamWriter } - - stringFromStream { ostream => - Console.withOut(ostream) { - val input = new BufferedReader(new StringReader(code)) - val output = new JPrintWriter(new OutputStreamWriter(ostream), true) - val repl = new SparkILoop(input, output) - - if (sets.classpath.isDefault) { - sets.classpath.value = sys.props("java.class.path") - } - repl process sets - } - } - } - def run(lines: List[String]): String = run(lines.map(_ + "\n").mkString) -} diff --git a/repl/scala-2.12/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.12/src/main/scala/org/apache/spark/repl/SparkILoop.scala deleted file mode 100644 index ffb2e5f5db7e2..0000000000000 --- a/repl/scala-2.12/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.repl - -import java.io.BufferedReader - -import scala.tools.nsc.Settings -import scala.tools.nsc.interpreter.{ILoop, JPrintWriter} -import scala.tools.nsc.util.stringFromStream -import scala.util.Properties.{javaVersion, javaVmName, versionString} - -/** - * A Spark-specific interactive shell. - */ -class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) - extends ILoop(in0, out) { - def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out) - def this() = this(None, new JPrintWriter(Console.out, true)) - - val initializationCommands: Seq[String] = Seq( - """ - @transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) { - org.apache.spark.repl.Main.sparkSession - } else { - org.apache.spark.repl.Main.createSparkSession() - } - @transient val sc = { - val _sc = spark.sparkContext - if (_sc.getConf.getBoolean("spark.ui.reverseProxy", false)) { - val proxyUrl = _sc.getConf.get("spark.ui.reverseProxyUrl", null) - if (proxyUrl != null) { - println( - s"Spark Context Web UI is available at ${proxyUrl}/proxy/${_sc.applicationId}") - } else { - println(s"Spark Context Web UI is available at Spark Master Public URL") - } - } else { - _sc.uiWebUrl.foreach { - webUrl => println(s"Spark context Web UI available at ${webUrl}") - } - } - println("Spark context available as 'sc' " + - s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") - println("Spark session available as 'spark'.") - _sc - } - """, - "import org.apache.spark.SparkContext._", - "import spark.implicits._", - "import spark.sql", - "import org.apache.spark.sql.functions._" - ) - - def initializeSpark() { - intp.beQuietDuring { - savingReplayStack { // remove the commands from session history. 
- initializationCommands.foreach(command) - } - } - } - - /** Print a welcome message */ - override def printWelcome() { - import org.apache.spark.SPARK_VERSION - echo("""Welcome to - ____ __ - / __/__ ___ _____/ /__ - _\ \/ _ \/ _ `/ __/ '_/ - /___/ .__/\_,_/_/ /_/\_\ version %s - /_/ - """.format(SPARK_VERSION)) - val welcomeMsg = "Using Scala %s (%s, Java %s)".format( - versionString, javaVmName, javaVersion) - echo(welcomeMsg) - echo("Type in expressions to have them evaluated.") - echo("Type :help for more information.") - } - - /** Available commands */ - override def commands: List[LoopCommand] = standardCommands - - /** - * We override `createInterpreter` because we need to initialize Spark *before* the REPL - * sees any files, so that the Spark context is visible in those files. This is a bit of a - * hack, but there isn't another hook available to us at this point. - */ - override def createInterpreter(): Unit = { - super.createInterpreter() - initializeSpark() - } - - override def resetCommand(line: String): Unit = { - super.resetCommand(line) - initializeSpark() - echo("Note that after :reset, state of SparkSession and SparkContext is unchanged.") - } - - override def replay(): Unit = { - initializeSpark() - super.replay() - } - -} - -object SparkILoop { - - /** - * Creates an interpreter loop with default settings and feeds - * the given code to it as input. - */ - def run(code: String, sets: Settings = new Settings): String = { - import java.io.{ BufferedReader, StringReader, OutputStreamWriter } - - stringFromStream { ostream => - Console.withOut(ostream) { - val input = new BufferedReader(new StringReader(code)) - val output = new JPrintWriter(new OutputStreamWriter(ostream), true) - val repl = new SparkILoop(input, output) - - if (sets.classpath.isDefault) { - sets.classpath.value = sys.props("java.class.path") - } - repl process sets - } - } - } - def run(lines: List[String]): String = run(lines.map(_ + "\n").mkString) -} diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/src/main/scala-2.11/org/apache/spark/repl/SparkExprTyper.scala similarity index 100% rename from repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala rename to repl/src/main/scala-2.11/org/apache/spark/repl/SparkExprTyper.scala diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala b/repl/src/main/scala-2.11/org/apache/spark/repl/SparkILoopInterpreter.scala similarity index 93% rename from repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala rename to repl/src/main/scala-2.11/org/apache/spark/repl/SparkILoopInterpreter.scala index 4e63816402a10..e736607a9a6b9 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala +++ b/repl/src/main/scala-2.11/org/apache/spark/repl/SparkILoopInterpreter.scala @@ -21,22 +21,8 @@ import scala.collection.mutable import scala.tools.nsc.Settings import scala.tools.nsc.interpreter._ -class SparkILoopInterpreter(settings: Settings, out: JPrintWriter, initializeSpark: () => Unit) - extends IMain(settings, out) { self => - - /** - * We override `initializeSynchronous` to initialize Spark *after* `intp` is properly initialized - * and *before* the REPL sees any files in the private `loadInitFiles` functions, so that - * the Spark context is visible in those files. - * - * This is a bit of a hack, but there isn't another hook available to us at this point. 
- * - * See the discussion in Scala community https://github.com/scala/bug/issues/10913 for detail. - */ - override def initializeSynchronous(): Unit = { - super.initializeSynchronous() - initializeSpark() - } +class SparkILoopInterpreter(settings: Settings, out: JPrintWriter) extends IMain(settings, out) { + self => override lazy val memberHandlers = new { val intp: self.type = self diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 4dc399827ffed..88eb0ad1da3d7 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -22,8 +22,8 @@ import java.net.{URI, URL, URLEncoder} import java.nio.channels.Channels import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.xbean.asm5._ -import org.apache.xbean.asm5.Opcodes._ +import org.apache.xbean.asm6._ +import org.apache.xbean.asm6.Opcodes._ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.deploy.SparkHadoopUtil @@ -187,7 +187,7 @@ class ExecutorClassLoader( } class ConstructorCleaner(className: String, cv: ClassVisitor) -extends ClassVisitor(ASM5, cv) { +extends ClassVisitor(ASM6, cv) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { val mv = cv.visitMethod(access, name, desc, sig, exceptions) diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala new file mode 100644 index 0000000000000..aa9aa2793b8b3 --- /dev/null +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -0,0 +1,319 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.repl + +import java.io.BufferedReader + +// scalastyle:off println +import scala.Predef.{println => _, _} +// scalastyle:on println +import scala.concurrent.Future +import scala.reflect.classTag +import scala.reflect.io.File +import scala.tools.nsc.{GenericRunnerSettings, Properties} +import scala.tools.nsc.Settings +import scala.tools.nsc.interpreter.{isReplDebug, isReplPower, replProps} +import scala.tools.nsc.interpreter.{AbstractOrMissingHandler, ILoop, IMain, JPrintWriter} +import scala.tools.nsc.interpreter.{NamedParam, SimpleReader, SplashLoop, SplashReader} +import scala.tools.nsc.interpreter.StdReplTags.tagOfIMain +import scala.tools.nsc.util.stringFromStream +import scala.util.Properties.{javaVersion, javaVmName, versionNumberString, versionString} + +/** + * A Spark-specific interactive shell. 
+ */
+class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter)
+    extends ILoop(in0, out) {
+  def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out)
+  def this() = this(None, new JPrintWriter(Console.out, true))
+
+  /**
+   * TODO: Remove the following `override` when support for Scala 2.11 ends.
+   * Scala 2.11 has a bug in finding imported types in class constructors and extends clauses,
+   * which is fixed in Scala 2.12 but was never back-ported to Scala 2.11.x.
+   * As a result, we copied the fixes into `SparkILoopInterpreter`. See SPARK-22393 for details.
+   */
+  override def createInterpreter(): Unit = {
+    if (isScala2_11) {
+      if (addedClasspath != "") {
+        settings.classpath append addedClasspath
+      }
+      // scalastyle:off classforname
+      // Have to use the default classloader to match the one used in
+      // `classOf[Settings]` and `classOf[JPrintWriter]`.
+      intp = Class.forName("org.apache.spark.repl.SparkILoopInterpreter")
+        .getDeclaredConstructor(Seq(classOf[Settings], classOf[JPrintWriter]): _*)
+        .newInstance(Seq(settings, out): _*)
+        .asInstanceOf[IMain]
+      // scalastyle:on classforname
+    } else {
+      super.createInterpreter()
+    }
+  }
+
+  private val isScala2_11 = versionNumberString.startsWith("2.11")
+
+  val initializationCommands: Seq[String] = Seq(
+    """
+    @transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) {
+        org.apache.spark.repl.Main.sparkSession
+      } else {
+        org.apache.spark.repl.Main.createSparkSession()
+      }
+    @transient val sc = {
+      val _sc = spark.sparkContext
+      if (_sc.getConf.getBoolean("spark.ui.reverseProxy", false)) {
+        val proxyUrl = _sc.getConf.get("spark.ui.reverseProxyUrl", null)
+        if (proxyUrl != null) {
+          println(
+            s"Spark Context Web UI is available at ${proxyUrl}/proxy/${_sc.applicationId}")
+        } else {
+          println(s"Spark Context Web UI is available at Spark Master Public URL")
+        }
+      } else {
+        _sc.uiWebUrl.foreach {
+          webUrl => println(s"Spark context Web UI available at ${webUrl}")
+        }
+      }
+      println("Spark context available as 'sc' " +
+        s"(master = ${_sc.master}, app id = ${_sc.applicationId}).")
+      println("Spark session available as 'spark'.")
+      _sc
+    }
+    """,
+    "import org.apache.spark.SparkContext._",
+    "import spark.implicits._",
+    "import spark.sql",
+    "import org.apache.spark.sql.functions._"
+  )
+
+  def initializeSpark(): Unit = {
+    if (!intp.reporter.hasErrors) {
+      // `savingReplayStack` removes the commands from session history.
+ savingReplayStack { + initializationCommands.foreach(intp quietRun _) + } + } else { + throw new RuntimeException(s"Scala $versionString interpreter encountered " + + "errors during initialization") + } + } + + /** Print a welcome message */ + override def printWelcome() { + import org.apache.spark.SPARK_VERSION + echo("""Welcome to + ____ __ + / __/__ ___ _____/ /__ + _\ \/ _ \/ _ `/ __/ '_/ + /___/ .__/\_,_/_/ /_/\_\ version %s + /_/ + """.format(SPARK_VERSION)) + val welcomeMsg = "Using Scala %s (%s, Java %s)".format( + versionString, javaVmName, javaVersion) + echo(welcomeMsg) + echo("Type in expressions to have them evaluated.") + echo("Type :help for more information.") + } + + /** Available commands */ + override def commands: List[LoopCommand] = standardCommands + + override def resetCommand(line: String): Unit = { + super.resetCommand(line) + initializeSpark() + echo("Note that after :reset, state of SparkSession and SparkContext is unchanged.") + } + + override def replay(): Unit = { + initializeSpark() + super.replay() + } + + /** + * TODO: Remove `runClosure` when the support of Scala 2.11 is ended + */ + private def runClosure(body: => Boolean): Boolean = { + if (isScala2_11) { + // In Scala 2.11, there is a bug that interpret could set the current thread's + // context classloader, but fails to reset it to its previous state when returning + // from that method. This is fixed in SI-8521 https://github.com/scala/scala/pull/5657 + // which is never back-ported into Scala 2.11.x. The following is a workaround fix. + val original = Thread.currentThread().getContextClassLoader + try { + body + } finally { + Thread.currentThread().setContextClassLoader(original) + } + } else { + body + } + } + + /** + * The following code is mostly a copy of `process` implementation in `ILoop.scala` in Scala + * + * In newer version of Scala, `printWelcome` is the first thing to be called. As a result, + * SparkUI URL information would be always shown after the welcome message. + * + * However, this is inconsistent compared with the existing version of Spark which will always + * show SparkUI URL first. + * + * The only way we can make it consistent will be duplicating the Scala code. + * + * We should remove this duplication once Scala provides a way to load our custom initialization + * code, and also customize the ordering of printing welcome message. + */ + override def process(settings: Settings): Boolean = runClosure { + + def newReader = in0.fold(chooseReader(settings))(r => SimpleReader(r, out, interactive = true)) + + /** Reader to use before interpreter is online. */ + def preLoop = { + val sr = SplashReader(newReader) { r => + in = r + in.postInit() + } + in = sr + SplashLoop(sr, prompt) + } + + /* Actions to cram in parallel while collecting first user input at prompt. + * Run with output muted both from ILoop and from the intp reporter. + */ + def loopPostInit(): Unit = mumly { + // Bind intp somewhere out of the regular namespace where + // we can get at it in generated code. + intp.quietBind(NamedParam[IMain]("$intp", intp)(tagOfIMain, classTag[IMain])) + + // Auto-run code via some setting. + ( replProps.replAutorunCode.option + flatMap (f => File(f).safeSlurp()) + foreach (intp quietRun _) + ) + // power mode setup + if (isReplPower) enablePowerMode(true) + initializeSpark() + loadInitFiles() + // SI-7418 Now, and only now, can we enable TAB completion. 
+ in.postInit() + } + def loadInitFiles(): Unit = settings match { + case settings: GenericRunnerSettings => + for (f <- settings.loadfiles.value) { + loadCommand(f) + addReplay(s":load $f") + } + for (f <- settings.pastefiles.value) { + pasteCommand(f) + addReplay(s":paste $f") + } + case _ => + } + // wait until after startup to enable noisy settings + def withSuppressedSettings[A](body: => A): A = { + val ss = this.settings + import ss._ + val noisy = List(Xprint, Ytyperdebug) + val noisesome = noisy.exists(!_.isDefault) + val current = (Xprint.value, Ytyperdebug.value) + if (isReplDebug || !noisesome) body + else { + this.settings.Xprint.value = List.empty + this.settings.Ytyperdebug.value = false + try body + finally { + Xprint.value = current._1 + Ytyperdebug.value = current._2 + intp.global.printTypings = current._2 + } + } + } + def startup(): String = withSuppressedSettings { + // let them start typing + val splash = preLoop + + // while we go fire up the REPL + try { + // don't allow ancient sbt to hijack the reader + savingReader { + createInterpreter() + } + intp.initializeSynchronous() + + val field = classOf[ILoop].getDeclaredFields.filter(_.getName.contains("globalFuture")).head + field.setAccessible(true) + field.set(this, Future successful true) + + if (intp.reporter.hasErrors) { + echo("Interpreter encountered errors during initialization!") + null + } else { + loopPostInit() + printWelcome() + splash.start() + + val line = splash.line // what they typed in while they were waiting + if (line == null) { // they ^D + try out print Properties.shellInterruptedString + finally closeInterpreter() + } + line + } + } finally splash.stop() + } + + this.settings = settings + startup() match { + case null => false + case line => + try loop(line) match { + case LineResults.EOF => out print Properties.shellInterruptedString + case _ => + } + catch AbstractOrMissingHandler() + finally closeInterpreter() + true + } + } +} + +object SparkILoop { + + /** + * Creates an interpreter loop with default settings and feeds + * the given code to it as input. 
+ */ + def run(code: String, sets: Settings = new Settings): String = { + import java.io.{ BufferedReader, StringReader, OutputStreamWriter } + + stringFromStream { ostream => + Console.withOut(ostream) { + val input = new BufferedReader(new StringReader(code)) + val output = new JPrintWriter(new OutputStreamWriter(ostream), true) + val repl = new SparkILoop(input, output) + + if (sets.classpath.isDefault) { + sets.classpath.value = sys.props("java.class.path") + } + repl process sets + } + } + } + def run(lines: List[String]): String = run(lines.map(_ + "\n").mkString) +} diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index cdd5cdd841740..4f3df729177fb 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -21,6 +21,7 @@ import java.io._ import java.net.URLClassLoader import scala.collection.mutable.ArrayBuffer +import scala.tools.nsc.interpreter.SimpleReader import org.apache.log4j.{Level, LogManager} @@ -84,6 +85,7 @@ class ReplSuite extends SparkFunSuite { settings = new scala.tools.nsc.Settings settings.usejavacp.value = true org.apache.spark.repl.Main.interp = this + in = SimpleReader() } val out = new StringWriter() diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index a6dd47a6b7d95..920f0f6ebf2c8 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -47,6 +47,12 @@ test + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + + io.fabric8 kubernetes-client diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index bf33179ae3dab..71e4d321a0e3a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -65,6 +65,7 @@ private[spark] object Config extends Logging { "spark.kubernetes.authenticate.driver" val KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX = "spark.kubernetes.authenticate.driver.mounted" + val KUBERNETES_AUTH_CLIENT_MODE_PREFIX = "spark.kubernetes.authenticate" val OAUTH_TOKEN_CONF_SUFFIX = "oauthToken" val OAUTH_TOKEN_FILE_CONF_SUFFIX = "oauthTokenFile" val CLIENT_KEY_FILE_CONF_SUFFIX = "clientKeyFile" @@ -90,7 +91,7 @@ private[spark] object Config extends Logging { ConfigBuilder("spark.kubernetes.submitInDriver") .internal() .booleanConf - .createOptional + .createWithDefault(false) val KUBERNETES_EXECUTOR_LIMIT_CORES = ConfigBuilder("spark.kubernetes.executor.limit.cores") @@ -138,6 +139,19 @@ private[spark] object Config extends Logging { .stringConf .createOptional + val KUBERNETES_R_MAIN_APP_RESOURCE = + ConfigBuilder("spark.kubernetes.r.mainAppResource") + .doc("The main app resource for SparkR jobs") + .internal() + .stringConf + .createOptional + + val KUBERNETES_R_APP_ARGS = + ConfigBuilder("spark.kubernetes.r.appArgs") + .doc("The app arguments for SparkR Jobs") + .internal() + .stringConf + .createOptional val KUBERNETES_ALLOCATION_BATCH_SIZE = ConfigBuilder("spark.kubernetes.allocation.batch.size") @@ -204,13 +218,29 @@ private[spark] object Config extends Logging { .createWithDefault(0.1) val PYSPARK_MAJOR_PYTHON_VERSION = - ConfigBuilder("spark.kubernetes.pyspark.pythonversion") + 
ConfigBuilder("spark.kubernetes.pyspark.pythonVersion") .doc("This sets the major Python version. Either 2 or 3. (Python2 or Python3)") .stringConf .checkValue(pv => List("2", "3").contains(pv), "Ensure that major Python version is either Python2 or Python3") .createWithDefault("2") + val APP_RESOURCE_TYPE = + ConfigBuilder("spark.kubernetes.resource.type") + .doc("This sets the resource type internally") + .internal() + .stringConf + .createOptional + + val KUBERNETES_LOCAL_DIRS_TMPFS = + ConfigBuilder("spark.kubernetes.local.dirs.tmpfs") + .doc("If set to true then emptyDir volumes created to back SPARK_LOCAL_DIRS will have " + + "their medium set to Memory so that they will be created as tmpfs (i.e. RAM) backed " + + "volumes. This may improve performance but scratch space usage will count towards " + + "your pods memory limit so you may wish to request more memory.") + .booleanConf + .createWithDefault(false) + val KUBERNETES_AUTH_SUBMISSION_CONF_PREFIX = "spark.kubernetes.authenticate.submission" @@ -220,11 +250,23 @@ private[spark] object Config extends Logging { val KUBERNETES_DRIVER_ANNOTATION_PREFIX = "spark.kubernetes.driver.annotation." val KUBERNETES_DRIVER_SECRETS_PREFIX = "spark.kubernetes.driver.secrets." val KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX = "spark.kubernetes.driver.secretKeyRef." + val KUBERNETES_DRIVER_VOLUMES_PREFIX = "spark.kubernetes.driver.volumes." val KUBERNETES_EXECUTOR_LABEL_PREFIX = "spark.kubernetes.executor.label." val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation." val KUBERNETES_EXECUTOR_SECRETS_PREFIX = "spark.kubernetes.executor.secrets." val KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX = "spark.kubernetes.executor.secretKeyRef." + val KUBERNETES_EXECUTOR_VOLUMES_PREFIX = "spark.kubernetes.executor.volumes." + + val KUBERNETES_VOLUMES_HOSTPATH_TYPE = "hostPath" + val KUBERNETES_VOLUMES_PVC_TYPE = "persistentVolumeClaim" + val KUBERNETES_VOLUMES_EMPTYDIR_TYPE = "emptyDir" + val KUBERNETES_VOLUMES_MOUNT_PATH_KEY = "mount.path" + val KUBERNETES_VOLUMES_MOUNT_READONLY_KEY = "mount.readOnly" + val KUBERNETES_VOLUMES_OPTIONS_PATH_KEY = "options.path" + val KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY = "options.claimName" + val KUBERNETES_VOLUMES_OPTIONS_MEDIUM_KEY = "options.medium" + val KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY = "options.sizeLimit" val KUBERNETES_DRIVER_ENV_PREFIX = "spark.kubernetes.driverEnv." 
} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala index 69bd03d1eda6f..8202d874a4626 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -25,9 +25,6 @@ private[spark] object Constants { val SPARK_POD_DRIVER_ROLE = "driver" val SPARK_POD_EXECUTOR_ROLE = "executor" - // Annotations - val SPARK_APP_NAME_ANNOTATION = "spark-app-name" - // Credentials secrets val DRIVER_CREDENTIALS_SECRETS_BASE_DIR = "/mnt/secrets/spark-kubernetes-credentials" @@ -50,17 +47,15 @@ private[spark] object Constants { val DEFAULT_BLOCKMANAGER_PORT = 7079 val DRIVER_PORT_NAME = "driver-rpc-port" val BLOCK_MANAGER_PORT_NAME = "blockmanager" - val EXECUTOR_PORT_NAME = "executor" + val UI_PORT_NAME = "spark-ui" // Environment Variables - val ENV_EXECUTOR_PORT = "SPARK_EXECUTOR_PORT" val ENV_DRIVER_URL = "SPARK_DRIVER_URL" val ENV_EXECUTOR_CORES = "SPARK_EXECUTOR_CORES" val ENV_EXECUTOR_MEMORY = "SPARK_EXECUTOR_MEMORY" val ENV_APPLICATION_ID = "SPARK_APPLICATION_ID" val ENV_EXECUTOR_ID = "SPARK_EXECUTOR_ID" val ENV_EXECUTOR_POD_IP = "SPARK_EXECUTOR_POD_IP" - val ENV_MOUNTED_CLASSPATH = "SPARK_MOUNTED_CLASSPATH" val ENV_JAVA_OPT_PREFIX = "SPARK_JAVA_OPT_" val ENV_CLASSPATH = "SPARK_CLASSPATH" val ENV_DRIVER_BIND_ADDRESS = "SPARK_DRIVER_BIND_ADDRESS" @@ -76,6 +71,8 @@ private[spark] object Constants { val ENV_PYSPARK_FILES = "PYSPARK_FILES" val ENV_PYSPARK_ARGS = "PYSPARK_APP_ARGS" val ENV_PYSPARK_MAJOR_PYTHON_VERSION = "PYSPARK_MAJOR_PYTHON_VERSION" + val ENV_R_PRIMARY = "R_PRIMARY" + val ENV_R_ARGS = "R_APP_ARGS" // Miscellaneous val KUBERNETES_MASTER_INTERNAL_URL = "https://kubernetes.default.svc" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala index b0ccaa36b01ed..cae6e7d5ad518 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -24,6 +24,7 @@ import org.apache.spark.SparkConf import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.submit._ +import org.apache.spark.deploy.k8s.submit.KubernetesClientApplication._ import org.apache.spark.internal.config.ConfigEntry @@ -43,7 +44,7 @@ private[spark] case class KubernetesDriverSpecificConf( */ private[spark] case class KubernetesExecutorSpecificConf( executorId: String, - driverPod: Pod) + driverPod: Option[Pod]) extends KubernetesRoleSpecificConf /** @@ -59,6 +60,7 @@ private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( roleSecretNamesToMountPaths: Map[String, String], roleSecretEnvNamesToKeyRefs: Map[String, String], roleEnvs: Map[String, String], + roleVolumes: Iterable[KubernetesVolumeSpec[_ <: KubernetesVolumeSpecificConf]], sparkFiles: Seq[String]) { def namespace(): String = sparkConf.get(KUBERNETES_NAMESPACE) @@ -77,6 +79,9 @@ private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( def pySparkPythonVersion(): String = sparkConf .get(PYSPARK_MAJOR_PYTHON_VERSION) + def sparkRMainResource(): Option[String] = sparkConf + 
.get(KUBERNETES_R_MAIN_APP_RESOURCE) + def imagePullPolicy(): String = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) def imagePullSecrets(): Seq[LocalObjectReference] = { @@ -124,7 +129,7 @@ private[spark] object KubernetesConf { sparkConfWithMainAppJar.setJars(previousJars ++ Seq(res)) } // The function of this outer match is to account for multiple nonJVM - // bindings that will all have increased MEMORY_OVERHEAD_FACTOR to 0.4 + // bindings that will all have increased default MEMORY_OVERHEAD_FACTOR to 0.4 case nonJVM: NonJVMResource => nonJVM match { case PythonMainAppResource(res) => @@ -132,6 +137,9 @@ private[spark] object KubernetesConf { maybePyFiles.foreach{maybePyFiles => additionalFiles.appendAll(maybePyFiles.split(","))} sparkConfWithMainAppJar.set(KUBERNETES_PYSPARK_MAIN_APP_RESOURCE, res) + case RMainAppResource(res) => + additionalFiles += res + sparkConfWithMainAppJar.set(KUBERNETES_R_MAIN_APP_RESOURCE, res) } sparkConfWithMainAppJar.setIfMissing(MEMORY_OVERHEAD_FACTOR, 0.4) } @@ -155,6 +163,12 @@ private[spark] object KubernetesConf { sparkConf, KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX) val driverEnvs = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_DRIVER_ENV_PREFIX) + val driverVolumes = KubernetesVolumeUtils.parseVolumesWithPrefix( + sparkConf, KUBERNETES_DRIVER_VOLUMES_PREFIX).map(_.get) + // Also parse executor volumes in order to verify configuration + // before the driver pod is created + KubernetesVolumeUtils.parseVolumesWithPrefix( + sparkConf, KUBERNETES_EXECUTOR_VOLUMES_PREFIX).map(_.get) val sparkFiles = sparkConf .getOption("spark.files") @@ -171,6 +185,7 @@ private[spark] object KubernetesConf { driverSecretNamesToMountPaths, driverSecretEnvNamesToKeyRefs, driverEnvs, + driverVolumes, sparkFiles) } @@ -178,7 +193,7 @@ private[spark] object KubernetesConf { sparkConf: SparkConf, executorId: String, appId: String, - driverPod: Pod): KubernetesConf[KubernetesExecutorSpecificConf] = { + driverPod: Option[Pod]): KubernetesConf[KubernetesExecutorSpecificConf] = { val executorCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_EXECUTOR_LABEL_PREFIX) require( @@ -203,17 +218,30 @@ private[spark] object KubernetesConf { val executorEnvSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX) val executorEnv = sparkConf.getExecutorEnv.toMap + val executorVolumes = KubernetesVolumeUtils.parseVolumesWithPrefix( + sparkConf, KUBERNETES_EXECUTOR_VOLUMES_PREFIX).map(_.get) + + // If no prefix is defined then we are in pure client mode + // (not the one used by cluster mode inside the container) + val appResourceNamePrefix = { + if (sparkConf.getOption(KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key).isEmpty) { + getResourceNamePrefix(getAppName(sparkConf)) + } else { + sparkConf.get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX) + } + } KubernetesConf( sparkConf.clone(), KubernetesExecutorSpecificConf(executorId, driverPod), - sparkConf.get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX), + appResourceNamePrefix, appId, executorLabels, executorAnnotations, executorMountSecrets, executorEnvSecrets, executorEnv, + executorVolumes, Seq.empty[String]) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index 593fb531a004d..f5fae7cc8c470 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala +++ 
b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -16,9 +16,11 @@ */ package org.apache.spark.deploy.k8s -import io.fabric8.kubernetes.api.model.LocalObjectReference +import scala.collection.JavaConverters._ -import org.apache.spark.SparkConf +import io.fabric8.kubernetes.api.model.{ContainerStateRunning, ContainerStateTerminated, ContainerStateWaiting, ContainerStatus, Pod, Time} + +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.util.Utils private[spark] object KubernetesUtils { @@ -60,4 +62,83 @@ private[spark] object KubernetesUtils { case _ => uri } } + + def parseMasterUrl(url: String): String = url.substring("k8s://".length) + + def formatPairsBundle(pairs: Seq[(String, String)], indent: Int = 1) : String = { + // Use more loggable format if value is null or empty + val indentStr = "\t" * indent + pairs.map { + case (k, v) => s"\n$indentStr $k: ${Option(v).filter(_.nonEmpty).getOrElse("N/A")}" + }.mkString("") + } + + /** + * Given a pod, output a human readable representation of its state + * + * @param pod Pod + * @return Human readable pod state + */ + def formatPodState(pod: Pod): String = { + val details = Seq[(String, String)]( + // pod metadata + ("pod name", pod.getMetadata.getName), + ("namespace", pod.getMetadata.getNamespace), + ("labels", pod.getMetadata.getLabels.asScala.mkString(", ")), + ("pod uid", pod.getMetadata.getUid), + ("creation time", formatTime(pod.getMetadata.getCreationTimestamp)), + + // spec details + ("service account name", pod.getSpec.getServiceAccountName), + ("volumes", pod.getSpec.getVolumes.asScala.map(_.getName).mkString(", ")), + ("node name", pod.getSpec.getNodeName), + + // status + ("start time", formatTime(pod.getStatus.getStartTime)), + ("phase", pod.getStatus.getPhase), + ("container status", containersDescription(pod, 2)) + ) + + formatPairsBundle(details) + } + + def containersDescription(p: Pod, indent: Int = 1): String = { + p.getStatus.getContainerStatuses.asScala.map { status => + Seq( + ("container name", status.getName), + ("container image", status.getImage)) ++ + containerStatusDescription(status) + }.map(p => formatPairsBundle(p, indent)).mkString("\n\n") + } + + def containerStatusDescription(containerStatus: ContainerStatus) + : Seq[(String, String)] = { + val state = containerStatus.getState + Option(state.getRunning) + .orElse(Option(state.getTerminated)) + .orElse(Option(state.getWaiting)) + .map { + case running: ContainerStateRunning => + Seq( + ("container state", "running"), + ("container started at", formatTime(running.getStartedAt))) + case waiting: ContainerStateWaiting => + Seq( + ("container state", "waiting"), + ("pending reason", waiting.getReason)) + case terminated: ContainerStateTerminated => + Seq( + ("container state", "terminated"), + ("container started at", formatTime(terminated.getStartedAt)), + ("container finished at", formatTime(terminated.getFinishedAt)), + ("exit code", terminated.getExitCode.toString), + ("termination reason", terminated.getReason)) + case unknown => + throw new SparkException(s"Unexpected container status type ${unknown.getClass}.") + }.getOrElse(Seq(("container state", "N/A"))) + } + + def formatTime(time: Time): String = { + if (time != null) time.getTime else "N/A" + } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala new 
file mode 100644 index 0000000000000..b1762d1efe2ea --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +private[spark] sealed trait KubernetesVolumeSpecificConf + +private[spark] case class KubernetesHostPathVolumeConf( + hostPath: String) + extends KubernetesVolumeSpecificConf + +private[spark] case class KubernetesPVCVolumeConf( + claimName: String) + extends KubernetesVolumeSpecificConf + +private[spark] case class KubernetesEmptyDirVolumeConf( + medium: Option[String], + sizeLimit: Option[String]) + extends KubernetesVolumeSpecificConf + +private[spark] case class KubernetesVolumeSpec[T <: KubernetesVolumeSpecificConf]( + volumeName: String, + mountPath: String, + mountReadOnly: Boolean, + volumeConf: T) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala new file mode 100644 index 0000000000000..713df5fffc3a2 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import java.util.NoSuchElementException + +import scala.util.{Failure, Success, Try} + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ + +private[spark] object KubernetesVolumeUtils { + /** + * Extract Spark volume configuration properties with a given name prefix. 
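The property keys handled below decompose as `<volumeType>.<volumeName>.<property>`; a tiny Python sketch of that split (illustrative only, not the Scala implementation that follows):

    # e.g. key suffix "hostPath.checkpoints.mount.path"
    def volume_type_and_name(key):
        parts = key.split('.')
        # the first two segments identify the volume; the rest name the property
        return (parts[0], parts[1]) if len(parts) >= 2 else None

    assert volume_type_and_name("hostPath.checkpoints.mount.path") == \
        ("hostPath", "checkpoints")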
+ * + * @param sparkConf Spark configuration + * @param prefix the given property name prefix + * @return an Iterable of parsed volume specs, each wrapped in a Try so that per-volume parsing errors can be reported to the caller + */ + def parseVolumesWithPrefix( + sparkConf: SparkConf, + prefix: String): Iterable[Try[KubernetesVolumeSpec[_ <: KubernetesVolumeSpecificConf]]] = { + val properties = sparkConf.getAllWithPrefix(prefix).toMap + + getVolumeTypesAndNames(properties).map { case (volumeType, volumeName) => + val pathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_PATH_KEY" + val readOnlyKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_READONLY_KEY" + + for { + path <- properties.getTry(pathKey) + volumeConf <- parseVolumeSpecificConf(properties, volumeType, volumeName) + } yield KubernetesVolumeSpec( + volumeName = volumeName, + mountPath = path, + mountReadOnly = properties.get(readOnlyKey).exists(_.toBoolean), + volumeConf = volumeConf + ) + } + } + + /** + * Get unique pairs of volumeType and volumeName, + * assuming options are formatted in this way: + * `volumeType`.`volumeName`.`property` = `value` + * @param properties flat mapping of property names to values + * @return Set[(volumeType, volumeName)] + */ + private def getVolumeTypesAndNames( + properties: Map[String, String] + ): Set[(String, String)] = { + properties.keys.flatMap { k => + k.split('.').toList match { + case tpe :: name :: _ => Some((tpe, name)) + case _ => None + } + }.toSet + } + + private def parseVolumeSpecificConf( + options: Map[String, String], + volumeType: String, + volumeName: String): Try[KubernetesVolumeSpecificConf] = { + volumeType match { + case KUBERNETES_VOLUMES_HOSTPATH_TYPE => + val pathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_PATH_KEY" + for { + path <- options.getTry(pathKey) + } yield KubernetesHostPathVolumeConf(path) + + case KUBERNETES_VOLUMES_PVC_TYPE => + val claimNameKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY" + for { + claimName <- options.getTry(claimNameKey) + } yield KubernetesPVCVolumeConf(claimName) + + case KUBERNETES_VOLUMES_EMPTYDIR_TYPE => + val mediumKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_MEDIUM_KEY" + val sizeLimitKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY" + Success(KubernetesEmptyDirVolumeConf(options.get(mediumKey), options.get(sizeLimitKey))) + + case _ => + Failure(new RuntimeException(s"Kubernetes Volume type `$volumeType` is not supported")) + } + } + + /** + * Convenience wrapper to accumulate key lookup errors + */ + implicit private class MapOps[A, B](m: Map[A, B]) { + def getTry(key: A): Try[B] = { + m + .get(key) + .fold[Try[B]](Failure(new NoSuchElementException(key.toString)))(Success(_)) + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala index 143dc8a12304e..575bc54ffe2bb 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala @@ -19,14 +19,15 @@ package org.apache.spark.deploy.k8s.features import scala.collection.JavaConverters._ import scala.collection.mutable -import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, EnvVarSourceBuilder, HasMetadata, PodBuilder, QuantityBuilder} +import
io.fabric8.kubernetes.api.model._ import org.apache.spark.SparkException -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod} +import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.submit._ import org.apache.spark.internal.config._ +import org.apache.spark.ui.SparkUI private[spark] class BasicDriverFeatureStep( conf: KubernetesConf[KubernetesDriverSpecificConf]) @@ -72,10 +73,31 @@ private[spark] class BasicDriverFeatureStep( ("cpu", new QuantityBuilder(false).withAmount(limitCores).build()) } + val driverPort = conf.sparkConf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT) + val driverBlockManagerPort = conf.sparkConf.getInt( + DRIVER_BLOCK_MANAGER_PORT.key, + DEFAULT_BLOCKMANAGER_PORT + ) + val driverUIPort = SparkUI.getUIPort(conf.sparkConf) val driverContainer = new ContainerBuilder(pod.container) .withName(DRIVER_CONTAINER_NAME) .withImage(driverContainerImage) .withImagePullPolicy(conf.imagePullPolicy()) + .addNewPort() + .withName(DRIVER_PORT_NAME) + .withContainerPort(driverPort) + .withProtocol("TCP") + .endPort() + .addNewPort() + .withName(BLOCK_MANAGER_PORT_NAME) + .withContainerPort(driverBlockManagerPort) + .withProtocol("TCP") + .endPort() + .addNewPort() + .withName(UI_PORT_NAME) + .withContainerPort(driverUIPort) + .withProtocol("TCP") + .endPort() .addAllToEnv(driverCustomEnvs.asJava) .addNewEnv() .withName(ENV_DRIVER_BIND_ADDRESS) @@ -103,6 +125,7 @@ private[spark] class BasicDriverFeatureStep( .addToImagePullSecrets(conf.imagePullSecrets(): _*) .endSpec() .build() + SparkPod(driverPod, driverContainer) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index 91c54a9776982..d89995ba5e4f4 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -18,13 +18,13 @@ package org.apache.spark.deploy.k8s.features import scala.collection.JavaConverters._ -import io.fabric8.kubernetes.api.model.{ContainerBuilder, ContainerPortBuilder, EnvVar, EnvVarBuilder, EnvVarSourceBuilder, HasMetadata, PodBuilder, QuantityBuilder} +import io.fabric8.kubernetes.api.model._ import org.apache.spark.SparkException -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} +import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD, PYSPARK_EXECUTOR_MEMORY} import org.apache.spark.rpc.RpcEndpointAddress import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils @@ -58,6 +58,16 @@ private[spark] class BasicExecutorFeatureStep( (kubernetesConf.get(MEMORY_OVERHEAD_FACTOR) * executorMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB + private val executorMemoryTotal = 
kubernetesConf.sparkConf + .getOption(APP_RESOURCE_TYPE.key).map{ res => + val additionalPySparkMemory = res match { + case "python" => + kubernetesConf.sparkConf + .get(PYSPARK_EXECUTOR_MEMORY).map(_.toInt).getOrElse(0) + case _ => 0 + } + executorMemoryWithOverhead + additionalPySparkMemory + }.getOrElse(executorMemoryWithOverhead) private val executorCores = kubernetesConf.sparkConf.getInt("spark.executor.cores", 1) private val executorCoresRequest = @@ -76,7 +86,7 @@ private[spark] class BasicExecutorFeatureStep( // executorId val hostname = name.substring(Math.max(0, name.length - 63)) val executorMemoryQuantity = new QuantityBuilder(false) - .withAmount(s"${executorMemoryWithOverhead}Mi") + .withAmount(s"${executorMemoryTotal}Mi") .build() val executorCpuQuantity = new QuantityBuilder(false) .withAmount(executorCoresRequest) @@ -152,19 +162,20 @@ private[spark] class BasicExecutorFeatureStep( .build() }.getOrElse(executorContainer) val driverPod = kubernetesConf.roleSpecificConf.driverPod + val ownerReference = driverPod.map(pod => + new OwnerReferenceBuilder() + .withController(true) + .withApiVersion(pod.getApiVersion) + .withKind(pod.getKind) + .withName(pod.getMetadata.getName) + .withUid(pod.getMetadata.getUid) + .build()) val executorPod = new PodBuilder(pod.pod) .editOrNewMetadata() .withName(name) .withLabels(kubernetesConf.roleLabels.asJava) .withAnnotations(kubernetesConf.roleAnnotations.asJava) - .withOwnerReferences() - .addNewOwnerReference() - .withController(true) - .withApiVersion(driverPod.getApiVersion) - .withKind(driverPod.getKind) - .withName(driverPod.getMetadata.getName) - .withUid(driverPod.getMetadata.getUid) - .endOwnerReference() + .addToOwnerReferences(ownerReference.toSeq: _*) .endMetadata() .editOrNewSpec() .withHostname(hostname) @@ -173,6 +184,7 @@ private[spark] class BasicExecutorFeatureStep( .addToImagePullSecrets(kubernetesConf.imagePullSecrets(): _*) .endSpec() .build() + SparkPod(executorPod, containerWithLimitCores) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala index 70b307303d149..be386e119d465 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala @@ -22,6 +22,7 @@ import java.util.UUID import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, VolumeBuilder, VolumeMountBuilder} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ private[spark] class LocalDirsFeatureStep( conf: KubernetesConf[_ <: KubernetesRoleSpecificConf], @@ -37,6 +38,7 @@ private[spark] class LocalDirsFeatureStep( .orElse(conf.getOption("spark.local.dir")) .getOrElse(defaultLocalDir) .split(",") + private val useLocalDirTmpFs = conf.get(KUBERNETES_LOCAL_DIRS_TMPFS) override def configurePod(pod: SparkPod): SparkPod = { val localDirVolumes = resolvedLocalDirs @@ -45,6 +47,7 @@ private[spark] class LocalDirsFeatureStep( new VolumeBuilder() .withName(s"spark-local-dir-${index + 1}") .withNewEmptyDir() + .withMedium(if (useLocalDirTmpFs) "Memory" else null) .endEmptyDir() .build() } diff --git 
a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala new file mode 100644 index 0000000000000..bb0e2b3128efd --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model._ + +import org.apache.spark.deploy.k8s._ + +private[spark] class MountVolumesFeatureStep( + kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf]) + extends KubernetesFeatureConfigStep { + + override def configurePod(pod: SparkPod): SparkPod = { + val (volumeMounts, volumes) = constructVolumes(kubernetesConf.roleVolumes).unzip + + val podWithVolumes = new PodBuilder(pod.pod) + .editSpec() + .addToVolumes(volumes.toSeq: _*) + .endSpec() + .build() + + val containerWithVolumeMounts = new ContainerBuilder(pod.container) + .addToVolumeMounts(volumeMounts.toSeq: _*) + .build() + + SparkPod(podWithVolumes, containerWithVolumeMounts) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty + + private def constructVolumes( + volumeSpecs: Iterable[KubernetesVolumeSpec[_ <: KubernetesVolumeSpecificConf]] + ): Iterable[(VolumeMount, Volume)] = { + volumeSpecs.map { spec => + val volumeMount = new VolumeMountBuilder() + .withMountPath(spec.mountPath) + .withReadOnly(spec.mountReadOnly) + .withName(spec.volumeName) + .build() + + val volumeBuilder = spec.volumeConf match { + case KubernetesHostPathVolumeConf(hostPath) => + new VolumeBuilder() + .withHostPath(new HostPathVolumeSource(hostPath)) + + case KubernetesPVCVolumeConf(claimName) => + new VolumeBuilder() + .withPersistentVolumeClaim( + new PersistentVolumeClaimVolumeSource(claimName, spec.mountReadOnly)) + + case KubernetesEmptyDirVolumeConf(medium, sizeLimit) => + new VolumeBuilder() + .withEmptyDir( + new EmptyDirVolumeSource(medium.getOrElse(""), + new Quantity(sizeLimit.orNull))) + } + + val volume = volumeBuilder.withName(spec.volumeName).build() + + (volumeMount, volume) + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala index f52ec9fdc677e..6f063b253cd73 100644 --- 
a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala @@ -19,6 +19,7 @@ package org.apache.spark.deploy.k8s.features.bindings import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config.APP_RESOURCE_TYPE import org.apache.spark.deploy.k8s.Constants.SPARK_CONF_PATH import org.apache.spark.deploy.k8s.features.KubernetesFeatureConfigStep import org.apache.spark.launcher.SparkLauncher @@ -38,7 +39,8 @@ .build() SparkPod(pod.pod, withDriverArgs) } - override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + override def getAdditionalPodSystemProperties(): Map[String, String] = + Map(APP_RESOURCE_TYPE.key -> "java") override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala index c20bcac1f8987..cf0c03b22bd7e 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala @@ -21,6 +21,7 @@ import scala.collection.JavaConverters._ import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, HasMetadata} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod} +import org.apache.spark.deploy.k8s.Config.APP_RESOURCE_TYPE import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.features.KubernetesFeatureConfigStep @@ -30,11 +31,12 @@ private[spark] class PythonDriverFeatureStep( override def configurePod(pod: SparkPod): SparkPod = { val roleConf = kubernetesConf.roleSpecificConf require(roleConf.mainAppResource.isDefined, "PySpark Main Resource must be defined") + // Delineation is done by " " because that is input into PythonRunner val maybePythonArgs = Option(roleConf.appArgs).filter(_.nonEmpty).map( pyArgs => new EnvVarBuilder() .withName(ENV_PYSPARK_ARGS) - .withValue(pyArgs.mkString(",")) + .withValue(pyArgs.mkString(" ")) .build()) val maybePythonFiles = kubernetesConf.pyFiles().map( // Delineation by ":" is to append the PySpark files to the PYTHONPATH @@ -67,7 +69,8 @@ SparkPod(pod.pod, withPythonPrimaryContainer) } - override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + override def getAdditionalPodSystemProperties(): Map[String, String] = + Map(APP_RESOURCE_TYPE.key -> "python") override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStep.scala new file mode 100644 index 0000000000000..1a7ef52fefe70 --- /dev/null +++
b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStep.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features.bindings + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, HasMetadata} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod} +import org.apache.spark.deploy.k8s.Config.APP_RESOURCE_TYPE +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.features.KubernetesFeatureConfigStep + +private[spark] class RDriverFeatureStep( + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]) + extends KubernetesFeatureConfigStep { + override def configurePod(pod: SparkPod): SparkPod = { + val roleConf = kubernetesConf.roleSpecificConf + require(roleConf.mainAppResource.isDefined, "R Main Resource must be defined") + // Delineation is done by " " because that is input into RRunner + val maybeRArgs = Option(roleConf.appArgs).filter(_.nonEmpty).map( + rArgs => + new EnvVarBuilder() + .withName(ENV_R_ARGS) + .withValue(rArgs.mkString(" ")) + .build()) + val envSeq = + Seq(new EnvVarBuilder() + .withName(ENV_R_PRIMARY) + .withValue(KubernetesUtils.resolveFileUri(kubernetesConf.sparkRMainResource().get)) + .build()) + val rEnvs = envSeq ++ + maybeRArgs.toSeq + + val withRPrimaryContainer = new ContainerBuilder(pod.container) + .addAllToEnv(rEnvs.asJava) + .addToArgs("driver-r") + .addToArgs("--properties-file", SPARK_CONF_PATH) + .addToArgs("--class", roleConf.mainClass) + .build() + + SparkPod(pod.pod, withRPrimaryContainer) + } + override def getAdditionalPodSystemProperties(): Map[String, String] = + Map(APP_RESOURCE_TYPE.key -> "r") + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index eaff47205dbbc..edeaa380194ac 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -27,7 +27,7 @@ import scala.util.control.NonFatal import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkApplication -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkKubernetesClientFactory} +import org.apache.spark.deploy.k8s.{KubernetesConf, 
KubernetesDriverSpecificConf, KubernetesUtils, SparkKubernetesClientFactory} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.Logging @@ -60,6 +60,8 @@ private[spark] object ClientArguments { mainAppResource = Some(JavaMainAppResource(primaryJavaResource)) case Array("--primary-py-file", primaryPythonResource: String) => mainAppResource = Some(PythonMainAppResource(primaryPythonResource)) + case Array("--primary-r-file", primaryRFile: String) => + mainAppResource = Some(RMainAppResource(primaryRFile)) case Array("--other-py-files", pyFiles: String) => maybePyFiles = Some(pyFiles) case Array("--main-class", clazz: String) => @@ -209,11 +211,8 @@ private[spark] class KubernetesClientApplication extends SparkApplication { // considerably restrictive, e.g. must be no longer than 63 characters in length. So we generate // a unique app ID (captured by spark.app.id) in the format below. val kubernetesAppId = s"spark-${UUID.randomUUID().toString.replaceAll("-", "")}" - val launchTime = System.currentTimeMillis() val waitForAppCompletion = sparkConf.get(WAIT_FOR_APP_COMPLETION) - val kubernetesResourceNamePrefix = { - s"$appName-$launchTime".toLowerCase.replaceAll("\\.", "-") - } + val kubernetesResourceNamePrefix = KubernetesClientApplication.getResourceNamePrefix(appName) sparkConf.set(KUBERNETES_PYSPARK_PY_FILES, clientArguments.maybePyFiles.getOrElse("")) val kubernetesConf = KubernetesConf.createDriverConf( sparkConf, @@ -228,7 +227,7 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val namespace = kubernetesConf.namespace() // The master URL has been checked for validity already in SparkSubmit. // We just need to get rid of the "k8s://" prefix here. 
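// For illustration (editor's note, not part of the original patch): parseMasterUrl simply // strips the scheme prefix, so KubernetesUtils.parseMasterUrl("k8s://https://host:6443") // returns "https://host:6443".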
- val master = sparkConf.get("spark.master").substring("k8s://".length) + val master = KubernetesUtils.parseMasterUrl(sparkConf.get("spark.master")) val loggingInterval = if (waitForAppCompletion) Some(sparkConf.get(REPORT_INTERVAL)) else None val watcher = new LoggingPodStatusWatcherImpl(kubernetesAppId, loggingInterval) @@ -252,3 +251,19 @@ private[spark] class KubernetesClientApplication extends SparkApplication { } } } + +private[spark] object KubernetesClientApplication { + + def getAppName(conf: SparkConf): String = conf.getOption("spark.app.name").getOrElse("spark") + + def getResourceNamePrefix(appName: String): String = { + val launchTime = System.currentTimeMillis() + s"$appName-$launchTime" + .trim + .toLowerCase + .replaceAll("\\s+", "-") + .replaceAll("\\.", "-") + .replaceAll("[^a-z0-9\\-]", "") + .replaceAll("-+", "-") + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index 5762d8245f778..8f3f18ffadc3b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -17,8 +17,8 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf} -import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, EnvSecretsFeatureStep, KubernetesFeatureConfigStep, LocalDirsFeatureStep, MountSecretsFeatureStep} -import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, EnvSecretsFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep, MountVolumesFeatureStep} +import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep, RDriverFeatureStep} private[spark] class KubernetesDriverBuilder( provideBasicStep: (KubernetesConf[KubernetesDriverSpecificConf]) => BasicDriverFeatureStep = @@ -33,18 +33,25 @@ private[spark] class KubernetesDriverBuilder( new MountSecretsFeatureStep(_), provideEnvSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] => EnvSecretsFeatureStep) = - new EnvSecretsFeatureStep(_), - provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] - => LocalDirsFeatureStep) = + new EnvSecretsFeatureStep(_), + provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) + => LocalDirsFeatureStep = new LocalDirsFeatureStep(_), - provideJavaStep: ( - KubernetesConf[KubernetesDriverSpecificConf] - => JavaDriverFeatureStep) = - new JavaDriverFeatureStep(_), + provideVolumesStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] + => MountVolumesFeatureStep) = + new MountVolumesFeatureStep(_), providePythonStep: ( KubernetesConf[KubernetesDriverSpecificConf] => PythonDriverFeatureStep) = - new PythonDriverFeatureStep(_)) { + new PythonDriverFeatureStep(_), + provideRStep: ( + KubernetesConf[KubernetesDriverSpecificConf] + => RDriverFeatureStep) = + new RDriverFeatureStep(_), + provideJavaStep: ( + KubernetesConf[KubernetesDriverSpecificConf] + => JavaDriverFeatureStep) = + new 
JavaDriverFeatureStep(_)) { def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]): KubernetesDriverSpec = { @@ -54,22 +61,27 @@ private[spark] class KubernetesDriverBuilder( provideServiceStep(kubernetesConf), provideLocalDirsStep(kubernetesConf)) - val maybeRoleSecretNamesStep = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { - Some(provideSecretsStep(kubernetesConf)) } else None - - val maybeProvideSecretsStep = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { - Some(provideEnvSecretsStep(kubernetesConf)) } else None + val secretFeature = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + Seq(provideSecretsStep(kubernetesConf)) + } else Nil + val envSecretFeature = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { + Seq(provideEnvSecretsStep(kubernetesConf)) + } else Nil + val volumesFeature = if (kubernetesConf.roleVolumes.nonEmpty) { + Seq(provideVolumesStep(kubernetesConf)) + } else Nil val bindingsStep = kubernetesConf.roleSpecificConf.mainAppResource.map { case JavaMainAppResource(_) => provideJavaStep(kubernetesConf) case PythonMainAppResource(_) => - providePythonStep(kubernetesConf)}.getOrElse(provideJavaStep(kubernetesConf)) + providePythonStep(kubernetesConf) + case RMainAppResource(_) => + provideRStep(kubernetesConf)} + .getOrElse(provideJavaStep(kubernetesConf)) - val allFeatures: Seq[KubernetesFeatureConfigStep] = - (baseFeatures :+ bindingsStep) ++ - maybeRoleSecretNamesStep.toSeq ++ - maybeProvideSecretsStep.toSeq + val allFeatures = (baseFeatures :+ bindingsStep) ++ + secretFeature ++ envSecretFeature ++ volumesFeature var spec = KubernetesDriverSpec.initialSpec(kubernetesConf.sparkConf.getAll.toMap) for (feature <- allFeatures) { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/LoggingPodStatusWatcher.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/LoggingPodStatusWatcher.scala index 173ac541626a7..1889fe5eb3e9b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/LoggingPodStatusWatcher.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/LoggingPodStatusWatcher.scala @@ -25,6 +25,7 @@ import io.fabric8.kubernetes.client.{KubernetesClientException, Watcher} import io.fabric8.kubernetes.client.Watcher.Action import org.apache.spark.SparkException +import org.apache.spark.deploy.k8s.KubernetesUtils._ import org.apache.spark.internal.Logging import org.apache.spark.util.ThreadUtils @@ -99,82 +100,10 @@ private[k8s] class LoggingPodStatusWatcherImpl( scheduler.shutdown() } - private def formatPodState(pod: Pod): String = { - val details = Seq[(String, String)]( - // pod metadata - ("pod name", pod.getMetadata.getName), - ("namespace", pod.getMetadata.getNamespace), - ("labels", pod.getMetadata.getLabels.asScala.mkString(", ")), - ("pod uid", pod.getMetadata.getUid), - ("creation time", formatTime(pod.getMetadata.getCreationTimestamp)), - - // spec details - ("service account name", pod.getSpec.getServiceAccountName), - ("volumes", pod.getSpec.getVolumes.asScala.map(_.getName).mkString(", ")), - ("node name", pod.getSpec.getNodeName), - - // status - ("start time", formatTime(pod.getStatus.getStartTime)), - ("container images", - pod.getStatus.getContainerStatuses - .asScala - .map(_.getImage) - .mkString(", ")), - ("phase", pod.getStatus.getPhase), - ("status", pod.getStatus.getContainerStatuses.toString) - ) - - 
formatPairsBundle(details) - } - - private def formatPairsBundle(pairs: Seq[(String, String)]) = { - // Use more loggable format if value is null or empty - pairs.map { - case (k, v) => s"\n\t $k: ${Option(v).filter(_.nonEmpty).getOrElse("N/A")}" - }.mkString("") - } - override def awaitCompletion(): Unit = { podCompletedFuture.await() logInfo(pod.map { p => s"Container final statuses:\n\n${containersDescription(p)}" }.getOrElse("No containers were found in the driver pod.")) } - - private def containersDescription(p: Pod): String = { - p.getStatus.getContainerStatuses.asScala.map { status => - Seq( - ("Container name", status.getName), - ("Container image", status.getImage)) ++ - containerStatusDescription(status) - }.map(formatPairsBundle).mkString("\n\n") - } - - private def containerStatusDescription( - containerStatus: ContainerStatus): Seq[(String, String)] = { - val state = containerStatus.getState - Option(state.getRunning) - .orElse(Option(state.getTerminated)) - .orElse(Option(state.getWaiting)) - .map { - case running: ContainerStateRunning => - Seq( - ("Container state", "Running"), - ("Container started at", formatTime(running.getStartedAt))) - case waiting: ContainerStateWaiting => - Seq( - ("Container state", "Waiting"), - ("Pending reason", waiting.getReason)) - case terminated: ContainerStateTerminated => - Seq( - ("Container state", "Terminated"), - ("Exit code", terminated.getExitCode.toString)) - case unknown => - throw new SparkException(s"Unexpected container status type ${unknown.getClass}.") - }.getOrElse(Seq(("Container state", "N/A"))) - } - - private def formatTime(time: Time): String = { - if (time != null) time.getTime else "N/A" - } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala index cbe081ae35683..dd5a4549743df 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala @@ -24,3 +24,6 @@ private[spark] case class JavaMainAppResource(primaryResource: String) extends M private[spark] case class PythonMainAppResource(primaryResource: String) extends MainAppResource with NonJVMResource + +private[spark] case class RMainAppResource(primaryResource: String) + extends MainAppResource with NonJVMResource diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala index 5a143ad3600fd..77bb9c3fcc9f4 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala @@ -46,13 +46,18 @@ private[spark] class ExecutorPodsAllocator( private val podCreationTimeout = math.max(podAllocationDelay * 5, 60000) + private val namespace = conf.get(KUBERNETES_NAMESPACE) + private val kubernetesDriverPodName = conf .get(KUBERNETES_DRIVER_POD_NAME) - .getOrElse(throw new SparkException("Must specify the driver pod name")) - private val driverPod = kubernetesClient.pods() - .withName(kubernetesDriverPodName) - .get() + private val driverPod = kubernetesDriverPodName + 
.map(name => Option(kubernetesClient.pods() + .withName(name) + .get()) + .getOrElse(throw new SparkException( + s"No pod was found named $kubernetesDriverPodName in the cluster in the " + + s"namespace $namespace (this was supposed to be the driver pod)."))) // Executor IDs that have been requested from Kubernetes but have not been detected in any // snapshot yet. Mapped to the timestamp when they were created. diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala index b28d93990313e..e2800cff7b720 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala @@ -24,6 +24,7 @@ import scala.collection.mutable import org.apache.spark.SparkConf import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.KubernetesUtils._ import org.apache.spark.internal.Logging import org.apache.spark.scheduler.ExecutorExited import org.apache.spark.util.Utils @@ -151,13 +152,15 @@ private[spark] class ExecutorPodsLifecycleManager( private def exitReasonMessage(podState: FinalPodState, execId: Long, exitCode: Int) = { val pod = podState.pod + val reason = Option(pod.getStatus.getReason) + val message = Option(pod.getStatus.getMessage) s""" |The executor with id $execId exited with exit code $exitCode. - |The API gave the following brief reason: ${pod.getStatus.getReason} - |The API gave the following message: ${pod.getStatus.getMessage} + |The API gave the following brief reason: ${reason.getOrElse("N/A")} + |The API gave the following message: ${message.getOrElse("N/A")} |The API gave the following container statuses: | - |${pod.getStatus.getContainerStatuses.asScala.map(_.toString).mkString("\n===\n")} + |${containersDescription(pod)} """.stripMargin } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index c6e931a38405f..9999c62c878df 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -22,7 +22,7 @@ import java.util.concurrent.TimeUnit import com.google.common.cache.CacheBuilder import io.fabric8.kubernetes.client.Config -import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.SparkContext import org.apache.spark.deploy.k8s.{KubernetesUtils, SparkKubernetesClientFactory} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ @@ -35,12 +35,6 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit override def canCreate(masterURL: String): Boolean = masterURL.startsWith("k8s") override def createTaskScheduler(sc: SparkContext, masterURL: String): TaskScheduler = { - if (masterURL.startsWith("k8s") && - sc.deployMode == "client" && - !sc.conf.get(KUBERNETES_DRIVER_SUBMIT_CHECK).getOrElse(false)) { - throw new SparkException("Client mode is currently not supported for Kubernetes.") - } - new
TaskSchedulerImpl(sc) } @@ -48,15 +42,32 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit sc: SparkContext, masterURL: String, scheduler: TaskScheduler): SchedulerBackend = { - val executorSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( - sc.conf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) + val wasSparkSubmittedInClusterMode = sc.conf.get(KUBERNETES_DRIVER_SUBMIT_CHECK) + val (authConfPrefix, + apiServerUri, + defaultServiceAccountToken, + defaultServiceAccountCaCrt) = if (wasSparkSubmittedInClusterMode) { + require(sc.conf.get(KUBERNETES_DRIVER_POD_NAME).isDefined, + "If the application is deployed using spark-submit in cluster mode, the driver pod name " + + "must be provided.") + (KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX, + KUBERNETES_MASTER_INTERNAL_URL, + Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), + Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) + } else { + (KUBERNETES_AUTH_CLIENT_MODE_PREFIX, + KubernetesUtils.parseMasterUrl(masterURL), + None, + None) + } + val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( - KUBERNETES_MASTER_INTERNAL_URL, + apiServerUri, Some(sc.conf.get(KUBERNETES_NAMESPACE)), - KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX, + authConfPrefix, sc.conf, - Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), - Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) + defaultServiceAccountToken, + defaultServiceAccountCaCrt) val requestExecutorsService = ThreadUtils.newDaemonCachedThreadPool( "kubernetes-executor-requests") diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala index 769a0a5a63047..364b6fb367722 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -17,37 +17,41 @@ package org.apache.spark.scheduler.cluster.k8s import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, KubernetesFeatureConfigStep, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features._ +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep} private[spark] class KubernetesExecutorBuilder( - provideBasicStep: (KubernetesConf[KubernetesExecutorSpecificConf]) => BasicExecutorFeatureStep = + provideBasicStep: (KubernetesConf [KubernetesExecutorSpecificConf]) + => BasicExecutorFeatureStep = new BasicExecutorFeatureStep(_), - provideSecretsStep: - (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => MountSecretsFeatureStep = + provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) + => MountSecretsFeatureStep = new MountSecretsFeatureStep(_), provideEnvSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] => EnvSecretsFeatureStep) = new EnvSecretsFeatureStep(_), provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => LocalDirsFeatureStep = - new LocalDirsFeatureStep(_)) { + new LocalDirsFeatureStep(_), + provideVolumesStep: (KubernetesConf[_ <: 
KubernetesRoleSpecificConf] + => MountVolumesFeatureStep) = + new MountVolumesFeatureStep(_)) { def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = { - val baseFeatures = Seq( - provideBasicStep(kubernetesConf), - provideLocalDirsStep(kubernetesConf)) - val maybeRoleSecretNamesStep = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { - Some(provideSecretsStep(kubernetesConf)) } else None + val baseFeatures = Seq(provideBasicStep(kubernetesConf), provideLocalDirsStep(kubernetesConf)) + val secretFeature = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + Seq(provideSecretsStep(kubernetesConf)) + } else Nil + val secretEnvFeature = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { + Seq(provideEnvSecretsStep(kubernetesConf)) + } else Nil + val volumesFeature = if (kubernetesConf.roleVolumes.nonEmpty) { + Seq(provideVolumesStep(kubernetesConf)) + } else Nil - val maybeProvideSecretsStep = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { - Some(provideEnvSecretsStep(kubernetesConf)) } else None - - val allFeatures: Seq[KubernetesFeatureConfigStep] = - baseFeatures ++ - maybeRoleSecretNamesStep.toSeq ++ - maybeProvideSecretsStep.toSeq + val allFeatures = baseFeatures ++ secretFeature ++ secretEnvFeature ++ volumesFeature var executorPod = SparkPod.initialPod() for (feature <- allFeatures) { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala index 661f942435921..e3c19cdb81567 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala @@ -122,6 +122,28 @@ class KubernetesConfSuite extends SparkFunSuite { === Array("local:///opt/spark/example4.py", mainResourceFile) ++ inputPyFiles) } + test("Creating driver conf with an R primary file") { + val mainResourceFile = "local:///opt/spark/main.R" + val sparkConf = new SparkConf(false) + .setJars(Seq("local:///opt/spark/jar1.jar")) + .set("spark.files", "local:///opt/spark/example2.R") + val mainAppResource = Some(RMainAppResource(mainResourceFile)) + val kubernetesConfWithMainResource = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + mainAppResource, + MAIN_CLASS, + APP_ARGS, + maybePyFiles = None) + assert(kubernetesConfWithMainResource.sparkConf.get("spark.jars").split(",") + === Array("local:///opt/spark/jar1.jar")) + assert(kubernetesConfWithMainResource.sparkConf.get(MEMORY_OVERHEAD_FACTOR) === 0.4) + assert(kubernetesConfWithMainResource.sparkFiles + === Array("local:///opt/spark/example2.R", mainResourceFile)) + } + test("Testing explicit setting of memory overhead on non-JVM tasks") { val sparkConf = new SparkConf(false) .set(MEMORY_OVERHEAD_FACTOR, 0.3) @@ -184,9 +206,9 @@ class KubernetesConfSuite extends SparkFunSuite { new SparkConf(false), EXECUTOR_ID, APP_ID, - DRIVER_POD) + Some(DRIVER_POD)) assert(conf.roleSpecificConf.executorId === EXECUTOR_ID) - assert(conf.roleSpecificConf.driverPod === DRIVER_POD) + assert(conf.roleSpecificConf.driverPod.get === DRIVER_POD) } test("Image pull secrets.") { @@ -195,7 +217,7 @@ class KubernetesConfSuite extends SparkFunSuite { .set(IMAGE_PULL_SECRETS, "my-secret-1,my-secret-2 "), EXECUTOR_ID, APP_ID, - DRIVER_POD) + Some(DRIVER_POD))
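// Editor's note (not part of the original patch): the driver pod is now wrapped in Some(...) // because KubernetesExecutorSpecificConf takes an Option[Pod]; with the client-mode check // removed from KubernetesClusterManager, executors may be configured without a driver pod.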
assert(conf.imagePullSecrets() === Seq( new LocalObjectReferenceBuilder().withName("my-secret-1").build(), @@ -221,7 +243,7 @@ class KubernetesConfSuite extends SparkFunSuite { sparkConf, EXECUTOR_ID, APP_ID, - DRIVER_POD) + Some(DRIVER_POD)) assert(conf.roleLabels === Map( SPARK_EXECUTOR_ID_LABEL -> EXECUTOR_ID, SPARK_APP_ID_LABEL -> APP_ID, diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala new file mode 100644 index 0000000000000..d795d159773a8 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import org.apache.spark.{SparkConf, SparkFunSuite} + +class KubernetesVolumeUtilsSuite extends SparkFunSuite { + test("Parses hostPath volumes correctly") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.hostPath.volumeName.mount.path", "/path") + sparkConf.set("test.hostPath.volumeName.mount.readOnly", "true") + sparkConf.set("test.hostPath.volumeName.options.path", "/hostPath") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly === true) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesHostPathVolumeConf] === + KubernetesHostPathVolumeConf("/hostPath")) + } + + test("Parses persistentVolumeClaim volumes correctly") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.persistentVolumeClaim.volumeName.mount.path", "/path") + sparkConf.set("test.persistentVolumeClaim.volumeName.mount.readOnly", "true") + sparkConf.set("test.persistentVolumeClaim.volumeName.options.claimName", "claimName") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly === true) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesPVCVolumeConf] === + KubernetesPVCVolumeConf("claimName")) + } + + test("Parses emptyDir volumes correctly") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.emptyDir.volumeName.mount.path", "/path") + sparkConf.set("test.emptyDir.volumeName.mount.readOnly", "true") + sparkConf.set("test.emptyDir.volumeName.options.medium", "medium") + sparkConf.set("test.emptyDir.volumeName.options.sizeLimit", "5G") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get +
assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly === true) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesEmptyDirVolumeConf] === + KubernetesEmptyDirVolumeConf(Some("medium"), Some("5G"))) + } + + test("Allows emptyDir volume options to be optional") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.emptyDir.volumeName.mount.path", "/path") + sparkConf.set("test.emptyDir.volumeName.mount.readOnly", "true") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly === true) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesEmptyDirVolumeConf] === + KubernetesEmptyDirVolumeConf(None, None)) + } + + test("Defaults optional readOnly to false") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.hostPath.volumeName.mount.path", "/path") + sparkConf.set("test.hostPath.volumeName.options.path", "/hostPath") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.mountReadOnly === false) + } + + test("Gracefully fails on missing mount key") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.emptyDir.volumeName.mnt.path", "/path") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head + assert(volumeSpec.isFailure === true) + assert(volumeSpec.failed.get.getMessage === "emptyDir.volumeName.mount.path") + } + + test("Gracefully fails on missing option key") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.hostPath.volumeName.mount.path", "/path") + sparkConf.set("test.hostPath.volumeName.mount.readOnly", "true") + sparkConf.set("test.hostPath.volumeName.options.pth", "/hostPath") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head + assert(volumeSpec.isFailure === true) + assert(volumeSpec.failed.get.getMessage === "hostPath.volumeName.options.path") + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala index 04b909db9d9f3..0968cce971c31 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.k8s.features import scala.collection.JavaConverters._ -import io.fabric8.kubernetes.api.model.LocalObjectReferenceBuilder +import io.fabric8.kubernetes.api.model.{ContainerPort, ContainerPortBuilder, LocalObjectReferenceBuilder} import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} @@ -26,6 +26,7 @@ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.submit.JavaMainAppResource import org.apache.spark.deploy.k8s.submit.PythonMainAppResource +import org.apache.spark.ui.SparkUI class BasicDriverFeatureStepSuite extends SparkFunSuite { @@ -50,6 +51,11 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { TEST_IMAGE_PULL_SECRETS.map { secret => new
LocalObjectReferenceBuilder().withName(secret).build() } + private val emptyDriverSpecificConf = KubernetesDriverSpecificConf( + None, + APP_NAME, + MAIN_CLASS, + APP_ARGS) test("Check the pod respects all configurations from the user.") { val sparkConf = new SparkConf() @@ -62,11 +68,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { .set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS.mkString(",")) val kubernetesConf = KubernetesConf( sparkConf, - KubernetesDriverSpecificConf( - Some(JavaMainAppResource("")), - APP_NAME, - MAIN_CLASS, - APP_ARGS), + emptyDriverSpecificConf, RESOURCE_NAME_PREFIX, APP_ID, DRIVER_LABELS, @@ -74,6 +76,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { Map.empty, Map.empty, DRIVER_ENVS, + Nil, Seq.empty[String]) val featureStep = new BasicDriverFeatureStep(kubernetesConf) @@ -84,6 +87,14 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { assert(configuredPod.container.getImage === "spark-driver:latest") assert(configuredPod.container.getImagePullPolicy === CONTAINER_IMAGE_PULL_POLICY) + val expectedPortNames = Set( + containerPort(DRIVER_PORT_NAME, DEFAULT_DRIVER_PORT), + containerPort(BLOCK_MANAGER_PORT_NAME, DEFAULT_BLOCKMANAGER_PORT), + containerPort(UI_PORT_NAME, SparkUI.DEFAULT_PORT) + ) + val foundPortNames = configuredPod.container.getPorts.asScala.toSet + assert(expectedPortNames === foundPortNames) + assert(configuredPod.container.getEnv.size === 3) val envs = configuredPod.container .getEnv @@ -143,6 +154,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { Map.empty, Map.empty, DRIVER_ENVS, + Nil, Seq.empty[String]) val pythonKubernetesConf = KubernetesConf( pythonSparkConf, @@ -158,6 +170,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { Map.empty, Map.empty, DRIVER_ENVS, + Nil, Seq.empty[String]) val javaFeatureStep = new BasicDriverFeatureStep(javaKubernetesConf) val pythonFeatureStep = new BasicDriverFeatureStep(pythonKubernetesConf) @@ -176,11 +189,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { .set(CONTAINER_IMAGE, "spark-driver:latest") val kubernetesConf = KubernetesConf( sparkConf, - KubernetesDriverSpecificConf( - Some(JavaMainAppResource("")), - APP_NAME, - MAIN_CLASS, - APP_ARGS), + emptyDriverSpecificConf, RESOURCE_NAME_PREFIX, APP_ID, DRIVER_LABELS, @@ -188,7 +197,9 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { Map.empty, Map.empty, DRIVER_ENVS, + Nil, allFiles) + val step = new BasicDriverFeatureStep(kubernetesConf) val additionalProperties = step.getAdditionalPodSystemProperties() val expectedSparkConf = Map( @@ -200,4 +211,11 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { "spark.files" -> "https://localhost:9000/file1.txt,/opt/spark/file2.txt") assert(additionalProperties === expectedSparkConf) } + + def containerPort(name: String, portNumber: Int): ContainerPort = + new ContainerPortBuilder() + .withName(name) + .withContainerPort(portNumber) + .withProtocol("TCP") + .build() } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala index f06030aa55c0c..63b237b9dfe46 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -75,13 
+75,14 @@ class BasicExecutorFeatureStepSuite .set("spark.driver.host", DRIVER_HOSTNAME) .set("spark.driver.port", DRIVER_PORT.toString) .set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS.mkString(",")) + .set("spark.kubernetes.resource.type", "java") } test("basic executor pod has reasonable defaults") { val step = new BasicExecutorFeatureStep( KubernetesConf( baseConf, - KubernetesExecutorSpecificConf("1", DRIVER_POD), + KubernetesExecutorSpecificConf("1", Some(DRIVER_POD)), RESOURCE_NAME_PREFIX, APP_ID, LABELS, @@ -89,6 +90,7 @@ class BasicExecutorFeatureStepSuite Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String])) val executor = step.configurePod(SparkPod.initialPod()) @@ -120,7 +122,7 @@ val step = new BasicExecutorFeatureStep( KubernetesConf( conf, - KubernetesExecutorSpecificConf("1", DRIVER_POD), + KubernetesExecutorSpecificConf("1", Some(DRIVER_POD)), longPodNamePrefix, APP_ID, LABELS, @@ -128,6 +130,7 @@ Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String])) assert(step.configurePod(SparkPod.initialPod()).pod.getSpec.getHostname.length === 63) } @@ -140,7 +143,7 @@ val step = new BasicExecutorFeatureStep( KubernetesConf( conf, - KubernetesExecutorSpecificConf("1", DRIVER_POD), + KubernetesExecutorSpecificConf("1", Some(DRIVER_POD)), RESOURCE_NAME_PREFIX, APP_ID, LABELS, @@ -148,6 +151,7 @@ Map.empty, Map.empty, Map("qux" -> "quux"), + Nil, Seq.empty[String])) val executor = step.configurePod(SparkPod.initialPod()) @@ -158,6 +162,29 @@ checkOwnerReferences(executor.pod, DRIVER_POD_UID) } + test("test executor pyspark memory") { + val conf = baseConf.clone() + conf.set("spark.kubernetes.resource.type", "python") + conf.set(org.apache.spark.internal.config.PYSPARK_EXECUTOR_MEMORY, 42L) + + val step = new BasicExecutorFeatureStep( + KubernetesConf( + conf, + KubernetesExecutorSpecificConf("1", Some(DRIVER_POD)), + RESOURCE_NAME_PREFIX, + APP_ID, + LABELS, + ANNOTATIONS, + Map.empty, + Map.empty, + Map.empty, + Nil, + Seq.empty[String])) + val executor = step.configurePod(SparkPod.initialPod()) + // Base executor memory (1024Mi) plus default overhead (384Mi) is 1408Mi; adding the + // pyspark memory (42Mi) gives 1450Mi. + assert(executor.container.getResources.getRequests.get("memory").getAmount === "1450Mi") + } + // There is always exactly one controller reference, and it points to the driver pod.
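// Editor's note (an assumption about Kubernetes semantics, not stated in the patch): making the // driver pod the controller owner lets Kubernetes garbage-collect executor pods automatically // once the driver pod is deleted.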
private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { assert(executor.getMetadata.getOwnerReferences.size() === 1) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala index 7cea83591f3e8..7e916b3854404 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala @@ -61,6 +61,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) @@ -92,6 +93,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) @@ -130,6 +132,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) val resolvedProperties = kubernetesCredentialsStep.getAdditionalPodSystemProperties() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala index 77d38bf19cd10..8b91e93eecd8c 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala @@ -67,6 +67,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String])) assert(configurationStep.configurePod(SparkPod.initialPod()) === SparkPod.initialPod()) assert(configurationStep.getAdditionalKubernetesResources().size === 1) @@ -98,6 +99,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String])) val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX @@ -119,6 +121,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String])) val resolvedService = configurationStep .getAdditionalKubernetesResources() @@ -149,6 +152,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]), clock) val driverService = configurationStep @@ -176,6 +180,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]), clock) fail("The driver bind address should not be allowed.") @@ -201,6 +206,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with 
BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]), clock) fail("The driver host address should not be allowed.") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala index af6b35eae484a..85c6cb282d2b0 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala @@ -37,7 +37,7 @@ class EnvSecretsFeatureStepSuite extends SparkFunSuite{ val sparkConf = new SparkConf(false) val kubernetesConf = KubernetesConf( sparkConf, - KubernetesExecutorSpecificConf("1", new PodBuilder().build()), + KubernetesExecutorSpecificConf("1", Some(new PodBuilder().build())), "resource-name-prefix", "app-id", Map.empty, @@ -45,6 +45,7 @@ class EnvSecretsFeatureStepSuite extends SparkFunSuite{ Map.empty, envVarsToKeys, Map.empty, + Nil, Seq.empty[String]) val step = new EnvSecretsFeatureStep(kubernetesConf) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala index bd6ce4b42fc8e..acdd07bc594b2 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala @@ -18,10 +18,12 @@ package org.apache.spark.deploy.k8s.features import io.fabric8.kubernetes.api.model.{EnvVarBuilder, VolumeBuilder, VolumeMountBuilder} import org.mockito.Mockito +import org.scalatest._ import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { private val defaultLocalDir = "/var/data/default-local-dir" @@ -45,6 +47,7 @@ class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) } @@ -110,4 +113,32 @@ class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { .withValue("/var/data/my-local-dir-1,/var/data/my-local-dir-2") .build()) } + + test("Use tmpfs to back default local dir") { + Mockito.doReturn(null).when(sparkConf).get("spark.local.dir") + Mockito.doReturn(null).when(sparkConf).getenv("SPARK_LOCAL_DIRS") + Mockito.doReturn(true).when(sparkConf).get(KUBERNETES_LOCAL_DIRS_TMPFS) + val stepUnderTest = new LocalDirsFeatureStep(kubernetesConf, defaultLocalDir) + val configuredPod = stepUnderTest.configurePod(SparkPod.initialPod()) + assert(configuredPod.pod.getSpec.getVolumes.size === 1) + assert(configuredPod.pod.getSpec.getVolumes.get(0) === + new VolumeBuilder() + .withName(s"spark-local-dir-1") + .withNewEmptyDir() + .withMedium("Memory") + .endEmptyDir() + .build()) + assert(configuredPod.container.getVolumeMounts.size === 1) + 
assert(configuredPod.container.getVolumeMounts.get(0) === + new VolumeMountBuilder() + .withName(s"spark-local-dir-1") + .withMountPath(defaultLocalDir) + .build()) + assert(configuredPod.container.getEnv.size === 1) + assert(configuredPod.container.getEnv.get(0) === + new EnvVarBuilder() + .withName("SPARK_LOCAL_DIRS") + .withValue(defaultLocalDir) + .build()) + } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala index eff75b8a15daa..dad610c443acc 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala @@ -35,7 +35,7 @@ class MountSecretsFeatureStepSuite extends SparkFunSuite { val sparkConf = new SparkConf(false) val kubernetesConf = KubernetesConf( sparkConf, - KubernetesExecutorSpecificConf("1", new PodBuilder().build()), + KubernetesExecutorSpecificConf("1", Some(new PodBuilder().build())), "resource-name-prefix", "app-id", Map.empty, @@ -43,6 +43,7 @@ class MountSecretsFeatureStepSuite extends SparkFunSuite { secretNamesToMountPaths, Map.empty, Map.empty, + Nil, Seq.empty[String]) val step = new MountSecretsFeatureStep(kubernetesConf) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala new file mode 100644 index 0000000000000..d309aa94ec115 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.spark.deploy.k8s.features
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.k8s._
+
+class MountVolumesFeatureStepSuite extends SparkFunSuite {
+  private val sparkConf = new SparkConf(false)
+  private val emptyKubernetesConf = KubernetesConf(
+    sparkConf = sparkConf,
+    roleSpecificConf = KubernetesDriverSpecificConf(
+      None,
+      "app-name",
+      "main",
+      Seq.empty),
+    appResourceNamePrefix = "resource",
+    appId = "app-id",
+    roleLabels = Map.empty,
+    roleAnnotations = Map.empty,
+    roleSecretNamesToMountPaths = Map.empty,
+    roleSecretEnvNamesToKeyRefs = Map.empty,
+    roleEnvs = Map.empty,
+    roleVolumes = Nil,
+    sparkFiles = Nil)
+
+  test("Mounts hostPath volumes") {
+    val volumeConf = KubernetesVolumeSpec(
+      "testVolume",
+      "/tmp",
+      false,
+      KubernetesHostPathVolumeConf("/hostPath/tmp")
+    )
+    val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil)
+    val step = new MountVolumesFeatureStep(kubernetesConf)
+    val configuredPod = step.configurePod(SparkPod.initialPod())
+
+    assert(configuredPod.pod.getSpec.getVolumes.size() === 1)
+    assert(configuredPod.pod.getSpec.getVolumes.get(0).getHostPath.getPath === "/hostPath/tmp")
+    assert(configuredPod.container.getVolumeMounts.size() === 1)
+    assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp")
+    assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume")
+    assert(configuredPod.container.getVolumeMounts.get(0).getReadOnly === false)
+  }
+
+  test("Mounts persistentVolumeClaims") {
+    val volumeConf = KubernetesVolumeSpec(
+      "testVolume",
+      "/tmp",
+      true,
+      KubernetesPVCVolumeConf("pvcClaim")
+    )
+    val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil)
+    val step = new MountVolumesFeatureStep(kubernetesConf)
+    val configuredPod = step.configurePod(SparkPod.initialPod())
+
+    assert(configuredPod.pod.getSpec.getVolumes.size() === 1)
+    val pvcClaim = configuredPod.pod.getSpec.getVolumes.get(0).getPersistentVolumeClaim
+    assert(pvcClaim.getClaimName === "pvcClaim")
+    assert(configuredPod.container.getVolumeMounts.size() === 1)
+    assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp")
+    assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume")
+    assert(configuredPod.container.getVolumeMounts.get(0).getReadOnly === true)
+
+  }
+
+  test("Mounts emptyDir") {
+    val volumeConf = KubernetesVolumeSpec(
+      "testVolume",
+      "/tmp",
+      false,
+      KubernetesEmptyDirVolumeConf(Some("Memory"), Some("6G"))
+    )
+    val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil)
+    val step = new MountVolumesFeatureStep(kubernetesConf)
+    val configuredPod = step.configurePod(SparkPod.initialPod())
+
+    assert(configuredPod.pod.getSpec.getVolumes.size() === 1)
+    val emptyDir = configuredPod.pod.getSpec.getVolumes.get(0).getEmptyDir
+    assert(emptyDir.getMedium === "Memory")
+    assert(emptyDir.getSizeLimit.getAmount === "6G")
+    assert(configuredPod.container.getVolumeMounts.size() === 1)
+    assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp")
+    assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume")
+    assert(configuredPod.container.getVolumeMounts.get(0).getReadOnly === false)
+  }
+
+  test("Mounts emptyDir with no options") {
+    val volumeConf = KubernetesVolumeSpec(
+      "testVolume",
+      "/tmp",
+      false,
+      KubernetesEmptyDirVolumeConf(None, None)
+    )
+    val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil)
+    val
step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val emptyDir = configuredPod.pod.getSpec.getVolumes.get(0).getEmptyDir + assert(emptyDir.getMedium === "") + assert(emptyDir.getSizeLimit.getAmount === null) + assert(configuredPod.container.getVolumeMounts.size() === 1) + assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp") + assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume") + assert(configuredPod.container.getVolumeMounts.get(0).getReadOnly === false) + } + + test("Mounts multiple volumes") { + val hpVolumeConf = KubernetesVolumeSpec( + "hpVolume", + "/tmp", + false, + KubernetesHostPathVolumeConf("/hostPath/tmp") + ) + val pvcVolumeConf = KubernetesVolumeSpec( + "checkpointVolume", + "/checkpoints", + true, + KubernetesPVCVolumeConf("pvcClaim") + ) + val volumesConf = hpVolumeConf :: pvcVolumeConf :: Nil + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumesConf) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 2) + assert(configuredPod.container.getVolumeMounts.size() === 2) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala index 0f2bf2fa1d9b5..bf552aeb8b901 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala @@ -42,6 +42,7 @@ class JavaDriverFeatureStepSuite extends SparkFunSuite { roleSecretNamesToMountPaths = Map.empty, roleSecretEnvNamesToKeyRefs = Map.empty, roleEnvs = Map.empty, + roleVolumes = Nil, sparkFiles = Seq.empty[String]) val step = new JavaDriverFeatureStep(kubernetesConf) @@ -55,6 +56,5 @@ class JavaDriverFeatureStepSuite extends SparkFunSuite { "--properties-file", SPARK_CONF_PATH, "--class", "test-class", "spark-internal", "5 7")) - } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala index a1f9a5d9e264e..c14af1d3b0f01 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala @@ -44,7 +44,7 @@ class PythonDriverFeatureStepSuite extends SparkFunSuite { Some(PythonMainAppResource("local:///main.py")), "test-app", "python-runner", - Seq("5 7")), + Seq("5", "7", "9")), appResourceNamePrefix = "", appId = "", roleLabels = Map.empty, @@ -52,6 +52,7 @@ class PythonDriverFeatureStepSuite extends SparkFunSuite { roleSecretNamesToMountPaths = Map.empty, roleSecretEnvNamesToKeyRefs = Map.empty, roleEnvs = Map.empty, + roleVolumes = Nil, sparkFiles = Seq.empty[String]) val step = new PythonDriverFeatureStep(kubernetesConf) @@ -65,7 +66,7 @@ class PythonDriverFeatureStepSuite 
extends SparkFunSuite { .toMap assert(envs(ENV_PYSPARK_PRIMARY) === expectedMainResource) assert(envs(ENV_PYSPARK_FILES) === expectedPySparkFiles) - assert(envs(ENV_PYSPARK_ARGS) === "5 7") + assert(envs(ENV_PYSPARK_ARGS) === "5 7 9") assert(envs(ENV_PYSPARK_MAJOR_PYTHON_VERSION) === "2") } test("Python Step testing empty pyfiles") { @@ -88,6 +89,7 @@ class PythonDriverFeatureStepSuite extends SparkFunSuite { roleSecretNamesToMountPaths = Map.empty, roleSecretEnvNamesToKeyRefs = Map.empty, roleEnvs = Map.empty, + roleVolumes = Nil, sparkFiles = Seq.empty[String]) val step = new PythonDriverFeatureStep(kubernetesConf) val driverContainerwithPySpark = step.configurePod(baseDriverPod).container diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStepSuite.scala new file mode 100644 index 0000000000000..ace0faa8629c3 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStepSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s.features.bindings + +import scala.collection.JavaConverters._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.RMainAppResource + +class RDriverFeatureStepSuite extends SparkFunSuite { + + test("R Step modifies container correctly") { + val expectedMainResource = "/main.R" + val mainResource = "local:///main.R" + val baseDriverPod = SparkPod.initialPod() + val sparkConf = new SparkConf(false) + .set(KUBERNETES_R_MAIN_APP_RESOURCE, mainResource) + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + Some(RMainAppResource(mainResource)), + "test-app", + "r-runner", + Seq("5", "7", "9")), + appResourceNamePrefix = "", + appId = "", + roleLabels = Map.empty, + roleAnnotations = Map.empty, + roleSecretNamesToMountPaths = Map.empty, + roleSecretEnvNamesToKeyRefs = Map.empty, + roleEnvs = Map.empty, + roleVolumes = Seq.empty, + sparkFiles = Seq.empty[String]) + + val step = new RDriverFeatureStep(kubernetesConf) + val driverContainerwithR = step.configurePod(baseDriverPod).container + assert(driverContainerwithR.getEnv.size === 2) + val envs = driverContainerwithR + .getEnv + .asScala + .map(env => (env.getName, env.getValue)) + .toMap + assert(envs(ENV_R_PRIMARY) === expectedMainResource) + assert(envs(ENV_R_ARGS) === "5 7 9") + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index d045d9ae89c07..4d8e79189ff32 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -141,6 +141,7 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) when(driverBuilder.buildFromFeatures(kubernetesConf)).thenReturn(BUILT_KUBERNETES_SPEC) when(kubernetesClient.pods()).thenReturn(podOperations) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala index 4e8c300543430..4117c5487a41e 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf} +import org.apache.spark.deploy.k8s._ +import org.apache.spark.deploy.k8s.features._ import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, EnvSecretsFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} -import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep} +import 
org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep, RDriverFeatureStep} class KubernetesDriverBuilderSuite extends SparkFunSuite { @@ -30,7 +31,9 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val SECRETS_STEP_TYPE = "mount-secrets" private val JAVA_STEP_TYPE = "java-bindings" private val PYSPARK_STEP_TYPE = "pyspark-bindings" + private val R_STEP_TYPE = "r-bindings" private val ENV_SECRETS_STEP_TYPE = "env-secrets" + private val MOUNT_VOLUMES_STEP_TYPE = "mount-volumes" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( BASIC_STEP_TYPE, classOf[BasicDriverFeatureStep]) @@ -53,9 +56,15 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val pythonStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( PYSPARK_STEP_TYPE, classOf[PythonDriverFeatureStep]) + private val rStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + R_STEP_TYPE, classOf[RDriverFeatureStep]) + private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep]) + private val mountVolumesStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + MOUNT_VOLUMES_STEP_TYPE, classOf[MountVolumesFeatureStep]) + private val builderUnderTest: KubernetesDriverBuilder = new KubernetesDriverBuilder( _ => basicFeatureStep, @@ -64,8 +73,10 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { _ => secretsStep, _ => envSecretsStep, _ => localDirsStep, - _ => javaStep, - _ => pythonStep) + _ => mountVolumesStep, + _ => pythonStep, + _ => rStep, + _ => javaStep) test("Apply fundamental steps all the time.") { val conf = KubernetesConf( @@ -82,6 +93,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -107,6 +119,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map("secret" -> "secretMountPath"), Map("EnvName" -> "SecretName:secretKey"), Map.empty, + Nil, Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -134,6 +147,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -159,6 +173,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -169,6 +184,64 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { PYSPARK_STEP_TYPE) } + test("Apply volumes step if mounts are present.") { + val volumeSpec = KubernetesVolumeSpec( + "volume", + "/tmp", + false, + KubernetesHostPathVolumeConf("/path")) + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + None, + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + volumeSpec :: Nil, + Seq.empty[String]) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + MOUNT_VOLUMES_STEP_TYPE, + JAVA_STEP_TYPE) + } + + test("Apply R step if main resource is R.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + Some(RMainAppResource("example.R")), + "test-app", + 
"main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Nil, + Seq.empty[String]) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + R_STEP_TYPE) + } + private def validateStepTypesApplied(resolvedSpec: KubernetesDriverSpec, stepTypes: String*) : Unit = { assert(resolvedSpec.systemProperties.size === stepTypes.size) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala index 0c19f5946b75f..0e617b0021019 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala @@ -166,14 +166,24 @@ class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter { conf, executorSpecificConf.executorId, TEST_SPARK_APP_ID, - driverPod) - k8sConf.sparkConf.getAll.toMap == conf.getAll.toMap && + Some(driverPod)) + + // Set prefixes to a common string since KUBERNETES_EXECUTOR_POD_NAME_PREFIX + // has not be set for the tests and thus KubernetesConf will use a random + // string for the prefix, based on the app name, and this comparison here will fail. + val k8sConfCopy = k8sConf + .copy(appResourceNamePrefix = "") + .copy(sparkConf = conf) + val expectedK8sConfCopy = expectedK8sConf + .copy(appResourceNamePrefix = "") + .copy(sparkConf = conf) + + k8sConf.sparkConf.getAll.toMap == conf.getAll.toMap && // Since KubernetesConf.createExecutorConf clones the SparkConf object, force // deep equality comparison for the SparkConf object and use object equality // comparison on all other fields. - k8sConf.copy(sparkConf = conf) == expectedK8sConf.copy(sparkConf = conf) + k8sConfCopy == expectedK8sConfCopy } } }) - } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala index 562ace9f49d4d..d8409383b4a1c 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala @@ -31,6 +31,7 @@ import scala.collection.mutable import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.deploy.k8s.KubernetesUtils._ import org.apache.spark.scheduler.ExecutorExited import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ @@ -104,13 +105,15 @@ class ExecutorPodsLifecycleManagerSuite extends SparkFunSuite with BeforeAndAfte } private def exitReasonMessage(failedExecutorId: Int, failedPod: Pod): String = { + val reason = Option(failedPod.getStatus.getReason) + val message = Option(failedPod.getStatus.getMessage) s""" |The executor with id $failedExecutorId exited with exit code 1. 
- |The API gave the following brief reason: ${failedPod.getStatus.getReason} - |The API gave the following message: ${failedPod.getStatus.getMessage} + |The API gave the following brief reason: ${reason.getOrElse("N/A")} + |The API gave the following message: ${message.getOrElse("N/A")} |The API gave the following container statuses: | - |${failedPod.getStatus.getContainerStatuses.asScala.map(_.toString).mkString("\n===\n")} + |${containersDescription(failedPod)} """.stripMargin } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala index a6bc8bce32926..44fe4a24e1102 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -19,14 +19,15 @@ package org.apache.spark.scheduler.cluster.k8s import io.fabric8.kubernetes.api.model.PodBuilder import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s._ +import org.apache.spark.deploy.k8s.features._ class KubernetesExecutorBuilderSuite extends SparkFunSuite { private val BASIC_STEP_TYPE = "basic" private val SECRETS_STEP_TYPE = "mount-secrets" private val ENV_SECRETS_STEP_TYPE = "env-secrets" private val LOCAL_DIRS_STEP_TYPE = "local-dirs" + private val MOUNT_VOLUMES_STEP_TYPE = "mount-volumes" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( BASIC_STEP_TYPE, classOf[BasicExecutorFeatureStep]) @@ -36,18 +37,21 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep]) private val localDirsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep]) + private val mountVolumesStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + MOUNT_VOLUMES_STEP_TYPE, classOf[MountVolumesFeatureStep]) private val builderUnderTest = new KubernetesExecutorBuilder( _ => basicFeatureStep, _ => mountSecretsStep, _ => envSecretsStep, - _ => localDirsStep) + _ => localDirsStep, + _ => mountVolumesStep) test("Basic steps are consistently applied.") { val conf = KubernetesConf( new SparkConf(false), KubernetesExecutorSpecificConf( - "executor-id", new PodBuilder().build()), + "executor-id", Some(new PodBuilder().build())), "prefix", "appId", Map.empty, @@ -55,6 +59,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE) @@ -64,7 +69,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { val conf = KubernetesConf( new SparkConf(false), KubernetesExecutorSpecificConf( - "executor-id", new PodBuilder().build()), + "executor-id", Some(new PodBuilder().build())), "prefix", "appId", Map.empty, @@ -72,6 +77,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map("secret" -> "secretMountPath"), 
Map("secret-name" -> "secret-key"), Map.empty, + Nil, Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -81,6 +87,32 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { ENV_SECRETS_STEP_TYPE) } + test("Apply volumes step if mounts are present.") { + val volumeSpec = KubernetesVolumeSpec( + "volume", + "/tmp", + false, + KubernetesHostPathVolumeConf("/checkpoint")) + val conf = KubernetesConf( + new SparkConf(false), + KubernetesExecutorSpecificConf( + "executor-id", Some(new PodBuilder().build())), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + volumeSpec :: Nil, + Seq.empty[String]) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + MOUNT_VOLUMES_STEP_TYPE) + } + private def validateStepTypesApplied(resolvedPod: SparkPod, stepTypes: String*): Unit = { assert(resolvedPod.pod.getMetadata.getLabels.size === stepTypes.size) stepTypes.foreach { stepType => diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index 9badf8556afc3..7ae57bf6e42d0 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -29,12 +29,13 @@ ARG img_path=kubernetes/dockerfiles RUN set -ex && \ apk upgrade --no-cache && \ - apk add --no-cache bash tini libc6-compat && \ + apk add --no-cache bash tini libc6-compat linux-pam && \ mkdir -p /opt/spark && \ - mkdir -p /opt/spark/work-dir \ + mkdir -p /opt/spark/work-dir && \ touch /opt/spark/RELEASE && \ rm /bin/sh && \ ln -sv /bin/bash /bin/sh && \ + echo "auth required pam_wheel.so use_uid" >> /etc/pam.d/su && \ chgrp root /etc/passwd && chmod ug+rw /etc/passwd COPY ${spark_jars} /opt/spark/jars @@ -42,6 +43,7 @@ COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin COPY ${img_path}/spark/entrypoint.sh /opt/ COPY examples /opt/spark/examples +COPY kubernetes/tests /opt/spark/tests COPY data /opt/spark/data ENV SPARK_HOME /opt/spark diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/R/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/R/Dockerfile new file mode 100644 index 0000000000000..9f67422efeb3c --- /dev/null +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/R/Dockerfile @@ -0,0 +1,29 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +ARG base_img +FROM $base_img +WORKDIR / +RUN mkdir ${SPARK_HOME}/R + +RUN apk add --no-cache R R-dev + +COPY R ${SPARK_HOME}/R +ENV R_HOME /usr/lib/R + +WORKDIR /opt/spark/work-dir +ENTRYPOINT [ "/opt/entrypoint.sh" ] diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile index 72bb9620b45de..69b6efa6149a0 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile @@ -19,7 +19,6 @@ ARG base_img FROM $base_img WORKDIR / RUN mkdir ${SPARK_HOME}/python -COPY python/lib ${SPARK_HOME}/python/lib # TODO: Investigate running both pip and pip3 via virtualenvs RUN apk add --no-cache python && \ apk add --no-cache python3 && \ @@ -33,6 +32,7 @@ RUN apk add --no-cache python && \ # Removed the .cache to save space rm -r /root/.cache +COPY python/lib ${SPARK_HOME}/python/lib ENV PYTHONPATH ${SPARK_HOME}/python/lib/pyspark.zip:${SPARK_HOME}/python/lib/py4j-*.zip WORKDIR /opt/spark/work-dir diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index 2f4e115e84ecd..216e8fe31becb 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -38,7 +38,7 @@ fi SPARK_K8S_CMD="$1" case "$SPARK_K8S_CMD" in - driver | driver-py | executor) + driver | driver-py | driver-r | executor) shift 1 ;; "") @@ -51,12 +51,10 @@ esac SPARK_CLASSPATH="$SPARK_CLASSPATH:${SPARK_HOME}/jars/*" env | grep SPARK_JAVA_OPT_ | sort -t_ -k4 -n | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt -readarray -t SPARK_JAVA_OPTS < /tmp/java_opts.txt -if [ -n "$SPARK_MOUNTED_CLASSPATH" ]; then - SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_MOUNTED_CLASSPATH" -fi -if [ -n "$SPARK_MOUNTED_FILES_DIR" ]; then - cp -R "$SPARK_MOUNTED_FILES_DIR/." . +readarray -t SPARK_EXECUTOR_JAVA_OPTS < /tmp/java_opts.txt + +if [ -n "$SPARK_EXTRA_CLASSPATH" ]; then + SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_EXTRA_CLASSPATH" fi if [ -n "$PYSPARK_FILES" ]; then @@ -68,6 +66,10 @@ if [ -n "$PYSPARK_APP_ARGS" ]; then PYSPARK_ARGS="$PYSPARK_APP_ARGS" fi +R_ARGS="" +if [ -n "$R_APP_ARGS" ]; then + R_ARGS="$R_APP_ARGS" +fi if [ "$PYSPARK_MAJOR_PYTHON_VERSION" == "2" ]; then pyv="$(python -V 2>&1)" @@ -98,10 +100,18 @@ case "$SPARK_K8S_CMD" in "$@" $PYSPARK_PRIMARY $PYSPARK_ARGS ) ;; + driver-r) + CMD=( + "$SPARK_HOME/bin/spark-submit" + --conf "spark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS" + --deploy-mode client + "$@" $R_PRIMARY $R_ARGS + ) + ;; executor) CMD=( ${JAVA_HOME}/bin/java - "${SPARK_JAVA_OPTS[@]}" + "${SPARK_EXECUTOR_JAVA_OPTS[@]}" -Xms$SPARK_EXECUTOR_MEMORY -Xmx$SPARK_EXECUTOR_MEMORY -cp "$SPARK_CLASSPATH" diff --git a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh index ea893fa39eede..b28b8b82ca016 100755 --- a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh +++ b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh @@ -16,9 +16,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
 #
-TEST_ROOT_DIR=$(git rev-parse --show-toplevel)/resource-managers/kubernetes/integration-tests
-
-cd "${TEST_ROOT_DIR}"
+set -xo errexit
+TEST_ROOT_DIR=$(git rev-parse --show-toplevel)
 
 DEPLOY_MODE="minikube"
 IMAGE_REPO="docker.io/kubespark"
@@ -27,6 +26,8 @@ IMAGE_TAG="N/A"
 SPARK_MASTER=
 NAMESPACE=
 SERVICE_ACCOUNT=
+INCLUDE_TAGS="k8s"
+EXCLUDE_TAGS=
 
 # Parse arguments
 while (( "$#" )); do
@@ -59,6 +60,14 @@ while (( "$#" )); do
       SERVICE_ACCOUNT="$2"
       shift
       ;;
+    --include-tags)
+      INCLUDE_TAGS="k8s,$2"
+      shift
+      ;;
+    --exclude-tags)
+      EXCLUDE_TAGS="$2"
+      shift
+      ;;
     *)
       break
       ;;
@@ -66,13 +75,12 @@ while (( "$#" )); do
   shift
 done
 
-cd $TEST_ROOT_DIR
-
 properties=(
   -Dspark.kubernetes.test.sparkTgz=$SPARK_TGZ \
   -Dspark.kubernetes.test.imageTag=$IMAGE_TAG \
   -Dspark.kubernetes.test.imageRepo=$IMAGE_REPO \
-  -Dspark.kubernetes.test.deployMode=$DEPLOY_MODE
+  -Dspark.kubernetes.test.deployMode=$DEPLOY_MODE \
+  -Dtest.include.tags=$INCLUDE_TAGS
 )
 
 if [ -n $NAMESPACE ];
@@ -90,4 +98,9 @@
 then
   properties=( ${properties[@]} -Dspark.kubernetes.test.master=$SPARK_MASTER )
 fi
 
-../../../build/mvn integration-test ${properties[@]}
+if [ -n "$EXCLUDE_TAGS" ];
+then
+  properties=( ${properties[@]} -Dtest.exclude.tags=$EXCLUDE_TAGS )
+fi
+
+$TEST_ROOT_DIR/build/mvn integration-test -f $TEST_ROOT_DIR/pom.xml -pl resource-managers/kubernetes/integration-tests -am -Pkubernetes -Phadoop-2.7 ${properties[@]}
diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml
index 520bda89e034d..614705c1ed668 100644
--- a/resource-managers/kubernetes/integration-tests/pom.xml
+++ b/resource-managers/kubernetes/integration-tests/pom.xml
@@ -25,7 +25,6 @@ spark-kubernetes-integration-tests_2.11
-  spark-kubernetes-integration-tests
   1.3.0
   1.4.0
@@ -40,6 +39,7 @@ minikube
   docker.io/kubespark
+  jar
   Spark Project Kubernetes Integration Tests
@@ -62,6 +62,11 @@ kubernetes-client
   ${kubernetes-client.version}
+
+  org.apache.spark
+  spark-tags_${scala.binary.version}
+  test-jar
+
@@ -102,6 +107,15 @@
+
+  org.apache.maven.plugins
+  maven-surefire-plugin
+
+  true
+
+
@@ -126,6 +140,7 @@ ${spark.kubernetes.test.serviceAccountName}
   ${test.exclude.tags}
+  ${test.include.tags}
diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/BasicTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/BasicTestsSuite.scala
new file mode 100644
index 0000000000000..4e749c40563dc
--- /dev/null
+++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/BasicTestsSuite.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.spark.deploy.k8s.integrationtest + +import io.fabric8.kubernetes.api.model.Pod + +import org.apache.spark.launcher.SparkLauncher + +private[spark] trait BasicTestsSuite { k8sSuite: KubernetesSuite => + + import BasicTestsSuite._ + import KubernetesSuite.k8sTestTag + + test("Run SparkPi with no resources", k8sTestTag) { + runSparkPiAndVerifyCompletion() + } + + test("Run SparkPi with a very long application name.", k8sTestTag) { + sparkAppConf.set("spark.app.name", "long" * 40) + runSparkPiAndVerifyCompletion() + } + + test("Use SparkLauncher.NO_RESOURCE", k8sTestTag) { + sparkAppConf.setJars(Seq(containerLocalSparkDistroExamplesJar)) + runSparkPiAndVerifyCompletion( + appResource = SparkLauncher.NO_RESOURCE) + } + + test("Run SparkPi with a master URL without a scheme.", k8sTestTag) { + val url = kubernetesTestComponents.kubernetesClient.getMasterUrl + val k8sMasterUrl = if (url.getPort < 0) { + s"k8s://${url.getHost}" + } else { + s"k8s://${url.getHost}:${url.getPort}" + } + sparkAppConf.set("spark.master", k8sMasterUrl) + runSparkPiAndVerifyCompletion() + } + + test("Run SparkPi with an argument.", k8sTestTag) { + runSparkPiAndVerifyCompletion(appArgs = Array("5")) + } + + test("Run SparkPi with custom labels, annotations, and environment variables.", k8sTestTag) { + sparkAppConf + .set("spark.kubernetes.driver.label.label1", "label1-value") + .set("spark.kubernetes.driver.label.label2", "label2-value") + .set("spark.kubernetes.driver.annotation.annotation1", "annotation1-value") + .set("spark.kubernetes.driver.annotation.annotation2", "annotation2-value") + .set("spark.kubernetes.driverEnv.ENV1", "VALUE1") + .set("spark.kubernetes.driverEnv.ENV2", "VALUE2") + .set("spark.kubernetes.executor.label.label1", "label1-value") + .set("spark.kubernetes.executor.label.label2", "label2-value") + .set("spark.kubernetes.executor.annotation.annotation1", "annotation1-value") + .set("spark.kubernetes.executor.annotation.annotation2", "annotation2-value") + .set("spark.executorEnv.ENV1", "VALUE1") + .set("spark.executorEnv.ENV2", "VALUE2") + + runSparkPiAndVerifyCompletion( + driverPodChecker = (driverPod: Pod) => { + doBasicDriverPodCheck(driverPod) + checkCustomSettings(driverPod) + }, + executorPodChecker = (executorPod: Pod) => { + doBasicExecutorPodCheck(executorPod) + checkCustomSettings(executorPod) + }) + } + + test("Run extraJVMOptions check on driver", k8sTestTag) { + sparkAppConf + .set("spark.driver.extraJavaOptions", "-Dspark.test.foo=spark.test.bar") + runSparkJVMCheckAndVerifyCompletion( + expectedJVMValue = Seq("(spark.test.foo,spark.test.bar)")) + } + + test("Run SparkRemoteFileTest using a remote data file", k8sTestTag) { + sparkAppConf + .set("spark.files", REMOTE_PAGE_RANK_DATA_FILE) + runSparkRemoteCheckAndVerifyCompletion(appArgs = Array(REMOTE_PAGE_RANK_FILE_NAME)) + } +} + +private[spark] object BasicTestsSuite { + val SPARK_PAGE_RANK_MAIN_CLASS: String = "org.apache.spark.examples.SparkPageRank" + val CONTAINER_LOCAL_FILE_DOWNLOAD_PATH = "/var/spark-data/spark-files" + val CONTAINER_LOCAL_DOWNLOADED_PAGE_RANK_DATA_FILE = + s"$CONTAINER_LOCAL_FILE_DOWNLOAD_PATH/pagerank_data.txt" + val REMOTE_PAGE_RANK_DATA_FILE = + "https://storage.googleapis.com/spark-k8s-integration-tests/files/pagerank_data.txt" + val REMOTE_PAGE_RANK_FILE_NAME = "pagerank_data.txt" +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ClientModeTestsSuite.scala 
b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ClientModeTestsSuite.scala new file mode 100644 index 0000000000000..c8bd584516ea5 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ClientModeTestsSuite.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import org.scalatest.concurrent.Eventually +import scala.collection.JavaConverters._ + +import org.apache.spark.deploy.k8s.integrationtest.KubernetesSuite.{k8sTestTag, INTERVAL, TIMEOUT} + +private[spark] trait ClientModeTestsSuite { k8sSuite: KubernetesSuite => + + test("Run in client mode.", k8sTestTag) { + val labels = Map("spark-app-selector" -> driverPodName) + val driverPort = 7077 + val blockManagerPort = 10000 + val driverService = testBackend + .getKubernetesClient + .services() + .inNamespace(kubernetesTestComponents.namespace) + .createNew() + .withNewMetadata() + .withName(s"$driverPodName-svc") + .endMetadata() + .withNewSpec() + .withClusterIP("None") + .withSelector(labels.asJava) + .addNewPort() + .withName("driver-port") + .withPort(driverPort) + .withNewTargetPort(driverPort) + .endPort() + .addNewPort() + .withName("block-manager") + .withPort(blockManagerPort) + .withNewTargetPort(blockManagerPort) + .endPort() + .endSpec() + .done() + try { + val driverPod = testBackend + .getKubernetesClient + .pods() + .inNamespace(kubernetesTestComponents.namespace) + .createNew() + .withNewMetadata() + .withName(driverPodName) + .withLabels(labels.asJava) + .endMetadata() + .withNewSpec() + .withServiceAccountName(kubernetesTestComponents.serviceAccountName) + .addNewContainer() + .withName("spark-example") + .withImage(image) + .withImagePullPolicy("IfNotPresent") + .withCommand("/opt/spark/bin/run-example") + .addToArgs("--master", s"k8s://https://kubernetes.default.svc") + .addToArgs("--deploy-mode", "client") + .addToArgs("--conf", s"spark.kubernetes.container.image=$image") + .addToArgs( + "--conf", + s"spark.kubernetes.namespace=${kubernetesTestComponents.namespace}") + .addToArgs("--conf", "spark.kubernetes.authenticate.oauthTokenFile=" + + "/var/run/secrets/kubernetes.io/serviceaccount/token") + .addToArgs("--conf", "spark.kubernetes.authenticate.caCertFile=" + + "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt") + .addToArgs("--conf", s"spark.kubernetes.driver.pod.name=$driverPodName") + .addToArgs("--conf", "spark.executor.memory=500m") + .addToArgs("--conf", "spark.executor.cores=1") + .addToArgs("--conf", "spark.executor.instances=1") + .addToArgs("--conf", + s"spark.driver.host=" + + 
s"${driverService.getMetadata.getName}.${kubernetesTestComponents.namespace}.svc") + .addToArgs("--conf", s"spark.driver.port=$driverPort") + .addToArgs("--conf", s"spark.driver.blockManager.port=$blockManagerPort") + .addToArgs("SparkPi") + .addToArgs("10") + .endContainer() + .endSpec() + .done() + Eventually.eventually(TIMEOUT, INTERVAL) { + assert(kubernetesTestComponents.kubernetesClient + .pods() + .withName(driverPodName) + .getLog + .contains("Pi is roughly 3"), "The application did not complete.") + } + } finally { + // Have to delete the service manually since it doesn't have an owner reference + kubernetesTestComponents + .kubernetesClient + .services() + .inNamespace(kubernetesTestComponents.namespace) + .delete(driverService) + } + } + +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 65c513cf241a4..18541baf05813 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -21,33 +21,48 @@ import java.nio.file.{Path, Paths} import java.util.UUID import java.util.regex.Pattern -import scala.collection.JavaConverters._ - import com.google.common.io.PatternFilenameFilter -import io.fabric8.kubernetes.api.model.{Container, Pod} -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import io.fabric8.kubernetes.api.model.Pod +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, Tag} import org.scalatest.concurrent.{Eventually, PatienceConfiguration} import org.scalatest.time.{Minutes, Seconds, Span} +import scala.collection.JavaConverters._ import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.integrationtest.TestConfig._ import org.apache.spark.deploy.k8s.integrationtest.backend.{IntegrationTestBackend, IntegrationTestBackendFactory} -import org.apache.spark.deploy.k8s.integrationtest.config._ private[spark] class KubernetesSuite extends SparkFunSuite - with BeforeAndAfterAll with BeforeAndAfter { + with BeforeAndAfterAll with BeforeAndAfter with BasicTestsSuite with SecretsTestsSuite + with PythonTestsSuite with ClientModeTestsSuite { import KubernetesSuite._ - private var testBackend: IntegrationTestBackend = _ private var sparkHomeDir: Path = _ - private var kubernetesTestComponents: KubernetesTestComponents = _ - private var sparkAppConf: SparkAppConf = _ - private var image: String = _ - private var containerLocalSparkDistroExamplesJar: String = _ - private var appLocator: String = _ - private var driverPodName: String = _ + private var pyImage: String = _ + private var rImage: String = _ + + protected var image: String = _ + protected var testBackend: IntegrationTestBackend = _ + protected var driverPodName: String = _ + protected var kubernetesTestComponents: KubernetesTestComponents = _ + protected var sparkAppConf: SparkAppConf = _ + protected var containerLocalSparkDistroExamplesJar: String = _ + protected var appLocator: String = _ + + // Default memory limit is 1024M + 384M (minimum overhead constant) + private val baseMemory = s"${1024 + 384}Mi" + protected val memOverheadConstant = 0.8 + private val standardNonJVMMemory = s"${(1024 + 0.4*1024).toInt}Mi" + protected val additionalMemory = 200 + // 
209715200 is 200Mi + protected val additionalMemoryInBytes = 209715200 + private val extraDriverTotalMemory = s"${(1024 + memOverheadConstant*1024).toInt}Mi" + private val extraExecTotalMemory = + s"${(1024 + memOverheadConstant*1024 + additionalMemory).toInt}Mi" override def beforeAll(): Unit = { + super.beforeAll() // The scalatest-maven-plugin gives system properties that are referenced but not set null // values. We need to remove the null-value properties before initializing the test backend. val nullValueProperties = System.getProperties.asScala @@ -65,6 +80,8 @@ private[spark] class KubernetesSuite extends SparkFunSuite val imageTag = getTestImageTag val imageRepo = getTestImageRepo image = s"$imageRepo/spark:$imageTag" + pyImage = s"$imageRepo/spark-py:$imageTag" + rImage = s"$imageRepo/spark-r:$imageTag" val sparkDistroExamplesJarFile: File = sparkHomeDir.resolve(Paths.get("examples", "jars")) .toFile @@ -77,7 +94,11 @@ private[spark] class KubernetesSuite extends SparkFunSuite } override def afterAll(): Unit = { - testBackend.cleanUp() + try { + testBackend.cleanUp() + } finally { + super.afterAll() + } } before { @@ -100,72 +121,13 @@ private[spark] class KubernetesSuite extends SparkFunSuite deleteDriverPod() } - test("Run SparkPi with no resources") { - runSparkPiAndVerifyCompletion() - } - - test("Run SparkPi with a very long application name.") { - sparkAppConf.set("spark.app.name", "long" * 40) - runSparkPiAndVerifyCompletion() - } - - test("Run SparkPi with a master URL without a scheme.") { - val url = kubernetesTestComponents.kubernetesClient.getMasterUrl - val k8sMasterUrl = if (url.getPort < 0) { - s"k8s://${url.getHost}" - } else { - s"k8s://${url.getHost}:${url.getPort}" - } - sparkAppConf.set("spark.master", k8sMasterUrl) - runSparkPiAndVerifyCompletion() - } - - test("Run SparkPi with an argument.") { - runSparkPiAndVerifyCompletion(appArgs = Array("5")) - } - - test("Run SparkPi with custom labels, annotations, and environment variables.") { - sparkAppConf - .set("spark.kubernetes.driver.label.label1", "label1-value") - .set("spark.kubernetes.driver.label.label2", "label2-value") - .set("spark.kubernetes.driver.annotation.annotation1", "annotation1-value") - .set("spark.kubernetes.driver.annotation.annotation2", "annotation2-value") - .set("spark.kubernetes.driverEnv.ENV1", "VALUE1") - .set("spark.kubernetes.driverEnv.ENV2", "VALUE2") - .set("spark.kubernetes.executor.label.label1", "label1-value") - .set("spark.kubernetes.executor.label.label2", "label2-value") - .set("spark.kubernetes.executor.annotation.annotation1", "annotation1-value") - .set("spark.kubernetes.executor.annotation.annotation2", "annotation2-value") - .set("spark.executorEnv.ENV1", "VALUE1") - .set("spark.executorEnv.ENV2", "VALUE2") - - runSparkPiAndVerifyCompletion( - driverPodChecker = (driverPod: Pod) => { - doBasicDriverPodCheck(driverPod) - checkCustomSettings(driverPod) - }, - executorPodChecker = (executorPod: Pod) => { - doBasicExecutorPodCheck(executorPod) - checkCustomSettings(executorPod) - }) - } - - // TODO(ssuchter): Enable the below after debugging - // test("Run PageRank using remote data file") { - // sparkAppConf - // .set("spark.kubernetes.mountDependencies.filesDownloadDir", - // CONTAINER_LOCAL_FILE_DOWNLOAD_PATH) - // .set("spark.files", REMOTE_PAGE_RANK_DATA_FILE) - // runSparkPageRankAndVerifyCompletion( - // appArgs = Array(CONTAINER_LOCAL_DOWNLOADED_PAGE_RANK_DATA_FILE)) - // } - - private def runSparkPiAndVerifyCompletion( + protected def 
runSparkPiAndVerifyCompletion( appResource: String = containerLocalSparkDistroExamplesJar, driverPodChecker: Pod => Unit = doBasicDriverPodCheck, executorPodChecker: Pod => Unit = doBasicExecutorPodCheck, appArgs: Array[String] = Array.empty[String], - appLocator: String = appLocator): Unit = { + appLocator: String = appLocator, + isJVM: Boolean = true ): Unit = { runSparkApplicationAndVerifyCompletion( appResource, SPARK_PI_MAIN_CLASS, @@ -173,10 +135,11 @@ private[spark] class KubernetesSuite extends SparkFunSuite appArgs, driverPodChecker, executorPodChecker, - appLocator) + appLocator, + isJVM) } - private def runSparkPageRankAndVerifyCompletion( + protected def runSparkRemoteCheckAndVerifyCompletion( appResource: String = containerLocalSparkDistroExamplesJar, driverPodChecker: Pod => Unit = doBasicDriverPodCheck, executorPodChecker: Pod => Unit = doBasicExecutorPodCheck, @@ -184,27 +147,73 @@ private[spark] class KubernetesSuite extends SparkFunSuite appLocator: String = appLocator): Unit = { runSparkApplicationAndVerifyCompletion( appResource, - SPARK_PAGE_RANK_MAIN_CLASS, - Seq("1 has rank", "2 has rank", "3 has rank", "4 has rank"), + SPARK_REMOTE_MAIN_CLASS, + Seq(s"Mounting of ${appArgs.head} was true"), appArgs, driverPodChecker, executorPodChecker, - appLocator) + appLocator, + true) } - private def runSparkApplicationAndVerifyCompletion( + protected def runSparkJVMCheckAndVerifyCompletion( + appResource: String = containerLocalSparkDistroExamplesJar, + mainClass: String = SPARK_DRIVER_MAIN_CLASS, + driverPodChecker: Pod => Unit = doBasicDriverPodCheck, + appArgs: Array[String] = Array("5"), + expectedJVMValue: Seq[String]): Unit = { + val appArguments = SparkAppArguments( + mainAppResource = appResource, + mainClass = mainClass, + appArgs = appArgs) + SparkAppLauncher.launch( + appArguments, + sparkAppConf, + TIMEOUT.value.toSeconds.toInt, + sparkHomeDir, + true) + + val driverPod = kubernetesTestComponents.kubernetesClient + .pods() + .withLabel("spark-app-locator", appLocator) + .withLabel("spark-role", "driver") + .list() + .getItems + .get(0) + doBasicDriverPodCheck(driverPod) + + Eventually.eventually(TIMEOUT, INTERVAL) { + expectedJVMValue.foreach { e => + assert(kubernetesTestComponents.kubernetesClient + .pods() + .withName(driverPod.getMetadata.getName) + .getLog + .contains(e), "The application did not complete.") + } + } + } + + protected def runSparkApplicationAndVerifyCompletion( appResource: String, mainClass: String, expectedLogOnCompletion: Seq[String], appArgs: Array[String], driverPodChecker: Pod => Unit, executorPodChecker: Pod => Unit, - appLocator: String): Unit = { + appLocator: String, + isJVM: Boolean, + pyFiles: Option[String] = None): Unit = { val appArguments = SparkAppArguments( mainAppResource = appResource, mainClass = mainClass, appArgs = appArgs) - SparkAppLauncher.launch(appArguments, sparkAppConf, TIMEOUT.value.toSeconds.toInt, sparkHomeDir) + SparkAppLauncher.launch( + appArguments, + sparkAppConf, + TIMEOUT.value.toSeconds.toInt, + sparkHomeDir, + isJVM, + pyFiles) val driverPod = kubernetesTestComponents.kubernetesClient .pods() @@ -236,18 +245,64 @@ private[spark] class KubernetesSuite extends SparkFunSuite } } - private def doBasicDriverPodCheck(driverPod: Pod): Unit = { + protected def doBasicDriverPodCheck(driverPod: Pod): Unit = { assert(driverPod.getMetadata.getName === driverPodName) assert(driverPod.getSpec.getContainers.get(0).getImage === image) assert(driverPod.getSpec.getContainers.get(0).getName === "spark-kubernetes-driver") + 
assert(driverPod.getSpec.getContainers.get(0).getResources.getRequests.get("memory").getAmount + === baseMemory) } - private def doBasicExecutorPodCheck(executorPod: Pod): Unit = { + + protected def doBasicDriverPyPodCheck(driverPod: Pod): Unit = { + assert(driverPod.getMetadata.getName === driverPodName) + assert(driverPod.getSpec.getContainers.get(0).getImage === pyImage) + assert(driverPod.getSpec.getContainers.get(0).getName === "spark-kubernetes-driver") + assert(driverPod.getSpec.getContainers.get(0).getResources.getRequests.get("memory").getAmount + === standardNonJVMMemory) + } + + protected def doBasicDriverRPodCheck(driverPod: Pod): Unit = { + assert(driverPod.getMetadata.getName === driverPodName) + assert(driverPod.getSpec.getContainers.get(0).getImage === rImage) + assert(driverPod.getSpec.getContainers.get(0).getName === "spark-kubernetes-driver") + assert(driverPod.getSpec.getContainers.get(0).getResources.getRequests.get("memory").getAmount + === standardNonJVMMemory) + } + + + protected def doBasicExecutorPodCheck(executorPod: Pod): Unit = { assert(executorPod.getSpec.getContainers.get(0).getImage === image) assert(executorPod.getSpec.getContainers.get(0).getName === "executor") + assert(executorPod.getSpec.getContainers.get(0).getResources.getRequests.get("memory").getAmount + === baseMemory) } - private def checkCustomSettings(pod: Pod): Unit = { + protected def doBasicExecutorPyPodCheck(executorPod: Pod): Unit = { + assert(executorPod.getSpec.getContainers.get(0).getImage === pyImage) + assert(executorPod.getSpec.getContainers.get(0).getName === "executor") + assert(executorPod.getSpec.getContainers.get(0).getResources.getRequests.get("memory").getAmount + === standardNonJVMMemory) + } + + protected def doBasicExecutorRPodCheck(executorPod: Pod): Unit = { + assert(executorPod.getSpec.getContainers.get(0).getImage === rImage) + assert(executorPod.getSpec.getContainers.get(0).getName === "executor") + assert(executorPod.getSpec.getContainers.get(0).getResources.getRequests.get("memory").getAmount + === standardNonJVMMemory) + } + + protected def doDriverMemoryCheck(driverPod: Pod): Unit = { + assert(driverPod.getSpec.getContainers.get(0).getResources.getRequests.get("memory").getAmount + === extraDriverTotalMemory) + } + + protected def doExecutorMemoryCheck(executorPod: Pod): Unit = { + assert(executorPod.getSpec.getContainers.get(0).getResources.getRequests.get("memory").getAmount + === extraExecTotalMemory) + } + + protected def checkCustomSettings(pod: Pod): Unit = { assert(pod.getMetadata.getLabels.get("label1") === "label1-value") assert(pod.getMetadata.getLabels.get("label2") === "label2-value") assert(pod.getMetadata.getAnnotations.get("annotation1") === "annotation1-value") @@ -277,18 +332,10 @@ private[spark] class KubernetesSuite extends SparkFunSuite } private[spark] object KubernetesSuite { - + val k8sTestTag = Tag("k8s") + val SPARK_PI_MAIN_CLASS: String = "org.apache.spark.examples.SparkPi" + val SPARK_REMOTE_MAIN_CLASS: String = "org.apache.spark.examples.SparkRemoteFileTest" + val SPARK_DRIVER_MAIN_CLASS: String = "org.apache.spark.examples.DriverSubmissionTest" val TIMEOUT = PatienceConfiguration.Timeout(Span(2, Minutes)) val INTERVAL = PatienceConfiguration.Interval(Span(2, Seconds)) - val SPARK_PI_MAIN_CLASS: String = "org.apache.spark.examples.SparkPi" - val SPARK_PAGE_RANK_MAIN_CLASS: String = "org.apache.spark.examples.SparkPageRank" - - // val CONTAINER_LOCAL_FILE_DOWNLOAD_PATH = "/var/spark-data/spark-files" - - // val REMOTE_PAGE_RANK_DATA_FILE = 
- // "https://storage.googleapis.com/spark-k8s-integration-tests/files/pagerank_data.txt" - // val CONTAINER_LOCAL_DOWNLOADED_PAGE_RANK_DATA_FILE = - // s"$CONTAINER_LOCAL_FILE_DOWNLOAD_PATH/pagerank_data.txt" - - // case object ShuffleNotReadyException extends Exception } diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala index 48727142dd052..b602fdf39731f 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala @@ -32,7 +32,7 @@ private[spark] class KubernetesTestComponents(defaultClient: DefaultKubernetesCl val namespaceOption = Option(System.getProperty("spark.kubernetes.test.namespace")) val hasUserSpecifiedNamespace = namespaceOption.isDefined val namespace = namespaceOption.getOrElse(UUID.randomUUID().toString.replaceAll("-", "")) - private val serviceAccountName = + val serviceAccountName = Option(System.getProperty("spark.kubernetes.test.serviceAccountName")) .getOrElse("default") val kubernetesClient = defaultClient.inNamespace(namespace) @@ -97,24 +97,33 @@ private[spark] case class SparkAppArguments( appArgs: Array[String]) private[spark] object SparkAppLauncher extends Logging { - def launch( appArguments: SparkAppArguments, appConf: SparkAppConf, timeoutSecs: Int, - sparkHomeDir: Path): Unit = { + sparkHomeDir: Path, + isJVM: Boolean, + pyFiles: Option[String] = None): Unit = { val sparkSubmitExecutable = sparkHomeDir.resolve(Paths.get("bin", "spark-submit")) logInfo(s"Launching a spark app with arguments $appArguments and conf $appConf") - val appArgsArray = - if (appArguments.appArgs.length > 0) Array(appArguments.appArgs.mkString(" ")) - else Array[String]() - val commandLine = (Array(sparkSubmitExecutable.toFile.getAbsolutePath, + val preCommandLine = if (isJVM) { + mutable.ArrayBuffer(sparkSubmitExecutable.toFile.getAbsolutePath, "--deploy-mode", "cluster", "--class", appArguments.mainClass, - "--master", appConf.get("spark.master") - ) ++ appConf.toStringArray :+ - appArguments.mainAppResource) ++ - appArgsArray - ProcessUtils.executeProcess(commandLine, timeoutSecs) + "--master", appConf.get("spark.master")) + } else { + mutable.ArrayBuffer(sparkSubmitExecutable.toFile.getAbsolutePath, + "--deploy-mode", "cluster", + "--master", appConf.get("spark.master")) + } + val commandLine = + pyFiles.map(s => preCommandLine ++ Array("--py-files", s)).getOrElse(preCommandLine) ++ + appConf.toStringArray :+ appArguments.mainAppResource + + if (appArguments.appArgs.nonEmpty) { + commandLine += appArguments.appArgs.mkString(" ") + } + logInfo(s"Launching a spark app with command line: ${commandLine.mkString(" ")}") + ProcessUtils.executeProcess(commandLine.toArray, timeoutSecs) } } diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PythonTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PythonTestsSuite.scala new file mode 100644 index 0000000000000..06b73107ec236 --- /dev/null +++ 
b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PythonTestsSuite.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import org.apache.spark.deploy.k8s.integrationtest.TestConfig.{getTestImageRepo, getTestImageTag} + +private[spark] trait PythonTestsSuite { k8sSuite: KubernetesSuite => + + import PythonTestsSuite._ + import KubernetesSuite.k8sTestTag + + private val pySparkDockerImage = + s"${getTestImageRepo}/spark-py:${getTestImageTag}" + test("Run PySpark on simple pi.py example", k8sTestTag) { + sparkAppConf + .set("spark.kubernetes.container.image", pySparkDockerImage) + runSparkApplicationAndVerifyCompletion( + appResource = PYSPARK_PI, + mainClass = "", + expectedLogOnCompletion = Seq("Pi is roughly 3"), + appArgs = Array("5"), + driverPodChecker = doBasicDriverPyPodCheck, + executorPodChecker = doBasicExecutorPyPodCheck, + appLocator = appLocator, + isJVM = false) + } + + test("Run PySpark with Python2 to test a pyfiles example", k8sTestTag) { + sparkAppConf + .set("spark.kubernetes.container.image", pySparkDockerImage) + .set("spark.kubernetes.pyspark.pythonVersion", "2") + runSparkApplicationAndVerifyCompletion( + appResource = PYSPARK_FILES, + mainClass = "", + expectedLogOnCompletion = Seq( + "Python runtime version check is: True", + "Python environment version check is: True"), + appArgs = Array("python"), + driverPodChecker = doBasicDriverPyPodCheck, + executorPodChecker = doBasicExecutorPyPodCheck, + appLocator = appLocator, + isJVM = false, + pyFiles = Some(PYSPARK_CONTAINER_TESTS)) + } + + test("Run PySpark with Python3 to test a pyfiles example", k8sTestTag) { + sparkAppConf + .set("spark.kubernetes.container.image", pySparkDockerImage) + .set("spark.kubernetes.pyspark.pythonVersion", "3") + runSparkApplicationAndVerifyCompletion( + appResource = PYSPARK_FILES, + mainClass = "", + expectedLogOnCompletion = Seq( + "Python runtime version check is: True", + "Python environment version check is: True"), + appArgs = Array("python3"), + driverPodChecker = doBasicDriverPyPodCheck, + executorPodChecker = doBasicExecutorPyPodCheck, + appLocator = appLocator, + isJVM = false, + pyFiles = Some(PYSPARK_CONTAINER_TESTS)) + } + + test("Run PySpark with memory customization", k8sTestTag) { + sparkAppConf + .set("spark.kubernetes.container.image", pySparkDockerImage) + .set("spark.kubernetes.pyspark.pythonVersion", "3") + .set("spark.kubernetes.memoryOverheadFactor", s"$memOverheadConstant") + .set("spark.executor.pyspark.memory", s"${additionalMemory}m") + runSparkApplicationAndVerifyCompletion( + appResource = PYSPARK_MEMORY_CHECK, + mainClass = "", + expectedLogOnCompletion = Seq( + "PySpark Worker 
Memory Check is: True"), + appArgs = Array(s"$additionalMemoryInBytes"), + driverPodChecker = doDriverMemoryCheck, + executorPodChecker = doExecutorMemoryCheck, + appLocator = appLocator, + isJVM = false, + pyFiles = Some(PYSPARK_CONTAINER_TESTS)) + } +} + +private[spark] object PythonTestsSuite { + val CONTAINER_LOCAL_PYSPARK: String = "local:///opt/spark/examples/src/main/python/" + val PYSPARK_PI: String = CONTAINER_LOCAL_PYSPARK + "pi.py" + val TEST_LOCAL_PYSPARK: String = "local:///opt/spark/tests/" + val PYSPARK_FILES: String = TEST_LOCAL_PYSPARK + "pyfiles.py" + val PYSPARK_CONTAINER_TESTS: String = TEST_LOCAL_PYSPARK + "py_container_checks.py" + val PYSPARK_MEMORY_CHECK: String = TEST_LOCAL_PYSPARK + "worker_memory_check.py" +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/RTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/RTestsSuite.scala new file mode 100644 index 0000000000000..885a23cfb4864 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/RTestsSuite.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s.integrationtest + +import org.apache.spark.deploy.k8s.integrationtest.TestConfig.{getTestImageRepo, getTestImageTag} + +private[spark] trait RTestsSuite { k8sSuite: KubernetesSuite => + + import RTestsSuite._ + import KubernetesSuite.k8sTestTag + + test("Run SparkR on simple dataframe.R example", k8sTestTag) { + sparkAppConf + .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-r:${getTestImageTag}") + runSparkApplicationAndVerifyCompletion( + appResource = SPARK_R_DATAFRAME_TEST, + mainClass = "", + expectedLogOnCompletion = Seq("name: string (nullable = true)", "1 Justin"), + appArgs = Array.empty[String], + driverPodChecker = doBasicDriverRPodCheck, + executorPodChecker = doBasicExecutorRPodCheck, + appLocator = appLocator, + isJVM = false) + } +} + +private[spark] object RTestsSuite { + val CONTAINER_LOCAL_SPARKR: String = "local:///opt/spark/examples/src/main/r/" + val SPARK_R_DATAFRAME_TEST: String = CONTAINER_LOCAL_SPARKR + "dataframe.R" +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SecretsTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SecretsTestsSuite.scala new file mode 100644 index 0000000000000..9b039bb98dd9a --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SecretsTestsSuite.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s.integrationtest + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{Pod, Secret, SecretBuilder} +import org.apache.commons.codec.binary.Base64 +import org.apache.commons.io.output.ByteArrayOutputStream +import org.scalatest.concurrent.Eventually + +import org.apache.spark.deploy.k8s.integrationtest.KubernetesSuite._ + +private[spark] trait SecretsTestsSuite { k8sSuite: KubernetesSuite => + + import SecretsTestsSuite._ + + private def createTestSecret(): Unit = { + val sb = new SecretBuilder() + sb.withNewMetadata() + .withName(ENV_SECRET_NAME) + .endMetadata() + val secUsername = Base64.encodeBase64String(ENV_SECRET_VALUE_1.getBytes()) + val secPassword = Base64.encodeBase64String(ENV_SECRET_VALUE_2.getBytes()) + val envSecretData = Map(ENV_SECRET_KEY_1 -> secUsername, ENV_SECRET_KEY_2 -> secPassword) + sb.addToData(envSecretData.asJava) + val envSecret = sb.build() + val sec = kubernetesTestComponents + .kubernetesClient + .secrets() + .createOrReplace(envSecret) + } + + private def deleteTestSecret(): Unit = { + kubernetesTestComponents + .kubernetesClient + .secrets() + .withName(ENV_SECRET_NAME) + .delete() + } + + // TODO: [SPARK-25291] This test is flaky with regards to memory of executors + test("Run SparkPi with env and mount secrets.", k8sTestTag) { + createTestSecret() + sparkAppConf + .set(s"spark.kubernetes.driver.secrets.$ENV_SECRET_NAME", SECRET_MOUNT_PATH) + .set(s"spark.kubernetes.driver.secretKeyRef.USERNAME", s"$ENV_SECRET_NAME:username") + .set(s"spark.kubernetes.driver.secretKeyRef.PASSWORD", s"$ENV_SECRET_NAME:password") + .set(s"spark.kubernetes.executor.secrets.$ENV_SECRET_NAME", SECRET_MOUNT_PATH) + .set(s"spark.kubernetes.executor.secretKeyRef.USERNAME", s"$ENV_SECRET_NAME:username") + .set(s"spark.kubernetes.executor.secretKeyRef.PASSWORD", s"$ENV_SECRET_NAME:password") + try { + runSparkPiAndVerifyCompletion( + driverPodChecker = (driverPod: Pod) => { + doBasicDriverPodCheck(driverPod) + checkSecrets(driverPod) + }, + executorPodChecker = (executorPod: Pod) => { + doBasicExecutorPodCheck(executorPod) + checkSecrets(executorPod) + }, + appArgs = Array("1000") // give it enough time for all execs to be visible + ) + } finally { + // make sure this always runs + deleteTestSecret() + } + } + + private def checkSecrets(pod: Pod): Unit = { + Eventually.eventually(TIMEOUT, INTERVAL) { + implicit val podName: String = pod.getMetadata.getName + val env = executeCommand("env") + assert(env.toString.contains(ENV_SECRET_VALUE_1)) + assert(env.toString.contains(ENV_SECRET_VALUE_2)) + val fileUsernameContents = executeCommand("cat", s"$SECRET_MOUNT_PATH/$ENV_SECRET_KEY_1") + val filePasswordContents = executeCommand("cat", s"$SECRET_MOUNT_PATH/$ENV_SECRET_KEY_2") + assert(fileUsernameContents.toString.trim.equals(ENV_SECRET_VALUE_1)) + assert(filePasswordContents.toString.trim.equals(ENV_SECRET_VALUE_2)) + } + } + + private def executeCommand(cmd: String*)(implicit podName: String): String = { + val out = new ByteArrayOutputStream() + val watch = kubernetesTestComponents + .kubernetesClient + .pods() + .withName(podName) + .readingInput(System.in) + .writingOutput(out) + .writingError(System.err) + .withTTY() + .exec(cmd.toArray: _*) + // wait to get some result back + Thread.sleep(1000) + watch.close() + out.flush() + out.toString() + } +} + +private[spark] object SecretsTestsSuite { + val ENV_SECRET_NAME = "mysecret" + val SECRET_MOUNT_PATH = "/etc/secret" + val ENV_SECRET_KEY_1 = "username" + val
ENV_SECRET_KEY_2 = "password" + val ENV_SECRET_VALUE_1 = "secretusername" + val ENV_SECRET_VALUE_2 = "secretpassword" +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/config.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConfig.scala similarity index 98% rename from resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/config.scala rename to resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConfig.scala index a81ef455c6766..5a49e0779160c 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/config.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConfig.scala @@ -21,7 +21,7 @@ import java.io.File import com.google.common.base.Charsets import com.google.common.io.Files -package object config { +object TestConfig { def getTestImageTag: String = { val imageTagFileProp = System.getProperty("spark.kubernetes.test.imageTagFile") require(imageTagFileProp != null, "Image tag file must be provided in system properties.") diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/constants.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConstants.scala similarity index 97% rename from resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/constants.scala rename to resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConstants.scala index 0807a68cd823c..8595d0eab1126 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/constants.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConstants.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.deploy.k8s.integrationtest -package object constants { +object TestConstants { val MINIKUBE_TEST_BACKEND = "minikube" val GCE_TEST_BACKEND = "gce" } diff --git a/examples/src/main/python/py_container_checks.py b/resource-managers/kubernetes/integration-tests/tests/py_container_checks.py similarity index 100% rename from examples/src/main/python/py_container_checks.py rename to resource-managers/kubernetes/integration-tests/tests/py_container_checks.py diff --git a/examples/src/main/python/pyfiles.py b/resource-managers/kubernetes/integration-tests/tests/pyfiles.py similarity index 100% rename from examples/src/main/python/pyfiles.py rename to resource-managers/kubernetes/integration-tests/tests/pyfiles.py diff --git a/resource-managers/kubernetes/integration-tests/tests/worker_memory_check.py b/resource-managers/kubernetes/integration-tests/tests/worker_memory_check.py new file mode 100644 index 0000000000000..d312a29f388e4 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/tests/worker_memory_check.py @@ -0,0 +1,47 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import resource +import sys + +from pyspark.sql import SparkSession + + +if __name__ == "__main__": + """ + Usage: worker_memory_check [Memory_in_bytes] + """ + spark = SparkSession \ + .builder \ + .appName("PyMemoryTest") \ + .getOrCreate() + sc = spark.sparkContext + if len(sys.argv) < 2: + print("Usage: worker_memory_check [Memory_in_bytes]", file=sys.stderr) + sys.exit(-1) + + def f(x): + rLimit = resource.getrlimit(resource.RLIMIT_AS) + print("RLimit is " + str(rLimit)) + return rLimit + resourceValue = sc.parallelize([1]).map(f).collect()[0][0] + print("Resource Value is " + str(resourceValue)) + truthCheck = (resourceValue == int(sys.argv[1])) + print("PySpark Worker Memory Check is: " + str(truthCheck)) + spark.stop() diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala index ccf33e8d4283c..64698b55c6bb6 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -51,6 +51,14 @@ private[mesos] class MesosClusterDispatcher( conf: SparkConf) extends Logging { + { + // This doesn't support authentication because the RestSubmissionServer doesn't support it. + val authKey = SecurityManager.SPARK_AUTH_SECRET_CONF + require(conf.getOption(authKey).isEmpty, + s"The MesosClusterDispatcher does not support authentication via ${authKey}.
It is not " + + s"currently possible to run jobs in cluster mode with authentication on.") + } + private val publicAddress = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(args.host) private val recoveryMode = conf.get(RECOVERY_MODE).toUpperCase() logInfo("Recovery mode in Mesos dispatcher set to: " + recoveryMode) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 7d80eedcc43ce..cb1bcba651be6 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -202,7 +202,7 @@ private[spark] class MesosClusterScheduler( } else if (removeFromPendingRetryDrivers(submissionId)) { k.success = true k.message = "Removed driver while it's being retried" - } else if (finishedDrivers.exists(_.driverDescription.submissionId.equals(submissionId))) { + } else if (finishedDrivers.exists(_.driverDescription.submissionId == submissionId)) { k.success = false k.message = "Driver already terminated" } else { @@ -222,21 +222,21 @@ private[spark] class MesosClusterScheduler( } s.submissionId = submissionId stateLock.synchronized { - if (queuedDrivers.exists(_.submissionId.equals(submissionId))) { + if (queuedDrivers.exists(_.submissionId == submissionId)) { s.success = true s.driverState = "QUEUED" } else if (launchedDrivers.contains(submissionId)) { s.success = true s.driverState = "RUNNING" launchedDrivers(submissionId).mesosTaskStatus.foreach(state => s.message = state.toString) - } else if (finishedDrivers.exists(_.driverDescription.submissionId.equals(submissionId))) { + } else if (finishedDrivers.exists(_.driverDescription.submissionId == submissionId)) { s.success = true s.driverState = "FINISHED" finishedDrivers .find(d => d.driverDescription.submissionId.equals(submissionId)).get.mesosTaskStatus .foreach(state => s.message = state.toString) - } else if (pendingRetryDrivers.exists(_.submissionId.equals(submissionId))) { - val status = pendingRetryDrivers.find(_.submissionId.equals(submissionId)) + } else if (pendingRetryDrivers.exists(_.submissionId == submissionId)) { + val status = pendingRetryDrivers.find(_.submissionId == submissionId) .get.retryState.get.lastFailureStatus s.success = true s.driverState = "RETRYING" @@ -254,13 +254,13 @@ private[spark] class MesosClusterScheduler( */ def getDriverState(submissionId: String): Option[MesosDriverState] = { stateLock.synchronized { - queuedDrivers.find(_.submissionId.equals(submissionId)) + queuedDrivers.find(_.submissionId == submissionId) .map(d => new MesosDriverState("QUEUED", d)) .orElse(launchedDrivers.get(submissionId) .map(d => new MesosDriverState("RUNNING", d.driverDescription, Some(d)))) - .orElse(finishedDrivers.find(_.driverDescription.submissionId.equals(submissionId)) + .orElse(finishedDrivers.find(_.driverDescription.submissionId == submissionId) .map(d => new MesosDriverState("FINISHED", d.driverDescription, Some(d)))) - .orElse(pendingRetryDrivers.find(_.submissionId.equals(submissionId)) + .orElse(pendingRetryDrivers.find(_.submissionId == submissionId) .map(d => new MesosDriverState("RETRYING", d))) } } @@ -814,7 +814,7 @@ private[spark] class MesosClusterScheduler( status: Int): Unit = {} private def removeFromQueuedDrivers(subId: String): Boolean = { - val index = 
queuedDrivers.indexWhere(_.submissionId.equals(subId)) + val index = queuedDrivers.indexWhere(_.submissionId == subId) if (index != -1) { queuedDrivers.remove(index) queuedDriversState.expunge(subId) @@ -834,7 +834,7 @@ private[spark] class MesosClusterScheduler( } private def removeFromPendingRetryDrivers(subId: String): Boolean = { - val index = pendingRetryDrivers.indexWhere(_.submissionId.equals(subId)) + val index = pendingRetryDrivers.indexWhere(_.submissionId == subId) if (index != -1) { pendingRetryDrivers.remove(index) pendingRetryDriversState.expunge(subId) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 1ce2f816dffb2..178de30f0f381 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -102,7 +102,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // If shuffle service is enabled, the Spark driver will register with the shuffle service. // This is for cleaning up shuffle files reliably. - private val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) + private val shuffleServiceEnabled = conf.get(config.SHUFFLE_SERVICE_ENABLED) // Cores we have acquired with each Mesos task ID private val coresByTaskId = new mutable.HashMap[String, Int] @@ -624,7 +624,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( "External shuffle client was not instantiated even though shuffle service is enabled.") // TODO: Remove this and allow the MesosExternalShuffleService to detect // framework termination when new Mesos Framework HTTP API is available. 
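The MesosCoarseGrainedSchedulerBackend hunks on either side of this note replace raw string lookups such as conf.getBoolean("spark.shuffle.service.enabled", false) and conf.getInt("spark.shuffle.service.port", 7337) with typed entries from org.apache.spark.internal.config. A minimal sketch of that pattern, assuming the builder conventions of the config package (the entry definitions themselves are not part of this patch):

    import org.apache.spark.internal.config.ConfigBuilder

    // As they would appear inside the config package object: a typed entry
    // carries key, type, and default in one place, so call sites like
    // conf.get(config.SHUFFLE_SERVICE_PORT) cannot drift on the default value.
    private[spark] val SHUFFLE_SERVICE_ENABLED =
      ConfigBuilder("spark.shuffle.service.enabled")
        .booleanConf
        .createWithDefault(false)

    private[spark] val SHUFFLE_SERVICE_PORT =
      ConfigBuilder("spark.shuffle.service.port")
        .intConf
        .createWithDefault(7337)

The defaults shown mirror the values the old call sites passed inline.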
- val externalShufflePort = conf.getInt("spark.shuffle.service.port", 7337) + val externalShufflePort = conf.get(config.SHUFFLE_SERVICE_PORT) logDebug(s"Connecting to shuffle service on slave $slaveId, " + s"host ${slave.hostname}, port $externalShufflePort for app ${conf.getAppId}") diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index 71a70ff048ccc..0bb6fe0fa4bdf 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -453,4 +453,8 @@ private[spark] class MesosFineGrainedSchedulerBackend( super.applicationId } + override def maxNumConcurrentTasks(): Int = { + // TODO SPARK-25074 support this method for MesosFineGrainedSchedulerBackend + 0 + } } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala index bfb73611f0530..b4364a5e2eb3a 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala @@ -117,7 +117,7 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { case Array(key, value) => Some(param.setKey(key).setValue(value)) case spec => - logWarning(s"Unable to parse arbitary parameters: $params. " + logWarning(s"Unable to parse arbitrary parameters: $params. " + "Expected form: \"key=value(, ...)\"") None } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index ecbcc960fc5a0..8ef1e18f83de3 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -355,7 +355,7 @@ trait MesosSchedulerUtils extends Logging { * https://github.com/apache/mesos/blob/master/src/common/values.cpp * https://github.com/apache/mesos/blob/master/src/common/attributes.cpp * - * @param constraintsVal constaints string consisting of ';' separated key-value pairs (separated + * @param constraintsVal contains string consisting of ';' separated key-value pairs (separated by ':') * @return Map of constraints to match resources offers.
*/ diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala index 33e7d69d53d38..057c51db455ef 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.deploy.TestPrematureExit class MesosClusterDispatcherArgumentsSuite extends SparkFunSuite with TestPrematureExit { - test("test if spark config args are passed sucessfully") { + test("test if spark config args are passed successfully") { val args = Array[String]("--master", "mesos://localhost:5050", "--conf", "key1=value1", "--conf", "spark.mesos.key2=value2", "--verbose") val conf = new SparkConf() diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala index e534b9d7e3ed9..082d4bcfdf83a 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala @@ -21,7 +21,7 @@ import java.util.{Collection, Collections, Date} import scala.collection.JavaConverters._ -import org.apache.mesos.Protos.{Environment, Secret, TaskState => MesosTaskState, _} +import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} import org.apache.mesos.Protos.Value.{Scalar, Type} import org.apache.mesos.SchedulerDriver import org.mockito.{ArgumentCaptor, Matchers} @@ -146,14 +146,14 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi assert(scheduler.getResource(resources, "cpus") == 1.5) assert(scheduler.getResource(resources, "mem") == 1200) val resourcesSeq: Seq[Resource] = resources.asScala - val cpus = resourcesSeq.filter(_.getName.equals("cpus")).toList + val cpus = resourcesSeq.filter(_.getName == "cpus").toList assert(cpus.size == 2) - assert(cpus.exists(_.getRole().equals("role2"))) - assert(cpus.exists(_.getRole().equals("*"))) - val mem = resourcesSeq.filter(_.getName.equals("mem")).toList + assert(cpus.exists(_.getRole() == "role2")) + assert(cpus.exists(_.getRole() == "*")) + val mem = resourcesSeq.filter(_.getName == "mem").toList assert(mem.size == 2) - assert(mem.exists(_.getRole().equals("role2"))) - assert(mem.exists(_.getRole().equals("*"))) + assert(mem.exists(_.getRole() == "role2")) + assert(mem.exists(_.getRole() == "*")) verify(driver, times(1)).launchTasks( Matchers.eq(Collections.singleton(offer.getId)), diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index b790c7cd27794..da33d85d8fb2e 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -262,7 +262,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends 
SparkFunSuite } test("mesos doesn't register twice with the same shuffle service") { - setBackend(Map("spark.shuffle.service.enabled" -> "true")) + setBackend(Map(SHUFFLE_SERVICE_ENABLED.key -> "true")) val (mem, cpu) = (backend.executorMemory(sc), 4) val offer1 = createOffer("o1", "s1", mem, cpu) diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala index 2d2f90c63a309..1ead4b1ed7c7e 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala @@ -106,7 +106,7 @@ class MesosFineGrainedSchedulerBackendSuite // uri is null. val (executorInfo, _) = mesosSchedulerBackend.createExecutorInfo(resources, "test-id") val executorResources = executorInfo.getResourcesList - val cpus = executorResources.asScala.find(_.getName.equals("cpus")).get.getScalar.getValue + val cpus = executorResources.asScala.find(_.getName == "cpus").get.getScalar.getValue assert(cpus === mesosExecutorCores) } @@ -253,6 +253,7 @@ class MesosFineGrainedSchedulerBackendSuite executorId = "s1", name = "n1", index = 0, + partitionId = 0, addedFiles = mutable.Map.empty[String, Long], addedJars = mutable.Map.empty[String, Long], properties = new Properties(), @@ -361,6 +362,7 @@ class MesosFineGrainedSchedulerBackendSuite executorId = "s1", name = "n1", index = 0, + partitionId = 0, addedFiles = mutable.Map.empty[String, Long], addedJars = mutable.Map.empty[String, Long], properties = new Properties(), diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index ecc576910db9e..8f94e3f731007 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.yarn import java.io.{File, IOException} import java.lang.reflect.{InvocationTargetException, Modifier} -import java.net.{Socket, URI, URL} +import java.net.{URI, URL} import java.security.PrivilegedExceptionAction import java.util.concurrent.{TimeoutException, TimeUnit} @@ -28,6 +28,7 @@ import scala.concurrent.Promise import scala.concurrent.duration.Duration import scala.util.control.NonFatal +import org.apache.commons.lang3.{StringUtils => ComStrUtils} import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ @@ -43,6 +44,7 @@ import org.apache.spark.deploy.yarn.config._ import org.apache.spark.deploy.yarn.security.AMCredentialRenewer import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.metrics.MetricsSystem import org.apache.spark.rpc._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, YarnSchedulerBackend} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ @@ -67,6 +69,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends private val securityMgr = new SecurityManager(sparkConf) + private var metricsSystem: Option[MetricsSystem] = None + // Set system 
properties for each config entry. This covers two use cases: // - The default configuration stored by the SparkHadoopUtil class // - The user application creating a new SparkConf in cluster mode @@ -309,6 +313,16 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_UNCAUGHT_EXCEPTION, "Uncaught exception: " + StringUtils.stringifyException(e)) + } finally { + try { + metricsSystem.foreach { ms => + ms.report() + ms.stop() + } + } catch { + case e: Exception => + logWarning("Exception during stopping of the metric system: ", e) + } } } @@ -355,7 +369,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } logInfo(s"Final app status: $finalStatus, exitCode: $exitCode" + Option(msg).map(msg => s", (reason: $msg)").getOrElse("")) - finalMsg = msg + finalMsg = ComStrUtils.abbreviate(msg, sparkConf.get(AM_FINAL_MSG_LIMIT).toInt) finished = true if (!inShutdown && Thread.currentThread() != reporterThread && reporterThread != null) { logDebug("shutting down reporter thread") @@ -434,6 +448,11 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends rpcEnv.setupEndpoint("YarnAM", new AMEndpoint(rpcEnv, driverRef)) allocator.allocateResources() + val ms = MetricsSystem.createMetricsSystem("applicationMaster", sparkConf, securityMgr) + val prefix = _sparkConf.get(YARN_METRICS_NAMESPACE).getOrElse(appId) + ms.registerSource(new ApplicationMasterSource(prefix, allocator)) + ms.start() + metricsSystem = Some(ms) reporterThread = launchReporterThread() } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterSource.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterSource.scala new file mode 100644 index 0000000000000..0fec916582602 --- /dev/null +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterSource.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.yarn + +import com.codahale.metrics.{Gauge, MetricRegistry} + +import org.apache.spark.metrics.source.Source + +private[spark] class ApplicationMasterSource(prefix: String, yarnAllocator: YarnAllocator) + extends Source { + + override val sourceName: String = prefix + ".applicationMaster" + override val metricRegistry: MetricRegistry = new MetricRegistry() + + metricRegistry.register(MetricRegistry.name("numExecutorsFailed"), new Gauge[Int] { + override def getValue: Int = yarnAllocator.getNumExecutorsFailed + }) + + metricRegistry.register(MetricRegistry.name("numExecutorsRunning"), new Gauge[Int] { + override def getValue: Int = yarnAllocator.getNumExecutorsRunning + }) + + metricRegistry.register(MetricRegistry.name("numReleasedContainers"), new Gauge[Int] { + override def getValue: Int = yarnAllocator.getNumReleasedContainers + }) + + metricRegistry.register(MetricRegistry.name("numLocalityAwareTasks"), new Gauge[Int] { + override def getValue: Int = yarnAllocator.numLocalityAwareTasks + }) + + metricRegistry.register(MetricRegistry.name("numContainersPendingAllocate"), new Gauge[Int] { + override def getValue: Int = yarnAllocator.numContainersPendingAllocate + }) + +} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 793d012218490..4a85898ef880b 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -91,6 +91,13 @@ private[spark] class Client( private val executorMemoryOverhead = sparkConf.get(EXECUTOR_MEMORY_OVERHEAD).getOrElse( math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toLong, MEMORY_OVERHEAD_MIN)).toInt + private val isPython = sparkConf.get(IS_PYTHON_APP) + private val pysparkWorkerMemory: Int = if (isPython) { + sparkConf.get(PYSPARK_EXECUTOR_MEMORY).map(_.toInt).getOrElse(0) + } else { + 0 + } + private val distCacheMgr = new ClientDistributedCacheManager() private val principal = sparkConf.get(PRINCIPAL).orNull @@ -333,18 +340,19 @@ private[spark] class Client( val maxMem = newAppResponse.getMaximumResourceCapability().getMemory() logInfo("Verifying our application has not requested more than the maximum " + s"memory capability of the cluster ($maxMem MB per container)") - val executorMem = executorMemory + executorMemoryOverhead + val executorMem = executorMemory + executorMemoryOverhead + pysparkWorkerMemory if (executorMem > maxMem) { - throw new IllegalArgumentException(s"Required executor memory ($executorMemory" + - s"+$executorMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster! " + - "Please check the values of 'yarn.scheduler.maximum-allocation-mb' and/or " + - "'yarn.nodemanager.resource.memory-mb'.") + throw new IllegalArgumentException(s"Required executor memory ($executorMemory MB), overhead " + + s"($executorMemoryOverhead MB), and PySpark memory ($pysparkWorkerMemory MB) is above " + + s"the max threshold ($maxMem MB) of this cluster! Please check the values of " + + s"'yarn.scheduler.maximum-allocation-mb' and/or 'yarn.nodemanager.resource.memory-mb'.") } val amMem = amMemory + amMemoryOverhead if (amMem > maxMem) { throw new IllegalArgumentException(s"Required AM memory ($amMemory" + s"+$amMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster!
" + - "Please increase the value of 'yarn.scheduler.maximum-allocation-mb'.") + "Please check the values of 'yarn.scheduler.maximum-allocation-mb' and/or " + + "'yarn.nodemanager.resource.memory-mb'.") } logInfo("Will allocate AM container, with %d MB memory including %d MB overhead".format( amMem, @@ -437,7 +445,7 @@ private[spark] class Client( } } - /** + /* * Distribute a file to the cluster. * * If the file's path is a "local:" URI, it's actually not distributed. Other files are copied @@ -811,10 +819,12 @@ private[spark] class Client( // Finally, update the Spark config to propagate PYTHONPATH to the AM and executors. if (pythonPath.nonEmpty) { - val pythonPathStr = (sys.env.get("PYTHONPATH") ++ pythonPath) + val pythonPathList = (sys.env.get("PYTHONPATH") ++ pythonPath) + env("PYTHONPATH") = (env.get("PYTHONPATH") ++ pythonPathList) .mkString(ApplicationConstants.CLASS_PATH_SEPARATOR) - env("PYTHONPATH") = pythonPathStr - sparkConf.setExecutorEnv("PYTHONPATH", pythonPathStr) + val pythonPathExecutorEnv = (sparkConf.getExecutorEnv.toMap.get("PYTHONPATH") ++ + pythonPathList).mkString(ApplicationConstants.CLASS_PATH_SEPARATOR) + sparkConf.setExecutorEnv("PYTHONPATH", pythonPathExecutorEnv) } if (isClusterMode) { diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index fae054e0eea00..8a7551de7c088 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -133,10 +133,17 @@ private[yarn] class YarnAllocator( // Additional memory overhead. protected val memoryOverhead: Int = sparkConf.get(EXECUTOR_MEMORY_OVERHEAD).getOrElse( math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toInt, MEMORY_OVERHEAD_MIN)).toInt + protected val pysparkWorkerMemory: Int = if (sparkConf.get(IS_PYTHON_APP)) { + sparkConf.get(PYSPARK_EXECUTOR_MEMORY).map(_.toInt).getOrElse(0) + } else { + 0 + } // Number of cores per executor. 
protected val executorCores = sparkConf.get(EXECUTOR_CORES) // Resource capability requested for each executors - private[yarn] val resource = Resource.newInstance(executorMemory + memoryOverhead, executorCores) + private[yarn] val resource = Resource.newInstance( + executorMemory + memoryOverhead + pysparkWorkerMemory, + executorCores) private val launcherPool = ThreadUtils.newDaemonCachedThreadPool( "ContainerLauncher", sparkConf.get(CONTAINER_LAUNCH_MAX_THREADS)) @@ -150,7 +157,7 @@ private[yarn] class YarnAllocator( private var hostToLocalTaskCounts: Map[String, Int] = Map.empty // Number of tasks that have locality preferences in active stages - private var numLocalityAwareTasks: Int = 0 + private[yarn] var numLocalityAwareTasks: Int = 0 // A container placement strategy based on pending tasks' locality preference private[yarn] val containerPlacementStrategy = @@ -158,6 +165,8 @@ private[yarn] class YarnAllocator( def getNumExecutorsRunning: Int = runningExecutors.size() + def getNumReleasedContainers: Int = releasedContainers.size() + def getNumExecutorsFailed: Int = failureTracker.numFailedExecutors def isAllNodeBlacklisted: Boolean = allocatorBlacklistTracker.isAllNodeBlacklisted @@ -167,6 +176,10 @@ private[yarn] class YarnAllocator( */ def getPendingAllocate: Seq[ContainerRequest] = getPendingAtLocation(ANY_HOST) + def numContainersPendingAllocate: Int = synchronized { + getPendingAllocate.size + } + /** * A sequence of pending container requests at the given location that have not yet been * fulfilled. diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala index 1b48a0ee7ad32..ceac7cda5f8be 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala @@ -28,7 +28,7 @@ import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.scheduler.BlacklistTracker -import org.apache.spark.util.{Clock, SystemClock, Utils} +import org.apache.spark.util.{Clock, SystemClock} /** * YarnAllocatorBlacklistTracker is responsible for tracking the blacklisted nodes diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index b59dcf158d87c..05a7b1e1310c4 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -19,6 +19,7 @@ package org.apache.spark.deploy.yarn import scala.collection.JavaConverters._ +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest @@ -112,7 +113,16 @@ private[spark] class YarnRMClient extends Logging { val proxies = WebAppUtils.getProxyHostsAndPortsForAmFilter(conf) val hosts = proxies.asScala.map(_.split(":").head) val uriBases = proxies.asScala.map { proxy => prefix + proxy + proxyBase } - Map("PROXY_HOSTS" -> hosts.mkString(","), "PROXY_URI_BASES" -> uriBases.mkString(",")) + val params = + Map("PROXY_HOSTS" -> hosts.mkString(","), 
"PROXY_URI_BASES" -> uriBases.mkString(",")) + + // Handles RM HA urls + val rmIds = conf.getStringCollection(YarnConfiguration.RM_HA_IDS).asScala + if (rmIds != null && rmIds.nonEmpty) { + params + ("RM_HA_URLS" -> rmIds.map(getUrlByRmId(conf, _)).mkString(",")) + } else { + params + } } /** Returns the maximum number of attempts to register the AM. */ @@ -126,4 +136,21 @@ private[spark] class YarnRMClient extends Logging { } } + private def getUrlByRmId(conf: Configuration, rmId: String): String = { + val addressPropertyPrefix = if (YarnConfiguration.useHttps(conf)) { + YarnConfiguration.RM_WEBAPP_HTTPS_ADDRESS + } else { + YarnConfiguration.RM_WEBAPP_ADDRESS + } + + val addressWithRmId = if (rmId == null || rmId.isEmpty) { + addressPropertyPrefix + } else if (rmId.startsWith(".")) { + throw new IllegalStateException(s"rmId $rmId should not already have '.' prepended.") + } else { + s"$addressPropertyPrefix.$rmId" + } + + conf.get(addressWithRmId) + } } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 7250e58b6c49a..3a3272216294f 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -27,11 +27,8 @@ import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.records.{ApplicationAccessType, ContainerId, Priority} import org.apache.hadoop.yarn.util.ConverterUtils -import org.apache.spark.{SecurityManager, SparkConf, SparkException} -import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.deploy.yarn.security.YARNHadoopDelegationTokenManager -import org.apache.spark.internal.config._ import org.apache.spark.launcher.YarnCommandBuilderUtils import org.apache.spark.util.Utils @@ -193,8 +190,7 @@ object YarnSparkHadoopUtil { sparkConf: SparkConf, hadoopConf: Configuration): Set[FileSystem] = { val filesystemsToAccess = sparkConf.get(FILESYSTEMS_TO_ACCESS) - .map(new Path(_).getFileSystem(hadoopConf)) - .toSet + val requestAllDelegationTokens = filesystemsToAccess.isEmpty val stagingFS = sparkConf.get(STAGING_DIR) .map(new Path(_).getFileSystem(hadoopConf)) @@ -203,8 +199,8 @@ object YarnSparkHadoopUtil { // Add the list of available namenodes for all namespaces in HDFS federation. // If ViewFS is enabled, this is skipped as ViewFS already handles delegation tokens for its // namespaces. 
- val hadoopFilesystems = if (stagingFS.getScheme == "viewfs") { - Set.empty + val hadoopFilesystems = if (!requestAllDelegationTokens || stagingFS.getScheme == "viewfs") { + filesystemsToAccess.map(new Path(_).getFileSystem(hadoopConf)).toSet } else { val nameservices = hadoopConf.getTrimmedStrings("dfs.nameservices") // Retrieving the filesystem for the nameservices where HA is not enabled @@ -222,7 +218,7 @@ object YarnSparkHadoopUtil { (filesystemsWithoutHA ++ filesystemsWithHA).toSet } - filesystemsToAccess ++ hadoopFilesystems + stagingFS + hadoopFilesystems + stagingFS } } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index 129084a86597a..ab8273bd6321d 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -152,6 +152,11 @@ package object config { .timeConf(TimeUnit.MILLISECONDS) .createWithDefaultString("100s") + private[spark] val YARN_METRICS_NAMESPACE = ConfigBuilder("spark.yarn.metrics.namespace") + .doc("The root namespace for AM metrics reporting.") + .stringConf + .createOptional + private[spark] val AM_NODE_LABEL_EXPRESSION = ConfigBuilder("spark.yarn.am.nodeLabelExpression") .doc("Node label expression for the AM.") .stringConf @@ -187,6 +192,12 @@ package object config { .toSequence .createWithDefault(Nil) + private[spark] val AM_FINAL_MSG_LIMIT = ConfigBuilder("spark.yarn.am.finalMessageLimit") + .doc("The limit size of final diagnostic message for our ApplicationMaster to unregister from" + + " the ResourceManager.") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("1m") + /* Client-mode AM configuration. */ private[spark] val AM_CORES = ConfigBuilder("spark.yarn.am.cores") diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index f1a8df00f9c5b..9397a1e3de9ac 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -111,7 +111,7 @@ private[spark] class YarnClientSchedulerBackend( override def run() { try { val YarnAppReport(_, state, diags) = - client.monitorApplication(appId.get, logApplicationReport = true) + client.monitorApplication(appId.get, logApplicationReport = false) logError(s"YARN application has exited unexpectedly with state $state! 
" + "Check the YARN application logs for more details.") diags.foreach { err => diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala index b0abcc9149d08..3a7913122dd83 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -133,7 +133,8 @@ abstract class BaseYarnClusterSuite extraClassPath: Seq[String] = Nil, extraJars: Seq[String] = Nil, extraConf: Map[String, String] = Map(), - extraEnv: Map[String, String] = Map()): SparkAppHandle.State = { + extraEnv: Map[String, String] = Map(), + outFile: Option[File] = None): SparkAppHandle.State = { val deployMode = if (clientMode) "client" else "cluster" val propsFile = createConfFile(extraClassPath = extraClassPath, extraConf = extraConf) val env = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath()) ++ extraEnv @@ -161,6 +162,11 @@ abstract class BaseYarnClusterSuite } extraJars.foreach(launcher.addJar) + if (outFile.isDefined) { + launcher.redirectOutput(outFile.get) + launcher.redirectError() + } + val handle = launcher.startApplication() try { eventually(timeout(2 minutes), interval(1 second)) { @@ -179,17 +185,22 @@ abstract class BaseYarnClusterSuite * the tests enforce that something is written to a file after everything is ok to indicate * that the job succeeded. */ - protected def checkResult(finalState: SparkAppHandle.State, result: File): Unit = { - checkResult(finalState, result, "success") - } - protected def checkResult( finalState: SparkAppHandle.State, result: File, - expected: String): Unit = { - finalState should be (SparkAppHandle.State.FINISHED) + expected: String = "success", + outFile: Option[File] = None): Unit = { + // the context message is passed to assert as Any instead of a function. 
to lazily load the + // output from the file, this passes an anonymous object that loads it in toString when building + // an error message + val output = new Object() { + override def toString: String = outFile + .map(Files.toString(_, StandardCharsets.UTF_8)) + .getOrElse("(stdout/stderr was not captured)") + } + assert(finalState === SparkAppHandle.State.FINISHED, output) val resultString = Files.toString(result, StandardCharsets.UTF_8) - resultString should be (expected) + assert(resultString === expected, output) } protected def mainClassName(klass: Class[_]): String = { diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 7fa597167f3f0..26013a109c42b 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -191,7 +191,7 @@ class ClientSuite extends SparkFunSuite with Matchers { appContext.getQueue should be ("staging-queue") appContext.getAMContainerSpec should be (containerLaunchContext) appContext.getApplicationType should be ("SPARK") - appContext.getClass.getMethods.filter(_.getName.equals("getApplicationTags")).foreach{ method => + appContext.getClass.getMethods.filter(_.getName == "getApplicationTags").foreach { method => val tags = method.invoke(appContext).asInstanceOf[java.util.Set[String]] tags should contain allOf ("tag1", "dup", "tag2", "multi word") tags.asScala.count(_.nonEmpty) should be (4) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 3b78b88de778d..58d11e96942e1 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -108,7 +108,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite { "spark.executor.cores" -> "1", "spark.executor.memory" -> "512m", "spark.executor.instances" -> "2", - // Sending some senstive information, which we'll make sure gets redacted + // Sending some sensitive information, which we'll make sure gets redacted "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD, "spark.yarn.appMasterEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD )) @@ -282,13 +282,15 @@ class YarnClusterSuite extends BaseYarnClusterSuite { val mod2Archive = TestUtils.createJarWithFiles(Map("mod2.py" -> TEST_PYMODULE), moduleDir) val pyFiles = Seq(pyModule.getAbsolutePath(), mod2Archive.getPath()).mkString(",") val result = File.createTempFile("result", null, tempDir) + val outFile = Some(File.createTempFile("stdout", null, tempDir)) val finalState = runSpark(clientMode, primaryPyFile.getAbsolutePath(), sparkArgs = Seq("--py-files" -> pyFiles), appArgs = Seq(result.getAbsolutePath()), extraEnv = extraEnvVars, - extraConf = extraConf) - checkResult(finalState, result) + extraConf = extraConf, + outFile = outFile) + checkResult(finalState, result, outFile = outFile) } private def testUseClassPathFirst(clientMode: Boolean): Unit = { diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala index 
01db796096f26..37bccaf0439b4 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala @@ -44,7 +44,7 @@ class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite { yarnConfig.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle") yarnConfig.set(YarnConfiguration.NM_AUX_SERVICE_FMT.format("spark_shuffle"), classOf[YarnShuffleService].getCanonicalName) - yarnConfig.set("spark.shuffle.service.port", "0") + yarnConfig.set(SHUFFLE_SERVICE_PORT.key, "0") yarnConfig } @@ -54,8 +54,8 @@ class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite { logInfo("Shuffle service port = " + shuffleServicePort) Map( - "spark.shuffle.service.enabled" -> "true", - "spark.shuffle.service.port" -> shuffleServicePort.toString, + SHUFFLE_SERVICE_ENABLED.key -> "true", + SHUFFLE_SERVICE_PORT.key -> shuffleServicePort.toString, MAX_EXECUTOR_FAILURES.key -> "1" ) } diff --git a/scalastyle-config.xml b/scalastyle-config.xml index e65e3aafe5b5b..da5c3f29c32dc 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -150,6 +150,19 @@ This file is divided into 3 sections: // scalastyle:on println]]> + + spark(.sqlContext)?.sparkContext.hadoopConfiguration + + + @VisibleForTesting ' expression #lambda + | '(' IDENTIFIER (',' IDENTIFIER)+ ')' '->' expression #lambda | value=primaryExpression '[' index=valueExpression ']' #subscript | identifier #columnReference | base=primaryExpression '.' fieldName=identifier #dereference diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java deleted file mode 100644 index 05879902a4ed9..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions; - -/** - * Contains all the Utils methods used in the masking expressions. - */ -public class MaskExpressionsUtils { - static final int UNMASKED_VAL = -1; - - /** - * Returns the masking character for {@param c} or {@param c} is it should not be masked. 
- * @param c the character to transform - * @param maskedUpperChar the character to use instead of a uppercase letter - * @param maskedLowerChar the character to use instead of a lowercase letter - * @param maskedDigitChar the character to use instead of a digit - * @param maskedOtherChar the character to use instead of a any other character - * @return masking character for {@param c} - */ - public static int transformChar( - final int c, - int maskedUpperChar, - int maskedLowerChar, - int maskedDigitChar, - int maskedOtherChar) { - switch(Character.getType(c)) { - case Character.UPPERCASE_LETTER: - if(maskedUpperChar != UNMASKED_VAL) { - return maskedUpperChar; - } - break; - - case Character.LOWERCASE_LETTER: - if(maskedLowerChar != UNMASKED_VAL) { - return maskedLowerChar; - } - break; - - case Character.DECIMAL_DIGIT_NUMBER: - if(maskedDigitChar != UNMASKED_VAL) { - return maskedDigitChar; - } - break; - - default: - if(maskedOtherChar != UNMASKED_VAL) { - return maskedOtherChar; - } - break; - } - - return c; - } - - /** - * Returns the replacement char to use according to the {@param rep} specified by the user and - * the {@param def} default. - */ - public static int getReplacementChar(String rep, int def) { - if (rep != null && rep.length() > 0) { - return rep.codePointAt(0); - } - return def; - } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 4dd2b7365652a..9002abdcfd474 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -27,7 +27,6 @@ import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; -import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -241,8 +240,7 @@ public UTF8String getUTF8String(int ordinal) { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); final int size = (int) offsetAndSize; - MemoryBlock mb = MemoryBlock.allocateFromObject(baseObject, baseOffset + offset, size); - return new UTF8String(mb); + return UTF8String.fromAddress(baseObject, baseOffset + offset, size); } @Override @@ -450,7 +448,7 @@ public double[] toDoubleArray() { return values; } - private static UnsafeArrayData fromPrimitiveArray( + public static UnsafeArrayData fromPrimitiveArray( Object arr, int offset, int length, int elementSize) { final long headerInBytes = calculateHeaderPortionInBytes(length); final long valueRegionInBytes = (long)elementSize * length; @@ -463,14 +461,41 @@ private static UnsafeArrayData fromPrimitiveArray( final long[] data = new long[(int)totalSizeInLongs]; Platform.putLong(data, Platform.LONG_ARRAY_OFFSET, length); - Platform.copyMemory(arr, offset, data, - Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes); + if (arr != null) { + Platform.copyMemory(arr, offset, data, + Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes); + } + + UnsafeArrayData result = new UnsafeArrayData(); + result.pointTo(data, Platform.LONG_ARRAY_OFFSET, (int)totalSizeInLongs * 8); + return result; + } + + public static UnsafeArrayData createFreshArray(int length, int elementSize) { + final long headerInBytes = 
calculateHeaderPortionInBytes(length); + final long valueRegionInBytes = (long)elementSize * length; + final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8; + if (totalSizeInLongs > Integer.MAX_VALUE / 8) { + throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " + + "it's too big."); + } + + final long[] data = new long[(int)totalSizeInLongs]; + + Platform.putLong(data, Platform.LONG_ARRAY_OFFSET, length); UnsafeArrayData result = new UnsafeArrayData(); result.pointTo(data, Platform.LONG_ARRAY_OFFSET, (int)totalSizeInLongs * 8); return result; } + public static boolean shouldUseGenericArrayData(int elementSize, long length) { + final long headerInBytes = calculateHeaderPortionInBytes(length); + final long valueRegionInBytes = elementSize * length; + final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8; + return totalSizeInLongs > Integer.MAX_VALUE / 8; + } + public static UnsafeArrayData fromPrimitiveArray(boolean[] arr) { return fromPrimitiveArray(arr, Platform.BOOLEAN_ARRAY_OFFSET, arr.length, 1); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 469b0e60cc9a2..a76e6ef8c91c1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -37,7 +37,6 @@ import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; -import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -417,8 +416,7 @@ public UTF8String getUTF8String(int ordinal) { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); final int size = (int) offsetAndSize; - MemoryBlock mb = MemoryBlock.allocateFromObject(baseObject, baseOffset + offset, size); - return new UTF8String(mb); + return UTF8String.fromAddress(baseObject, baseOffset + offset, size); } @Override diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java index 8e9c0a2e9dc81..eb5051b284073 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.catalyst.expressions; -import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.UTF8String; // scalastyle: off @@ -72,13 +72,13 @@ public static long hashLong(long input, long seed) { return fmix(hash); } - public long hashUnsafeWordsBlock(MemoryBlock mb) { - return hashUnsafeWordsBlock(mb, seed); + public long hashUnsafeWords(Object base, long offset, int length) { + return hashUnsafeWords(base, offset, length, seed); } - public static long hashUnsafeWordsBlock(MemoryBlock mb, long seed) { - assert (mb.size() % 8 == 0) : "lengthInBytes must be a multiple of 8 (word-aligned)"; - long hash = hashBytesByWordsBlock(mb, seed); + public static long hashUnsafeWords(Object base, long offset, int length, long seed) { + assert (length % 8 == 0) : "lengthInBytes must be a multiple of 8 (word-aligned)"; + 
long hash = hashBytesByWords(base, offset, length, seed); return fmix(hash); } @@ -86,22 +86,20 @@ public long hashUnsafeBytes(Object base, long offset, int length) { return hashUnsafeBytes(base, offset, length, seed); } - public static long hashUnsafeBytesBlock(MemoryBlock mb, long seed) { - long offset = 0; - long length = mb.size(); + public static long hashUnsafeBytes(Object base, long offset, int length, long seed) { assert (length >= 0) : "lengthInBytes cannot be negative"; - long hash = hashBytesByWordsBlock(mb, seed); + long hash = hashBytesByWords(base, offset, length, seed); long end = offset + length; offset += length & -8; if (offset + 4L <= end) { - hash ^= (mb.getInt(offset) & 0xFFFFFFFFL) * PRIME64_1; + hash ^= (Platform.getInt(base, offset) & 0xFFFFFFFFL) * PRIME64_1; hash = Long.rotateLeft(hash, 23) * PRIME64_2 + PRIME64_3; offset += 4L; } while (offset < end) { - hash ^= (mb.getByte(offset) & 0xFFL) * PRIME64_5; + hash ^= (Platform.getByte(base, offset) & 0xFFL) * PRIME64_5; hash = Long.rotateLeft(hash, 11) * PRIME64_1; offset++; } @@ -109,11 +107,7 @@ public static long hashUnsafeBytesBlock(MemoryBlock mb, long seed) { } public static long hashUTF8String(UTF8String str, long seed) { - return hashUnsafeBytesBlock(str.getMemoryBlock(), seed); - } - - public static long hashUnsafeBytes(Object base, long offset, int length, long seed) { - return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, length), seed); + return hashUnsafeBytes(str.getBaseObject(), str.getBaseOffset(), str.numBytes(), seed); } private static long fmix(long hash) { @@ -125,31 +119,30 @@ private static long fmix(long hash) { return hash; } - private static long hashBytesByWordsBlock(MemoryBlock mb, long seed) { - long offset = 0; - long length = mb.size(); + private static long hashBytesByWords(Object base, long offset, int length, long seed) { + long end = offset + length; long hash; if (length >= 32) { - long limit = length - 32; + long limit = end - 32; long v1 = seed + PRIME64_1 + PRIME64_2; long v2 = seed + PRIME64_2; long v3 = seed; long v4 = seed - PRIME64_1; do { - v1 += mb.getLong(offset) * PRIME64_2; + v1 += Platform.getLong(base, offset) * PRIME64_2; v1 = Long.rotateLeft(v1, 31); v1 *= PRIME64_1; - v2 += mb.getLong(offset + 8) * PRIME64_2; + v2 += Platform.getLong(base, offset + 8) * PRIME64_2; v2 = Long.rotateLeft(v2, 31); v2 *= PRIME64_1; - v3 += mb.getLong(offset + 16) * PRIME64_2; + v3 += Platform.getLong(base, offset + 16) * PRIME64_2; v3 = Long.rotateLeft(v3, 31); v3 *= PRIME64_1; - v4 += mb.getLong(offset + 24) * PRIME64_2; + v4 += Platform.getLong(base, offset + 24) * PRIME64_2; v4 = Long.rotateLeft(v4, 31); v4 *= PRIME64_1; @@ -190,9 +183,9 @@ private static long hashBytesByWordsBlock(MemoryBlock mb, long seed) { hash += length; - long limit = length - 8; + long limit = end - 8; while (offset <= limit) { - long k1 = mb.getLong(offset); + long k1 = Platform.getLong(base, offset); hash ^= Long.rotateLeft(k1 * PRIME64_2, 31) * PRIME64_1; hash = Long.rotateLeft(hash, 27) * PRIME64_1 + PRIME64_4; offset += 8L; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index 537ef244b7e81..6a52a5b0e0664 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -35,6 +35,7 @@ 
final class BufferHolder { private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; + // buffer is guarantee to be word-aligned since UnsafeRow assumes each field is word-aligned. private byte[] buffer; private int cursor = Platform.BYTE_ARRAY_OFFSET; private final UnsafeRow row; @@ -52,7 +53,8 @@ final class BufferHolder { "too many fields (number of fields: " + row.numFields() + ")"); } this.fixedSize = bitsetWidthInBytes + 8 * row.numFields(); - this.buffer = new byte[fixedSize + initialSize]; + int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(fixedSize + initialSize); + this.buffer = new byte[roundedSize]; this.row = row; this.row.pointTo(buffer, buffer.length); } @@ -61,8 +63,12 @@ final class BufferHolder { * Grows the buffer by at least neededSize and points the row to the buffer. */ void grow(int neededSize) { + if (neededSize < 0) { + throw new IllegalArgumentException( + "Cannot grow BufferHolder by size " + neededSize + " because the size is negative"); + } if (neededSize > ARRAY_MAX - totalSize()) { - throw new UnsupportedOperationException( + throw new IllegalArgumentException( "Cannot grow BufferHolder by size " + neededSize + " because the size after growing " + "exceeds size limitation " + ARRAY_MAX); } @@ -70,7 +76,8 @@ void grow(int neededSize) { if (buffer.length < length) { // This will not happen frequently, because the buffer is re-used. int newLength = length < ARRAY_MAX / 2 ? length * 2 : ARRAY_MAX; - final byte[] tmp = new byte[newLength]; + int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(newLength); + final byte[] tmp = new byte[roundedSize]; Platform.copyMemory( buffer, Platform.BYTE_ARRAY_OFFSET, diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java index f8000d78cd1b6..f0f66bae245fd 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java @@ -19,8 +19,6 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; -import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; -import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.UTF8String; /** @@ -31,34 +29,43 @@ public class UTF8StringBuilder { private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; - private ByteArrayMemoryBlock buffer; - private int length = 0; + private byte[] buffer; + private int cursor = Platform.BYTE_ARRAY_OFFSET; public UTF8StringBuilder() { // Since initial buffer size is 16 in `StringBuilder`, we set the same size here - this.buffer = new ByteArrayMemoryBlock(16); + this.buffer = new byte[16]; } // Grows the buffer by at least `neededSize` private void grow(int neededSize) { - if (neededSize > ARRAY_MAX - length) { + if (neededSize > ARRAY_MAX - totalSize()) { throw new UnsupportedOperationException( "Cannot grow internal buffer by size " + neededSize + " because the size after growing " + "exceeds size limitation " + ARRAY_MAX); } - final int requestedSize = length + neededSize; - if (buffer.size() < requestedSize) { - int newLength = requestedSize < ARRAY_MAX / 2 ? 
requestedSize * 2 : ARRAY_MAX; - final ByteArrayMemoryBlock tmp = new ByteArrayMemoryBlock(newLength); - MemoryBlock.copyMemory(buffer, tmp, length); + final int length = totalSize() + neededSize; + if (buffer.length < length) { + int newLength = length < ARRAY_MAX / 2 ? length * 2 : ARRAY_MAX; + final byte[] tmp = new byte[newLength]; + Platform.copyMemory( + buffer, + Platform.BYTE_ARRAY_OFFSET, + tmp, + Platform.BYTE_ARRAY_OFFSET, + totalSize()); buffer = tmp; } } + private int totalSize() { + return cursor - Platform.BYTE_ARRAY_OFFSET; + } + public void append(UTF8String value) { grow(value.numBytes()); - value.writeToMemory(buffer.getByteArray(), length + Platform.BYTE_ARRAY_OFFSET); - length += value.numBytes(); + value.writeToMemory(buffer, cursor); + cursor += value.numBytes(); } public void append(String value) { @@ -66,6 +73,6 @@ public void append(String value) { } public UTF8String build() { - return UTF8String.fromBytes(buffer.getByteArray(), 0, length); + return UTF8String.fromBytes(buffer, 0, totalSize()); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java index bb77b5bf6de2a..40c2cc806e87a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java @@ -22,12 +22,10 @@ public final class RecordBinaryComparator extends RecordComparator { - // TODO(jiangxb) Add test suite for this. @Override public int compare( Object leftObj, long leftOff, int leftLen, Object rightObj, long rightOff, int rightLen) { int i = 0; - int res = 0; // If the arrays have different length, the longer one is larger. if (leftLen != rightLen) { @@ -40,27 +38,33 @@ public int compare( // check if stars align and we can get both offsets to be aligned if ((leftOff % 8) == (rightOff % 8)) { while ((leftOff + i) % 8 != 0 && i < leftLen) { - res = (Platform.getByte(leftObj, leftOff + i) & 0xff) - - (Platform.getByte(rightObj, rightOff + i) & 0xff); - if (res != 0) return res; + final int v1 = Platform.getByte(leftObj, leftOff + i) & 0xff; + final int v2 = Platform.getByte(rightObj, rightOff + i) & 0xff; + if (v1 != v2) { + return v1 > v2 ? 1 : -1; + } i += 1; } } // for architectures that support unaligned accesses, chew it up 8 bytes at a time if (Platform.unaligned() || (((leftOff + i) % 8 == 0) && ((rightOff + i) % 8 == 0))) { while (i <= leftLen - 8) { - res = (int) ((Platform.getLong(leftObj, leftOff + i) - - Platform.getLong(rightObj, rightOff + i)) % Integer.MAX_VALUE); - if (res != 0) return res; + final long v1 = Platform.getLong(leftObj, leftOff + i); + final long v2 = Platform.getLong(rightObj, rightOff + i); + if (v1 != v2) { + return v1 > v2 ? 1 : -1; + } i += 8; } } // this will finish off the unaligned comparisons, or do the entire aligned comparison // whichever is needed. while (i < leftLen) { - res = (Platform.getByte(leftObj, leftOff + i) & 0xff) - - (Platform.getByte(rightObj, rightOff + i) & 0xff); - if (res != 0) return res; + final int v1 = Platform.getByte(leftObj, leftOff + i) & 0xff; + final int v2 = Platform.getByte(rightObj, rightOff + i) & 0xff; + if (v1 != v2) { + return v1 > v2 ? 
1 : -1; + } i += 1; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 93df73ab1eaf6..6f5fbdd79e668 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -431,6 +431,12 @@ object CatalystTypeConverters { map, (key: Any) => convertToCatalyst(key), (value: Any) => convertToCatalyst(value)) + case (keys: Array[_], values: Array[_]) => + // case for mapdata with duplicate keys + new ArrayBasedMapData( + new GenericArrayData(keys.map(convertToCatalyst)), + new GenericArrayData(values.map(convertToCatalyst)) + ) case other => other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index f9acc208b715e..0238d57de2446 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -709,6 +709,8 @@ object ScalaReflection extends ScalaReflection { def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { case Schema(s: StructType, _) => s.toAttributes + case others => + throw new UnsupportedOperationException(s"Attributes for type $others is not supported") } /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ @@ -798,7 +800,12 @@ object ScalaReflection extends ScalaReflection { * Whether the fields of the given type is defined entirely by its constructor parameters. */ def definedByConstructorParams(tpe: Type): Boolean = cleanUpReflectionObjects { - tpe.dealias <:< localTypeOf[Product] || tpe.dealias <:< localTypeOf[DefinedByConstructorParams] + tpe.dealias match { + // `Option` is a `Product`, but we don't wanna treat `Option[Int]` as a struct type. + case t if t <:< localTypeOf[Option[_]] => definedByConstructorParams(t.typeArgs.head) + case _ => tpe.dealias <:< localTypeOf[Product] || + tpe.dealias <:< localTypeOf[DefinedByConstructorParams] + } } private val javaKeywords = Set("abstract", "assert", "boolean", "break", "byte", "case", "catch", @@ -925,15 +932,6 @@ trait ScalaReflection { tpe.dealias.erasure.typeSymbol.asClass.fullName } - /** - * Returns classes of input parameters of scala function object. - */ - def getParameterTypes(func: AnyRef): Seq[Class[_]] = { - val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge) - assert(methods.length == 1) - methods.head.getParameterTypes - } - /** * Returns the parameter names and types for the primary constructor of this type. 
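The `definedByConstructorParams` change above special-cases `Option` because `Option` itself extends `Product`; without unwrapping, `Option[Int]` would be classified as a struct type. A toy model of that recursion (the `TpeKind` hierarchy is invented for illustration, not Spark's reflection code):

```scala
object OptionStructSketch {
  sealed trait TpeKind
  case object ProductLike extends TpeKind          // e.g. a case class
  case object NonProduct extends TpeKind           // e.g. Int, String
  final case class OptionOf(inner: TpeKind) extends TpeKind

  // Unwrap Option before the Product test, because Option is a Product too.
  def definedByCtorParams(t: TpeKind): Boolean = t match {
    case OptionOf(inner) => definedByCtorParams(inner)
    case ProductLike     => true
    case NonProduct      => false
  }

  // definedByCtorParams(OptionOf(NonProduct))  == false  (Option[Int])
  // definedByCtorParams(OptionOf(ProductLike)) == true   (Option[SomeCaseClass])
}
```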
* diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e187133d03b17..580133dd971b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.catalyst.analysis +import java.util.Locale + +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.Random @@ -27,7 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.expressions.objects.{LambdaVariable, MapObjects, NewInstance, UnresolvedMapObjects} +import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -99,11 +102,11 @@ class Analyzer( this(catalog, conf, conf.optimizerMaxIterations) } - def executeAndCheck(plan: LogicalPlan): LogicalPlan = { + def executeAndCheck(plan: LogicalPlan): LogicalPlan = AnalysisHelper.markInAnalyzer { val analyzed = execute(plan) try { checkAnalysis(analyzed) - EliminateBarriers(analyzed) + analyzed } catch { case e: AnalysisException => val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed)) @@ -142,6 +145,7 @@ class Analyzer( lazy val batches: Seq[Batch] = Seq( Batch("Hints", fixedPoint, new ResolveHints.ResolveBroadcastHints(conf), + ResolveHints.ResolveCoalesceHints, ResolveHints.RemoveAllHints), Batch("Simple Sanity Check", Once, LookupFunctions), @@ -172,13 +176,16 @@ class Analyzer( ResolveWindowOrder :: ResolveWindowFrame :: ResolveNaturalAndUsingJoin :: + ResolveOutputRelation :: ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: TimeWindowing :: ResolveInlineTables(conf) :: + ResolveHigherOrderFunctions(catalog) :: + ResolveLambdaVariables(conf) :: ResolveTimeZone(conf) :: - ResolvedUuidExpressions :: + ResolveRandomSeed :: TypeCoercion.typeCoercionRules(conf) ++ extendedResolutionRules : _*), Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), @@ -200,7 +207,7 @@ class Analyzer( * Analyze cte definitions and substitute child plan with analyzed cte definitions. */ object CTESubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case With(child, relations) => substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) { case (resolved, (name, relation)) => @@ -210,8 +217,8 @@ class Analyzer( } def substituteCTE(plan: LogicalPlan, cteRelations: Seq[(String, LogicalPlan)]): LogicalPlan = { - plan transformDown { - case u : UnresolvedRelation => + plan resolveOperatorsDown { + case u: UnresolvedRelation => cteRelations.find(x => resolver(x._1, u.tableIdentifier.table)) .map(_._2).getOrElse(u) case other => @@ -228,19 +235,16 @@ class Analyzer( * Substitute child plan with WindowSpecDefinitions. 
*/
  object WindowsSubstitution extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
       // Lookup WindowSpecDefinitions. This rule works with unresolved children.
-      case WithWindowDefinition(windowDefinitions, child) =>
-        child.transform {
-          case p => p.transformExpressions {
-            case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) =>
-              val errorMessage =
-                s"Window specification $windowName is not defined in the WINDOW clause."
-              val windowSpecDefinition =
-                windowDefinitions.getOrElse(windowName, failAnalysis(errorMessage))
-              WindowExpression(c, windowSpecDefinition)
-          }
-        }
+      case WithWindowDefinition(windowDefinitions, child) => child.resolveExpressions {
+        case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) =>
+          val errorMessage =
+            s"Window specification $windowName is not defined in the WINDOW clause."
+          val windowSpecDefinition =
+            windowDefinitions.getOrElse(windowName, failAnalysis(errorMessage))
+          WindowExpression(c, windowSpecDefinition)
+      }
     }
   }
 @@ -268,7 +272,7 @@ class Analyzer(
     private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) =
       exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined)
-    def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
       case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) =>
         Aggregate(groups, assignAliases(aggs), child)
 @@ -439,17 +443,35 @@ class Analyzer(
         child: LogicalPlan): LogicalPlan = {
       val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
+      // In case of ANSI-SQL compliant syntax for GROUPING SETS, groupByExprs is optional and
+      // can be null. In that case, we derive the groupByExprs from the user-supplied values for
+      // grouping sets.
+      val finalGroupByExpressions = if (groupByExprs == Nil) {
+        selectedGroupByExprs.flatten.foldLeft(Seq.empty[Expression]) { (result, currentExpr) =>
+          // Only unique expressions are included in the group by expressions, as determined
+          // by their semantic equality. For example, grouping sets ((a * b), (b * a)) result
+          // in the single grouping expression (a * b).
+          if (result.find(_.semanticEquals(currentExpr)).isDefined) {
+            result
+          } else {
+            result :+ currentExpr
+          }
+        }
+      } else {
+        groupByExprs
+      }
+
       // Expand works by setting grouping expressions to null as determined by the
       // `selectedGroupByExprs`. To prevent these null values from being used in an aggregate
       // instead of the original value we need to create new aliases for all group by expressions
       // that will only be used for the intended purpose.
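The fold above keeps only semantically distinct grouping expressions when GROUPING SETS is used without an explicit GROUP BY. A stand-alone sketch of the same dedup, with a caller-supplied canonical form playing the role of `Expression.semanticEquals` (names hypothetical):

```scala
object GroupingSetDedupSketch {
  // Two expressions are duplicates when their canonical forms coincide.
  def dedupBySemanticEquality[T, K](exprs: Seq[T])(canonical: T => K): Seq[T] =
    exprs.foldLeft(Seq.empty[T]) { (result, e) =>
      if (result.exists(r => canonical(r) == canonical(e))) result
      else result :+ e
    }

  // Commutative products collapse: "a * b" and "b * a" share a canonical form.
  val deduped = dedupBySemanticEquality(Seq("a * b", "b * a", "c"))(
    _.filter(_.isLetter).sorted)
  // deduped == Seq("a * b", "c")
}
```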
- val groupByAliases = constructGroupByAlias(groupByExprs) + val groupByAliases = constructGroupByAlias(finalGroupByExpressions) val expand = constructExpand(selectedGroupByExprs, child, groupByAliases, gid) val groupingAttrs = expand.output.drop(child.output.length) val aggregations = constructAggregateExprs( - groupByExprs, aggregationExprs, groupByAliases, groupingAttrs, gid) + finalGroupByExpressions, aggregationExprs, groupByAliases, groupingAttrs, gid) Aggregate(groupingAttrs, aggregations, expand) } @@ -470,7 +492,7 @@ class Analyzer( } // This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { case a if !a.childrenResolved => a // be sure all of the children are resolved. // Ensure group by expressions and aggregate expressions have been resolved. @@ -503,25 +525,46 @@ class Analyzer( } object ResolvePivot extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved) || (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved)) - || !p.pivotColumn.resolved => p + || !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => + if (!RowOrdering.isOrderable(pivotColumn.dataType)) { + throw new AnalysisException( + s"Invalid pivot column '${pivotColumn}'. Pivot columns must be comparable.") + } // Check all aggregate expressions. - aggregates.foreach { e => - if (!isAggregateExpression(e)) { - throw new AnalysisException( - s"Aggregate expression required for pivot, found '$e'") + aggregates.foreach(checkValidAggregateExpression) + // Check all pivot values are literal and match pivot column data type. + val evalPivotValues = pivotValues.map { value => + val foldable = value match { + case Alias(v, _) => v.foldable + case _ => value.foldable + } + if (!foldable) { + throw new AnalysisException( + s"Literal expressions required for pivot values, found '$value'") + } + if (!Cast.canCast(value.dataType, pivotColumn.dataType)) { + throw new AnalysisException(s"Invalid pivot value '$value': " + + s"value data type ${value.dataType.simpleString} does not match " + + s"pivot column data type ${pivotColumn.dataType.catalogString}") } + Cast(value, pivotColumn.dataType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) } // Group-by expressions coming from SQL are implicit and need to be deduced. 
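The pivot-value checks above reduce to two requirements: each value must be foldable (a literal) and castable to the pivot column's type. A boiled-down version, with boolean flags standing in for `Expression.foldable` and `Cast.canCast` (hypothetical names, not the analyzer API):

```scala
object PivotValueCheckSketch {
  final case class PivotValue(sql: String, foldable: Boolean, castable: Boolean)

  // Returns the first analysis error, in the same order the rule checks them.
  def firstError(values: Seq[PivotValue]): Option[String] =
    values.collectFirst {
      case v if !v.foldable =>
        s"Literal expressions required for pivot values, found '${v.sql}'"
      case v if !v.castable =>
        s"Invalid pivot value '${v.sql}': value data type does not match pivot column data type"
    }

  // firstError(Seq(PivotValue("col + 1", foldable = false, castable = true)))
  //   == Some("Literal expressions required for pivot values, found 'col + 1'")
}
```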
val groupByExprs = groupByExprsOpt.getOrElse( (child.outputSet -- aggregates.flatMap(_.references) -- pivotColumn.references).toSeq) val singleAgg = aggregates.size == 1 - def outputName(value: Literal, aggregate: Expression): String = { - val utf8Value = Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) - val stringValue: String = Option(utf8Value).map(_.toString).getOrElse("null") + def outputName(value: Expression, aggregate: Expression): String = { + val stringValue = value match { + case n: NamedExpression => n.name + case _ => + val utf8Value = + Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) + Option(utf8Value).map(_.toString).getOrElse("null") + } if (singleAgg) { stringValue } else { @@ -542,9 +585,8 @@ class Analyzer( } val bigGroup = groupByExprs :+ namedPivotCol val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child) - val castPivotValues = pivotValues.map(Cast(_, pivotColumn.dataType).eval(EmptyRow)) val pivotAggs = namedAggExps.map { a => - Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, castPivotValues) + Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, evalPivotValues) .toAggregateExpression() , "__pivot_" + a.sql)() } @@ -559,8 +601,12 @@ class Analyzer( Project(groupByExprsAttr ++ pivotOutputs, secondAgg) } else { val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => - def ifExpr(expr: Expression) = { - If(EqualNullSafe(pivotColumn, value), expr, Literal(null)) + def ifExpr(e: Expression) = { + If( + EqualNullSafe( + pivotColumn, + Cast(value, pivotColumn.dataType, Some(conf.sessionLocalTimeZone))), + e, Literal(null)) } aggregates.map { aggregate => val filteredAggregate = aggregate.transformDown { @@ -586,12 +632,17 @@ class Analyzer( } } - private def isAggregateExpression(expr: Expression): Boolean = { - expr match { - case Alias(e, _) => isAggregateExpression(e) - case AggregateExpression(_, _, _, _) => true - case _ => false - } + // Support any aggregate expression that can appear in an Aggregate plan except Pandas UDF. + // TODO: Support Pandas UDF. + private def checkValidAggregateExpression(expr: Expression): Unit = expr match { + case _: AggregateExpression => // OK and leave the argument check to CheckAnalysis. + case expr: PythonUDF if PythonUDF.isGroupedAggPandasUDF(expr) => + failAnalysis("Pandas UDF aggregate expressions are currently not supported in pivot.") + case e: Attribute => + failAnalysis( + s"Aggregate expression required for pivot, but '${e.sql}' " + + s"did not appear in any aggregate function.") + case e => e.children.foreach(checkValidAggregateExpression) } } @@ -652,7 +703,7 @@ class Analyzer( case _ => plan } - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved => EliminateSubqueryAliases(lookupTableFromCatalog(u)) match { case v: View => @@ -713,12 +764,6 @@ class Analyzer( s"between $left and $right") right.collect { - // For `AnalysisBarrier`, recursively de-duplicate its child. - case oldVersion: AnalysisBarrier - if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => - val newVersion = dedupRight(left, oldVersion.child) - (oldVersion, AnalysisBarrier(newVersion)) - // Handle base relations that might appear more than once. 
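The conflict-resolution cases below share one pattern: when the right side of a self-join reuses attribute IDs from the left, the conflicting outputs are re-instanced with fresh IDs. A minimal model with integer IDs standing in for Spark's `ExprId` (names invented for illustration):

```scala
object DedupRightSketch {
  final case class Attr(name: String, id: Int)

  // Give right-side attributes that collide with left-side ids a fresh id,
  // as the cases above do via newInstance()/newAliases.
  def dedupRight(left: Seq[Attr], right: Seq[Attr], freshId: () => Int): Seq[Attr] = {
    val conflicting = left.map(_.id).toSet
    right.map(a => if (conflicting.contains(a.id)) a.copy(id = freshId()) else a)
  }

  // Self-join of a one-column plan: the right-hand "x" gets a fresh id.
  private var counter = 100
  val deduped = dedupRight(Seq(Attr("x", 1)), Seq(Attr("x", 1)), () => { counter += 1; counter })
  // deduped == Seq(Attr("x", 101))
}
```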
case oldVersion: MultiInstanceRelation if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => @@ -738,6 +783,10 @@ class Analyzer( if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))) + case oldVersion @ FlatMapGroupsInPandas(_, _, output, _) + if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => + (oldVersion, oldVersion.copy(output = output.map(_.newInstance()))) + case oldVersion: Generate if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty => val newOutput = oldVersion.generatorOutput.map(_.newInstance()) @@ -819,7 +868,7 @@ class Analyzer( private def dedupOuterReferencesInSubquery( plan: LogicalPlan, attrMap: AttributeMap[Attribute]): LogicalPlan = { - plan transformDown { case currentFragment => + plan resolveOperatorsDown { case currentFragment => currentFragment transformExpressions { case OuterReference(a: Attribute) => OuterReference(dedupAttr(a, attrMap)) @@ -830,6 +879,7 @@ class Analyzer( } private def resolve(e: Expression, q: LogicalPlan): Expression = e match { + case f: LambdaFunction if !f.bound => f case u @ UnresolvedAttribute(nameParts) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. val result = @@ -845,7 +895,7 @@ class Analyzer( case _ => e.mapChildren(resolve(_, q)) } - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p: LogicalPlan if !p.childrenResolved => p // If the projection list contains Stars, expand it. @@ -873,11 +923,10 @@ class Analyzer( // To resolve duplicate expression IDs for Join and Intersect case j @ Join(left, right, _, _) if !j.duplicateResolved => j.copy(right = dedupRight(left, right)) - case i @ Intersect(left, right) if !i.duplicateResolved => - i.copy(right = dedupRight(left, right)) - case i @ Except(left, right) if !i.duplicateResolved => + case i @ Intersect(left, right, _) if !i.duplicateResolved => i.copy(right = dedupRight(left, right)) - + case e @ Except(left, right, _) if !e.duplicateResolved => + e.copy(right = dedupRight(left, right)) // When resolve `SortOrder`s in Sort based on child, don't report errors as // we still have chance to resolve it based on its descendants case s @ Sort(ordering, global, child) if child.resolved && !s.resolved => @@ -1040,7 +1089,7 @@ class Analyzer( * have no effect on the results. */ object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p if !p.childrenResolved => p // Replace the index with the related attribute for ORDER BY, // which is a 1-base position of the projection list. @@ -1096,7 +1145,7 @@ class Analyzer( }} } - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case agg @ Aggregate(groups, aggs, child) if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && groups.exists(!_.resolved) => @@ -1120,12 +1169,12 @@ class Analyzer( * The HAVING clause could also used a grouping columns that is not presented in the SELECT. 
*/
  object ResolveMissingReferences extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
       // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
-      case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa
       case sa @ Sort(_, _, child: Aggregate) => sa
-      case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
+      case s @ Sort(order, _, child)
+        if (!s.resolved || s.missingInput.nonEmpty) && child.resolved =>
         val (newOrder, newChild) = resolveExprsAndAddMissingAttrs(order, child)
         val ordering = newOrder.map(_.asInstanceOf[SortOrder])
         if (child.output == newChild.output) {
 @@ -1136,7 +1185,7 @@ class Analyzer(
           Project(child.output, newSort)
         }
-      case f @ Filter(cond, child) if !f.resolved && child.resolved =>
+      case f @ Filter(cond, child) if (!f.resolved || f.missingInput.nonEmpty) && child.resolved =>
         val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(cond), child)
         if (child.output == newChild.output) {
           f.copy(condition = newCond.head)
 @@ -1147,29 +1196,34 @@ class Analyzer(
       }
     }
+    /**
+     * This method tries to resolve expressions and find missing attributes recursively.
+     * Specifically, when the expressions used in `Sort` or `Filter` contain unresolved attributes
+     * or resolved attributes which are missing from the child's output, this method finds the
+     * missing attributes and adds them to the projection.
+     */
     private def resolveExprsAndAddMissingAttrs(
         exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], LogicalPlan) = {
-      if (exprs.forall(_.resolved)) {
-        // All given expressions are resolved, no need to continue anymore.
+      // Missing attributes can be unresolved attributes or resolved attributes which are not in
+      // the output attributes of the plan.
+      if (exprs.forall(e => e.resolved && e.references.subsetOf(plan.outputSet))) {
        (exprs, plan)
      } else {
        plan match {
-          // For `AnalysisBarrier`, recursively resolve expressions and add missing attributes via
-          // its child.
-          case barrier: AnalysisBarrier =>
-            val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(exprs, barrier.child)
-            (newExprs, AnalysisBarrier(newChild))
-
          case p: Project =>
+            // Resolving expressions against current plan.
            val maybeResolvedExprs = exprs.map(resolveExpression(_, p))
+            // Recursively resolving expressions on the child of current plan.
            val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child)
-            val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs)
+            // If some attributes used by expressions are resolvable only on the rewritten child
+            // plan, we need to add them to the original projection.
+            val missingAttrs = (AttributeSet(newExprs) -- p.outputSet).intersect(newChild.outputSet)
            (newExprs, Project(p.projectList ++ missingAttrs, newChild))
          case a @ Aggregate(groupExprs, aggExprs, child) =>
            val maybeResolvedExprs = exprs.map(resolveExpression(_, a))
            val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child)
-            val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs)
+            val missingAttrs = (AttributeSet(newExprs) -- a.outputSet).intersect(newChild.outputSet)
            if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) {
              // All the missing attributes are grouping expressions, valid case.
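A stand-alone model of the recursion just documented, on a toy plan consisting only of nested projections (types invented for illustration): columns referenced by a Sort or Filter but dropped by an intermediate projection are re-added on the way down, and the caller re-projects the original output on top, like `Project(child.output, newSort)` above.

```scala
object MissingAttrsSketch {
  // Toy plan: a chain of projections over a base relation's columns.
  final case class Node(output: Set[String], child: Option[Node])

  def addMissing(needed: Set[String], plan: Node): Node = {
    val missing = needed -- plan.output
    if (missing.isEmpty) plan
    else plan.child match {
      case Some(c) => Node(plan.output ++ missing, Some(addMissing(missing, c)))
      case None    => plan // nothing further down can supply the attributes
    }
  }
}
```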
              (newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild))
 @@ -1204,16 +1258,46 @@ class Analyzer(
    * only performs simple existence check according to the function identifier to quickly identify
    * undefined functions without triggering relation resolution, which may incur potentially
    * expensive partition/schema discovery process in some cases.
-   *
+   * To avoid duplicate external function lookups, the resolved external function identifier is
+   * stored in the local hash set externalFunctionNameSet.
    * @see [[ResolveFunctions]]
    * @see https://issues.apache.org/jira/browse/SPARK-19737
    */
   object LookupFunctions extends Rule[LogicalPlan] {
-    override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
-      case f: UnresolvedFunction if !catalog.functionExists(f.name) =>
-        withPosition(f) {
-          throw new NoSuchFunctionException(f.name.database.getOrElse("default"), f.name.funcName)
-        }
+    override def apply(plan: LogicalPlan): LogicalPlan = {
+      val externalFunctionNameSet = new mutable.HashSet[FunctionIdentifier]()
+      plan.resolveExpressions {
+        case f: UnresolvedFunction
+          if externalFunctionNameSet.contains(normalizeFuncName(f.name)) => f
+        case f: UnresolvedFunction if catalog.isRegisteredFunction(f.name) => f
+        case f: UnresolvedFunction if catalog.isPersistentFunction(f.name) =>
+          externalFunctionNameSet.add(normalizeFuncName(f.name))
+          f
+        case f: UnresolvedFunction =>
+          withPosition(f) {
+            throw new NoSuchFunctionException(f.name.database.getOrElse(catalog.getCurrentDatabase),
+              f.name.funcName)
+          }
+      }
+    }
+
+    def normalizeFuncName(name: FunctionIdentifier): FunctionIdentifier = {
+      val funcName = if (conf.caseSensitiveAnalysis) {
+        name.funcName
+      } else {
+        name.funcName.toLowerCase(Locale.ROOT)
+      }
+
+      val databaseName = name.database match {
+        case Some(a) => formatDatabaseName(a)
+        case None => catalog.getCurrentDatabase
+      }
+
+      FunctionIdentifier(funcName, Some(databaseName))
+    }
+
+    protected def formatDatabaseName(name: String): String = {
+      if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT)
    }
  }
 @@ -1221,7 +1305,7 @@ class Analyzer(
    * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s.
    */
   object ResolveFunctions extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
       case q: LogicalPlan =>
         q transformExpressions {
           case u if !u.childrenResolved => u // Skip until children are resolved.
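The cache in `LookupFunctions` above exists because persistent-function lookups can be expensive, while registered (built-in or temporary) functions are cheap to check; only external identifiers are remembered. A rough stand-alone equivalent of the normalization plus memoized lookup (simplified types, not the SessionCatalog API):

```scala
import java.util.Locale
import scala.collection.mutable

object FunctionLookupCacheSketch {
  final case class FuncId(db: String, name: String)

  // Mirrors normalizeFuncName above: fold case unless analysis is
  // case-sensitive, and default the database to the current one.
  def normalize(db: Option[String], name: String,
      caseSensitive: Boolean, currentDb: String): FuncId = {
    def fmt(s: String) = if (caseSensitive) s else s.toLowerCase(Locale.ROOT)
    FuncId(fmt(db.getOrElse(currentDb)), fmt(name))
  }

  private val seen = mutable.HashSet.empty[FuncId]

  // Consult the expensive catalog only on the first occurrence of an id.
  def existsOnce(id: FuncId)(catalogLookup: FuncId => Boolean): Boolean =
    seen.contains(id) || (catalogLookup(id) && { seen.add(id); true })
}
```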
@@ -1276,7 +1360,7 @@ class Analyzer( * resolved outer references are wrapped in an [[OuterReference]] */ private def resolveOuterReferences(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = { - plan transformDown { + plan resolveOperatorsDown { case q: LogicalPlan if q.childrenResolved && !q.resolved => q transformExpressions { case u @ UnresolvedAttribute(nameParts) => @@ -1347,18 +1431,33 @@ class Analyzer( resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId)) case e @ Exists(sub, _, exprId) if !sub.resolved => resolveSubQuery(e, plans)(Exists(_, _, exprId)) - case In(value, Seq(l @ ListQuery(sub, _, exprId, _))) if value.resolved && !l.resolved => + case InSubquery(values, l @ ListQuery(_, _, exprId, _)) + if values.forall(_.resolved) && !l.resolved => val expr = resolveSubQuery(l, plans)((plan, exprs) => { ListQuery(plan, exprs, exprId, plan.output) }) - In(value, Seq(expr)) + val subqueryOutput = expr.plan.output + val resolvedIn = InSubquery(values, expr.asInstanceOf[ListQuery]) + if (values.length != subqueryOutput.length) { + throw new AnalysisException( + s"""Cannot analyze ${resolvedIn.sql}. + |The number of columns in the left hand side of an IN subquery does not match the + |number of columns in the output of subquery. + |#columns in left hand side: ${values.length} + |#columns in right hand side: ${subqueryOutput.length} + |Left side columns: + |[${values.map(_.sql).mkString(", ")}] + |Right side columns: + |[${subqueryOutput.map(_.sql).mkString(", ")}]""".stripMargin) + } + resolvedIn } } /** * Resolve and rewrite all subqueries in an operator tree.. */ - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { // In case of HAVING (a filter after an aggregate) we use both the aggregate and // its child for resolution. case f @ Filter(_, a: Aggregate) if f.childrenResolved => @@ -1374,7 +1473,7 @@ class Analyzer( */ object ResolveSubqueryColumnAliases extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case u @ UnresolvedSubqueryColumnAliases(columnNames, child) if child.resolved => // Resolves output attributes if a query has alias names in its subquery: // e.g., SELECT * FROM (SELECT 1 AS a, 1 AS b) t(col1, col2) @@ -1397,7 +1496,7 @@ class Analyzer( * Turns projections that contain aggregate expressions into aggregations. */ object GlobalAggregates extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, child) if containsAggregates(projectList) => Aggregate(Nil, projectList, child) } @@ -1423,9 +1522,7 @@ class Analyzer( * underlying aggregate operator and then projected away after the original operator. 
*/ object ResolveAggregateFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { - case Filter(cond, AnalysisBarrier(agg: Aggregate)) => - apply(Filter(cond, agg)).mapChildren(AnalysisBarrier) + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case f @ Filter(cond, agg @ Aggregate(grouping, originalAggExprs, child)) if agg.resolved => // Try resolving the condition of the filter as though it is in the aggregate clause @@ -1483,13 +1580,15 @@ class Analyzer( case ae: AnalysisException => f } - case Sort(sortOrder, global, AnalysisBarrier(aggregate: Aggregate)) => - apply(Sort(sortOrder, global, aggregate)).mapChildren(AnalysisBarrier) case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved => // Try resolving the ordering as though it is in the aggregate clause. try { - val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s)) + // If a sort order is unresolved, containing references not in aggregate, or containing + // `AggregateExpression`, we need to push down it to the underlying aggregate operator. + val unresolvedSortOrders = sortOrder.filter { s => + !s.resolved || !s.references.subsetOf(aggregate.outputSet) || containsAggregate(s) + } val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) @@ -1568,11 +1667,13 @@ class Analyzer( expr.find(_.isInstanceOf[Generator]).isDefined } - private def hasNestedGenerator(expr: NamedExpression): Boolean = expr match { - case UnresolvedAlias(_: Generator, _) => false - case Alias(_: Generator, _) => false - case MultiAlias(_: Generator, _) => false - case other => hasGenerator(other) + private def hasNestedGenerator(expr: NamedExpression): Boolean = { + CleanupAliases.trimNonTopLevelAliases(expr) match { + case UnresolvedAlias(_: Generator, _) => false + case Alias(_: Generator, _) => false + case MultiAlias(_: Generator, _) => false + case other => hasGenerator(other) + } } private def trimAlias(expr: NamedExpression): Expression = expr match { @@ -1598,7 +1699,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case Project(projectList, _) if projectList.exists(hasNestedGenerator) => val nestedGenerator = projectList.find(hasNestedGenerator).get throw new AnalysisException("Generators are not supported when it's nested in " + @@ -1613,24 +1714,26 @@ class Analyzer( // Holds the resolved generator, if one exists in the project list. var resolvedGenerator: Generate = null - val newProjectList = projectList.flatMap { - case AliasedGenerator(generator, names, outer) if generator.childrenResolved => - // It's a sanity check, this should not happen as the previous case will throw - // exception earlier. 
- assert(resolvedGenerator == null, "More than one generator found in SELECT.") - - resolvedGenerator = - Generate( - generator, - unrequiredChildIndex = Nil, - outer = outer, - qualifier = None, - generatorOutput = ResolveGenerate.makeGeneratorOutput(generator, names), - child) - - resolvedGenerator.generatorOutput - case other => other :: Nil - } + val newProjectList = projectList + .map(CleanupAliases.trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) + .flatMap { + case AliasedGenerator(generator, names, outer) if generator.childrenResolved => + // It's a sanity check, this should not happen as the previous case will throw + // exception earlier. + assert(resolvedGenerator == null, "More than one generator found in SELECT.") + + resolvedGenerator = + Generate( + generator, + unrequiredChildIndex = Nil, + outer = outer, + qualifier = None, + generatorOutput = ResolveGenerate.makeGeneratorOutput(generator, names), + child) + + resolvedGenerator.generatorOutput + case other => other :: Nil + } if (resolvedGenerator != null) { Project(newProjectList, resolvedGenerator) @@ -1656,7 +1759,7 @@ class Analyzer( * that wrap the [[Generator]]. */ object ResolveGenerate extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case g: Generate if !g.child.resolved || !g.generator.resolved => g case g: Generate if !g.resolved => g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) @@ -1697,7 +1800,7 @@ class Analyzer( */ object FixNullability extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { case p if !p.resolved => p // Skip unresolved nodes. case p: LogicalPlan if p.resolved => val childrenOutput = p.children.flatMap(c => c.output).groupBy(_.exprId).flatMap { @@ -1921,7 +2024,7 @@ class Analyzer( // We have to use transformDown at here to make sure the rule of // "Aggregate with Having clause" will be triggered. - def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case Filter(condition, _) if hasWindowFunction(condition) => failAnalysis("It is not allowed to use window functions inside WHERE and HAVING clauses") @@ -1981,7 +2084,7 @@ class Analyzer( * put them into an inner Project and finally project them away at the outer Project. */ object PullOutNondeterministic extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p if !p.resolved => p // Skip unresolved nodes. case p: Project => p case f: Filter => f @@ -2020,15 +2123,16 @@ class Analyzer( } /** - * Set the seed for random number generation in Uuid expressions. + * Set the seed for random number generation. 
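To make the surrounding rewrite concrete: `ExtractGenerator` pulls a generator out of the project list into a `Generate` node and re-points the projection at the generator's output. A rough sketch of the resulting shape (not verbatim analyzer output), assuming a hypothetical view `t`:

    spark.sql("SELECT explode(array(1, 2, 3)) AS n FROM t")
    // is analyzed roughly as:
    //   Project [n]
    //   +- Generate explode([1, 2, 3]), outer=false, [n]
    //      +- SubqueryAlias t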
*/ - object ResolvedUuidExpressions extends Rule[LogicalPlan] { + object ResolveRandomSeed extends Rule[LogicalPlan] { private lazy val random = new Random() - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p if p.resolved => p case p => p transformExpressionsUp { case Uuid(None) => Uuid(Some(random.nextLong())) + case Shuffle(child, None) => Shuffle(child, Some(random.nextLong())) } } } @@ -2040,23 +2144,39 @@ class Analyzer( * and we should return null if the input is null. */ object HandleNullInputsForUDF extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p if !p.resolved => p // Skip unresolved nodes. case p => p transformExpressionsUp { - case udf @ ScalaUDF(func, _, inputs, _, _, _, _) => - val parameterTypes = ScalaReflection.getParameterTypes(func) - assert(parameterTypes.length == inputs.length) + case udf@ScalaUDF(func, _, inputs, _, _, _, _, nullableTypes) => + if (nullableTypes.isEmpty) { + // If no nullability info is available, do nothing. No fields will be specially + // checked for null in the plan. If nullability info is incorrect, the results + // of the UDF could be wrong. + udf + } else { + // Otherwise, add special handling of null for fields that can't accept null. + // The result of operations like this, when passed null, is generally to return null. + assert(nullableTypes.length == inputs.length) - val inputsNullCheck = parameterTypes.zip(inputs) // TODO: skip null handling for not-nullable primitive inputs after we can completely // trust the `nullable` information. - // .filter { case (cls, expr) => cls.isPrimitive && expr.nullable } - .filter { case (cls, _) => cls.isPrimitive } - .map { case (_, expr) => IsNull(expr) } - .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2)) - inputsNullCheck.map(If(_, Literal.create(null, udf.dataType), udf)).getOrElse(udf) + val inputsNullCheck = nullableTypes.zip(inputs) + .filter { case (nullable, _) => !nullable } + .map { case (_, expr) => IsNull(expr) } + .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2)) + // Once we add an `If` check above the udf, it is safe to mark those checked inputs + // as not nullable (i.e., wrap them with `KnownNotNull`), because the null-returning + // branch of `If` will be called if any of these checked inputs is null. Thus we can + // prevent this rule from being applied repeatedly. + val newInputs = nullableTypes.zip(inputs).map { case (nullable, expr) => + if (nullable) expr else KnownNotNull(expr) + } + inputsNullCheck + .map(If(_, Literal.create(null, udf.dataType), udf.copy(children = newInputs))) + .getOrElse(udf) + } } } } @@ -2065,25 +2185,21 @@ class Analyzer( * Check and add proper window frames for all window functions. 
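A sketch of the `HandleNullInputsForUDF` rewrite under the new nullability-driven scheme, assuming `spark.implicits._` in scope and a DataFrame `df` with an integer column `a` (both names hypothetical):

    import org.apache.spark.sql.functions.udf

    // The parameter is a primitive Int, so nullableTypes marks it non-nullable.
    val inc = udf((x: Int) => x + 1)
    df.select(inc($"a"))
    // is rewritten roughly as:
    //   If(IsNull(a), Literal(null), inc(KnownNotNull(a)))
    // Wrapping the checked input in KnownNotNull is what stops the rule from
    // firing again on the already-rewritten plan.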
*/ object ResolveWindowFrame extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case logical: LogicalPlan => logical transformExpressions { - case WindowExpression(wf: WindowFunction, - WindowSpecDefinition(_, _, f: SpecifiedWindowFrame)) + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + case WindowExpression(wf: WindowFunction, WindowSpecDefinition(_, _, f: SpecifiedWindowFrame)) if wf.frame != UnspecifiedFrame && wf.frame != f => - failAnalysis(s"Window Frame $f must match the required frame ${wf.frame}") - case WindowExpression(wf: WindowFunction, - s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) + failAnalysis(s"Window Frame $f must match the required frame ${wf.frame}") + case WindowExpression(wf: WindowFunction, s @ WindowSpecDefinition(_, _, UnspecifiedFrame)) if wf.frame != UnspecifiedFrame => - WindowExpression(wf, s.copy(frameSpecification = wf.frame)) - case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) + WindowExpression(wf, s.copy(frameSpecification = wf.frame)) + case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) if e.resolved => - val frame = if (o.nonEmpty) { - SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) - } else { - SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing) - } - we.copy(windowSpec = s.copy(frameSpecification = frame)) - } + val frame = if (o.nonEmpty) { + SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) + } else { + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing) + } + we.copy(windowSpec = s.copy(frameSpecification = frame)) } } @@ -2091,16 +2207,14 @@ class Analyzer( * Check and add order to [[AggregateWindowFunction]]s. */ object ResolveWindowOrder extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case logical: LogicalPlan => logical transformExpressions { - case WindowExpression(wf: WindowFunction, spec) if spec.orderSpec.isEmpty => - failAnalysis(s"Window function $wf requires window to be ordered, please add ORDER BY " + - s"clause. For example SELECT $wf(value_expr) OVER (PARTITION BY window_partition " + - s"ORDER BY window_ordering) from table") - case WindowExpression(rank: RankLike, spec) if spec.resolved => - val order = spec.orderSpec.map(_.child) - WindowExpression(rank.withOrder(order), spec) - } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + case WindowExpression(wf: WindowFunction, spec) if spec.orderSpec.isEmpty => + failAnalysis(s"Window function $wf requires window to be ordered, please add ORDER BY " + + s"clause. For example SELECT $wf(value_expr) OVER (PARTITION BY window_partition " + + s"ORDER BY window_ordering) from table") + case WindowExpression(rank: RankLike, spec) if spec.resolved => + val order = spec.orderSpec.map(_.child) + WindowExpression(rank.withOrder(order), spec) } } @@ -2109,8 +2223,8 @@ class Analyzer( * Then apply a Project on a normal Join to eliminate natural or using join. 
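The frame-defaulting logic above is unchanged in substance; restated as a sketch against a hypothetical view `t(k, ts, v)`:

    // With ORDER BY, the default frame is
    //   RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW:
    spark.sql("SELECT sum(v) OVER (PARTITION BY k ORDER BY ts) FROM t")
    // Without ORDER BY it is
    //   ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING:
    spark.sql("SELECT sum(v) OVER (PARTITION BY k) FROM t")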
*/ object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { - case j @ Join(left, right, UsingJoin(joinType, usingCols), condition) + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { + case j @ Join(left, right, UsingJoin(joinType, usingCols), _) if left.resolved && right.resolved && j.duplicateResolved => commonNaturalJoinProcessing(left, right, joinType, usingCols, None) case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural => @@ -2120,6 +2234,102 @@ class Analyzer( } } + /** + * Resolves columns of an output table from the data in a logical plan. This rule will: + * + * - Reorder columns when the write is by name + * - Insert safe casts when data types do not match + * - Insert aliases when column names do not match + * - Detect plans that are not compatible with the output table and throw AnalysisException + */ + object ResolveOutputRelation extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + case append @ AppendData(table, query, isByName) + if table.resolved && query.resolved && !append.resolved => + val projection = resolveOutputColumns(table.name, table.output, query, isByName) + + if (projection != query) { + append.copy(query = projection) + } else { + append + } + } + + def resolveOutputColumns( + tableName: String, + expected: Seq[Attribute], + query: LogicalPlan, + byName: Boolean): LogicalPlan = { + + if (expected.size < query.output.size) { + throw new AnalysisException( + s"""Cannot write to '$tableName', too many data columns: + |Table columns: ${expected.map(c => s"'${c.name}'").mkString(", ")} + |Data columns: ${query.output.map(c => s"'${c.name}'").mkString(", ")}""".stripMargin) + } + + val errors = new mutable.ArrayBuffer[String]() + val resolved: Seq[NamedExpression] = if (byName) { + expected.flatMap { tableAttr => + query.resolveQuoted(tableAttr.name, resolver) match { + case Some(queryExpr) => + checkField(tableAttr, queryExpr, err => errors += err) + case None => + errors += s"Cannot find data for output column '${tableAttr.name}'" + None + } + } + + } else { + if (expected.size > query.output.size) { + throw new AnalysisException( + s"""Cannot write to '$tableName', not enough data columns: + |Table columns: ${expected.map(c => s"'${c.name}'").mkString(", ")} + |Data columns: ${query.output.map(c => s"'${c.name}'").mkString(", ")}""" + .stripMargin) + } + + query.output.zip(expected).flatMap { + case (queryExpr, tableAttr) => + checkField(tableAttr, queryExpr, err => errors += err) + } + } + + if (errors.nonEmpty) { + throw new AnalysisException( + s"Cannot write incompatible data to table '$tableName':\n- ${errors.mkString("\n- ")}") + } + + Project(resolved, query) + } + + private def checkField( + tableAttr: Attribute, + queryExpr: NamedExpression, + addError: String => Unit): Option[NamedExpression] = { + + // run the type check first to ensure type errors are present + val canWrite = DataType.canWrite( + queryExpr.dataType, tableAttr.dataType, resolver, tableAttr.name, addError) + + if (queryExpr.nullable && !tableAttr.nullable) { + addError(s"Cannot write nullable values to non-null column '${tableAttr.name}'") + None + + } else if (!canWrite) { + None + + } else { + // always add an UpCast. it will be removed in the optimizer if it is unnecessary. 
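The by-name and by-position branches above boil down to a small matching algorithm. A self-contained, simplified model of it (plain Scala; `Col` and both method names are made up for illustration and are not Catalyst APIs):

    case class Col(name: String, nullable: Boolean)

    // By name: every table column must be found in the query output.
    def resolveByName(expected: Seq[Col], query: Seq[Col]): Either[Seq[String], Seq[Col]] = {
      val errors = expected.flatMap { t =>
        if (query.exists(_.name.equalsIgnoreCase(t.name))) None
        else Some(s"Cannot find data for output column '${t.name}'")
      }
      if (errors.nonEmpty) Left(errors) else Right(expected)
    }

    // By position: a plain zip, after the column counts have been checked.
    def resolveByPosition(expected: Seq[Col], query: Seq[Col]): Seq[(Col, Col)] =
      query.zip(expected)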
+ Some(Alias( + UpCast(queryExpr, tableAttr.dataType, Seq()), tableAttr.name + )( + explicitMetadata = Option(tableAttr.metadata) + )) + } + } + } + private def commonNaturalJoinProcessing( left: LogicalPlan, right: LogicalPlan, @@ -2174,7 +2384,7 @@ class Analyzer( * to the given input attributes. */ object ResolveDeserializer extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2199,7 +2409,7 @@ class Analyzer( } expr case other => - throw new AnalysisException("need an array field but got " + other.simpleString) + throw new AnalysisException("need an array field but got " + other.catalogString) } } validateNestedTupleFields(result) @@ -2208,8 +2418,8 @@ class Analyzer( } private def fail(schema: StructType, maxOrdinal: Int): Unit = { - throw new AnalysisException(s"Try to map ${schema.simpleString} to Tuple${maxOrdinal + 1}, " + - "but failed as the number of fields does not line up.") + throw new AnalysisException(s"Try to map ${schema.catalogString} to Tuple${maxOrdinal + 1}" + + ", but failed as the number of fields does not line up.") } /** @@ -2260,7 +2470,7 @@ class Analyzer( * constructed is an inner class. */ object ResolveNewInstance extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2288,13 +2498,13 @@ class Analyzer( case e => e.sql } throw new AnalysisException(s"Cannot up cast $fromStr from " + - s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" + + s"${from.dataType.catalogString} to ${to.catalogString} as it may truncate\n" + "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") + "You can either add an explicit cast to the input data or choose a higher precision " + "type of the field in the target object") } - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2316,8 +2526,12 @@ class Analyzer( * scoping information for attributes and can be removed once analysis is complete. */ object EliminateSubqueryAliases extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case SubqueryAlias(_, child) => child + // This is also called in the beginning of the optimization phase, and as a result + // is using transformUp rather than resolveOperators. + def apply(plan: LogicalPlan): LogicalPlan = AnalysisHelper.allowInvokingTransformsInAnalyzer { + plan transformUp { + case SubqueryAlias(_, child) => child + } } } @@ -2325,7 +2539,7 @@ object EliminateSubqueryAliases extends Rule[LogicalPlan] { * Removes [[Union]] operators from the plan if it just has one child. 
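The switch to `catalogString` is user-visible mainly through messages like the up-cast error; for atomic types the two renderings coincide, but `catalogString` keeps the full field list for nested structs where `simpleString` may abbreviate it. A sketch, assuming `spark.implicits._` in scope:

    case class Rec(i: Int)
    Seq(1L).toDF("i").as[Rec]
    // org.apache.spark.sql.AnalysisException: Cannot up cast `i` from bigint to int
    // as it may truncate ...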
*/ object EliminateUnions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case Union(children) if children.size == 1 => children.head } } @@ -2341,6 +2555,7 @@ object CleanupAliases extends Rule[LogicalPlan] { private def trimAliases(e: Expression): Expression = { e.transformDown { case Alias(child, _) => child + case MultiAlias(child, _) => child } } @@ -2350,10 +2565,12 @@ object CleanupAliases extends Rule[LogicalPlan] { exprId = a.exprId, qualifier = a.qualifier, explicitMetadata = Some(a.metadata)) + case a: MultiAlias => + a.copy(child = trimAliases(a.child)) case other => trimAliases(other) } - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case Project(projectList, child) => val cleanedProjectList = projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) @@ -2363,7 +2580,7 @@ object CleanupAliases extends Rule[LogicalPlan] { val cleanedAggs = aggs.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) Aggregate(grouping.map(trimAliases), cleanedAggs, child) - case w @ Window(windowExprs, partitionSpec, orderSpec, child) => + case Window(windowExprs, partitionSpec, orderSpec, child) => val cleanedWindowExprs = windowExprs.map(e => trimNonTopLevelAliases(e).asInstanceOf[NamedExpression]) Window(cleanedWindowExprs, partitionSpec.map(trimAliases), @@ -2382,19 +2599,12 @@ object CleanupAliases extends Rule[LogicalPlan] { } } -/** Remove the barrier nodes of analysis */ -object EliminateBarriers extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { - case AnalysisBarrier(child) => child - } -} - /** * Ignore event time watermark in batch query, which is only supported in Structured Streaming. * TODO: add this rule into analyzer rule list. */ object EliminateEventTimeWatermark extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case EventTimeWatermark(_, _, child) if !child.isStreaming => child } } @@ -2439,7 +2649,7 @@ object TimeWindowing extends Rule[LogicalPlan] { * @return the logical plan that will generate the time windows using the Expand operator, with * the Filter operator for correctness and Project for usability. */ - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p: LogicalPlan if p.children.size == 1 => val child = p.children.head val windowExpressions = @@ -2527,7 +2737,7 @@ object TimeWindowing extends Rule[LogicalPlan] { * Resolve a [[CreateNamedStruct]] if it contains [[NamePlaceholder]]s. 
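As a reminder of what the (now `MultiAlias`-aware) trimming does: only the outermost alias carries meaning once analysis completes, so inner ones are stripped. A toy model with made-up expression types, not the Catalyst classes:

    sealed trait Expr
    case class Attr(name: String) extends Expr
    case class Add(l: Expr, r: Expr) extends Expr
    case class Alias(child: Expr, name: String) extends Expr

    def trimAliases(e: Expr): Expr = e match {
      case Alias(child, _) => trimAliases(child)
      case Add(l, r)       => Add(trimAliases(l), trimAliases(r))
      case other           => other
    }

    def trimNonTopLevelAliases(e: Expr): Expr = e match {
      case Alias(child, name) => Alias(trimAliases(child), name) // keep the top one
      case other              => trimAliases(other)
    }

    trimNonTopLevelAliases(Alias(Add(Alias(Attr("a"), "x"), Attr("b")), "y"))
    // => Alias(Add(Attr(a), Attr(b)), y)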
*/ object ResolveCreateNamedStruct extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { case e: CreateNamedStruct if !e.resolved => val children = e.children.grouped(2).flatMap { case Seq(NamePlaceholder, e: NamedExpression) if e.resolved => @@ -2579,7 +2789,7 @@ object UpdateOuterReferences extends Rule[LogicalPlan] { private def updateOuterReferenceInSubquery( plan: LogicalPlan, refExprs: Seq[Expression]): LogicalPlan = { - plan transformAllExpressions { case e => + plan resolveExpressions { case e => val outerAlias = refExprs.find(stripAlias(_).semanticEquals(stripOuterReference(e))) outerAlias match { @@ -2590,7 +2800,7 @@ object UpdateOuterReferences extends Rule[LogicalPlan] { } def apply(plan: LogicalPlan): LogicalPlan = { - plan transform { + plan resolveOperators { case f @ Filter(_, a: Aggregate) if f.resolved => f transformExpressions { case s: SubqueryExpression if s.children.nonEmpty => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index af256b98b34f3..6a91d556b2f3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -67,11 +67,15 @@ trait CheckAnalysis extends PredicateHelper { limitExpr.sql) case e if e.dataType != IntegerType => failAnalysis( s"The limit expression must be integer type, but got " + - e.dataType.simpleString) - case e if e.eval().asInstanceOf[Int] < 0 => failAnalysis( - "The limit expression must be equal to or greater than 0, but got " + - e.eval().asInstanceOf[Int]) - case e => // OK + e.dataType.catalogString) + case e => + e.eval() match { + case null => failAnalysis( + s"The evaluated limit expression must not be null, but got ${limitExpr.sql}") + case v: Int if v < 0 => failAnalysis( + s"The limit expression must be equal to or greater than 0, but got $v") + case _ => // OK + } } } @@ -79,10 +83,27 @@ trait CheckAnalysis extends PredicateHelper { // We transform up and order the rules so as to catch the first possible failure instead // of the result of cascading resolution failures. plan.foreachUp { + + case p if p.analyzed => // Skip already analyzed sub-plans + case u: UnresolvedRelation => u.failAnalysis(s"Table or view not found: ${u.tableIdentifier}") case operator: LogicalPlan => + // Check argument data types of higher-order functions downwards first. + // If the arguments of the higher-order functions are resolved but the type check fails, + // the argument functions will not get resolved, but we should report the argument type + // check failure instead of claiming the argument functions are unresolved. 
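Two of the CheckAnalysis changes in one sketch (hypothetical view `t`; the higher-order-function check itself follows in the diff below): the limit expression is now evaluated and rejected when null as well as when negative, and argument type mismatches of higher-order functions are reported before their lambdas are blamed as unresolved:

    spark.sql("SELECT * FROM t LIMIT -1")
    // The limit expression must be equal to or greater than 0, but got -1
    spark.sql("SELECT * FROM t LIMIT CAST(NULL AS INT)")
    // The evaluated limit expression must not be null, but got CAST(NULL AS INT)
    spark.sql("SELECT transform(1, x -> x + 1)")
    // roughly: cannot resolve 'transform(1, ...)' due to argument data type
    // mismatch (the first argument is not an array)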
+ operator transformExpressionsDown { + case hof: HigherOrderFunction + if hof.argumentsResolved && hof.checkArgumentDataTypes().isFailure => + hof.checkArgumentDataTypes() match { + case TypeCheckResult.TypeCheckFailure(message) => + hof.failAnalysis( + s"cannot resolve '${hof.sql}' due to argument data type mismatch: $message") + } + } + operator transformExpressionsUp { case a: Attribute if !a.resolved => val from = operator.inputSet.map(_.qualifiedName).mkString(", ") @@ -96,8 +117,8 @@ trait CheckAnalysis extends PredicateHelper { } case c: Cast if !c.resolved => - failAnalysis( - s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") + failAnalysis(s"invalid cast from ${c.child.dataType.catalogString} to " + + c.dataType.catalogString) case g: Grouping => failAnalysis("grouping() can only be used with GroupingSets/Cube/Rollup") @@ -144,12 +165,12 @@ trait CheckAnalysis extends PredicateHelper { case _ => failAnalysis( s"Event time must be defined on a window or a timestamp, but " + - s"${etw.eventTime.name} is of type ${etw.eventTime.dataType.simpleString}") + s"${etw.eventTime.name} is of type ${etw.eventTime.dataType.catalogString}") } case f: Filter if f.condition.dataType != BooleanType => failAnalysis( s"filter expression '${f.condition.sql}' " + - s"of type ${f.condition.dataType.simpleString} is not a boolean.") + s"of type ${f.condition.dataType.catalogString} is not a boolean.") case Filter(condition, _) if hasNullAwarePredicateWithinNot(condition) => failAnalysis("Null-aware predicate sub-queries cannot be used in nested " + @@ -158,7 +179,7 @@ trait CheckAnalysis extends PredicateHelper { case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType => failAnalysis( s"join condition '${condition.sql}' " + - s"of type ${condition.dataType.simpleString} is not a boolean.") + s"of type ${condition.dataType.catalogString} is not a boolean.") case Aggregate(groupingExprs, aggregateExprs, child) => def isAggregateExpression(expr: Expression) = { @@ -219,7 +240,7 @@ trait CheckAnalysis extends PredicateHelper { if (!RowOrdering.isOrderable(expr.dataType)) { failAnalysis( s"expression ${expr.sql} cannot be used as a grouping expression " + - s"because its data type ${expr.dataType.simpleString} is not an orderable " + + s"because its data type ${expr.dataType.catalogString} is not an orderable " + s"data type.") } @@ -239,7 +260,7 @@ trait CheckAnalysis extends PredicateHelper { orders.foreach { order => if (!RowOrdering.isOrderable(order.dataType)) { failAnalysis( - s"sorting is not supported for columns of type ${order.dataType.simpleString}") + s"sorting is not supported for columns of type ${order.dataType.catalogString}") } } @@ -342,7 +363,7 @@ trait CheckAnalysis extends PredicateHelper { val mapCol = mapColumnInSetOperation(o).get failAnalysis("Cannot have map type columns in DataFrame which calls " + s"set operations(intersect, except, etc.), but the type of column ${mapCol.name} " + - "is " + mapCol.dataType.simpleString) + "is " + mapCol.dataType.catalogString) case o if o.expressions.exists(!_.deterministic) && !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] && @@ -364,10 +385,11 @@ trait CheckAnalysis extends PredicateHelper { } extendedCheckRules.foreach(_(plan)) plan.foreachUp { - case AnalysisBarrier(child) if !child.resolved => checkAnalysis(child) case o if !o.resolved => failAnalysis(s"unresolved operator ${o.simpleString}") case _ => } + + plan.setAnalyzed() } /** @@ -531,9 +553,8 @@ trait CheckAnalysis extends 
PredicateHelper { var foundNonEqualCorrelatedPred: Boolean = false - // Simplify the predicates before validating any unsupported correlation patterns - // in the plan. - BooleanSimplification(sub).foreachUp { + // Simplify the predicates before validating any unsupported correlation patterns in the plan. + AnalysisHelper.allowInvokingTransformsInAnalyzer { BooleanSimplification(sub).foreachUp { // Whitelist operators allowed in a correlated subquery // There are 4 categories: // 1. Operators that are allowed anywhere in a correlated subquery, and, @@ -635,6 +656,6 @@ trait CheckAnalysis extends PredicateHelper { // are not allowed to have any correlated expressions. case p => failOnOuterReferenceInSubTree(p) - } + }} } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index ab63131b07573..e511f8064e28a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -82,14 +82,14 @@ object DecimalPrecision extends TypeCoercionRule { PromotePrecision(Cast(e, dataType)) } - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformUp { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // fix decimal precision for expressions case q => q.transformExpressionsUp( decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal)) } /** Decimal precision promotion for +, -, *, /, %, pmod, and binary comparison. */ - private val decimalAndDecimal: PartialFunction[Expression, Expression] = { + private[catalyst] val decimalAndDecimal: PartialFunction[Expression, Expression] = { // Skip nodes whose children have not been resolved yet case e if !e.childrenResolved => e @@ -286,7 +286,7 @@ object DecimalPrecision extends TypeCoercionRule { // Consider the following example: multiplying a column which is DECIMAL(38, 18) by 2. // If we use the default precision and scale for the integer type, 2 is considered a // DECIMAL(10, 0). According to the rules, the result would be DECIMAL(38 + 10 + 1, 18), - // which is out of range and therefore it will becomes DECIMAL(38, 7), leading to + // which is out of range and therefore it will become DECIMAL(38, 7), leading to // potentially loosing 11 digits of the fractional part. Using only the precision needed // by the Literal, instead, the result would be DECIMAL(38 + 1 + 1, 18), which would // become DECIMAL(38, 16), safely having a much lower precision loss. 
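To illustrate the literal-aware promotion described in the corrected comment, assume a hypothetical table `t` with a column `c` of type DECIMAL(38,18):

    spark.sql("SELECT c * 2 FROM t")
    // Treating 2 as DECIMAL(10,0) would force the result down to DECIMAL(38,7);
    // using only the precision the literal actually needs, DECIMAL(1,0), the
    // result becomes DECIMAL(38,16), losing far less of the fractional part.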
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index a574d8a84d4fb..77860e1584f42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -411,9 +411,12 @@ object FunctionRegistry { expression[CreateArray]("array"), expression[ArrayContains]("array_contains"), expression[ArraysOverlap]("arrays_overlap"), + expression[ArrayIntersect]("array_intersect"), expression[ArrayJoin]("array_join"), expression[ArrayPosition]("array_position"), expression[ArraySort]("array_sort"), + expression[ArrayExcept]("array_except"), + expression[ArrayUnion]("array_union"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), expression[ElementAt]("element_at"), @@ -422,11 +425,13 @@ object FunctionRegistry { expression[MapValues]("map_values"), expression[MapEntries]("map_entries"), expression[MapFromEntries]("map_from_entries"), + expression[MapConcat]("map_concat"), expression[Size]("size"), expression[Slice]("slice"), expression[Size]("cardinality"), expression[ArraysZip]("arrays_zip"), expression[SortArray]("sort_array"), + expression[Shuffle]("shuffle"), expression[ArrayMin]("array_min"), expression[ArrayMax]("array_max"), expression[Reverse]("reverse"), @@ -436,15 +441,17 @@ object FunctionRegistry { expression[ArrayRepeat]("array_repeat"), expression[ArrayRemove]("array_remove"), expression[ArrayDistinct]("array_distinct"), - CreateStruct.registryEntry, + expression[ArrayTransform]("transform"), + expression[MapFilter]("map_filter"), + expression[ArrayFilter]("filter"), + expression[ArrayExists]("exists"), + expression[ArrayAggregate]("aggregate"), + expression[TransformValues]("transform_values"), + expression[TransformKeys]("transform_keys"), + expression[MapZipWith]("map_zip_with"), + expression[ZipWith]("zip_with"), - // mask functions - expression[Mask]("mask"), - expression[MaskFirstN]("mask_first_n"), - expression[MaskLastN]("mask_last_n"), - expression[MaskShowFirstN]("mask_show_first_n"), - expression[MaskShowLastN]("mask_show_last_n"), - expression[MaskHash]("mask_hash"), + CreateStruct.registryEntry, // misc functions expression[AssertTrue]("assert_true"), @@ -505,6 +512,7 @@ object FunctionRegistry { // json expression[StructsToJson]("to_json"), expression[JsonToStructs]("from_json"), + expression[SchemaOfJson]("schema_of_json"), // cast expression[Cast]("cast"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala new file mode 100644 index 0000000000000..ad201f947b671 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
 You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +trait NamedRelation extends LogicalPlan { + def name: String +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index f068bce3e9b69..dbd4ed845e329 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.IntegerLiteral import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.CurrentOrigin @@ -85,7 +86,7 @@ object ResolveHints { } } - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { case h: UnresolvedHint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) => if (h.parameters.isEmpty) { // If there is no table alias specified, turn the entire subtree into a BroadcastHint. @@ -102,12 +103,38 @@ object ResolveHints { } } + /** + * The COALESCE hint accepts the names "COALESCE" and "REPARTITION". + * Its parameter is the target number of partitions. + */ + object ResolveCoalesceHints extends Rule[LogicalPlan] { + private val COALESCE_HINT_NAMES = Set("COALESCE", "REPARTITION") + + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + case h: UnresolvedHint if COALESCE_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) => + val hintName = h.name.toUpperCase(Locale.ROOT) + val shuffle = hintName match { + case "REPARTITION" => true + case "COALESCE" => false + } + val numPartitions = h.parameters match { + case Seq(IntegerLiteral(numPartitions)) => + numPartitions + case Seq(numPartitions: Int) => + numPartitions + case _ => + throw new AnalysisException(s"$hintName Hint expects a partition number as parameter") + } + Repartition(numPartitions, shuffle, h.child) + } + } + /** * Removes all the hints, used to remove invalid hints provided by the user. * This must be executed after all the other hint rules are executed.
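A usage sketch for the new rule, against a hypothetical view `t`:

    // REPARTITION maps to Repartition(numPartitions, shuffle = true),
    // COALESCE to Repartition(numPartitions, shuffle = false).
    spark.sql("SELECT /*+ REPARTITION(4) */ * FROM t")
    spark.sql("SELECT /*+ COALESCE(2) */ * FROM t")
    // A missing or non-integer parameter raises, e.g.:
    //   COALESCE Hint expects a partition number as parameter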
*/ object RemoveAllHints extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { case h: UnresolvedHint => h.child } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index 71ed75454cd4d..4edfe507a7580 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.types.{StructField, StructType} * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]]. */ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case table: UnresolvedInlineTable if table.expressionsResolved => validateInputDimension(table) validateInputEvaluable(table) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index a214e59302cd9..983e4b0e901cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Alias, Expression} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Range} import org.apache.spark.sql.catalyst.rules._ @@ -68,9 +69,11 @@ object ResolveTableValuedFunctions { : (ArgumentList, Seq[Any] => LogicalPlan) = { (ArgumentList(args: _*), pf orElse { - case args => - throw new IllegalArgumentException( - "Invalid arguments for resolved function: " + args.mkString(", ")) + case arguments => + // This is caught again by the apply function and rethrown with richer information about + // position, etc., for a better error message. + throw new AnalysisException( + "Invalid arguments for resolved function: " + arguments.mkString(", ")) }) } @@ -103,24 +106,37 @@ object ResolveTableValuedFunctions { }) ) - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => + // The whole resolution is somewhat difficult to understand here due to too many + // abstractions. We should probably rewrite the following at some point. Reynold was just + // here to improve error messages and didn't have time to do a proper rewrite.
val resolvedFunc = builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match { case Some(tvf) => + + def failAnalysis(): Nothing = { + val argTypes = u.functionArgs.map(_.dataType.typeName).mkString(", ") + u.failAnalysis( + s"""error: table-valued function ${u.functionName} with alternatives: + |${tvf.keys.map(_.toString).toSeq.sorted.map(x => s" ($x)").mkString("\n")} + |cannot be applied to: ($argTypes)""".stripMargin) + } + val resolved = tvf.flatMap { case (argList, resolver) => argList.implicitCast(u.functionArgs) match { case Some(casted) => - Some(resolver(casted.map(_.eval()))) + try { + Some(resolver(casted.map(_.eval()))) + } catch { + case e: AnalysisException => + failAnalysis() + } case _ => None } } resolved.headOption.getOrElse { - val argTypes = u.functionArgs.map(_.dataType.typeName).mkString(", ") - u.failAnalysis( - s"""error: table-valued function ${u.functionName} with alternatives: - |${tvf.keys.map(_.toString).toSeq.sorted.map(x => s" ($x)").mkString("\n")} - |cannot be applied to: (${argTypes})""".stripMargin) + failAnalysis() } case _ => u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala index f9fd0df9e4010..860d20f897690 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala @@ -33,7 +33,7 @@ class SubstituteUnresolvedOrdinals(conf: SQLConf) extends Rule[LogicalPlan] { case _ => false } - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) => val newOrders = s.order.map { case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 3ebab430ffbcd..288b6358fbff1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -54,6 +54,7 @@ object TypeCoercion { BooleanEquality :: FunctionArgumentConversion :: ConcatCoercion(conf) :: + MapZipWithCoercion :: EltCoercion(conf) :: CaseWhenCoercion :: IfCoercion :: @@ -102,25 +103,7 @@ object TypeCoercion { case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) => Some(TimestampType) - case (t1 @ StructType(fields1), t2 @ StructType(fields2)) if t1.sameType(t2) => - Some(StructType(fields1.zip(fields2).map { case (f1, f2) => - // Since `t1.sameType(t2)` is true, two StructTypes have the same DataType - // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`. - // - Different names: use f1.name - // - Different nullabilities: `nullable` is true iff one of them is nullable. 
- val dataType = findTightestCommonType(f1.dataType, f2.dataType).get - StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) - })) - - case (a1 @ ArrayType(et1, hasNull1), a2 @ ArrayType(et2, hasNull2)) if a1.sameType(a2) => - findTightestCommonType(et1, et2).map(ArrayType(_, hasNull1 || hasNull2)) - - case (m1 @ MapType(kt1, vt1, hasNull1), m2 @ MapType(kt2, vt2, hasNull2)) if m1.sameType(m2) => - val keyType = findTightestCommonType(kt1, kt2) - val valueType = findTightestCommonType(vt1, vt2) - Some(MapType(keyType.get, valueType.get, hasNull1 || hasNull2)) - - case _ => None + case (t1, t2) => findTypeForComplex(t1, t2, findTightestCommonType) } /** Promotes all the way to StringType. */ @@ -166,6 +149,60 @@ object TypeCoercion { case (l, r) => None } + private def findTypeForComplex( + t1: DataType, + t2: DataType, + findTypeFunc: (DataType, DataType) => Option[DataType]): Option[DataType] = (t1, t2) match { + case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => + findTypeFunc(et1, et2).map { et => + ArrayType(et, containsNull1 || containsNull2 || + Cast.forceNullable(et1, et) || Cast.forceNullable(et2, et)) + } + case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) => + findTypeFunc(kt1, kt2) + .filter { kt => !Cast.forceNullable(kt1, kt) && !Cast.forceNullable(kt2, kt) } + .flatMap { kt => + findTypeFunc(vt1, vt2).map { vt => + MapType(kt, vt, valueContainsNull1 || valueContainsNull2 || + Cast.forceNullable(vt1, vt) || Cast.forceNullable(vt2, vt)) + } + } + case (StructType(fields1), StructType(fields2)) if fields1.length == fields2.length => + val resolver = SQLConf.get.resolver + fields1.zip(fields2).foldLeft(Option(new StructType())) { + case (Some(struct), (field1, field2)) if resolver(field1.name, field2.name) => + findTypeFunc(field1.dataType, field2.dataType).map { dt => + struct.add(field1.name, dt, field1.nullable || field2.nullable || + Cast.forceNullable(field1.dataType, dt) || Cast.forceNullable(field2.dataType, dt)) + } + case _ => None + } + case _ => None + } + + /** + * The method finds a common type for data types that differ only in nullable, containsNull + * and valueContainsNull flags. If the input types are too different, None is returned. + */ + def findCommonTypeDifferentOnlyInNullFlags(t1: DataType, t2: DataType): Option[DataType] = { + if (t1 == t2) { + Some(t1) + } else { + findTypeForComplex(t1, t2, findCommonTypeDifferentOnlyInNullFlags) + } + } + + def findCommonTypeDifferentOnlyInNullFlags(types: Seq[DataType]): Option[DataType] = { + if (types.isEmpty) { + None + } else { + types.tail.foldLeft[Option[DataType]](Some(types.head)) { + case (Some(t1), t2) => findCommonTypeDifferentOnlyInNullFlags(t1, t2) + case _ => None + } + } + } + /** * Case 2 type widening (see the classdoc comment above for TypeCoercion). 
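The new helper only reconciles types whose shapes already match, merging the nullability flags upward. A sketch of the intended behavior (the results shown are my reading of the code above, not test output):

    import org.apache.spark.sql.catalyst.analysis.TypeCoercion
    import org.apache.spark.sql.types._

    // Same element type, different containsNull: merged to the permissive flag.
    TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(
      ArrayType(IntegerType, containsNull = false),
      ArrayType(IntegerType, containsNull = true))
    // => Some(ArrayType(IntegerType, containsNull = true))

    // Genuinely different element types: no common type.
    TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(
      ArrayType(IntegerType, containsNull = false),
      ArrayType(LongType, containsNull = false))
    // => None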
* @@ -176,11 +213,7 @@ object TypeCoercion { findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) .orElse(stringPromotion(t1, t2)) - .orElse((t1, t2) match { - case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) - case _ => None - }) + .orElse(findTypeForComplex(t1, t2, findWiderTypeForTwo)) } /** @@ -216,12 +249,7 @@ object TypeCoercion { t2: DataType): Option[DataType] = { findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) - .orElse((t1, t2) match { - case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeWithoutStringPromotionForTwo(et1, et2) - .map(ArrayType(_, containsNull1 || containsNull2)) - case _ => None - }) + .orElse(findTypeForComplex(t1, t2, findWiderTypeWithoutStringPromotionForTwo)) } def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = { @@ -250,8 +278,25 @@ object TypeCoercion { } } - private def haveSameType(exprs: Seq[Expression]): Boolean = - exprs.map(_.dataType).distinct.length == 1 + /** + * Check whether the given types are equal ignoring nullable, containsNull and valueContainsNull. + */ + def haveSameType(types: Seq[DataType]): Boolean = { + if (types.size <= 1) { + true + } else { + val head = types.head + types.tail.forall(_.sameType(head)) + } + } + + private def castIfNotSameType(expr: Expression, dt: DataType): Expression = { + if (!expr.dataType.sameType(dt)) { + Cast(expr, dt) + } else { + expr + } + } /** * Widens numeric types and converts strings to numbers when appropriate. @@ -281,12 +326,18 @@ object TypeCoercion { */ object WidenSetOperationTypes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case s @ SetOperation(left, right) if s.childrenResolved && - left.output.length == right.output.length && !s.resolved => + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { + case s @ Except(left, right, isAll) if s.childrenResolved && + left.output.length == right.output.length && !s.resolved => val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) assert(newChildren.length == 2) - s.makeCopy(Array(newChildren.head, newChildren.last)) + Except(newChildren.head, newChildren.last, isAll) + + case s @ Intersect(left, right, isAll) if s.childrenResolved && + left.output.length == right.output.length && !s.resolved => + val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) + assert(newChildren.length == 2) + Intersect(newChildren.head, newChildren.last, isAll) case s: Union if s.childrenResolved && s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved => @@ -354,7 +405,7 @@ object TypeCoercion { } override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -406,27 +457,16 @@ object TypeCoercion { * Analysis Exception will be raised at the type checking phase. */ case class InConversion(conf: SQLConf) extends TypeCoercionRule { - private def flattenExpr(expr: Expression): Seq[Expression] = { - expr match { - // Multi columns in IN clause is represented as a CreateNamedStruct. - // flatten the named struct to get the list of expressions. 
- case cns: CreateNamedStruct => cns.valExprs - case expr => Seq(expr) - } - } - override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e // Handle type casting required between value expression and subquery output // in IN subquery. - case i @ In(a, Seq(ListQuery(sub, children, exprId, _))) - if !i.resolved && flattenExpr(a).length == sub.output.length => - // LHS is the value expression of IN subquery. - val lhs = flattenExpr(a) - + case i @ InSubquery(lhs, ListQuery(sub, children, exprId, _)) + if !i.resolved && lhs.length == sub.output.length => + // LHS is the value expressions of IN subquery. // RHS is the subquery output. val rhs = sub.output @@ -442,20 +482,13 @@ object TypeCoercion { case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)() case (e, _) => e } - val castedLhs = lhs.zip(commonTypes).map { + val newLhs = lhs.zip(commonTypes).map { case (e, dt) if e.dataType != dt => Cast(e, dt) case (e, _) => e } - // Before constructing the In expression, wrap the multi values in LHS - // in a CreatedNamedStruct. - val newLhs = castedLhs match { - case Seq(lhs) => lhs - case _ => CreateStruct(castedLhs) - } - val newSub = Project(castedRhs, sub) - In(newLhs, Seq(ListQuery(newSub, children, exprId, newSub.output))) + InSubquery(newLhs, ListQuery(newSub, children, exprId, newSub.output)) } else { i } @@ -475,7 +508,7 @@ object TypeCoercion { private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal.ONE) private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal.ZERO) - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -516,23 +549,24 @@ object TypeCoercion { * This ensure that the types for various functions are as expected. */ object FunctionArgumentConversion extends TypeCoercionRule { + override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. 
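With `InSubquery` carrying the left-hand value expressions directly, the coercion above casts each value against the matching subquery output column, and multi-column IN no longer round-trips through `CreateNamedStruct`. A sketch with hypothetical views `t(a smallint, b string)` and `u(x int, y string)`:

    // a is cast to int to line up with x; b and y already agree.
    spark.sql("SELECT * FROM t WHERE (a, b) IN (SELECT x, y FROM u)")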
case e if !e.childrenResolved => e - case a @ CreateArray(children) if !haveSameType(children) => + case a @ CreateArray(children) if !haveSameType(children.map(_.dataType)) => val types = children.map(_.dataType) findWiderCommonType(types) match { - case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType))) + case Some(finalDataType) => CreateArray(children.map(castIfNotSameType(_, finalDataType))) case None => a } case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) && - !haveSameType(children) => + !haveSameType(c.inputTypesForMerging) => val types = children.map(_.dataType) findWiderCommonType(types) match { - case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType))) + case Some(finalDataType) => Concat(children.map(castIfNotSameType(_, finalDataType))) case None => c } @@ -544,33 +578,34 @@ object TypeCoercion { case None => aj } - case s @ Sequence(_, _, _, timeZoneId) if !haveSameType(s.coercibleChildren) => + case s @ Sequence(_, _, _, timeZoneId) + if !haveSameType(s.coercibleChildren.map(_.dataType)) => val types = s.coercibleChildren.map(_.dataType) findWiderCommonType(types) match { case Some(widerDataType) => s.castChildrenTo(widerDataType) case None => s } + case m @ MapConcat(children) if children.forall(c => MapType.acceptsType(c.dataType)) && + !haveSameType(m.inputTypesForMerging) => + val types = children.map(_.dataType) + findWiderCommonType(types) match { + case Some(finalDataType) => MapConcat(children.map(castIfNotSameType(_, finalDataType))) + case None => m + } + case m @ CreateMap(children) if m.keys.length == m.values.length && - (!haveSameType(m.keys) || !haveSameType(m.values)) => - val newKeys = if (haveSameType(m.keys)) { - m.keys - } else { - val types = m.keys.map(_.dataType) - findWiderCommonType(types) match { - case Some(finalDataType) => m.keys.map(Cast(_, finalDataType)) - case None => m.keys - } + (!haveSameType(m.keys.map(_.dataType)) || !haveSameType(m.values.map(_.dataType))) => + val keyTypes = m.keys.map(_.dataType) + val newKeys = findWiderCommonType(keyTypes) match { + case Some(finalDataType) => m.keys.map(castIfNotSameType(_, finalDataType)) + case None => m.keys } - val newValues = if (haveSameType(m.values)) { - m.values - } else { - val types = m.values.map(_.dataType) - findWiderCommonType(types) match { - case Some(finalDataType) => m.values.map(Cast(_, finalDataType)) - case None => m.values - } + val valueTypes = m.values.map(_.dataType) + val newValues = findWiderCommonType(valueTypes) match { + case Some(finalDataType) => m.values.map(castIfNotSameType(_, finalDataType)) + case None => m.values } CreateMap(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) }) @@ -593,27 +628,27 @@ object TypeCoercion { // Coalesce should return the first non-null value, which could be any column // from the list. So we need to make sure the return type is deterministic and // compatible with every child column. 
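The effect of switching to `castIfNotSameType` is that a Cast is inserted only when a child's type genuinely differs from the common type, rather than unconditionally. For example, with literal arguments:

    spark.sql("SELECT array(1, 2L)")          // 1 is cast; result is array<bigint>
    spark.sql("SELECT map(1, 'a', 2L, 'b')")  // keys widened to bigint, values untouched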
- case c @ Coalesce(es) if !haveSameType(es) => + case c @ Coalesce(es) if !haveSameType(c.inputTypesForMerging) => val types = es.map(_.dataType) findWiderCommonType(types) match { - case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) + case Some(finalDataType) => Coalesce(es.map(castIfNotSameType(_, finalDataType))) case None => c } // When finding wider type for `Greatest` and `Least`, we should handle decimal types even if // we need to truncate, but we should not promote one side to string if the other side is // string.g - case g @ Greatest(children) if !haveSameType(children) => + case g @ Greatest(children) if !haveSameType(g.inputTypesForMerging) => val types = children.map(_.dataType) findWiderTypeWithoutStringPromotion(types) match { - case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType))) + case Some(finalDataType) => Greatest(children.map(castIfNotSameType(_, finalDataType))) case None => g } - case l @ Least(children) if !haveSameType(children) => + case l @ Least(children) if !haveSameType(l.inputTypesForMerging) => val types = children.map(_.dataType) findWiderTypeWithoutStringPromotion(types) match { - case Some(finalDataType) => Least(children.map(Cast(_, finalDataType))) + case Some(finalDataType) => Least(children.map(castIfNotSameType(_, finalDataType))) case None => l } @@ -631,7 +666,7 @@ object TypeCoercion { */ object Division extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who has not been resolved yet, // as this is an extra rule which should be applied at last. case e if !e.childrenResolved => e @@ -654,28 +689,15 @@ object TypeCoercion { */ object CaseWhenCoercion extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual => - val maybeCommonType = findWiderCommonType(c.valueTypes) + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + case c: CaseWhen if c.childrenResolved && !haveSameType(c.inputTypesForMerging) => + val maybeCommonType = findWiderCommonType(c.inputTypesForMerging) maybeCommonType.map { commonType => - var changed = false val newBranches = c.branches.map { case (condition, value) => - if (value.dataType.sameType(commonType)) { - (condition, value) - } else { - changed = true - (condition, Cast(value, commonType)) - } + (condition, castIfNotSameType(value, commonType)) } - val newElseValue = c.elseValue.map { value => - if (value.dataType.sameType(commonType)) { - value - } else { - changed = true - Cast(value, commonType) - } - } - if (changed) CaseWhen(newBranches, newElseValue) else c + val newElseValue = c.elseValue.map(castIfNotSameType(_, commonType)) + CaseWhen(newBranches, newElseValue) }.getOrElse(c) } } @@ -685,13 +707,13 @@ object TypeCoercion { */ object IfCoercion extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. 
- case i @ If(pred, left, right) if left.dataType != right.dataType => + case i @ If(pred, left, right) if !haveSameType(i.inputTypesForMerging) => findWiderTypeForTwo(left.dataType, right.dataType).map { widestType => - val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) - val newRight = if (right.dataType == widestType) right else Cast(right, widestType) + val newLeft = castIfNotSameType(left, widestType) + val newRight = castIfNotSameType(right, widestType) If(pred, newLeft, newRight) }.getOrElse(i) // If there is no applicable conversion, leave expression unchanged. case If(Literal(null, NullType), left, right) => @@ -705,7 +727,7 @@ object TypeCoercion { * Coerces NullTypes in the Stack expression to the column types of the corresponding positions. */ object StackCoercion extends TypeCoercionRule { - override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case s @ Stack(children) if s.childrenResolved && s.hasFoldableNumRows => Stack(children.zipWithIndex.map { // The first child is the number of rows for stack. @@ -725,20 +747,46 @@ object TypeCoercion { */ case class ConcatCoercion(conf: SQLConf) extends TypeCoercionRule { - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p => - p transformExpressionsUp { - // Skip nodes if unresolved or empty children - case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c - case c @ Concat(children) if conf.concatBinaryAsString || + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = { + plan resolveOperators { case p => + p transformExpressionsUp { + // Skip nodes if unresolved or empty children + case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c + case c @ Concat(children) if conf.concatBinaryAsString || !children.map(_.dataType).forall(_ == BinaryType) => - val newChildren = c.children.map { e => - ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) - } - c.copy(children = newChildren) + val newChildren = c.children.map { e => + ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) + } + c.copy(children = newChildren) + } } } } + /** + * Coerces key types of two different [[MapType]] arguments of the [[MapZipWith]] expression + * to a common type. + */ + object MapZipWithCoercion extends TypeCoercionRule { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + // Lambda function isn't resolved when the rule is executed. + case m @ MapZipWith(left, right, function) if m.arguments.forall(a => a.resolved && + MapType.acceptsType(a.dataType)) && !m.leftKeyType.sameType(m.rightKeyType) => + findWiderTypeForTwo(m.leftKeyType, m.rightKeyType) match { + case Some(finalKeyType) if !Cast.forceNullable(m.leftKeyType, finalKeyType) && + !Cast.forceNullable(m.rightKeyType, finalKeyType) => + val newLeft = castIfNotSameType( + left, + MapType(finalKeyType, m.leftValueType, m.leftValueContainsNull)) + val newRight = castIfNotSameType( + right, + MapType(finalKeyType, m.rightValueType, m.rightValueContainsNull)) + MapZipWith(newLeft, newRight, function) + case _ => m + } + } + } + /** * Coerces the types of [[Elt]] children to expected ones. 
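A usage sketch for the new `MapZipWithCoercion` rule: when the two maps' key types differ but widen safely, both arguments are cast to a common key type; if widening would force nulls into keys, the expression is left alone and fails type checking later.

    // int keys vs bigint keys: both sides are cast to map<bigint, string>.
    spark.sql(
      "SELECT map_zip_with(map(1, 'a'), map(2L, 'b'), (k, v1, v2) -> concat(v1, v2))")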
* @@ -747,22 +795,24 @@ object TypeCoercion { */ case class EltCoercion(conf: SQLConf) extends TypeCoercionRule { - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p => - p transformExpressionsUp { - // Skip nodes if unresolved or not enough children - case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c - case c @ Elt(children) => - val index = children.head - val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index) - val newInputs = if (conf.eltOutputAsString || + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = { + plan resolveOperators { case p => + p transformExpressionsUp { + // Skip nodes if unresolved or not enough children + case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c + case c @ Elt(children) => + val index = children.head + val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index) + val newInputs = if (conf.eltOutputAsString || !children.tail.map(_.dataType).forall(_ == BinaryType)) { - children.tail.map { e => - ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) + children.tail.map { e => + ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) + } + } else { + children.tail } - } else { - children.tail - } - c.copy(children = newIndex +: newInputs) + c.copy(children = newIndex +: newInputs) + } } } } @@ -775,7 +825,7 @@ object TypeCoercion { private val acceptedTypes = Seq(DateType, TimestampType, StringType) - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes whose children have not been resolved yet. case e if !e.childrenResolved => e @@ -796,7 +846,7 @@ object TypeCoercion { private def rejectTzInString = conf.getConf(SQLConf.REJECT_TIMEZONE_IN_STRING) override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes whose children have not been resolved yet. case e if !e.childrenResolved => e @@ -935,7 +985,7 @@ object TypeCoercion { */ object WindowFrameCoercion extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case s @ WindowSpecDefinition(_, Seq(order), SpecifiedWindowFrame(RangeFrame, lower, upper)) if order.resolved => s.copy(frameSpecification = SpecifiedWindowFrame( @@ -973,7 +1023,7 @@ trait TypeCoercionRule extends Rule[LogicalPlan] with Logging { protected def coerceTypes(plan: LogicalPlan): LogicalPlan - private def propagateTypes(plan: LogicalPlan): LogicalPlan = plan transformUp { + private def propagateTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { // No propagation required for leaf nodes.
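The recurring transformAllExpressions -> resolveExpressions and transformUp -> resolveOperatorsUp substitutions throughout this file share one idea: skip subtrees the analyzer has already finished. A rough model, assuming only an `analyzed` marker on plan nodes (the toy Node stands in for LogicalPlan):

case class Node(children: Seq[Node], analyzed: Boolean = false)

def resolveOperatorsUp(n: Node)(rule: PartialFunction[Node, Node]): Node =
  if (n.analyzed) n // already-analyzed subtrees are returned as-is
  else {
    val recursed = n.copy(children = n.children.map(resolveOperatorsUp(_)(rule)))
    rule.applyOrElse(recursed, identity[Node])
  }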
case q: LogicalPlan if q.children.isEmpty => q diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 5ced1ca200daa..cff4cee09427f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -306,17 +306,19 @@ object UnsupportedOperationChecker { case u: Union if u.children.map(_.isStreaming).distinct.size == 2 => throwError("Union between streaming and batch DataFrames/Datasets is not supported") - case Except(left, right) if right.isStreaming => + case Except(left, right, _) if right.isStreaming => throwError("Except on a streaming DataFrame/Dataset on the right is not supported") - case Intersect(left, right) if left.isStreaming && right.isStreaming => + case Intersect(left, right, _) if left.isStreaming && right.isStreaming => throwError("Intersect between two streaming DataFrames/Datasets is not supported") case GroupingSets(_, _, child, _) if child.isStreaming => throwError("GroupingSets is not supported on streaming DataFrames/Datasets") - case GlobalLimit(_, _) | LocalLimit(_, _) if subPlan.children.forall(_.isStreaming) => - throwError("Limits are not supported on streaming DataFrames/Datasets") + case GlobalLimit(_, _) | LocalLimit(_, _) + if subPlan.children.forall(_.isStreaming) && outputMode == InternalOutputModes.Update => + throwError("Limits are not supported on streaming DataFrames/Datasets in Update " + + "output mode") case Sort(_, _, _) if !containsCompleteData(subPlan) => throwError("Sorting is not supported on streaming DataFrames/Datasets, unless it is on " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala new file mode 100644 index 0000000000000..dd08190e1e8a3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.catalog.SessionCatalog +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.DataType + +/** + * Resolve higher order functions from the catalog.
This is different from regular function + * resolution because lambda functions can only be resolved after the function has been resolved; + * so we need to resolve a higher order function when all children are either resolved or lambda + * functions. + */ +case class ResolveHigherOrderFunctions(catalog: SessionCatalog) extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { + case u @ UnresolvedFunction(fn, children, false) + if hasLambdaAndResolvedArguments(children) => + withPosition(u) { + catalog.lookupFunction(fn, children) match { + case func: HigherOrderFunction => func + case other => other.failAnalysis( + "A lambda function should only be used in a higher order function. However, " + + s"its class is ${other.getClass.getCanonicalName}, which is not a " + + s"higher order function.") + } + } + } + + /** + * Check that the arguments of a function include at least one lambda function and that all + * other arguments are resolved. + */ + private def hasLambdaAndResolvedArguments(expressions: Seq[Expression]): Boolean = { + val (lambdas, others) = expressions.partition(_.isInstanceOf[LambdaFunction]) + lambdas.nonEmpty && others.forall(_.resolved) + } +} + +/** + * Resolve the lambda variables exposed by a higher order function. + * + * This rule works in two steps: + * [1]. Bind the anonymous variables exposed by the higher order function to the lambda function's + * arguments; this creates named and typed lambda variables. The argument names are checked + * for duplicates and the number of arguments is checked during this step. + * [2]. Resolve the lambda variables used in the lambda function's expression tree. + * Note that we allow the use of variables from outside the current lambda; these can come + * from a lambda defined in an outer scope or from an attribute produced by the plan's + * child. If names are duplicated, the name defined in the innermost scope is used. + */ +case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] { + + type LambdaVariableMap = Map[String, NamedExpression] + + private val canonicalizer = { + if (!conf.caseSensitiveAnalysis) { + s: String => s.toLowerCase + } else { + s: String => s + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + plan.resolveOperators { + case q: LogicalPlan => + q.mapExpressions(resolve(_, Map.empty)) + } + } + + /** + * Create a bound lambda function by binding the arguments of a lambda function to the given + * partial arguments (dataType and nullability only). If the expression happens to be an already + * bound lambda function then we assume it has been bound to the correct arguments and do + * nothing. This function will produce a lambda function with hidden arguments when it is passed + * an arbitrary expression.
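A small model of the scoping rule in step [2], assuming nothing beyond plain Maps: because the inner map is concatenated last when scopes are merged, an inner lambda variable shadows an outer one with the same name while outer names stay visible:

val outerScope = Map("x" -> "outer.x", "y" -> "outer.y")
val innerScope = Map("x" -> "inner.x")
val visible = outerScope ++ innerScope // mirrors parentLambdaMap ++ lambdaMap
assert(visible("x") == "inner.x") // innermost definition wins
assert(visible("y") == "outer.y") // outer names remain reachable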
+ */ + private def createLambda( + e: Expression, + argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match { + case f: LambdaFunction if f.bound => f + + case LambdaFunction(function, names, _) => + if (names.size != argInfo.size) { + e.failAnalysis( + s"The number of lambda function arguments '${names.size}' does not " + + "match the number of arguments expected by the higher order function " + + s"'${argInfo.size}'.") + } + + if (names.map(a => canonicalizer(a.name)).distinct.size < names.size) { + e.failAnalysis( + "Lambda function arguments should not have names that are semantically the same.") + } + + val arguments = argInfo.zip(names).map { + case ((dataType, nullable), ne) => + NamedLambdaVariable(ne.name, dataType, nullable) + } + LambdaFunction(function, arguments) + + case _ => + // This expression does not consume any of the lambda's arguments (it is independent). We + // still create a lambda function with default parameters because this is expected by the + // higher order function. Note that we hide the lambda variables produced by this function + // in order to prevent accidental naming collisions. + val arguments = argInfo.zipWithIndex.map { + case ((dataType, nullable), i) => + NamedLambdaVariable(s"col$i", dataType, nullable) + } + LambdaFunction(e, arguments, hidden = true) + } + + /** + * Resolve lambda variables in the expression subtree, using the passed lambda variable registry. + */ + private def resolve(e: Expression, parentLambdaMap: LambdaVariableMap): Expression = e match { + case _ if e.resolved => e + + case h: HigherOrderFunction if h.argumentsResolved && h.checkArgumentDataTypes().isSuccess => + h.bind(createLambda).mapChildren(resolve(_, parentLambdaMap)) + + case l: LambdaFunction if !l.bound => + // Do not resolve an unbound lambda function. If we see such a lambda function this means + // that either the higher order function has yet to be resolved, or that we are seeing a + // dangling lambda function.
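A toy rendering of the fallback branch of createLambda, with case classes invented for illustration: an argument-independent body is still wrapped in a lambda, and the synthesized col0, col1, ... variables are flagged hidden so user references can never bind to them:

case class Arg(name: String, hidden: Boolean)
case class ToyLambda(body: String, args: Seq[Arg])

def wrapIndependent(body: String, arity: Int): ToyLambda =
  ToyLambda(body, (0 until arity).map(i => Arg(s"col$i", hidden = true)))

// e.g. a call like transform(array(1, 2), 42) would, under this model, wrap
// the constant as wrapIndependent("42", 1) rather than failing to resolve.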
+ l + + case l: LambdaFunction if !l.hidden => + val lambdaMap = l.arguments.map(v => canonicalizer(v.name) -> v).toMap + l.mapChildren(resolve(_, parentLambdaMap ++ lambdaMap)) + + case u @ UnresolvedAttribute(name +: nestedFields) => + parentLambdaMap.get(canonicalizer(name)) match { + case Some(lambda) => + nestedFields.foldLeft(lambda: Expression) { (expr, fieldName) => + ExtractValue(expr, Literal(fieldName), conf.resolver) + } + case None => u + } + + case _ => + e.mapChildren(resolve(_, parentLambdaMap)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala index af1f9165b0044..a27aa845bf0ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala @@ -38,7 +38,7 @@ case class ResolveTimeZone(conf: SQLConf) extends Rule[LogicalPlan] { } override def apply(plan: LogicalPlan): LogicalPlan = - plan.transformAllExpressions(transformTimeZoneExprs) + plan.resolveExpressions(transformTimeZoneExprs) def resolveTimeZones(e: Expression): Expression = e.transform(transformTimeZoneExprs) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 71e23175168e2..c1ec736c32ed4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -104,12 +104,12 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un override def exprId: ExprId = throw new UnresolvedException(this, "exprId") override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") - override def qualifier: Option[String] = throw new UnresolvedException(this, "qualifier") + override def qualifier: Seq[String] = throw new UnresolvedException(this, "qualifier") override lazy val resolved = false override def newInstance(): UnresolvedAttribute = this override def withNullability(newNullability: Boolean): UnresolvedAttribute = this - override def withQualifier(newQualifier: Option[String]): UnresolvedAttribute = this + override def withQualifier(newQualifier: Seq[String]): UnresolvedAttribute = this override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName) override def withMetadata(newMetadata: Metadata): Attribute = this @@ -240,7 +240,7 @@ abstract class Star extends LeafExpression with NamedExpression { override def exprId: ExprId = throw new UnresolvedException(this, "exprId") override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") - override def qualifier: Option[String] = throw new UnresolvedException(this, "qualifier") + override def qualifier: Seq[String] = throw new UnresolvedException(this, "qualifier") override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance") override lazy val resolved = false @@ -262,17 +262,46 @@ abstract class Star extends LeafExpression with NamedExpression { */ case class 
UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevaluable { - override def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] = { + /** + * Returns true if the nameParts match the qualifier of the attribute + * + * There are two checks: i) Check if the nameParts match the qualifier fully. + * E.g. SELECT db.t1.* FROM db1.t1 In this case, the nameParts is Seq("db1", "t1") and + * qualifier of the attribute is Seq("db1","t1") + * ii) If (i) is not true, then check if nameParts is only a single element and it + * matches the table portion of the qualifier + * + * E.g. SELECT t1.* FROM db1.t1 In this case nameParts is Seq("t1") and + * qualifier is Seq("db1","t1") + * SELECT a.* FROM db1.t1 AS a + * In this case nameParts is Seq("a") and qualifier for + * attribute is Seq("a") + */ + private def matchedQualifier( + attribute: Attribute, + nameParts: Seq[String], + resolver: Resolver): Boolean = { + val qualifierList = attribute.qualifier + + val matched = nameParts.corresponds(qualifierList)(resolver) || { + // check if it matches the table portion of the qualifier + if (nameParts.length == 1 && qualifierList.nonEmpty) { + resolver(nameParts.head, qualifierList.last) + } else { + false + } + } + matched + } + + override def expand( + input: LogicalPlan, + resolver: Resolver): Seq[NamedExpression] = { // If there is no table specified, use all input attributes. if (target.isEmpty) return input.output - val expandedAttributes = - if (target.get.size == 1) { - // If there is a table, pick out attributes that are part of this table. - input.output.filter(_.qualifier.exists(resolver(_, target.get.head))) - } else { - List() - } + val expandedAttributes = input.output.filter(matchedQualifier(_, target.get, resolver)) + if (expandedAttributes.nonEmpty) return expandedAttributes // Try to resolve it as a struct expansion. 
If there is a conflict and both are possible, @@ -316,8 +345,8 @@ case class UnresolvedRegex(regexPattern: String, table: Option[String], caseSens // If there is no table specified, use all input attributes that match expr case None => input.output.filter(_.name.matches(pattern)) // If there is a table, pick out attributes that are part of this table that match expr - case Some(t) => input.output.filter(_.qualifier.exists(resolver(_, t))) .filter(_.name.matches(pattern)) + case Some(t) => input.output.filter(a => a.qualifier.nonEmpty && + resolver(a.qualifier.last, t)).filter(_.name.matches(pattern)) } } @@ -345,7 +374,7 @@ case class MultiAlias(child: Expression, names: Seq[String]) override def nullable: Boolean = throw new UnresolvedException(this, "nullable") - override def qualifier: Option[String] = throw new UnresolvedException(this, "qualifier") + override def qualifier: Seq[String] = throw new UnresolvedException(this, "qualifier") override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") @@ -403,7 +432,7 @@ case class UnresolvedAlias( extends UnaryExpression with NamedExpression with Unevaluable { override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") - override def qualifier: Option[String] = throw new UnresolvedException(this, "qualifier") + override def qualifier: Seq[String] = throw new UnresolvedException(this, "qualifier") override def exprId: ExprId = throw new UnresolvedException(this, "exprId") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override def dataType: DataType = throw new UnresolvedException(this, "dataType") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala index 20216087b0158..af74693000c44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala @@ -48,7 +48,7 @@ import org.apache.spark.sql.internal.SQLConf * completely resolved during the batch of Resolution. */ case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { case v @ View(desc, output, child) if child.resolved && output != child.output => val resolver = conf.resolver val queryColumnNames = desc.viewQueryColumnNames @@ -76,7 +76,8 @@ case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupp // Will throw an AnalysisException if the cast can't be performed or might truncate.
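The Seq-based qualifier matching used above (in both UnresolvedStar.expand and UnresolvedRegex) reduces to a small predicate; a self-contained sketch over plain values, where a resolver is just an equality strategy:

def matchesQualifier(
    nameParts: Seq[String],
    qualifier: Seq[String],
    resolver: (String, String) => Boolean): Boolean =
  nameParts.corresponds(qualifier)(resolver) ||
    (nameParts.length == 1 && qualifier.nonEmpty &&
      resolver(nameParts.head, qualifier.last))

// SELECT db1.t1.* FROM db1.t1 -> matchesQualifier(Seq("db1", "t1"), Seq("db1", "t1"), _ == _)
// SELECT t1.* FROM db1.t1     -> matchesQualifier(Seq("t1"), Seq("db1", "t1"), _ == _)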
if (Cast.mayTruncate(originAttr.dataType, attr.dataType)) { throw new AnalysisException(s"Cannot up cast ${originAttr.sql} from " + - s"${originAttr.dataType.simpleString} to ${attr.simpleString} as it may truncate\n") + s"${originAttr.dataType.catalogString} to ${attr.dataType.catalogString} as it " + + s"may truncate\n") } else { Alias(cast(originAttr, attr.dataType), attr.name)(exprId = attr.exprId, qualifier = attr.qualifier, explicitMetadata = Some(attr.metadata)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index c390337c03ff5..c11b444212946 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, ImplicitCastInputTypes} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias, View} import org.apache.spark.sql.catalyst.util.StringUtils @@ -101,6 +101,8 @@ class SessionCatalog( @GuardedBy("this") protected var currentDb: String = formatDatabaseName(DEFAULT_DATABASE) + private val validNameFormat = "([\\w_]+)".r + /** * Checks if the given name conforms to the Hive standard ("[a-zA-Z_0-9]+"), * i.e. if this name only contains letters, numbers, and _. @@ -109,7 +111,6 @@ class SessionCatalog( * org.apache.hadoop.hive.metastore.MetaStoreUtils.validateName. */ private def validateName(name: String): Unit = { - val validNameFormat = "([\\w_]+)".r if (!validNameFormat.pattern.matcher(name).matches()) { throw new AnalysisException(s"`$name` is not a valid name for tables/databases. " + "Valid names only contain alphabet characters, numbers and _.") @@ -619,6 +620,7 @@ class SessionCatalog( requireTableExists(TableIdentifier(oldTableName, Some(db))) requireTableNotExists(TableIdentifier(newTableName, Some(db))) validateName(newTableName) + validateNewLocationOfRename(oldName, newName) externalCatalog.renameTable(db, oldTableName, newTableName) } else { if (newName.database.isDefined) { @@ -683,6 +685,7 @@ class SessionCatalog( * * If the relation is a view, we generate a [[View]] operator from the view description, and * wrap the logical plan in a [[SubqueryAlias]] which will track the name of the view. + * [[SubqueryAlias]] will also keep track of the name and database (optional) of the table/view. * * @param name The name of the table/view that we look up.
*/ @@ -692,12 +695,13 @@ class SessionCatalog( val table = formatTableName(name.table) if (db == globalTempViewManager.database) { globalTempViewManager.get(table).map { viewDef => - SubqueryAlias(table, viewDef) + SubqueryAlias(table, db, viewDef) }.getOrElse(throw new NoSuchTableException(db, table)) } else if (name.database.isDefined || !tempViews.contains(table)) { val metadata = externalCatalog.getTable(db, table) if (metadata.tableType == CatalogTableType.VIEW) { val viewText = metadata.viewText.getOrElse(sys.error("Invalid view without text.")) + logDebug(s"'$viewText' will be used for the view($table).") // The relation is a view, so we wrap the relation by: // 1. Add a [[View]] operator over the relation to keep track of the view desc; // 2. Wrap the logical plan in a [[SubqueryAlias]] which tracks the name of the view. @@ -705,9 +709,9 @@ class SessionCatalog( desc = metadata, output = metadata.schema.toAttributes, child = parser.parsePlan(viewText)) - SubqueryAlias(table, child) + SubqueryAlias(table, db, child) } else { - SubqueryAlias(table, UnresolvedCatalogRelation(metadata)) + SubqueryAlias(table, db, UnresolvedCatalogRelation(metadata)) } } else { SubqueryAlias(table, tempViews(table)) @@ -1058,7 +1062,7 @@ class SessionCatalog( } /** - * overwirte a metastore function in the database specified in `funcDefinition`.. + * Overwrites a metastore function in the database specified in `funcDefinition`. * If no database is specified, assume the function is in the current database. */ def alterFunction(funcDefinition: CatalogFunction): Unit = { @@ -1123,13 +1127,22 @@ class SessionCatalog( name: String, clazz: Class[_], input: Seq[Expression]): Expression = { + // Unfortunately we need to use reflection here because UserDefinedAggregateFunction + // and ScalaUDAF are defined in the sql/core module. val clsForUDAF = Utils.classForName("org.apache.spark.sql.expressions.UserDefinedAggregateFunction") if (clsForUDAF.isAssignableFrom(clazz)) { val cls = Utils.classForName("org.apache.spark.sql.execution.aggregate.ScalaUDAF") - cls.getConstructor(classOf[Seq[Expression]], clsForUDAF, classOf[Int], classOf[Int]) + val e = cls.getConstructor(classOf[Seq[Expression]], clsForUDAF, classOf[Int], classOf[Int]) .newInstance(input, clazz.newInstance().asInstanceOf[Object], Int.box(1), Int.box(1)) - .asInstanceOf[Expression] + .asInstanceOf[ImplicitCastInputTypes] + + // Check input argument size + if (e.inputTypes.size != input.size) { + throw new AnalysisException(s"Invalid number of arguments for function $name. " + + s"Expected: ${e.inputTypes.size}; Found: ${input.size}") + } + e } else { throw new AnalysisException(s"No handler for UDAF '${clazz.getCanonicalName}'. " + s"Use sparkSession.udf.register(...) instead.") @@ -1192,6 +1205,22 @@ class SessionCatalog( !hiveFunctions.contains(name.funcName.toLowerCase(Locale.ROOT)) } + /** + * Returns whether this function has been registered in the function registry of the current + * session. Returns false if it does not exist. + */ + def isRegisteredFunction(name: FunctionIdentifier): Boolean = { + functionRegistry.functionExists(name) + } + + /** + * Returns whether it is a persistent function. Returns false if it does not exist.
+ */ + def isPersistentFunction(name: FunctionIdentifier): Boolean = { + val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) + databaseExists(db) && externalCatalog.functionExists(db, name.funcName) + } + protected def failFunctionLookup(name: FunctionIdentifier): Nothing = { throw new NoSuchFunctionException( db = name.database.getOrElse(getCurrentDatabase), func = name.funcName) @@ -1366,4 +1395,23 @@ class SessionCatalog( // copy over temporary views tempViews.foreach(kv => target.tempViews.put(kv._1, kv._2)) } + + /** + * Validate the new location before renaming a managed table; the new location must not + * already exist. + */ + private def validateNewLocationOfRename( + oldName: TableIdentifier, + newName: TableIdentifier): Unit = { + val oldTable = getTableMetadata(oldName) + if (oldTable.tableType == CatalogTableType.MANAGED) { + val databaseLocation = + externalCatalog.getDatabase(oldName.database.getOrElse(currentDb)).locationUri + val newTableLocation = new Path(new Path(databaseLocation), formatTableName(newName.table)) + val fs = newTableLocation.getFileSystem(hadoopConf) + if (fs.exists(newTableLocation)) { + throw new AnalysisException(s"Cannot rename the managed table('$oldName')" + + s". The associated location('$newTableLocation') already exists.") + } + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index c6105c5526049..30ded13410f7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.catalyst.util.quoteIdentifier +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -114,7 +115,10 @@ case class CatalogTablePartition( map.put("Partition Parameters", s"{${parameters.map(p => p._1 + "=" + p._2).mkString(", ")}}") } map.put("Created Time", new Date(createTime).toString) - map.put("Last Access", new Date(lastAccessTime).toString) + val lastAccess = { + if (-1 == lastAccessTime) "UNKNOWN" else new Date(lastAccessTime).toString + } + map.put("Last Access", lastAccess) stats.foreach(s => map.put("Partition Statistics", s.simpleString)) map } @@ -170,9 +174,12 @@ case class BucketSpec( numBuckets: Int, bucketColumnNames: Seq[String], sortColumnNames: Seq[String]) { - if (numBuckets <= 0 || numBuckets >= 100000) { + def conf: SQLConf = SQLConf.get + + if (numBuckets <= 0 || numBuckets > conf.bucketingMaxBuckets) { throw new AnalysisException( - s"Number of buckets should be greater than 0 but less than 100000. Got `$numBuckets`") + s"Number of buckets should be greater than 0 but less than or equal to " + + s"bucketing.maxBuckets (`${conf.bucketingMaxBuckets}`).
Got `$numBuckets`") } override def toString: String = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 89e8c998f740d..d3ccd18d0245e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -88,7 +88,13 @@ package object dsl { def <=> (other: Expression): Predicate = EqualNullSafe(expr, other) def =!= (other: Expression): Predicate = Not(EqualTo(expr, other)) - def in(list: Expression*): Expression = In(expr, list) + def in(list: Expression*): Expression = list match { + case Seq(l: ListQuery) => expr match { + case c: CreateNamedStruct => InSubquery(c.valExprs, l) + case other => InSubquery(Seq(other), l) + } + case _ => In(expr, list) + } def like(other: Expression): Expression = Like(expr, other) def rlike(other: Expression): Expression = RLike(expr, other) @@ -166,6 +172,9 @@ package object dsl { def maxDistinct(e: Expression): Expression = Max(e).toAggregateExpression(isDistinct = true) def upper(e: Expression): Expression = Upper(e) def lower(e: Expression): Expression = Lower(e) + def coalesce(args: Expression*): Expression = Coalesce(args) + def greatest(args: Expression*): Expression = Greatest(args) + def least(args: Expression*): Expression = Least(args) def sqrt(e: Expression): Expression = Sqrt(e) def abs(e: Expression): Expression = Abs(e) def star(names: String*): Expression = names match { @@ -355,9 +364,11 @@ package object dsl { def subquery(alias: Symbol): LogicalPlan = SubqueryAlias(alias.name, logicalPlan) - def except(otherPlan: LogicalPlan): LogicalPlan = Except(logicalPlan, otherPlan) + def except(otherPlan: LogicalPlan, isAll: Boolean): LogicalPlan = + Except(logicalPlan, otherPlan, isAll) - def intersect(otherPlan: LogicalPlan): LogicalPlan = Intersect(logicalPlan, otherPlan) + def intersect(otherPlan: LogicalPlan, isAll: Boolean): LogicalPlan = + Intersect(logicalPlan, otherPlan, isAll) def union(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index df3ab05e02c76..77582e10f9ff2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, JavaCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ @@ -53,7 +53,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) ev.copy(code = oev.code) } else { assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.") - val javaType = CodeGenerator.javaType(dataType) + val javaType = JavaCode.javaType(dataType) val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) if (nullable) { 
ev.copy(code = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala index 7541f527a52a8..fe6db8b344d3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala @@ -87,8 +87,6 @@ object Canonicalize { case Not(LessThanOrEqual(l, r)) => GreaterThan(l, r) // order the list in the In operator - // In subqueries contain only one element of type ListQuery. So checking that the length > 1 - // we are not reordering In subqueries. case In(value, list) if list.length > 1 => In(value, list.sortBy(_.hashCode())) case _ => e diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 699ea53b5df0f..8f777997bf615 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -134,6 +134,35 @@ object Cast { toPrecedence > 0 && fromPrecedence > toPrecedence } + /** + * Returns true iff we can safely cast the `from` type to `to` type without any truncation or + * precision loss, e.g. int -> long, date -> timestamp. + */ + def canSafeCast(from: AtomicType, to: AtomicType): Boolean = (from, to) match { + case _ if from == to => true + case (from: NumericType, to: DecimalType) if to.isWiderThan(from) => true + case (from: DecimalType, to: NumericType) if from.isTighterThan(to) => true + case (from, to) if legalNumericPrecedence(from, to) => true + case (DateType, TimestampType) => true + case (_, StringType) => true + case _ => false + } + + private def legalNumericPrecedence(from: DataType, to: DataType): Boolean = { + val fromPrecedence = TypeCoercion.numericPrecedence.indexOf(from) + val toPrecedence = TypeCoercion.numericPrecedence.indexOf(to) + fromPrecedence >= 0 && fromPrecedence < toPrecedence + } + + def canNullSafeCastToDecimal(from: DataType, to: DecimalType): Boolean = from match { + case from: BooleanType if to.isWiderThan(DecimalType.BooleanDecimal) => true + case from: NumericType if to.isWiderThan(from) => true + case from: DecimalType => + // truncation or precision loss + (to.precision - to.scale) > (from.precision - from.scale) + case _ => false // overflow + } + def forceNullable(from: DataType, to: DataType): Boolean = (from, to) match { case (NullType, _) => true case (_, _) if from == to => false @@ -149,7 +178,7 @@ object Cast { case (DateType, _) => true case (_, CalendarIntervalType) => true - case (_, _: DecimalType) => true // overflow + case (_, to: DecimalType) if !canNullSafeCastToDecimal(from, to) => true case (_: FractionalType, _: IntegralType) => true // NaN, infinity case _ => false } @@ -182,7 +211,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String TypeCheckResult.TypeCheckSuccess } else { TypeCheckResult.TypeCheckFailure( - s"cannot cast ${child.dataType.simpleString} to ${dataType.simpleString}") + s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}") } } @@ -625,25 +654,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val eval = child.genCode(ctx) val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) - ev.copy(code = - code""" - ${eval.code} -
// This comment is added for manually tracking reference of ${eval.value}, ${eval.isNull} - ${castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)} - """) + ev.copy(code = eval.code + + castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)) } // The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull` // in parameter list, because the returned code will be put in null safe evaluation region. - private[this] type CastFunction = (String, String, String) => String + private[this] type CastFunction = (ExprValue, ExprValue, ExprValue) => Block private[this] def nullSafeCastFunction( from: DataType, to: DataType, ctx: CodegenContext): CastFunction = to match { - case _ if from == NullType => (c, evPrim, evNull) => s"$evNull = true;" - case _ if to == from => (c, evPrim, evNull) => s"$evPrim = $c;" + case _ if from == NullType => (c, evPrim, evNull) => code"$evNull = true;" + case _ if to == from => (c, evPrim, evNull) => code"$evPrim = $c;" case StringType => castToStringCode(from, ctx) case BinaryType => castToBinaryCode(from) case DateType => castToDateCode(from, ctx) @@ -664,18 +689,19 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx) case udt: UserDefinedType[_] if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass => - (c, evPrim, evNull) => s"$evPrim = $c;" + (c, evPrim, evNull) => code"$evPrim = $c;" case _: UserDefinedType[_] => throw new SparkException(s"Cannot cast $from to $to.") } // Since we need to cast input expressions recursively inside ComplexTypes, such as Map's // Key and Value, Struct's field, we need to name out all the variable names involved in a cast. 
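For intuition, the Block produced by the new castCode below expands, for a nullable int-to-long cast, to Java of roughly this shape (identifier names are freshly generated per compilation; the ones shown, and -1 as the numeric default, are illustrative assumptions):

// boolean resultIsNull_0 = inputIsNull_0;
// long result_0 = -1L;
// if (!inputIsNull_0) {
//   result_0 = (long) input_0; // body supplied by the int -> long CastFunction
// }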
- private[this] def castCode(ctx: CodegenContext, input: String, inputIsNull: String, - result: String, resultIsNull: String, resultType: DataType, cast: CastFunction): String = { - s""" + private[this] def castCode(ctx: CodegenContext, input: ExprValue, inputIsNull: ExprValue, + result: ExprValue, resultIsNull: ExprValue, resultType: DataType, cast: CastFunction): Block = { + val javaType = JavaCode.javaType(resultType) + code""" boolean $resultIsNull = $inputIsNull; - ${CodeGenerator.javaType(resultType)} $result = ${CodeGenerator.defaultValue(resultType)}; + $javaType $result = ${CodeGenerator.defaultValue(resultType)}; if (!$inputIsNull) { ${cast(input, result, resultIsNull)} } @@ -684,22 +710,24 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private def writeArrayToStringBuilder( et: DataType, - array: String, - buffer: String, - ctx: CodegenContext): String = { + array: ExprValue, + buffer: ExprValue, + ctx: CodegenContext): Block = { val elementToStringCode = castToStringCode(et, ctx) val funcName = ctx.freshName("elementToString") - val elementToStringFunc = ctx.addNewFunction(funcName, + val element = JavaCode.variable("element", et) + val elementStr = JavaCode.variable("elementStr", StringType) + val elementToStringFunc = inline"${ctx.addNewFunction(funcName, s""" - |private UTF8String $funcName(${CodeGenerator.javaType(et)} element) { - | UTF8String elementStr = null; - | ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)} + |private UTF8String $funcName(${CodeGenerator.javaType(et)} $element) { + | UTF8String $elementStr = null; + | ${elementToStringCode(element, elementStr, null /* resultIsNull won't be used */)} | return elementStr; |} - """.stripMargin) + """.stripMargin)}" - val loopIndex = ctx.freshName("loopIndex") - s""" + val loopIndex = ctx.freshVariable("loopIndex", IntegerType) + code""" |$buffer.append("["); |if ($array.numElements() > 0) { | if (!$array.isNullAt(0)) { @@ -720,31 +748,37 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private def writeMapToStringBuilder( kt: DataType, vt: DataType, - map: String, - buffer: String, - ctx: CodegenContext): String = { + map: ExprValue, + buffer: ExprValue, + ctx: CodegenContext): Block = { def dataToStringFunc(func: String, dataType: DataType) = { val funcName = ctx.freshName(func) val dataToStringCode = castToStringCode(dataType, ctx) - ctx.addNewFunction(funcName, + val data = JavaCode.variable("data", dataType) + val dataStr = JavaCode.variable("dataStr", StringType) + val functionCall = ctx.addNewFunction(funcName, s""" - |private UTF8String $funcName(${CodeGenerator.javaType(dataType)} data) { - | UTF8String dataStr = null; - | ${dataToStringCode("data", "dataStr", null /* resultIsNull won't be used */)} + |private UTF8String $funcName(${CodeGenerator.javaType(dataType)} $data) { + | UTF8String $dataStr = null; + | ${dataToStringCode(data, dataStr, null /* resultIsNull won't be used */)} | return dataStr; |} """.stripMargin) + inline"$functionCall" } val keyToStringFunc = dataToStringFunc("keyToString", kt) val valueToStringFunc = dataToStringFunc("valueToString", vt) - val loopIndex = ctx.freshName("loopIndex") - val getMapFirstKey = CodeGenerator.getValue(s"$map.keyArray()", kt, "0") - val getMapFirstValue = CodeGenerator.getValue(s"$map.valueArray()", vt, "0") - val getMapKeyArray = CodeGenerator.getValue(s"$map.keyArray()", kt, loopIndex) - val getMapValueArray = 
CodeGenerator.getValue(s"$map.valueArray()", vt, loopIndex) - s""" + val loopIndex = ctx.freshVariable("loopIndex", IntegerType) + val mapKeyArray = JavaCode.expression(s"$map.keyArray()", classOf[ArrayData]) + val mapValueArray = JavaCode.expression(s"$map.valueArray()", classOf[ArrayData]) + val getMapFirstKey = CodeGenerator.getValue(mapKeyArray, kt, JavaCode.literal("0", IntegerType)) + val getMapFirstValue = CodeGenerator.getValue(mapValueArray, vt, + JavaCode.literal("0", IntegerType)) + val getMapKeyArray = CodeGenerator.getValue(mapKeyArray, kt, loopIndex) + val getMapValueArray = CodeGenerator.getValue(mapValueArray, vt, loopIndex) + code""" |$buffer.append("["); |if ($map.numElements() > 0) { | $buffer.append($keyToStringFunc($getMapFirstKey)); @@ -769,20 +803,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private def writeStructToStringBuilder( st: Seq[DataType], - row: String, - buffer: String, - ctx: CodegenContext): String = { + row: ExprValue, + buffer: ExprValue, + ctx: CodegenContext): Block = { val structToStringCode = st.zipWithIndex.map { case (ft, i) => val fieldToStringCode = castToStringCode(ft, ctx) - val field = ctx.freshName("field") - val fieldStr = ctx.freshName("fieldStr") - s""" - |${if (i != 0) s"""$buffer.append(",");""" else ""} + val field = ctx.freshVariable("field", ft) + val fieldStr = ctx.freshVariable("fieldStr", StringType) + val javaType = JavaCode.javaType(ft) + code""" + |${if (i != 0) code"""$buffer.append(",");""" else EmptyBlock} |if (!$row.isNullAt($i)) { - | ${if (i != 0) s"""$buffer.append(" ");""" else ""} + | ${if (i != 0) code"""$buffer.append(" ");""" else EmptyBlock} | | // Append $i field into the string buffer - | ${CodeGenerator.javaType(ft)} $field = ${CodeGenerator.getValue(row, ft, s"$i")}; + | $javaType $field = ${CodeGenerator.getValue(row, ft, s"$i")}; | UTF8String $fieldStr = null; | ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)} | $buffer.append($fieldStr); @@ -791,11 +826,12 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } val writeStructCode = ctx.splitExpressions( - expressions = structToStringCode, + expressions = structToStringCode.map(_.code), funcName = "fieldToString", - arguments = ("InternalRow", row) :: (classOf[UTF8StringBuilder].getName, buffer) :: Nil) + arguments = ("InternalRow", row.code) :: + (classOf[UTF8StringBuilder].getName, buffer.code) :: Nil) - s""" + code""" |$buffer.append("["); |$writeStructCode |$buffer.append("]"); @@ -805,20 +841,20 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case BinaryType => - (c, evPrim, evNull) => s"$evPrim = UTF8String.fromBytes($c);" + (c, evPrim, evNull) => code"$evPrim = UTF8String.fromBytes($c);" case DateType => - (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( + (c, evPrim, evNull) => code"""$evPrim = UTF8String.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c));""" case TimestampType => - val tz = ctx.addReferenceObj("timeZone", timeZone) - (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), timeZone.getClass) + (c, evPrim, evNull) => code"""$evPrim = UTF8String.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));""" case ArrayType(et, _) => (c, evPrim, evNull) => { - 
val buffer = ctx.freshName("buffer") - val bufferClass = classOf[UTF8StringBuilder].getName + val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder]) + val bufferClass = JavaCode.javaType(classOf[UTF8StringBuilder]) val writeArrayElemCode = writeArrayToStringBuilder(et, c, buffer, ctx) - s""" + code""" |$bufferClass $buffer = new $bufferClass(); |$writeArrayElemCode; |$evPrim = $buffer.build(); @@ -826,10 +862,10 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } case MapType(kt, vt, _) => (c, evPrim, evNull) => { - val buffer = ctx.freshName("buffer") - val bufferClass = classOf[UTF8StringBuilder].getName + val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder]) + val bufferClass = JavaCode.javaType(classOf[UTF8StringBuilder]) val writeMapElemCode = writeMapToStringBuilder(kt, vt, c, buffer, ctx) - s""" + code""" |$bufferClass $buffer = new $bufferClass(); |$writeMapElemCode; |$evPrim = $buffer.build(); @@ -837,11 +873,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } case StructType(fields) => (c, evPrim, evNull) => { - val row = ctx.freshName("row") - val buffer = ctx.freshName("buffer") - val bufferClass = classOf[UTF8StringBuilder].getName + val row = ctx.freshVariable("row", classOf[InternalRow]) + val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder]) + val bufferClass = JavaCode.javaType(classOf[UTF8StringBuilder]) val writeStructCode = writeStructToStringBuilder(fields.map(_.dataType), row, buffer, ctx) - s""" + code""" |InternalRow $row = $c; |$bufferClass $buffer = new $bufferClass(); |$writeStructCode @@ -850,26 +886,26 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } case pudt: PythonUserDefinedType => castToStringCode(pudt.sqlType, ctx) case udt: UserDefinedType[_] => - val udtRef = ctx.addReferenceObj("udt", udt) + val udtRef = JavaCode.global(ctx.addReferenceObj("udt", udt), udt.sqlType) (c, evPrim, evNull) => { - s"$evPrim = UTF8String.fromString($udtRef.deserialize($c).toString());" + code"$evPrim = UTF8String.fromString($udtRef.deserialize($c).toString());" } case _ => - (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" + (c, evPrim, evNull) => code"$evPrim = UTF8String.fromString(String.valueOf($c));" } } private[this] def castToBinaryCode(from: DataType): CastFunction = from match { case StringType => - (c, evPrim, evNull) => s"$evPrim = $c.getBytes();" + (c, evPrim, evNull) => code"$evPrim = $c.getBytes();" } private[this] def castToDateCode( from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val intOpt = ctx.freshName("intOpt") - (c, evPrim, evNull) => s""" + val intOpt = ctx.freshVariable("intOpt", classOf[Option[Integer]]) + (c, evPrim, evNull) => code""" scala.Option $intOpt = org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToDate($c); if ($intOpt.isDefined()) { @@ -879,75 +915,85 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ case TimestampType => - val tz = ctx.addReferenceObj("timeZone", timeZone) + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), timeZone.getClass) (c, evPrim, evNull) => - s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToDays($c / 1000L, $tz);" + code"""$evPrim = + org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToDays($c / 1000L, $tz);""" case _ => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = 
true;" } - private[this] def changePrecision(d: String, decimalType: DecimalType, - evPrim: String, evNull: String): String = - s""" - if ($d.changePrecision(${decimalType.precision}, ${decimalType.scale})) { - $evPrim = $d; - } else { - $evNull = true; - } - """ + private[this] def changePrecision(d: ExprValue, decimalType: DecimalType, + evPrim: ExprValue, evNull: ExprValue, canNullSafeCast: Boolean): Block = { + if (canNullSafeCast) { + code""" + |$d.changePrecision(${decimalType.precision}, ${decimalType.scale}); + |$evPrim = $d; + """.stripMargin + } else { + code""" + |if ($d.changePrecision(${decimalType.precision}, ${decimalType.scale})) { + | $evPrim = $d; + |} else { + | $evNull = true; + |} + """.stripMargin + } + } private[this] def castToDecimalCode( from: DataType, target: DecimalType, ctx: CodegenContext): CastFunction = { - val tmp = ctx.freshName("tmpDecimal") + val tmp = ctx.freshVariable("tmpDecimal", classOf[Decimal]) + val canNullSafeCast = Cast.canNullSafeCastToDecimal(from, target) from match { case StringType => (c, evPrim, evNull) => - s""" + code""" try { Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString())); - ${changePrecision(tmp, target, evPrim, evNull)} + ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} } catch (java.lang.NumberFormatException e) { $evNull = true; } """ case BooleanType => (c, evPrim, evNull) => - s""" + code""" Decimal $tmp = $c ? Decimal.apply(1) : Decimal.apply(0); - ${changePrecision(tmp, target, evPrim, evNull)} + ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} """ case DateType => // date can't cast to decimal in Hive - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => // Note that we lose precision here. 
(c, evPrim, evNull) => - s""" + code""" Decimal $tmp = Decimal.apply( scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)})); - ${changePrecision(tmp, target, evPrim, evNull)} + ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} """ case DecimalType() => (c, evPrim, evNull) => - s""" + code""" Decimal $tmp = $c.clone(); - ${changePrecision(tmp, target, evPrim, evNull)} + ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} """ case x: IntegralType => (c, evPrim, evNull) => - s""" + code""" Decimal $tmp = Decimal.apply((long) $c); - ${changePrecision(tmp, target, evPrim, evNull)} + ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} """ case x: FractionalType => // All other numeric types can be represented precisely as Doubles (c, evPrim, evNull) => - s""" + code""" try { Decimal $tmp = Decimal.apply(scala.math.BigDecimal.valueOf((double) $c)); - ${changePrecision(tmp, target, evPrim, evNull)} + ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} } catch (java.lang.NumberFormatException e) { $evNull = true; } @@ -959,10 +1005,10 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val tz = ctx.addReferenceObj("timeZone", timeZone) - val longOpt = ctx.freshName("longOpt") + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), timeZone.getClass) + val longOpt = ctx.freshVariable("longOpt", classOf[Option[Long]]) (c, evPrim, evNull) => - s""" + code""" scala.Option $longOpt = org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToTimestamp($c, $tz); if ($longOpt.isDefined()) { @@ -972,18 +1018,19 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;" + (c, evPrim, evNull) => code"$evPrim = $c ? 
1L : 0L;" case _: IntegralType => - (c, evPrim, evNull) => s"$evPrim = ${longToTimeStampCode(c)};" + (c, evPrim, evNull) => code"$evPrim = ${longToTimeStampCode(c)};" case DateType => - val tz = ctx.addReferenceObj("timeZone", timeZone) + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), timeZone.getClass) (c, evPrim, evNull) => - s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMillis($c, $tz) * 1000;" + code"""$evPrim = + org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMillis($c, $tz) * 1000;""" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = ${decimalToTimestampCode(c)};" + (c, evPrim, evNull) => code"$evPrim = ${decimalToTimestampCode(c)};" case DoubleType => (c, evPrim, evNull) => - s""" + code""" if (Double.isNaN($c) || Double.isInfinite($c)) { $evNull = true; } else { @@ -992,7 +1039,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String """ case FloatType => (c, evPrim, evNull) => - s""" + code""" if (Float.isNaN($c) || Float.isInfinite($c)) { $evNull = true; } else { @@ -1004,7 +1051,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToIntervalCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => - s"""$evPrim = CalendarInterval.fromString($c.toString()); + code"""$evPrim = CalendarInterval.fromString($c.toString()); if(${evPrim} == null) { ${evNull} = true; } @@ -1012,18 +1059,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } - private[this] def decimalToTimestampCode(d: String): String = - s"($d.toBigDecimal().bigDecimal().multiply(new java.math.BigDecimal(1000000L))).longValue()" - private[this] def longToTimeStampCode(l: String): String = s"$l * 1000000L" - private[this] def timestampToIntegerCode(ts: String): String = - s"java.lang.Math.floor((double) $ts / 1000000L)" - private[this] def timestampToDoubleCode(ts: String): String = s"$ts / 1000000.0" + private[this] def decimalToTimestampCode(d: ExprValue): Block = { + val block = inline"new java.math.BigDecimal(1000000L)" + code"($d.toBigDecimal().bigDecimal().multiply($block)).longValue()" + } + private[this] def longToTimeStampCode(l: ExprValue): Block = code"$l * 1000000L" + private[this] def timestampToIntegerCode(ts: ExprValue): Block = + code"java.lang.Math.floor((double) $ts / 1000000L)" + private[this] def timestampToDoubleCode(ts: ExprValue): Block = + code"$ts / 1000000.0" private[this] def castToBooleanCode(from: DataType): CastFunction = from match { case StringType => - val stringUtils = StringUtils.getClass.getName.stripSuffix("$") + val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}" (c, evPrim, evNull) => - s""" + code""" if ($stringUtils.isTrueString($c)) { $evPrim = true; } else if ($stringUtils.isFalseString($c)) { @@ -1033,21 +1083,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ case TimestampType => - (c, evPrim, evNull) => s"$evPrim = $c != 0;" + (c, evPrim, evNull) => code"$evPrim = $c != 0;" case DateType => // Hive would return null when cast from date to boolean - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = !$c.isZero();" + (c, evPrim, evNull) => code"$evPrim = !$c.isZero();" case n: NumericType => - (c, evPrim, evNull) => s"$evPrim = $c != 0;" + (c, evPrim, evNull) => code"$evPrim = $c != 0;" } private[this] def 
castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("intWrapper") + val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => - s""" + code""" UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); if ($c.toByte($wrapper)) { $evPrim = (byte) $wrapper.value; @@ -1057,24 +1107,24 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $wrapper = null; """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? (byte) 1 : (byte) 0;" + (c, evPrim, evNull) => code"$evPrim = $c ? (byte) 1 : (byte) 0;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = (byte) ${timestampToIntegerCode(c)};" + (c, evPrim, evNull) => code"$evPrim = (byte) ${timestampToIntegerCode(c)};" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toByte();" + (c, evPrim, evNull) => code"$evPrim = $c.toByte();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (byte) $c;" + (c, evPrim, evNull) => code"$evPrim = (byte) $c;" } private[this] def castToShortCode( from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("intWrapper") + val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => - s""" + code""" UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); if ($c.toShort($wrapper)) { $evPrim = (short) $wrapper.value; @@ -1084,22 +1134,22 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $wrapper = null; """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? (short) 1 : (short) 0;" + (c, evPrim, evNull) => code"$evPrim = $c ? (short) 1 : (short) 0;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = (short) ${timestampToIntegerCode(c)};" + (c, evPrim, evNull) => code"$evPrim = (short) ${timestampToIntegerCode(c)};" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toShort();" + (c, evPrim, evNull) => code"$evPrim = $c.toShort();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (short) $c;" + (c, evPrim, evNull) => code"$evPrim = (short) $c;" } private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("intWrapper") + val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => - s""" + code""" UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); if ($c.toInt($wrapper)) { $evPrim = $wrapper.value; @@ -1109,23 +1159,23 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $wrapper = null; """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + (c, evPrim, evNull) => code"$evPrim = $c ? 
1 : 0;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = (int) ${timestampToIntegerCode(c)};" + (c, evPrim, evNull) => code"$evPrim = (int) ${timestampToIntegerCode(c)};" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toInt();" + (c, evPrim, evNull) => code"$evPrim = $c.toInt();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (int) $c;" + (c, evPrim, evNull) => code"$evPrim = (int) $c;" } private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("longWrapper") + val wrapper = ctx.freshVariable("longWrapper", classOf[UTF8String.LongWrapper]) (c, evPrim, evNull) => - s""" + code""" UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper(); if ($c.toLong($wrapper)) { $evPrim = $wrapper.value; @@ -1135,21 +1185,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $wrapper = null; """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;" + (c, evPrim, evNull) => code"$evPrim = $c ? 1L : 0L;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = (long) ${timestampToIntegerCode(c)};" + (c, evPrim, evNull) => code"$evPrim = (long) ${timestampToIntegerCode(c)};" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toLong();" + (c, evPrim, evNull) => code"$evPrim = $c.toLong();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (long) $c;" + (c, evPrim, evNull) => code"$evPrim = (long) $c;" } private[this] def castToFloatCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => - s""" + code""" try { $evPrim = Float.valueOf($c.toString()); } catch (java.lang.NumberFormatException e) { @@ -1157,21 +1207,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1.0f : 0.0f;" + (c, evPrim, evNull) => code"$evPrim = $c ? 1.0f : 0.0f;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = (float) (${timestampToDoubleCode(c)});" + (c, evPrim, evNull) => code"$evPrim = (float) (${timestampToDoubleCode(c)});" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toFloat();" + (c, evPrim, evNull) => code"$evPrim = $c.toFloat();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (float) $c;" + (c, evPrim, evNull) => code"$evPrim = (float) $c;" } private[this] def castToDoubleCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => - s""" + code""" try { $evPrim = Double.valueOf($c.toString()); } catch (java.lang.NumberFormatException e) { @@ -1179,31 +1229,32 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1.0d : 0.0d;" + (c, evPrim, evNull) => code"$evPrim = $c ? 
1.0d : 0.0d;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = ${timestampToDoubleCode(c)};" + (c, evPrim, evNull) => code"$evPrim = ${timestampToDoubleCode(c)};" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toDouble();" + (c, evPrim, evNull) => code"$evPrim = $c.toDouble();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (double) $c;" + (c, evPrim, evNull) => code"$evPrim = (double) $c;" } private[this] def castArrayCode( fromType: DataType, toType: DataType, ctx: CodegenContext): CastFunction = { val elementCast = nullSafeCastFunction(fromType, toType, ctx) - val arrayClass = classOf[GenericArrayData].getName - val fromElementNull = ctx.freshName("feNull") - val fromElementPrim = ctx.freshName("fePrim") - val toElementNull = ctx.freshName("teNull") - val toElementPrim = ctx.freshName("tePrim") - val size = ctx.freshName("n") - val j = ctx.freshName("j") - val values = ctx.freshName("values") + val arrayClass = JavaCode.javaType(classOf[GenericArrayData]) + val fromElementNull = ctx.freshVariable("feNull", BooleanType) + val fromElementPrim = ctx.freshVariable("fePrim", fromType) + val toElementNull = ctx.freshVariable("teNull", BooleanType) + val toElementPrim = ctx.freshVariable("tePrim", toType) + val size = ctx.freshVariable("n", IntegerType) + val j = ctx.freshVariable("j", IntegerType) + val values = ctx.freshVariable("values", classOf[Array[Object]]) + val javaType = JavaCode.javaType(fromType) (c, evPrim, evNull) => - s""" + code""" final int $size = $c.numElements(); final Object[] $values = new Object[$size]; for (int $j = 0; $j < $size; $j ++) { @@ -1211,7 +1262,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $values[$j] = null; } else { boolean $fromElementNull = false; - ${CodeGenerator.javaType(fromType)} $fromElementPrim = + $javaType $fromElementPrim = ${CodeGenerator.getValue(c, fromType, j)}; ${castCode(ctx, fromElementPrim, fromElementNull, toElementPrim, toElementNull, toType, elementCast)} @@ -1230,23 +1281,23 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val keysCast = castArrayCode(from.keyType, to.keyType, ctx) val valuesCast = castArrayCode(from.valueType, to.valueType, ctx) - val mapClass = classOf[ArrayBasedMapData].getName + val mapClass = JavaCode.javaType(classOf[ArrayBasedMapData]) - val keys = ctx.freshName("keys") - val convertedKeys = ctx.freshName("convertedKeys") - val convertedKeysNull = ctx.freshName("convertedKeysNull") + val keys = ctx.freshVariable("keys", ArrayType(from.keyType)) + val convertedKeys = ctx.freshVariable("convertedKeys", ArrayType(to.keyType)) + val convertedKeysNull = ctx.freshVariable("convertedKeysNull", BooleanType) - val values = ctx.freshName("values") - val convertedValues = ctx.freshName("convertedValues") - val convertedValuesNull = ctx.freshName("convertedValuesNull") + val values = ctx.freshVariable("values", ArrayType(from.valueType)) + val convertedValues = ctx.freshVariable("convertedValues", ArrayType(to.valueType)) + val convertedValuesNull = ctx.freshVariable("convertedValuesNull", BooleanType) (c, evPrim, evNull) => - s""" + code""" final ArrayData $keys = $c.keyArray(); final ArrayData $values = $c.valueArray(); - ${castCode(ctx, keys, "false", + ${castCode(ctx, keys, FalseLiteral, convertedKeys, convertedKeysNull, ArrayType(to.keyType), keysCast)} - ${castCode(ctx, values, "false", + 
${castCode(ctx, values, FalseLiteral, convertedValues, convertedValuesNull, ArrayType(to.valueType), valuesCast)} $evPrim = new $mapClass($convertedKeys, $convertedValues); @@ -1259,17 +1310,18 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val fieldsCasts = from.fields.zip(to.fields).map { case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx) } - val rowClass = classOf[GenericInternalRow].getName - val tmpResult = ctx.freshName("tmpResult") - val tmpInput = ctx.freshName("tmpInput") + val tmpResult = ctx.freshVariable("tmpResult", classOf[GenericInternalRow]) + val rowClass = JavaCode.javaType(classOf[GenericInternalRow]) + val tmpInput = ctx.freshVariable("tmpInput", classOf[InternalRow]) val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) => - val fromFieldPrim = ctx.freshName("ffp") - val fromFieldNull = ctx.freshName("ffn") - val toFieldPrim = ctx.freshName("tfp") - val toFieldNull = ctx.freshName("tfn") - val fromType = CodeGenerator.javaType(from.fields(i).dataType) - s""" + val fromFieldPrim = ctx.freshVariable("ffp", from.fields(i).dataType) + val fromFieldNull = ctx.freshVariable("ffn", BooleanType) + val toFieldPrim = ctx.freshVariable("tfp", to.fields(i).dataType) + val toFieldNull = ctx.freshVariable("tfn", BooleanType) + val fromType = JavaCode.javaType(from.fields(i).dataType) + val setColumn = CodeGenerator.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim) + code""" boolean $fromFieldNull = $tmpInput.isNullAt($i); if ($fromFieldNull) { $tmpResult.setNullAt($i); @@ -1281,18 +1333,18 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String if ($toFieldNull) { $tmpResult.setNullAt($i); } else { - ${CodeGenerator.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)}; + $setColumn; } } """ } val fieldsEvalCodes = ctx.splitExpressions( - expressions = fieldsEvalCode, + expressions = fieldsEvalCode.map(_.code), funcName = "castStruct", - arguments = ("InternalRow", tmpInput) :: (rowClass, tmpResult) :: Nil) + arguments = ("InternalRow", tmpInput.code) :: (rowClass.code, tmpResult.code) :: Nil) (input, result, resultIsNull) => - s""" + code""" final $rowClass $tmpResult = new $rowClass(${fieldsCasts.length}); final InternalRow $tmpInput = $input; $fieldsEvalCodes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala index fb25e781e72e4..07fa813a98922 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala @@ -17,24 +17,12 @@ package org.apache.spark.sql.catalyst.expressions -import org.codehaus.commons.compiler.CompileException -import org.codehaus.janino.InternalCompilerException +import scala.util.control.NonFatal -import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils -/** - * Catches compile error during code generation. 
- */ -object CodegenError { - def unapply(throwable: Throwable): Option[Exception] = throwable match { - case e: InternalCompilerException => Some(e) - case e: CompileException => Some(e) - case _ => None - } -} - /** * Defines values for `SQLConf` config of fallback mode. Use for test only. */ @@ -44,10 +32,10 @@ object CodegenObjectFactoryMode extends Enumeration { /** * A codegen object generator which creates objects with codegen path first. Once any compile - * error happens, it can fallbacks to interpreted implementation. In tests, we can use a SQL config + * error happens, it can fall back to the interpreted implementation. In tests, we can use a SQL config * `SQLConf.CODEGEN_FACTORY_MODE` to control fallback behavior. */ -abstract class CodeGeneratorWithInterpretedFallback[IN, OUT] { +abstract class CodeGeneratorWithInterpretedFallback[IN, OUT] extends Logging { def createObject(in: IN): OUT = { // We are allowed to choose codegen-only or no-codegen modes if under tests. @@ -63,7 +51,10 @@ abstract class CodeGeneratorWithInterpretedFallback[IN, OUT] { try { createCodeGeneratedObject(in) } catch { - case CodegenError(_) => createInterpretedObject(in) + case NonFatal(_) => + // We should have already seen the error message in `CodeGenerator` + logWarning("Expr codegen error and falling back to interpreter mode") + createInterpretedObject(in) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index 98f25a9ad7597..981ce0b6a29fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.types.AbstractDataType * This trait is typically used by operator expressions (e.g. [[Add]], [[Subtract]]) to define * expected input types without any implicit casting. * - * Most function expressions (e.g. [[Substring]] should extends [[ImplicitCastInputTypes]]) instead. + * Most function expressions (e.g. [[Substring]]) should extend [[ImplicitCastInputTypes]] instead. */ trait ExpectsInputTypes extends Expression { @@ -41,10 +41,19 @@ trait ExpectsInputTypes extends Expression { def inputTypes: Seq[AbstractDataType] override def checkInputDataTypes(): TypeCheckResult = { - val mismatches = children.zip(inputTypes).zipWithIndex.collect { - case ((child, expected), idx) if !expected.acceptsType(child.dataType) => + ExpectsInputTypes.checkInputDataTypes(children, inputTypes) + } +} + +object ExpectsInputTypes { + + def checkInputDataTypes( + inputs: Seq[Expression], + inputTypes: Seq[AbstractDataType]): TypeCheckResult = { + val mismatches = inputs.zip(inputTypes).zipWithIndex.collect { + case ((input, expected), idx) if !expected.acceptsType(input.dataType) => s"argument ${idx + 1} requires ${expected.simpleString} type, " + - s"however, '${child.sql}' is of ${child.dataType.simpleString} type." + s"however, '${input.sql}' is of ${input.dataType.catalogString} type." } if (mismatches.isEmpty) { @@ -55,7 +64,6 @@ trait ExpectsInputTypes extends Expression { } } - /** * A mixin for the analyzer to perform implicit type casting using * [[org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts]]. 
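The `ExpectsInputTypes` hunk above moves the per-argument type check out of the trait and into a companion object so callers can run it over arbitrary input expressions, not only an expression's `children`. A minimal, self-contained sketch of that check logic, using simplified stand-ins (`Expr`, `SimpleType`, `CheckResult`) for Spark's `Expression`, `AbstractDataType`, and `TypeCheckResult`:

```scala
// Simplified stand-ins; in Spark these are TypeCheckResult, AbstractDataType, Expression.
sealed trait CheckResult
case object TypeCheckSuccess extends CheckResult
final case class TypeCheckFailure(message: String) extends CheckResult

final case class SimpleType(catalogString: String) {
  def acceptsType(other: SimpleType): Boolean = other == this
}

final case class Expr(sql: String, dataType: SimpleType)

object InputTypeCheck {
  // Mirrors ExpectsInputTypes.checkInputDataTypes: collect one message per mismatched argument.
  def checkInputDataTypes(inputs: Seq[Expr], inputTypes: Seq[SimpleType]): CheckResult = {
    val mismatches = inputs.zip(inputTypes).zipWithIndex.collect {
      case ((input, expected), idx) if !expected.acceptsType(input.dataType) =>
        s"argument ${idx + 1} requires ${expected.catalogString} type, " +
          s"however, '${input.sql}' is of ${input.dataType.catalogString} type."
    }
    if (mismatches.isEmpty) TypeCheckSuccess else TypeCheckFailure(mismatches.mkString(" "))
  }
}

object InputTypeCheckDemo extends App {
  val int = SimpleType("int")
  val string = SimpleType("string")
  // The second argument mismatches: a string is passed where an int is expected.
  println(InputTypeCheck.checkInputDataTypes(
    Seq(Expr("'abc'", string), Expr("'x'", string)), Seq(string, int)))
}
```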
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 9b9fa41a47d0f..773aefc0ac1f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Locale import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TreeNode @@ -580,10 +580,10 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { // First check whether left and right have the same type, then check if the type is acceptable. if (!left.dataType.sameType(right.dataType)) { TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " + - s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).") + s"(${left.dataType.catalogString} and ${right.dataType.catalogString}).") } else if (!inputType.acceptsType(left.dataType)) { TypeCheckResult.TypeCheckFailure(s"'$sql' requires ${inputType.simpleString} type," + - s" not ${left.dataType.simpleString}") + s" not ${left.dataType.catalogString}") } else { TypeCheckResult.TypeCheckSuccess } @@ -695,6 +695,36 @@ abstract class TernaryExpression extends Expression { } } +/** + * A trait resolving the nullable, containsNull, valueContainsNull flags of the output data type. + * This logic is usually utilized by expressions combining data from multiple child expressions + * of non-primitive types (e.g. [[CaseWhen]]). + */ +trait ComplexTypeMergingExpression extends Expression { + + /** + * A collection of data types used to resolve the output type of the expression. By default, + * data types of all child expressions. The collection must not be empty. + */ + @transient + lazy val inputTypesForMerging: Seq[DataType] = children.map(_.dataType) + + def dataTypeCheck: Unit = { + require( + inputTypesForMerging.nonEmpty, + "The collection of input data types must not be empty.") + require( + TypeCoercion.haveSameType(inputTypesForMerging), + "All input types must be the same except nullable, containsNull, valueContainsNull flags." + + s" The input types found are\n\t${inputTypesForMerging.mkString("\n\t")}") + } + + override def dataType: DataType = { + dataTypeCheck + inputTypesForMerging.reduceLeft(TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(_, _).get) + } +} + /** * Common base trait for user-defined functions, including UDF/UDAF/UDTF of different languages * and Hive function wrappers. 
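The new `ComplexTypeMergingExpression` trait lets expressions such as `CaseWhen`, `Least`, and `Greatest` accept inputs whose types differ only in their null flags, computing the output type by folding those flags together. A hedged sketch of that merge under a toy `DType` hierarchy (standing in for Spark's `DataType`), with `findCommonType` mimicking `TypeCoercion.findCommonTypeDifferentOnlyInNullFlags`:

```scala
// Toy data types: an atomic int and an array type carrying a containsNull flag.
sealed trait DType
case object IntT extends DType
final case class ArrayT(element: DType, containsNull: Boolean) extends DType

object NullFlagMerge {
  // Mimics TypeCoercion.findCommonTypeDifferentOnlyInNullFlags: the types must
  // agree structurally, and the null flags are OR-ed together.
  def findCommonType(a: DType, b: DType): Option[DType] = (a, b) match {
    case (ArrayT(ea, na), ArrayT(eb, nb)) =>
      findCommonType(ea, eb).map(e => ArrayT(e, na || nb))
    case _ if a == b => Some(a)
    case _ => None
  }

  // Mirrors ComplexTypeMergingExpression.dataType: check the inputs, then fold.
  def mergedDataType(inputTypesForMerging: Seq[DType]): DType = {
    require(inputTypesForMerging.nonEmpty,
      "The collection of input data types must not be empty.")
    inputTypesForMerging.reduceLeft { (l, r) =>
      findCommonType(l, r).getOrElse(
        sys.error(s"All input types must be the same except null flags, got $l and $r"))
    }
  }
}

// NullFlagMerge.mergedDataType(Seq(ArrayT(IntT, false), ArrayT(IntT, true)))
// == ArrayT(IntT, containsNull = true)
```

The fold ORs the flags, so the output is nullable wherever any input could produce a null.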
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 6493f09100577..226a4ddcffaa8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import scala.util.control.NonFatal + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} import org.apache.spark.sql.types.{DataType, StructType} @@ -180,7 +182,10 @@ object UnsafeProjection try { GenerateUnsafeProjection.generate(unsafeExprs, subexpressionEliminationEnabled) } catch { - case CodegenError(_) => InterpretedUnsafeProjection.createProjection(unsafeExprs) + case NonFatal(_) => + // We should have already seen the error message in `CodeGenerator` + logWarning("Expr codegen error and falling back to interpreter mode") + InterpretedUnsafeProjection.createProjection(unsafeExprs) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 3e7ca88249737..8954fe8a58e6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.types.DataType * @param nullable True if the UDF can return null value. * @param udfDeterministic True if the UDF is deterministic. Deterministic UDF returns same result * each time it is invoked with a particular input. + * @param nullableTypes which of the inputTypes are nullable (i.e. 
not primitive) */ case class ScalaUDF( function: AnyRef, @@ -47,7 +48,8 @@ case class ScalaUDF( inputTypes: Seq[DataType] = Nil, udfName: Option[String] = None, nullable: Boolean = true, - udfDeterministic: Boolean = true) + udfDeterministic: Boolean = true, + nullableTypes: Seq[Boolean] = Nil) extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression { // The constructor for SPARK 2.1 and 2.2 @@ -58,7 +60,8 @@ case class ScalaUDF( inputTypes: Seq[DataType], udfName: Option[String]) = { this( - function, dataType, children, inputTypes, udfName, nullable = true, udfDeterministic = true) + function, dataType, children, inputTypes, udfName, nullable = true, + udfDeterministic = true, nullableTypes = Nil) } override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) @@ -1048,8 +1051,9 @@ case class ScalaUDF( lazy val udfErrorMessage = { val funcCls = function.getClass.getSimpleName - val inputTypes = children.map(_.dataType.simpleString).mkString(", ") - s"Failed to execute user defined function($funcCls: ($inputTypes) => ${dataType.simpleString})" + val inputTypes = children.map(_.dataType.catalogString).mkString(", ") + val outputType = dataType.catalogString + s"Failed to execute user defined function($funcCls: ($inputTypes) => $outputType)" } override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 76a881146a146..536276b5cb29f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -73,7 +73,7 @@ case class SortOrder( if (RowOrdering.isOrderable(dataType)) { TypeCheckResult.TypeCheckSuccess } else { - TypeCheckResult.TypeCheckFailure(s"cannot sort data type ${dataType.simpleString}") + TypeCheckResult.TypeCheckFailure(s"cannot sort data type ${dataType.catalogString}") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 84e38a8b2711e..8e48856d4607c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -80,16 +80,13 @@ case class TimeWindow( if (slideDuration <= 0) { return TypeCheckFailure(s"The slide duration ($slideDuration) must be greater than 0.") } - if (startTime < 0) { - return TypeCheckFailure(s"The start time ($startTime) must be greater than or equal to 0.") - } if (slideDuration > windowDuration) { return TypeCheckFailure(s"The slide duration ($slideDuration) must be less than or equal" + s" to the windowDuration ($windowDuration).") } - if (startTime >= slideDuration) { - return TypeCheckFailure(s"The start time ($startTime) must be less than the " + - s"slideDuration ($slideDuration).") + if (startTime.abs >= slideDuration) { + return TypeCheckFailure(s"The absolute value of start time ($startTime) must be less " + + s"than the slideDuration ($slideDuration).") } } dataTypeCheck diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala index d4421ca20a9bd..f96a087972f1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala @@ -63,11 +63,11 @@ case class ApproxCountDistinctForIntervals( } // Mark as lazy so that endpointsExpression is not evaluated during tree transformation. - lazy val endpoints: Array[Double] = - (endpointsExpression.dataType, endpointsExpression.eval()) match { - case (ArrayType(elementType, _), arrayData: ArrayData) => - arrayData.toObjectArray(elementType).map(_.toString.toDouble) - } + lazy val endpoints: Array[Double] = { + val endpointsType = endpointsExpression.dataType.asInstanceOf[ArrayType] + val endpoints = endpointsExpression.eval().asInstanceOf[ArrayData] + endpoints.toObjectArray(endpointsType.elementType).map(_.toString.toDouble) + } override def checkInputDataTypes(): TypeCheckResult = { val defaultCheck = super.checkInputDataTypes() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index f1bbbdabb41f3..c790d87492c73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -132,7 +132,7 @@ case class ApproximatePercentile( case TimestampType => value.asInstanceOf[Long].toDouble case n: NumericType => n.numeric.toDouble(value.asInstanceOf[n.InternalType]) case other: DataType => - throw new UnsupportedOperationException(s"Unexpected data type ${other.simpleString}") + throw new UnsupportedOperationException(s"Unexpected data type ${other.catalogString}") } buffer.add(doubleValue) } @@ -157,7 +157,7 @@ case class ApproximatePercentile( case DoubleType => doubleResult case _: DecimalType => doubleResult.map(Decimal(_)) case other: DataType => - throw new UnsupportedOperationException(s"Unexpected data type ${other.simpleString}") + throw new UnsupportedOperationException(s"Unexpected data type ${other.catalogString}") } if (result.length == 0) { null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index a133bc2361eb5..5ecb77be5965e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, TypeCheckResult} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils @@ -46,7 +46,7 @@ abstract class AverageLike(child: Expression) extends DeclarativeAggregate { override lazy val aggBufferAttributes = sum :: count :: Nil override lazy val initialValues = Seq( - /* sum = */ Cast(Literal(0), sumDataType), + /* sum = */ 
Literal(0).cast(sumDataType), /* count = */ Literal(0L) ) @@ -57,21 +57,18 @@ abstract class AverageLike(child: Expression) extends DeclarativeAggregate { // If all input are nulls, count will be 0 and we will get null after the division. override lazy val evaluateExpression = child.dataType match { - case DecimalType.Fixed(p, s) => - // increase the precision and scale to prevent precision loss - val dt = DecimalType.bounded(p + 14, s + 4) - Cast(Cast(sum, dt) / Cast(count, DecimalType.bounded(DecimalType.MAX_PRECISION, 0)), - resultType) + case _: DecimalType => + DecimalPrecision.decimalAndDecimal(sum / count.cast(DecimalType.LongDecimal)).cast(resultType) case _ => - Cast(sum, resultType) / Cast(count, resultType) + sum.cast(resultType) / count.cast(resultType) } protected def updateExpressionsDef: Seq[Expression] = Seq( /* sum = */ Add( sum, - Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)), - /* count = */ If(IsNull(child), count, count + 1L) + coalesce(child.cast(sumDataType), Literal(0).cast(sumDataType))), + /* count = */ If(child.isNull, count, count + 1L) ) override lazy val updateExpressions = updateExpressionsDef diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index 6bbb083f1e18e..e2ff0efba07ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -75,7 +75,7 @@ abstract class CentralMomentAgg(child: Expression) val n2 = n.right val newN = n1 + n2 val delta = avg.right - avg.left - val deltaN = If(newN === Literal(0.0), Literal(0.0), delta / newN) + val deltaN = If(newN === 0.0, 0.0, delta / newN) val newAvg = avg.left + deltaN * n2 // higher order moments computed according to: @@ -102,7 +102,7 @@ abstract class CentralMomentAgg(child: Expression) } protected def updateExpressionsDef: Seq[Expression] = { - val newN = n + Literal(1.0) + val newN = n + 1.0 val delta = child - avg val deltaN = delta / newN val newAvg = avg + deltaN @@ -123,11 +123,11 @@ abstract class CentralMomentAgg(child: Expression) } trimHigherOrder(Seq( - If(IsNull(child), n, newN), - If(IsNull(child), avg, newAvg), - If(IsNull(child), m2, newM2), - If(IsNull(child), m3, newM3), - If(IsNull(child), m4, newM4) + If(child.isNull, n, newN), + If(child.isNull, avg, newAvg), + If(child.isNull, m2, newM2), + If(child.isNull, m3, newM3), + If(child.isNull, m4, newM4) )) } } @@ -142,8 +142,7 @@ case class StddevPop(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - Sqrt(m2 / n)) + If(n === 0.0, Literal.create(null, DoubleType), sqrt(m2 / n)) } override def prettyName: String = "stddev_pop" @@ -159,9 +158,8 @@ case class StddevSamp(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - If(n === Literal(1.0), Literal(Double.NaN), - Sqrt(m2 / (n - Literal(1.0))))) + If(n === 0.0, Literal.create(null, DoubleType), + If(n === 1.0, Double.NaN, sqrt(m2 / (n - 1.0)))) } override def prettyName: String = "stddev_samp" @@ -175,8 +173,7 @@ case class 
VariancePop(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - m2 / n) + If(n === 0.0, Literal.create(null, DoubleType), m2 / n) } override def prettyName: String = "var_pop" @@ -190,9 +187,8 @@ case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - If(n === Literal(1.0), Literal(Double.NaN), - m2 / (n - Literal(1.0)))) + If(n === 0.0, Literal.create(null, DoubleType), + If(n === 1.0, Double.NaN, m2 / (n - 1.0))) } override def prettyName: String = "var_samp" @@ -207,9 +203,8 @@ case class Skewness(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 3 override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - If(m2 === Literal(0.0), Literal(Double.NaN), - Sqrt(n) * m3 / Sqrt(m2 * m2 * m2))) + If(n === 0.0, Literal.create(null, DoubleType), + If(m2 === 0.0, Double.NaN, sqrt(n) * m3 / sqrt(m2 * m2 * m2))) } } @@ -220,9 +215,8 @@ case class Kurtosis(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 4 override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - If(m2 === Literal(0.0), Literal(Double.NaN), - n * m4 / (m2 * m2) - Literal(3.0))) + If(n === 0.0, Literal.create(null, DoubleType), + If(m2 === 0.0, Double.NaN, n * m4 / (m2 * m2) - 3.0)) } override def prettyName: String = "kurtosis" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index 3cdef72c1f2c4..e14cc716ea223 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -54,9 +54,9 @@ abstract class PearsonCorrelation(x: Expression, y: Expression) val n2 = n.right val newN = n1 + n2 val dx = xAvg.right - xAvg.left - val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN) + val dxN = If(newN === 0.0, 0.0, dx / newN) val dy = yAvg.right - yAvg.left - val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN) + val dyN = If(newN === 0.0, 0.0, dy / newN) val newXAvg = xAvg.left + dxN * n2 val newYAvg = yAvg.left + dyN * n2 val newCk = ck.left + ck.right + dx * dyN * n1 * n2 @@ -67,7 +67,7 @@ abstract class PearsonCorrelation(x: Expression, y: Expression) } protected def updateExpressionsDef: Seq[Expression] = { - val newN = n + Literal(1.0) + val newN = n + 1.0 val dx = x - xAvg val dxN = dx / newN val dy = y - yAvg @@ -78,7 +78,7 @@ abstract class PearsonCorrelation(x: Expression, y: Expression) val newXMk = xMk + dx * (x - newXAvg) val newYMk = yMk + dy * (y - newYAvg) - val isNull = IsNull(x) || IsNull(y) + val isNull = x.isNull || y.isNull Seq( If(isNull, n, newN), If(isNull, xAvg, newXAvg), @@ -99,9 +99,8 @@ case class Corr(x: Expression, y: Expression) extends PearsonCorrelation(x, y) { override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - If(n === Literal(1.0), Literal(Double.NaN), - ck / Sqrt(xMk * yMk))) + If(n === 0.0, Literal.create(null, DoubleType), + If(n === 1.0, Double.NaN, ck / 
sqrt(xMk * yMk))) } override def prettyName: String = "corr" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala index 72a7c62b328ee..ee28eb591882f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala @@ -50,9 +50,9 @@ abstract class Covariance(x: Expression, y: Expression) val n2 = n.right val newN = n1 + n2 val dx = xAvg.right - xAvg.left - val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN) + val dxN = If(newN === 0.0, 0.0, dx / newN) val dy = yAvg.right - yAvg.left - val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN) + val dyN = If(newN === 0.0, 0.0, dy / newN) val newXAvg = xAvg.left + dxN * n2 val newYAvg = yAvg.left + dyN * n2 val newCk = ck.left + ck.right + dx * dyN * n1 * n2 @@ -61,7 +61,7 @@ abstract class Covariance(x: Expression, y: Expression) } protected def updateExpressionsDef: Seq[Expression] = { - val newN = n + Literal(1.0) + val newN = n + 1.0 val dx = x - xAvg val dy = y - yAvg val dyN = dy / newN @@ -69,7 +69,7 @@ abstract class Covariance(x: Expression, y: Expression) val newYAvg = yAvg + dyN val newCk = ck + dx * (y - newYAvg) - val isNull = IsNull(x) || IsNull(y) + val isNull = x.isNull || y.isNull Seq( If(isNull, n, newN), If(isNull, xAvg, newXAvg), @@ -83,8 +83,7 @@ abstract class Covariance(x: Expression, y: Expression) usage = "_FUNC_(expr1, expr2) - Returns the population covariance of a set of number pairs.") case class CovPopulation(left: Expression, right: Expression) extends Covariance(left, right) { override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - ck / n) + If(n === 0.0, Literal.create(null, DoubleType), ck / n) } override def prettyName: String = "covar_pop" } @@ -94,9 +93,8 @@ case class CovPopulation(left: Expression, right: Expression) extends Covariance usage = "_FUNC_(expr1, expr2) - Returns the sample covariance of a set of number pairs.") case class CovSample(left: Expression, right: Expression) extends Covariance(left, right) { override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - If(n === Literal(1.0), Literal(Double.NaN), - ck / (n - Literal(1.0)))) + If(n === 0.0, Literal.create(null, DoubleType), + If(n === 1.0, Double.NaN, ck / (n - 1.0))) } override def prettyName: String = "covar_samp" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index 4e671e1f3e6eb..f51bfd591204a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -80,8 +81,8 @@ case class First(child: Expression, ignoreNullsExpr: Expression) override lazy val 
updateExpressions: Seq[Expression] = { if (ignoreNulls) { Seq( - /* first = */ If(Or(valueSet, IsNull(child)), first, child), - /* valueSet = */ Or(valueSet, IsNotNull(child)) + /* first = */ If(valueSet || child.isNull, first, child), + /* valueSet = */ valueSet || child.isNotNull ) } else { Seq( @@ -97,7 +98,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression) // false, we are safe to do so because first.right will be null in this case). Seq( /* first = */ If(valueSet.left, first.left, first.right), - /* valueSet = */ Or(valueSet.left, valueSet.right) + /* valueSet = */ valueSet.left || valueSet.right ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index 0ccabb9d98914..2650d7b5908fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -80,8 +81,8 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) override lazy val updateExpressions: Seq[Expression] = { if (ignoreNulls) { Seq( - /* last = */ If(IsNull(child), last, child), - /* valueSet = */ Or(valueSet, IsNotNull(child)) + /* last = */ If(child.isNull, last, child), + /* valueSet = */ valueSet || child.isNotNull ) } else { Seq( @@ -95,7 +96,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) // Prefer the right hand expression if it has been set. 
Seq( /* last = */ If(valueSet.right, last.right, last.left), - /* valueSet = */ Or(valueSet.right, valueSet.left) + /* valueSet = */ valueSet.right || valueSet.left ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index 58fd1d8620e16..71099eba0fc75 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -45,12 +46,12 @@ case class Max(child: Expression) extends DeclarativeAggregate { ) override lazy val updateExpressions: Seq[Expression] = Seq( - /* max = */ Greatest(Seq(max, child)) + /* max = */ greatest(max, child) ) override lazy val mergeExpressions: Seq[Expression] = { Seq( - /* max = */ Greatest(Seq(max.left, max.right)) + /* max = */ greatest(max.left, max.right) ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala index b2724ee76827c..8c4ba93231cbe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -45,12 +46,12 @@ case class Min(child: Expression) extends DeclarativeAggregate { ) override lazy val updateExpressions: Seq[Expression] = Seq( - /* min = */ Least(Seq(min, child)) + /* min = */ least(min, child) ) override lazy val mergeExpressions: Seq[Expression] = { Seq( - /* min = */ Least(Seq(min.left, min.right)) + /* min = */ least(min.left, min.right) ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala index 523714869242d..33bc5b5821b36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import scala.collection.immutable.HashMap +import scala.collection.immutable.{HashMap, TreeMap} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.catalyst.util.{GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ object PivotFirst { @@ -83,7 +83,12 @@ case class PivotFirst( override val dataType: DataType = ArrayType(valueDataType) - val pivotIndex = HashMap(pivotColumnValues.zipWithIndex: _*) + val pivotIndex = 
if (pivotColumn.dataType.isInstanceOf[AtomicType]) { + HashMap(pivotColumnValues.zipWithIndex: _*) + } else { + TreeMap(pivotColumnValues.zipWithIndex: _*)( + TypeUtils.getInterpretedOrdering(pivotColumn.dataType)) + } val indexSize = pivotIndex.size diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 86e40a9713b36..761dba111c074 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -61,12 +62,12 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast if (child.nullable) { Seq( /* sum = */ - Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum)) + coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) ) } else { Seq( /* sum = */ - Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)) + coalesce(sum, zero) + child.cast(sumDataType) ) } } @@ -74,7 +75,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast override lazy val mergeExpressions: Seq[Expression] = { Seq( /* sum = */ - Coalesce(Seq(Add(Coalesce(Seq(sum.left, zero)), sum.right), sum.left)) + coalesce(coalesce(sum.left, zero) + sum.right, sum.left) ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index fe91e520169b4..c827226d58420 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.TypeUtils @@ -514,7 +514,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { > SELECT _FUNC_(10, 9, 2, 4, 3); 2 """) -case class Least(children: Seq[Expression]) extends Expression { +case class Least(children: Seq[Expression]) extends ComplexTypeMergingExpression { override def nullable: Boolean = children.forall(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -525,17 +525,15 @@ case class Least(children: Seq[Expression]) extends Expression { if (children.length <= 1) { TypeCheckResult.TypeCheckFailure( s"input to function $prettyName requires at least two arguments") - } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { + } else if (!TypeCoercion.haveSameType(inputTypesForMerging)) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + - s" got LEAST(${children.map(_.dataType.simpleString).mkString(", ")}).") + s" got 
LEAST(${children.map(_.dataType.catalogString).mkString(", ")}).") } else { TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") } } - override def dataType: DataType = children.head.dataType - override def eval(input: InternalRow): Any = { children.foldLeft[Any](null)((r, c) => { val evalc = c.eval(input) @@ -589,7 +587,7 @@ case class Least(children: Seq[Expression]) extends Expression { > SELECT _FUNC_(10, 9, 2, 4, 3); 10 """) -case class Greatest(children: Seq[Expression]) extends Expression { +case class Greatest(children: Seq[Expression]) extends ComplexTypeMergingExpression { override def nullable: Boolean = children.forall(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -600,17 +598,15 @@ case class Greatest(children: Seq[Expression]) extends Expression { if (children.length <= 1) { TypeCheckResult.TypeCheckFailure( s"input to function $prettyName requires at least two arguments") - } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { + } else if (!TypeCoercion.haveSameType(inputTypesForMerging)) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + - s" got GREATEST(${children.map(_.dataType.simpleString).mkString(", ")}).") + s" got GREATEST(${children.map(_.dataType.catalogString).mkString(", ")}).") } else { TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") } } - override def dataType: DataType = children.head.dataType - override def eval(input: InternalRow): Any = { children.foldLeft[Any](null)((r, c) => { val evalc = c.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala index 7b398f424cead..ea1bb87d415c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala @@ -27,6 +27,10 @@ import java.util.regex.Matcher */ object CodeFormatter { val commentHolder = """\/\*(.+?)\*\/""".r + val commentRegexp = + ("""([ |\t]*?\/\*[\s|\S]*?\*\/[ |\t]*?)|""" + // strip /*comment*/ + """([ |\t]*?\/\/[\s\S]*?\n)""").r // strip //comment + val extraNewLinesRegexp = """\n\s*\n""".r // strip extra newlines def format(code: CodeAndComment, maxLines: Int = -1): String = { val formatter = new CodeFormatter @@ -91,11 +95,7 @@ object CodeFormatter { } def stripExtraNewLinesAndComments(input: String): String = { - val commentReg = - ("""([ |\t]*?\/\*[\s|\S]*?\*\/[ |\t]*?)|""" + // strip /*comment*/ - """([ |\t]*?\/\/[\s\S]*?\n)""").r // strip //comment - val codeWithoutComment = commentReg.replaceAllIn(input, "") - codeWithoutComment.replaceAll("""\n\s*\n""", "\n") // strip ExtraNewLines + extraNewLinesRegexp.replaceAllIn(commentRegexp.replaceAllIn(input, ""), "\n") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4cc0968911cb5..d5857e060a2c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -39,7 +39,7 @@ import org.apache.spark.metrics.source.CodegenMetrics import org.apache.spark.sql.catalyst.InternalRow import 
org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform @@ -471,6 +471,8 @@ class CodegenContext { case NewFunctionSpec(functionName, None, None) => functionName case NewFunctionSpec(functionName, Some(_), Some(innerClassInstance)) => innerClassInstance + "." + functionName + case _ => + throw new IllegalArgumentException(s"$funcName is not matched at addNewFunction") } } @@ -579,6 +581,18 @@ class CodegenContext { s"${fullName}_$id" } + /** + * Creates an `ExprValue` representing a local Java variable of the required data type. + */ + def freshVariable(name: String, dt: DataType): VariableValue = + JavaCode.variable(freshName(name), dt) + + /** + * Creates an `ExprValue` representing a local Java variable of the required Java class. + */ + def freshVariable(name: String, javaClass: Class[_]): VariableValue = + JavaCode.variable(freshName(name), javaClass) + /** * Generates code for equal expression in Java. */ @@ -596,7 +610,7 @@ class CodegenContext { case NullType => "false" case _ => throw new IllegalArgumentException( - "cannot generate equality code for un-comparable type: " + dataType.simpleString) + "cannot generate equality code for un-comparable type: " + dataType.catalogString) } /** @@ -683,7 +697,7 @@ class CodegenContext { case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2) case _ => throw new IllegalArgumentException( - "cannot generate compare code for un-comparable type: " + dataType.simpleString) + "cannot generate compare code for un-comparable type: " + dataType.catalogString) } /** @@ -732,73 +746,6 @@ class CodegenContext { """.stripMargin } - /** - * Generates code creating a [[UnsafeArrayData]]. - * - * @param arrayName name of the array to create - * @param numElements code representing the number of elements the array should contain - * @param elementType data type of the elements in the array - * @param additionalErrorMessage string to include in the error message - */ - def createUnsafeArray( - arrayName: String, - numElements: String, - elementType: DataType, - additionalErrorMessage: String): String = { - val arraySize = freshName("size") - val arrayBytes = freshName("arrayBytes") - - s""" - |long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( - | $numElements, - | ${elementType.defaultSize}); - |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | throw new RuntimeException("Unsuccessful try create array with " + $arraySize + - | " bytes of data due to exceeding the limit " + - | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} bytes for UnsafeArrayData." + - | "$additionalErrorMessage"); - |} - |byte[] $arrayBytes = new byte[(int)$arraySize]; - |UnsafeArrayData $arrayName = new UnsafeArrayData(); - |Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements); - |$arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize); - """.stripMargin - } - - /** - * Generates code creating a [[UnsafeArrayData]]. The generated code executes - * a provided fallback when the size of backing array would exceed the array size limit. 
- * @param arrayName a name of the array to create - * @param numElements a piece of code representing the number of elements the array should contain - * @param elementSize a size of an element in bytes - * @param bodyCode a function generating code that fills up the [[UnsafeArrayData]] - * and getting the backing array as a parameter - * @param fallbackCode a piece of code executed when the array size limit is exceeded - */ - def createUnsafeArrayWithFallback( - arrayName: String, - numElements: String, - elementSize: Int, - bodyCode: String => String, - fallbackCode: String): String = { - val arraySize = freshName("size") - val arrayBytes = freshName("arrayBytes") - s""" - |final long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( - | $numElements, - | $elementSize); - |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | $fallbackCode - |} else { - | final byte[] $arrayBytes = new byte[(int)$arraySize]; - | UnsafeArrayData $arrayName = new UnsafeArrayData(); - | Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements); - | $arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize); - | ${bodyCode(arrayBytes)} - |} - """.stripMargin - } - /** * Generates code to do null safe execution, i.e. only execute the code when the input is not * null by adding null check if necessary. @@ -1173,12 +1120,7 @@ class CodegenContext { text: => String, placeholderId: String = "", force: Boolean = false): Block = { - // By default, disable comments in generated code because computing the comments themselves can - // be extremely expensive in certain cases, such as deeply-nested expressions which operate over - // inputs with wide schemas. For more details on the performance issues that motivated this - // flat, see SPARK-15680. - if (force || - SparkEnv.get != null && SparkEnv.get.conf.getBoolean("spark.sql.codegen.comments", false)) { + if (force || SQLConf.get.codegenComments) { val name = if (placeholderId != "") { assert(!placeHolderToComments.contains(placeholderId)) placeholderId @@ -1261,7 +1203,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin object CodeGenerator extends Logging { - // This is the value of HugeMethodLimit in the OpenJDK JVM settings + // This is the default value of HugeMethodLimit in the OpenJDK HotSpot JVM, + // beyond which methods will be rejected from JIT compilation final val DEFAULT_JVM_HUGE_METHOD_LIMIT = 8000 // The max valid length of method parameters in JVM. 
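`DEFAULT_JVM_HUGE_METHOD_LIMIT` matters because HotSpot refuses to JIT-compile any method whose bytecode exceeds `-XX:HugeMethodLimit` (8000 bytes by default), leaving it interpreted. A later hunk adds an INFO log when a generated method crosses that threshold; a minimal sketch of that size check, with illustrative names rather than Spark's actual API:

```scala
object HugeMethodLimitCheck {
  // Default HotSpot -XX:HugeMethodLimit: methods larger than this are not JIT-compiled.
  final val DefaultJvmHugeMethodLimit = 8000

  // Given per-method bytecode sizes (as read from each method's Code attribute),
  // report the methods the JIT would reject.
  def tooLargeToJit(className: String, methodSizes: Map[String, Int]): Seq[String] =
    methodSizes.collect {
      case (method, byteCodeSize) if byteCodeSize > DefaultJvmHugeMethodLimit =>
        s"Generated method too long to be JIT compiled: " +
          s"$className.$method is $byteCodeSize bytes"
    }.toSeq
}

// HugeMethodLimitCheck.tooLargeToJit("GeneratedClass", Map("processNext" -> 12345, "init" -> 120))
// => Seq("Generated method too long to be JIT compiled: GeneratedClass.processNext is 12345 bytes")
```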
@@ -1319,7 +1262,7 @@ object CodeGenerator extends Logging { evaluator.setParentClassLoader(parentClassLoader) // Cannot be under package codegen, or fail with java.lang.InstantiationException evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass") - evaluator.setDefaultImports(Array( + evaluator.setDefaultImports( classOf[Platform].getName, classOf[InternalRow].getName, classOf[UnsafeRow].getName, @@ -1334,7 +1277,7 @@ object CodeGenerator extends Logging { classOf[TaskContext].getName, classOf[TaskKilledException].getName, classOf[InputMetrics].getName - )) + ) evaluator.setExtendedClass(classOf[GeneratedClass]) logDebug({ @@ -1388,9 +1331,15 @@ object CodeGenerator extends Logging { try { val cf = new ClassFile(new ByteArrayInputStream(classBytes)) val stats = cf.methodInfos.asScala.flatMap { method => - method.getAttributes().filter(_.getClass.getName == codeAttr.getName).map { a => + method.getAttributes().filter(_.getClass eq codeAttr).map { a => val byteCodeSize = codeAttrField.get(a).asInstanceOf[Array[Byte]].length CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.update(byteCodeSize) + + if (byteCodeSize > DEFAULT_JVM_HUGE_METHOD_LIMIT) { + logInfo("Generated method too long to be JIT compiled: " + + s"${cf.getThisClassName}.${method.getName} is $byteCodeSize bytes") + } + byteCodeSize } } @@ -1415,7 +1364,7 @@ object CodeGenerator extends Logging { * weak keys/values and thus does not respond to memory pressure. */ private val cache = CacheBuilder.newBuilder() - .maximumSize(100) + .maximumSize(SQLConf.get.codegenCacheMaxEntries) .build( new CacheLoader[CodeAndComment, (GeneratedClass, Int)]() { override def load(code: CodeAndComment): (GeneratedClass, Int) = { @@ -1474,6 +1423,59 @@ object CodeGenerator extends Logging { } } + /** + * Generates code creating a [[UnsafeArrayData]] or [[GenericArrayData]] based on + * given parameters. 
+ * + * @param arrayName name of the array to create + * @param elementType data type of the elements in source array + * @param numElements code representing the number of elements the array should contain + * @param additionalErrorMessage string to include in the error message + * + * @return code representing the allocation of [[ArrayData]] + */ + def createArrayData( + arrayName: String, + elementType: DataType, + numElements: String, + additionalErrorMessage: String): String = { + val elementSize = if (CodeGenerator.isPrimitiveType(elementType)) { + elementType.defaultSize + } else { + -1 + } + s""" + |ArrayData $arrayName = ArrayData.allocateArrayData( + | $elementSize, $numElements, "$additionalErrorMessage"); + """.stripMargin + } + + /** + * Generates assignment code for an [[ArrayData]]. + * + * @param dstArray name of the array to be assigned + * @param elementType data type of the elements in destination and source arrays + * @param srcArray name of the array to be read + * @param needNullCheck whether a null check is required for the generated + * assignment + * @param dstArrayIndex an index variable to access each element of destination array + * @param srcArrayIndex an index variable to access each element of source array + * + * @return code representing an assignment to each element of the [[ArrayData]], which requires + * a pair of destination and source loop index variables + */ + def createArrayAssignment( + dstArray: String, + elementType: DataType, + srcArray: String, + dstArrayIndex: String, + srcArrayIndex: String, + needNullCheck: Boolean): String = { + CodeGenerator.setArrayElement(dstArray, elementType, dstArrayIndex, + CodeGenerator.getValue(srcArray, elementType, srcArrayIndex), + if (needNullCheck) Some(s"$srcArray.isNullAt($srcArrayIndex)") else None) + } + /** * Returns the code to update a column in Row for a given DataType. */ @@ -1542,6 +1544,34 @@ object CodeGenerator extends Logging { } } + /** + * Generates setter code for an [[ArrayData]]. + */ + def setArrayElement( + array: String, + elementType: DataType, + i: String, + value: String, + isNull: Option[String] = None): String = { + val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType) + val setFunc = if (isPrimitiveType) { + s"set${CodeGenerator.primitiveTypeName(elementType)}" + } else { + "update" + } + if (isNull.isDefined && isPrimitiveType) { + s""" + |if (${isNull.get}) { + | $array.setNullAt($i); + |} else { + | $array.$setFunc($i, $value); + |} + """.stripMargin + } else { + s"$array.$setFunc($i, $value);" + } + } + /** * Returns the specialized code to set a given value in a column vector for a given `DataType` * that could potentially be nullable. 
*/ def canSupport(dataType: DataType): Boolean = UserDefinedType.sqlType(dataType) match { case NullType => true @@ -43,19 +45,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => false } - // TODO: if the nullability of field is correct, we can use it to save null check. private def writeStructToBuffer( ctx: CodegenContext, input: String, index: String, - fieldTypes: Seq[DataType], + schemas: Seq[Schema], rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") - val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - ExprCode( - JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)"), - JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt)) + val fieldEvals = schemas.zipWithIndex.map { case (Schema(dt, nullable), i) => + val isNull = if (nullable) { + JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)") + } else { + FalseLiteral + } + ExprCode(isNull, JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt)) } val rowWriterClass = classOf[UnsafeRowWriter].getName @@ -70,7 +74,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | // Remember the current cursor so that we can calculate how many bytes are | // written later. | final int $previousCursor = $rowWriter.cursor(); - | ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, structRowWriter)} + | ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, schemas, structRowWriter)} | $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); |} """.stripMargin @@ -80,14 +84,14 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, row: String, inputs: Seq[ExprCode], - inputTypes: Seq[DataType], + schemas: Seq[Schema], rowWriter: String, isTopLevel: Boolean = false): String = { val resetWriter = if (isTopLevel) { // For top level row writer, it always writes to the beginning of the global buffer holder, // which means its fixed-size region always in the same position, so we don't need to call // `reset` to set up its fixed-size region every time. - if (inputs.map(_.isNull).forall(_ == "false")) { + if (inputs.map(_.isNull).forall(_ == FalseLiteral)) { // If all fields are not nullable, which means the null bits never changes, then we don't // need to clear it out every time. "" @@ -98,8 +102,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$rowWriter.resetRowWriter();" } - val writeFields = inputs.zip(inputTypes).zipWithIndex.map { - case ((input, dataType), index) => + val writeFields = inputs.zip(schemas).zipWithIndex.map { + case ((input, Schema(dataType, nullable)), index) => val dt = UserDefinedType.sqlType(dataType) val setNull = dt match { @@ -110,7 +114,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter) - if (input.isNull == FalseLiteral) { + if (!nullable) { s""" |${input.code} |${writeField.trim} @@ -143,11 +147,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro """.stripMargin } - // TODO: if the nullability of array element is correct, we can use it to save null check. 
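+ // A null check per element is now emitted only when the schema says the array can + // contain nulls; see the `containsNull` parameter and `elementAssignment` below.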
private def writeArrayToBuffer( ctx: CodegenContext, input: String, elementType: DataType, + containsNull: Boolean, rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") @@ -170,6 +174,18 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val element = CodeGenerator.getValue(tmpInput, et, index) + val elementAssignment = if (containsNull) { + s""" + |if ($tmpInput.isNullAt($index)) { + | $arrayWriter.setNull${elementOrOffsetSize}Bytes($index); + |} else { + | ${writeElement(ctx, element, index, et, arrayWriter)} + |} + """.stripMargin + } else { + writeElement(ctx, element, index, et, arrayWriter) + } + s""" |final ArrayData $tmpInput = $input; |if ($tmpInput instanceof UnsafeArrayData) { @@ -179,23 +195,19 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | $arrayWriter.initialize($numElements); | | for (int $index = 0; $index < $numElements; $index++) { - | if ($tmpInput.isNullAt($index)) { - | $arrayWriter.setNull${elementOrOffsetSize}Bytes($index); - | } else { - | ${writeElement(ctx, element, index, et, arrayWriter)} - | } + | $elementAssignment | } |} """.stripMargin } - // TODO: if the nullability of value element is correct, we can use it to save null check. private def writeMapToBuffer( ctx: CodegenContext, input: String, index: String, keyType: DataType, valueType: DataType, + valueContainsNull: Boolean, rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") @@ -203,6 +215,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val previousCursor = ctx.freshName("previousCursor") // Writes out unsafe map according to the format described in `UnsafeMapData`. + val keyArray = writeArrayToBuffer( + ctx, s"$tmpInput.keyArray()", keyType, false, rowWriter) + val valueArray = writeArrayToBuffer( + ctx, s"$tmpInput.valueArray()", valueType, valueContainsNull, rowWriter) + s""" |final MapData $tmpInput = $input; |if ($tmpInput instanceof UnsafeMapData) { @@ -219,7 +236,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | // Remember the current cursor so that we can write numBytes of key array later. | final int $tmpCursor = $rowWriter.cursor(); | - | ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)} + | $keyArray | | // Write the numBytes of key array into the first 8 bytes. | Platform.putLong( @@ -227,7 +244,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | $tmpCursor - 8, | $rowWriter.cursor() - $tmpCursor); | - | ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)} + | $valueArray | $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); |} """.stripMargin @@ -240,20 +257,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro dt: DataType, writer: String): String = dt match { case t: StructType => - writeStructToBuffer(ctx, input, index, t.map(_.dataType), writer) + writeStructToBuffer( + ctx, input, index, t.map(e => Schema(e.dataType, e.nullable)), writer) - case ArrayType(et, _) => + case ArrayType(et, en) => val previousCursor = ctx.freshName("previousCursor") s""" |// Remember the current cursor so that we can calculate how many bytes are |// written later. 
|final int $previousCursor = $writer.cursor(); - |${writeArrayToBuffer(ctx, input, et, writer)} + |${writeArrayToBuffer(ctx, input, et, en, writer)} |$writer.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """.stripMargin - case MapType(kt, vt, _) => - writeMapToBuffer(ctx, input, index, kt, vt, writer) + case MapType(kt, vt, vn) => + writeMapToBuffer(ctx, input, index, kt, vt, vn, writer) case DecimalType.Fixed(precision, scale) => s"$writer.write($index, $input, $precision, $scale);" @@ -268,12 +286,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro expressions: Seq[Expression], useSubexprElimination: Boolean = false): ExprCode = { val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination) - val exprTypes = expressions.map(_.dataType) + val exprSchemas = expressions.map(e => Schema(e.dataType, e.nullable)) - val numVarLenFields = exprTypes.count { - case dt if UnsafeRow.isFixedLength(dt) => false + val numVarLenFields = exprSchemas.count { + case Schema(dt, _) => !UnsafeRow.isFixedLength(dt) // TODO: consider large decimal and interval type - case _ => true } val rowWriterClass = classOf[UnsafeRowWriter].getName @@ -284,7 +301,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val evalSubexpr = ctx.subexprFunctions.mkString("\n") val writeExpressions = writeExpressionsToBuffer( - ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true) + ctx, ctx.INPUT_ROW, exprEvals, exprSchemas, rowWriter, isTopLevel = true) val code = code""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 250ce48d059e0..17d4a0dc4e884 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -22,6 +22,7 @@ import java.lang.{Boolean => JBool} import scala.collection.mutable.ArrayBuffer import scala.language.{existentials, implicitConversions} +import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{BooleanType, DataType} /** @@ -113,15 +114,28 @@ object JavaCode { def isNullExpression(code: String): SimpleExprValue = { expression(code, BooleanType) } + + /** + * Creates an `Inline` for a Java class name. + */ + def javaType(javaClass: Class[_]): Inline = Inline(javaClass.getName) + + /** + * Creates an `Inline` for the Java type name of the given `DataType`. + */ + def javaType(dataType: DataType): Inline = Inline(CodeGenerator.javaType(dataType)) + + /** + * Creates an `Inline` for the boxed Java type name of the given `DataType`. + */ + def boxedType(dataType: DataType): Inline = Inline(CodeGenerator.boxedType(dataType)) } /** * A trait representing a block of java code. */ -trait Block extends JavaCode { - - // The expressions to be evaluated inside this block. - def exprValues: Set[ExprValue] +trait Block extends TreeNode[Block] with JavaCode { + import Block._ // Returns java code string for this code block. override def toString: String = _marginChar match { @@ -131,7 +145,9 @@ trait Block extends JavaCode { def length: Int = toString.length - def nonEmpty: Boolean = toString.nonEmpty + def isEmpty: Boolean = toString.isEmpty + + def nonEmpty: Boolean = !isEmpty // The leading prefix that should be stripped from each line. // By default we strip blanks or control characters followed by '|' from the line.
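A sketch of how the new `Inline` factories compose with the existing `code` interpolator (not part of the patch; `myVal` and `tmp` are illustrative names, and the imports are assumed):

```scala
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.IntegerType

// JavaCode.javaType(IntegerType) yields Inline("int"): the interpolator splices it
// verbatim without tracking it, while `myVal` is an ExprValue and therefore remains
// visible to transformExprValues for later manipulation such as method splitting.
val myVal = JavaCode.variable("myVal", IntegerType)
val block = code"final ${JavaCode.javaType(IntegerType)} tmp = $myVal;"
// block.toString == "final int tmp = myVal;"
```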
@@ -147,15 +163,58 @@ trait Block extends JavaCode { this } + /** + * Applies a map function to each java expression (`ExprValue`) present in this java code, and + * returns a new java code block based on the mapped expressions. + */ + def transformExprValues(f: PartialFunction[ExprValue, ExprValue]): this.type = { + var changed = false + + @inline def transform(e: ExprValue): ExprValue = { + val newE = f lift e + if (newE.isEmpty || newE.get.equals(e)) { + e + } else { + changed = true + newE.get + } + } + + def doTransform(arg: Any): AnyRef = arg match { + case e: ExprValue => transform(e) + case Some(value) => Some(doTransform(value)) + case seq: Traversable[_] => seq.map(doTransform) + case other: AnyRef => other + } + + val newArgs = mapProductIterator(doTransform) + if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this + } + // Concatenates this block with the other block. - def + (other: Block): Block + def + (other: Block): Block = other match { + case EmptyBlock => this + case _ => code"$this\n$other" + } + + override def verboseString: String = toString } object Block { val CODE_BLOCK_BUFFER_LENGTH: Int = 512 - implicit def blocksToBlock(blocks: Seq[Block]): Block = Blocks(blocks) + /** + * A custom string interpolator which inlines a string into a code block. + */ + implicit class InlineHelper(val sc: StringContext) extends AnyVal { + def inline(args: Any*): Inline = { + val inlineString = sc.raw(args: _*) + Inline(inlineString) + } + } + + implicit def blocksToBlock(blocks: Seq[Block]): Block = blocks.reduceLeft(_ + _) implicit class BlockHelper(val sc: StringContext) extends AnyVal { def code(args: Any*): Block = { @@ -164,9 +223,8 @@ object Block { EmptyBlock } else { args.foreach { - case _: ExprValue => + case _: ExprValue | _: Inline | _: Block => case _: Int | _: Long | _: Float | _: Double | _: String => - case _: Block => case other => throw new IllegalArgumentException( s"Can not interpolate ${other.getClass.getName} into code block.") } @@ -190,18 +248,17 @@ object Block { while (strings.hasNext) { val input = inputs.next input match { - case _: ExprValue | _: Block => + case _: ExprValue | _: CodeBlock => codeParts += buf.toString buf.clear blockInputs += input.asInstanceOf[JavaCode] + case EmptyBlock => case _ => buf.append(input) } buf.append(strings.next) } - if (buf.nonEmpty) { - codeParts += buf.toString - } + codeParts += buf.toString (codeParts.toSeq, blockInputs.toSeq) } @@ -209,15 +266,15 @@ object Block { /** * A block of java code, including a sequence of code parts and some inputs to this block. - * The actual java code is generated by embedding the inputs into the code parts. + * The actual java code is generated by embedding the inputs into the code parts. Here we keep + * inputs of `JavaCode` instead of simply folding them as a string of code, because we need to + * track expressions (`ExprValue`) in this code block. We need to be able to manipulate the + * expressions later without changing the behavior of this code block in some applications, e.g., + * method splitting.
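+ * For example, in `code"boolean $isNull = false;"` the `isNull` expression is kept in + * `blockInputs` rather than folded into the string, so a later transformation can swap in a + * different variable without re-parsing the generated Java source.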
*/ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends Block { - override lazy val exprValues: Set[ExprValue] = { - blockInputs.flatMap { - case b: Block => b.exprValues - case e: ExprValue => Set(e) - }.toSet - } + override def children: Seq[Block] = + blockInputs.filter(_.isInstanceOf[Block]).asInstanceOf[Seq[Block]] override lazy val code: String = { val strings = codeParts.iterator @@ -230,30 +287,19 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends } buf.toString } - - override def + (other: Block): Block = other match { - case c: CodeBlock => Blocks(Seq(this, c)) - case b: Blocks => Blocks(Seq(this) ++ b.blocks) - case EmptyBlock => this - } -} - -case class Blocks(blocks: Seq[Block]) extends Block { - override lazy val exprValues: Set[ExprValue] = blocks.flatMap(_.exprValues).toSet - override lazy val code: String = blocks.map(_.toString).mkString("\n") - - override def + (other: Block): Block = other match { - case c: CodeBlock => Blocks(blocks :+ c) - case b: Blocks => Blocks(blocks ++ b.blocks) - case EmptyBlock => this - } } -object EmptyBlock extends Block with Serializable { +case object EmptyBlock extends Block with Serializable { override val code: String = "" - override val exprValues: Set[ExprValue] = Set.empty + override def children: Seq[Block] = Seq.empty +} - override def + (other: Block): Block = other +/** + * A java code snippet that inlines all types of input arguments into a string, without + * tracking any references to `JavaCode` instances. + */ +case class Inline(codeString: String) extends JavaCode { + override val code: String = codeString } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 8b278f067749e..3ad21ec5e51f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -64,7 +64,7 @@ trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression TypeCheckResult.TypeCheckSuccess case _ => TypeCheckResult.TypeCheckFailure(s"input to function $prettyName should have " + s"been two ${ArrayType.simpleString}s with same element type, but it's " + - s"[${left.dataType.simpleString}, ${right.dataType.simpleString}]") + s"[${left.dataType.catalogString}, ${right.dataType.catalogString}]") } } } @@ -89,15 +89,9 @@ trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression > SELECT _FUNC_(NULL); -1 """) -case class Size( - child: Expression, - legacySizeOfNull: Boolean) - extends UnaryExpression with ExpectsInputTypes { +case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes { - def this(child: Expression) = - this( - child, - legacySizeOfNull = SQLConf.get.getConf(SQLConf.LEGACY_SIZE_OF_NULL)) + val legacySizeOfNull = SQLConf.get.legacySizeOfNull override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType)) @@ -137,7 +131,7 @@ case class Size( examples = """ Examples: > SELECT _FUNC_(map(1, 'a', 2, 'b')); - [1,2] + [1, 2] """) case class MapKeys(child: Expression) extends UnaryExpression with ExpectsInputTypes { @@ -174,27 +168,22 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI override def inputTypes: Seq[AbstractDataType] =
Seq.fill(children.length)(ArrayType) - override def dataType: DataType = ArrayType(mountSchema) - - override def nullable: Boolean = children.exists(_.nullable) - - private lazy val arrayTypes = children.map(_.dataType.asInstanceOf[ArrayType]) - - private lazy val arrayElementTypes = arrayTypes.map(_.elementType) - - @transient private lazy val mountSchema: StructType = { + @transient override lazy val dataType: DataType = { val fields = children.zip(arrayElementTypes).zipWithIndex.map { case ((expr: NamedExpression, elementType), _) => StructField(expr.name, elementType, nullable = true) case ((_, elementType), idx) => StructField(idx.toString, elementType, nullable = true) } - StructType(fields) + ArrayType(StructType(fields), containsNull = false) } - @transient lazy val numberOfArrays: Int = children.length + override def nullable: Boolean = children.exists(_.nullable) + + @transient private lazy val arrayElementTypes = + children.map(_.dataType.asInstanceOf[ArrayType].elementType) - @transient lazy val genericArrayData = classOf[GenericArrayData].getName + private def genericArrayData = classOf[GenericArrayData].getName def emptyInputGenCode(ev: ExprCode): ExprCode = { ev.copy(code""" @@ -262,7 +251,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI ("ArrayData[]", arrVals) :: Nil) val initVariables = s""" - |ArrayData[] $arrVals = new ArrayData[$numberOfArrays]; + |ArrayData[] $arrVals = new ArrayData[${children.length}]; |int $biggestCardinality = 0; |${CodeGenerator.javaType(dataType)} ${ev.value} = null; """.stripMargin @@ -274,7 +263,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI |if (!${ev.isNull}) { | Object[] $args = new Object[$biggestCardinality]; | for (int $i = 0; $i < $biggestCardinality; $i ++) { - | Object[] $currentRow = new Object[$numberOfArrays]; + | Object[] $currentRow = new Object[${children.length}]; | $getValueForTypeSplitted | $args[$i] = new $genericInternalRow($currentRow); | } @@ -284,7 +273,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (numberOfArrays == 0) { + if (children.length == 0) { emptyInputGenCode(ev) } else { nonEmptyInputGenCode(ctx, ev) @@ -331,7 +320,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI examples = """ Examples: > SELECT _FUNC_(map(1, 'a', 2, 'b')); - ["a","b"] + ["a", "b"] """) case class MapValues(child: Expression) extends UnaryExpression with ExpectsInputTypes { @@ -359,14 +348,14 @@ case class MapValues(child: Expression) examples = """ Examples: > SELECT _FUNC_(map(1, 'a', 2, 'b')); - [(1,"a"),(2,"b")] + [[1, "a"], [2, "b"]] """, since = "2.4.0") case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(MapType) - lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType] + @transient private lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType] override def dataType: DataType = { ArrayType( @@ -383,7 +372,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp val values = childMap.valueArray() val length = childMap.numElements() val resultData = new Array[AnyRef](length) - var i = 0; + var i = 0 while (i < length) { val key = keys.get(i, childDataType.keyType) val value = values.get(i, childDataType.valueType) @@ -396,113 +385,331 @@ case class 
MapEntries(child: Expression) extends UnaryExpression with ExpectsInp override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, c => { + val arrayData = ctx.freshName("arrayData") val numElements = ctx.freshName("numElements") val keys = ctx.freshName("keys") val values = ctx.freshName("values") val isKeyPrimitive = CodeGenerator.isPrimitiveType(childDataType.keyType) val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType) - val code = if (isKeyPrimitive && isValuePrimitive) { - genCodeForPrimitiveElements(ctx, keys, values, ev.value, numElements) + + val wordSize = UnsafeRow.WORD_SIZE + val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2 + val (isPrimitive, elementSize) = if (isKeyPrimitive && isValuePrimitive) { + (true, structSize + wordSize) + } else { + (false, -1) + } + + val allocation = + s""" + |ArrayData $arrayData = ArrayData.allocateArrayData( + | $elementSize, $numElements, " $prettyName failed."); + """.stripMargin + + val code = if (isPrimitive) { + val genCodeForPrimitive = genCodeForPrimitiveElements( + ctx, arrayData, keys, values, ev.value, numElements, structSize) + s""" + |if ($arrayData instanceof UnsafeArrayData) { + | $genCodeForPrimitive + |} else { + | ${genCodeForAnyElements(ctx, arrayData, keys, values, ev.value, numElements)} + |} + """.stripMargin } else { - genCodeForAnyElements(ctx, keys, values, ev.value, numElements) + s"${genCodeForAnyElements(ctx, arrayData, keys, values, ev.value, numElements)}" } + s""" |final int $numElements = $c.numElements(); |final ArrayData $keys = $c.keyArray(); |final ArrayData $values = $c.valueArray(); + |$allocation |$code """.stripMargin }) } - private def getKey(varName: String) = CodeGenerator.getValue(varName, childDataType.keyType, "z") + private def getKey(varName: String, index: String) = + CodeGenerator.getValue(varName, childDataType.keyType, index) - private def getValue(varName: String) = { - CodeGenerator.getValue(varName, childDataType.valueType, "z") - } + private def getValue(varName: String, index: String) = + CodeGenerator.getValue(varName, childDataType.valueType, index) private def genCodeForPrimitiveElements( ctx: CodegenContext, + arrayData: String, keys: String, values: String, - arrayData: String, - numElements: String): String = { - val unsafeRow = ctx.freshName("unsafeRow") + resultArrayData: String, + numElements: String, + structSize: Int): String = { val unsafeArrayData = ctx.freshName("unsafeArrayData") + val baseObject = ctx.freshName("baseObject") + val unsafeRow = ctx.freshName("unsafeRow") val structsOffset = ctx.freshName("structsOffset") + val offset = ctx.freshName("offset") + val z = ctx.freshName("z") val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes" val baseOffset = Platform.BYTE_ARRAY_OFFSET val wordSize = UnsafeRow.WORD_SIZE - val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2 - val structSizeAsLong = structSize + "L" - val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType) - val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType) + val structSizeAsLong = s"${structSize}L" - val valueAssignment = s"$unsafeRow.set$valueTypeName(1, ${getValue(values)});" - val valueAssignmentChecked = if (childDataType.valueContainsNull) { - s""" - |if ($values.isNullAt(z)) { - | $unsafeRow.setNullAt(1); - |} else { - | $valueAssignment - |} - """.stripMargin - } else { - valueAssignment - } + val setKey = 
CodeGenerator.setColumn(unsafeRow, childDataType.keyType, 0, getKey(keys, z)) - val assignmentLoop = (byteArray: String) => - s""" - |final int $structsOffset = $calculateHeader($numElements) + $numElements * $wordSize; - |UnsafeRow $unsafeRow = new UnsafeRow(2); - |for (int z = 0; z < $numElements; z++) { - | long offset = $structsOffset + z * $structSizeAsLong; - | $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong); - | $unsafeRow.pointTo($byteArray, $baseOffset + offset, $structSize); - | $unsafeRow.set$keyTypeName(0, ${getKey(keys)}); - | $valueAssignmentChecked - |} - |$arrayData = $unsafeArrayData; - """.stripMargin + val valueAssignmentChecked = CodeGenerator.createArrayAssignment( + unsafeRow, childDataType.valueType, values, "1", z, childDataType.valueContainsNull) - ctx.createUnsafeArrayWithFallback( - unsafeArrayData, - numElements, - structSize + wordSize, - assignmentLoop, - genCodeForAnyElements(ctx, keys, values, arrayData, numElements)) + s""" + |UnsafeArrayData $unsafeArrayData = (UnsafeArrayData)$arrayData; + |Object $baseObject = $unsafeArrayData.getBaseObject(); + |final int $structsOffset = $calculateHeader($numElements) + $numElements * $wordSize; + |UnsafeRow $unsafeRow = new UnsafeRow(2); + |for (int $z = 0; $z < $numElements; $z++) { + | long $offset = $structsOffset + $z * $structSizeAsLong; + | $unsafeArrayData.setLong($z, ($offset << 32) + $structSizeAsLong); + | $unsafeRow.pointTo($baseObject, $baseOffset + $offset, $structSize); + | $setKey; + | $valueAssignmentChecked + |} + |$resultArrayData = $arrayData; + """.stripMargin } private def genCodeForAnyElements( ctx: CodegenContext, + arrayData: String, keys: String, values: String, - arrayData: String, + resultArrayData: String, numElements: String): String = { - val genericArrayClass = classOf[GenericArrayData].getName - val rowClass = classOf[GenericInternalRow].getName - val data = ctx.freshName("internalRowArray") - + val z = ctx.freshName("z") val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType) val getValueWithCheck = if (childDataType.valueContainsNull && isValuePrimitive) { - s"$values.isNullAt(z) ? null : (Object)${getValue(values)}" + s"$values.isNullAt($z) ? null : (Object)${getValue(values, z)}" } else { - getValue(values) + getValue(values, z) } + val rowClass = classOf[GenericInternalRow].getName + val genericArrayDataClass = classOf[GenericArrayData].getName + val genericArrayData = ctx.freshName("genericArrayData") + val rowObject = s"new $rowClass(new Object[]{${getKey(keys, z)}, $getValueWithCheck})" s""" - |final Object[] $data = new Object[$numElements]; - |for (int z = 0; z < $numElements; z++) { - | $data[z] = new $rowClass(new Object[]{${getKey(keys)}, $getValueWithCheck}); + |$genericArrayDataClass $genericArrayData = ($genericArrayDataClass)$arrayData; + |for (int $z = 0; $z < $numElements; $z++) { + | $genericArrayData.update($z, $rowObject); |} - |$arrayData = new $genericArrayClass($data); + |$resultArrayData = $arrayData; """.stripMargin } override def prettyName: String = "map_entries" } +/** + * Returns the union of all the given maps. + */ +@ExpressionDescription( + usage = "_FUNC_(map, ...) 
- Returns the union of all the given maps", + examples = """ + Examples: + > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd')); + [1 -> "a", 2 -> "b", 2 -> "c", 3 -> "d"] + """, since = "2.4.0") +case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpression { + + override def checkInputDataTypes(): TypeCheckResult = { + val funcName = s"function $prettyName" + if (children.exists(!_.dataType.isInstanceOf[MapType])) { + TypeCheckResult.TypeCheckFailure( + s"input to $funcName should all be of type map, but it's " + + children.map(_.dataType.catalogString).mkString("[", ", ", "]")) + } else { + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), funcName) + } + } + + @transient override lazy val dataType: MapType = { + if (children.isEmpty) { + MapType(StringType, StringType) + } else { + super.dataType.asInstanceOf[MapType] + } + } + + override def nullable: Boolean = children.exists(_.nullable) + + override def eval(input: InternalRow): Any = { + val maps = children.map(_.eval(input)) + if (maps.contains(null)) { + return null + } + val keyArrayDatas = maps.map(_.asInstanceOf[MapData].keyArray()) + val valueArrayDatas = maps.map(_.asInstanceOf[MapData].valueArray()) + + val numElements = keyArrayDatas.foldLeft(0L)((sum, ad) => sum + ad.numElements()) + if (numElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful attempt to concat maps with $numElements " + + s"elements due to exceeding the map size limit " + + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") + } + val finalKeyArray = new Array[AnyRef](numElements.toInt) + val finalValueArray = new Array[AnyRef](numElements.toInt) + var position = 0 + for (i <- keyArrayDatas.indices) { + val keyArray = keyArrayDatas(i).toObjectArray(dataType.keyType) + val valueArray = valueArrayDatas(i).toObjectArray(dataType.valueType) + Array.copy(keyArray, 0, finalKeyArray, position, keyArray.length) + Array.copy(valueArray, 0, finalValueArray, position, valueArray.length) + position += keyArray.length + } + + new ArrayBasedMapData(new GenericArrayData(finalKeyArray), + new GenericArrayData(finalValueArray)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val mapCodes = children.map(_.genCode(ctx)) + val keyType = dataType.keyType + val valueType = dataType.valueType + val argsName = ctx.freshName("args") + val hasNullName = ctx.freshName("hasNull") + val mapDataClass = classOf[MapData].getName + val arrayBasedMapDataClass = classOf[ArrayBasedMapData].getName + val arrayDataClass = classOf[ArrayData].getName + + val init = + s""" + |$mapDataClass[] $argsName = new $mapDataClass[${mapCodes.size}]; + |boolean ${ev.isNull}, $hasNullName = false; + |$mapDataClass ${ev.value} = null; + """.stripMargin + + val assignments = mapCodes.zip(children.map(_.nullable)).zipWithIndex.map { + case ((m, true), i) => + s""" + |if (!$hasNullName) { + | ${m.code} + | if (!${m.isNull}) { + | $argsName[$i] = ${m.value}; + | } else { + | $hasNullName = true; + | } + |} + """.stripMargin + case ((m, false), i) => + s""" + |if (!$hasNullName) { + | ${m.code} + | $argsName[$i] = ${m.value}; + |} + """.stripMargin + } + + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = assignments, + funcName = "getMapConcatInputs", + extraArguments = (s"$mapDataClass[]", argsName) :: ("boolean", hasNullName) :: Nil, + returnType = "boolean", + makeSplitFunction = body => + s""" + |$body + |return $hasNullName; + """.stripMargin, + foldFunctions = 
_.map(funcCall => s"$hasNullName = $funcCall;").mkString("\n") + ) + + val idxName = ctx.freshName("idx") + val numElementsName = ctx.freshName("numElems") + val finKeysName = ctx.freshName("finalKeys") + val finValsName = ctx.freshName("finalValues") + + val keyConcat = genCodeForArrays(ctx, keyType, false) + + val valueConcat = + if (valueType.sameType(keyType) && + !(CodeGenerator.isPrimitiveType(valueType) && dataType.valueContainsNull)) { + keyConcat + } else { + genCodeForArrays(ctx, valueType, dataType.valueContainsNull) + } + + val keyArgsName = ctx.freshName("keyArgs") + val valArgsName = ctx.freshName("valArgs") + + val mapMerge = + s""" + |${ev.isNull} = $hasNullName; + |if (!${ev.isNull}) { + | $arrayDataClass[] $keyArgsName = new $arrayDataClass[${mapCodes.size}]; + | $arrayDataClass[] $valArgsName = new $arrayDataClass[${mapCodes.size}]; + | long $numElementsName = 0; + | for (int $idxName = 0; $idxName < $argsName.length; $idxName++) { + | $keyArgsName[$idxName] = $argsName[$idxName].keyArray(); + | $valArgsName[$idxName] = $argsName[$idxName].valueArray(); + | $numElementsName += $argsName[$idxName].numElements(); + | } + | if ($numElementsName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Unsuccessful attempt to concat maps with " + + | $numElementsName + " elements due to exceeding the map size limit " + + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); + | } + | $arrayDataClass $finKeysName = $keyConcat($keyArgsName, + | (int) $numElementsName); + | $arrayDataClass $finValsName = $valueConcat($valArgsName, + | (int) $numElementsName); + | ${ev.value} = new $arrayBasedMapDataClass($finKeysName, $finValsName); + |} + """.stripMargin + + ev.copy( + code = code""" + |$init + |$codes + |$mapMerge + """.stripMargin) + } + + private def genCodeForArrays( + ctx: CodegenContext, + elementType: DataType, + checkForNull: Boolean): String = { + val counter = ctx.freshName("counter") + val arrayData = ctx.freshName("arrayData") + val argsName = ctx.freshName("args") + val numElemName = ctx.freshName("numElements") + val y = ctx.freshName("y") + val z = ctx.freshName("z") + + val allocation = CodeGenerator.createArrayData( + arrayData, elementType, numElemName, s" $prettyName failed.") + val assignment = CodeGenerator.createArrayAssignment( + arrayData, elementType, s"$argsName[$y]", counter, z, checkForNull) + + val concat = ctx.freshName("concat") + val concatDef = + s""" + |private ArrayData $concat(ArrayData[] $argsName, int $numElemName) { + | $allocation + | int $counter = 0; + | for (int $y = 0; $y < ${children.length}; $y++) { + | for (int $z = 0; $z < $argsName[$y].numElements(); $z++) { + | $assignment + | $counter++; + | } + | } + | return $arrayData; + |} + """.stripMargin + + ctx.addNewFunction(concat, concatDef) + } + + override def prettyName: String = "map_concat" +} + /** * Returns a map created from the given array of entries. 
*/ @@ -511,7 +718,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp examples = """ Examples: > SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b'))); - {1:"a",2:"b"} + [1 -> "a", 2 -> "b"] """, since = "2.4.0") case class MapFromEntries(child: Expression) extends UnaryExpression { @@ -526,16 +733,16 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { case _ => None } - private def nullEntries: Boolean = dataTypeDetails.get._3 + @transient private lazy val nullEntries: Boolean = dataTypeDetails.get._3 override def nullable: Boolean = child.nullable || nullEntries - override def dataType: MapType = dataTypeDetails.get._1 + @transient override lazy val dataType: MapType = dataTypeDetails.get._1 override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match { case Some(_) => TypeCheckResult.TypeCheckSuccess case None => TypeCheckResult.TypeCheckFailure(s"'${child.sql}' is of " + - s"${child.dataType.simpleString} type. $prettyName accepts only arrays of pair structs.") + s"${child.dataType.catalogString} type. $prettyName accepts only arrays of pair structs.") } override protected def nullSafeEval(input: Any): Any = { @@ -632,25 +839,12 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { val valueSize = dataType.valueType.defaultSize val kByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $keySize)" val vByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $valueSize)" - val keyTypeName = CodeGenerator.primitiveTypeName(dataType.keyType) - val valueTypeName = CodeGenerator.primitiveTypeName(dataType.valueType) - val keyAssignment = (key: String, idx: String) => s"$keyArrayData.set$keyTypeName($idx, $key);" - val valueAssignment = (entry: String, idx: String) => { - val value = CodeGenerator.getValue(entry, dataType.valueType, "1") - val valueNullUnsafeAssignment = s"$valueArrayData.set$valueTypeName($idx, $value);" - if (dataType.valueContainsNull) { - s""" - |if ($entry.isNullAt(1)) { - | $valueArrayData.setNullAt($idx); - |} else { - | $valueNullUnsafeAssignment - |} - """.stripMargin - } else { - valueNullUnsafeAssignment - } - } + val keyAssignment = (key: String, idx: String) => + CodeGenerator.setArrayElement(keyArrayData, dataType.keyType, idx, key) + val valueAssignment = (entry: String, idx: String) => + CodeGenerator.createArrayAssignment( + valueArrayData, dataType.valueType, entry, idx, "1", dataType.valueContainsNull) val assignmentLoop = genCodeForAssignmentLoop( ctx, childVariable, @@ -728,8 +922,7 @@ trait ArraySortLike extends ExpectsInputTypes { protected def nullOrder: NullOrder - @transient - private lazy val lt: Comparator[Any] = { + @transient private lazy val lt: Comparator[Any] = { val ordering = arrayExpression.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] @@ -751,8 +944,7 @@ trait ArraySortLike extends ExpectsInputTypes { } } - @transient - private lazy val gt: Comparator[Any] = { + @transient private lazy val gt: Comparator[Any] = { val ordering = arrayExpression.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] @@ -768,13 +960,15 @@ trait ArraySortLike extends ExpectsInputTypes { } else if (o2 == null) { nullOrder } else { - -ordering.compare(o1, o2) + 
ordering.compare(o2, o1) } } } } - def elementType: DataType = arrayExpression.dataType.asInstanceOf[ArrayType].elementType + @transient lazy val elementType: DataType = + arrayExpression.dataType.asInstanceOf[ArrayType].elementType + def containsNull: Boolean = arrayExpression.dataType.asInstanceOf[ArrayType].containsNull def sortEval(array: Any, ascending: Boolean): Any = { @@ -811,8 +1005,9 @@ trait ArraySortLike extends ExpectsInputTypes { } else { s"int $c = ${ctx.genComp(elementType, s"(($jt) $o1)", s"(($jt) $o2)")};" } - val nonNullPrimitiveAscendingSort = - if (CodeGenerator.isPrimitiveType(elementType) && !containsNull) { + val canPerformFastSort = + CodeGenerator.isPrimitiveType(elementType) && elementType != BooleanType && !containsNull + val nonNullPrimitiveAscendingSort = if (canPerformFastSort) { val javaType = CodeGenerator.javaType(elementType) val primitiveTypeName = CodeGenerator.primitiveTypeName(elementType) s""" @@ -876,7 +1071,7 @@ object ArraySortLike { examples = """ Examples: > SELECT _FUNC_(array('b', 'd', null, 'c', 'a'), true); - [null,"a","b","c","d"] + [null, "a", "b", "c", "d"] """) // scalastyle:on line.size.limit case class SortArray(base: Expression, ascendingOrder: Expression) @@ -902,7 +1097,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) "Sort order in second argument requires a boolean literal.") } case ArrayType(dt, _) => - val dtSimple = dt.simpleString + val dtSimple = dt.catalogString TypeCheckResult.TypeCheckFailure( s"$prettyName does not support sorting array of type $dtSimple which is not orderable") case _ => @@ -934,7 +1129,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) examples = """ Examples: > SELECT _FUNC_(array('b', 'd', null, 'c', 'a')); - ["a","b","c","d",null] + ["a", "b", "c", "d", null] """, since = "2.4.0") // scalastyle:on line.size.limit @@ -950,7 +1145,7 @@ case class ArraySort(child: Expression) extends UnaryExpression with ArraySortLi case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => TypeCheckResult.TypeCheckSuccess case ArrayType(dt, _) => - val dtSimple = dt.simpleString + val dtSimple = dt.catalogString TypeCheckResult.TypeCheckFailure( s"$prettyName does not support sorting array of type $dtSimple which is not orderable") case _ => @@ -968,6 +1163,89 @@ case class ArraySort(child: Expression) extends UnaryExpression with ArraySortLi override def prettyName: String = "array_sort" } +/** + * Returns a random permutation of the given array. 
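+ * + * The seed is fixed at analysis time, and `partitionIndex` is mixed in at execution time + * (see `initializeInternal`), so each partition shuffles with its own deterministic stream.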
+ */ +@ExpressionDescription( + usage = "_FUNC_(array) - Returns a random permutation of the given array.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 20, 3, 5)); + [3, 1, 5, 20] + > SELECT _FUNC_(array(1, 20, null, 3)); + [20, null, 3, 1] + """, + note = "The function is non-deterministic.", + since = "2.4.0") +case class Shuffle(child: Expression, randomSeed: Option[Long] = None) + extends UnaryExpression with ExpectsInputTypes with Stateful with ExpressionWithRandomSeed { + + def this(child: Expression) = this(child, None) + + override def withNewSeed(seed: Long): Shuffle = copy(randomSeed = Some(seed)) + + override lazy val resolved: Boolean = + childrenResolved && checkInputDataTypes().isSuccess && randomSeed.isDefined + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + override def dataType: DataType = child.dataType + + @transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + + @transient private[this] var random: RandomIndicesGenerator = _ + + override protected def initializeInternal(partitionIndex: Int): Unit = { + random = RandomIndicesGenerator(randomSeed.get + partitionIndex) + } + + override protected def evalInternal(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + null + } else { + val source = value.asInstanceOf[ArrayData] + val numElements = source.numElements() + val indices = random.getNextIndices(numElements) + new GenericArrayData(indices.map(source.get(_, elementType))) + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => shuffleArrayCodeGen(ctx, ev, c)) + } + + private def shuffleArrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = { + val randomClass = classOf[RandomIndicesGenerator].getName + + val rand = ctx.addMutableState(randomClass, "rand", forceInline = true) + ctx.addPartitionInitializationStatement( + s"$rand = new $randomClass(${randomSeed.get}L + partitionIndex);") + + val numElements = ctx.freshName("numElements") + val arrayData = ctx.freshName("arrayData") + val indices = ctx.freshName("indices") + val i = ctx.freshName("i") + + val initialization = CodeGenerator.createArrayData( + arrayData, elementType, numElements, s" $prettyName failed.") + val assignment = CodeGenerator.createArrayAssignment(arrayData, elementType, childName, + i, s"$indices[$i]", dataType.asInstanceOf[ArrayType].containsNull) + + s""" + |int $numElements = $childName.numElements(); + |int[] $indices = $rand.getNextIndices($numElements); + |$initialization + |for (int $i = 0; $i < $numElements; $i++) { + | $assignment + |} + |${ev.value} = $arrayData; + """.stripMargin + } + + override def freshCopy(): Shuffle = Shuffle(child, randomSeed) +} + /** * Returns a reversed string or an array with reverse order of elements. 
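+ * For arrays, null elements are preserved and simply change position.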
*/ @@ -976,7 +1254,7 @@ case class ArraySort(child: Expression) extends UnaryExpression with ArraySortLi examples = """ Examples: > SELECT _FUNC_('Spark SQL'); - LQS krapS + "LQS krapS" > SELECT _FUNC_(array(2, 1, 4, 3)); [3, 4, 1, 2] """, @@ -990,7 +1268,7 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI override def dataType: DataType = child.dataType - lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType override def nullSafeEval(input: Any): Any = input match { case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse) @@ -1009,46 +1287,26 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI } private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = { - val length = ctx.freshName("length") - val javaElementType = CodeGenerator.javaType(elementType) - val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType) - val initialization = if (isPrimitiveType) { - s"$childName.copy()" - } else { - s"new ${classOf[GenericArrayData].getName()}(new Object[$length])" - } - - val numberOfIterations = if (isPrimitiveType) s"$length / 2" else length - - val swapAssigments = if (isPrimitiveType) { - val setFunc = "set" + CodeGenerator.primitiveTypeName(elementType) - val getCall = (index: String) => CodeGenerator.getValue(ev.value, elementType, index) - s"""|boolean isNullAtK = ${ev.value}.isNullAt(k); - |boolean isNullAtL = ${ev.value}.isNullAt(l); - |if(!isNullAtK) { - | $javaElementType el = ${getCall("k")}; - | if(!isNullAtL) { - | ${ev.value}.$setFunc(k, ${getCall("l")}); - | } else { - | ${ev.value}.setNullAt(k); - | } - | ${ev.value}.$setFunc(l, el); - |} else if (!isNullAtL) { - | ${ev.value}.$setFunc(k, ${getCall("l")}); - | ${ev.value}.setNullAt(l); - |}""".stripMargin - } else { - s"${ev.value}.update(k, ${CodeGenerator.getValue(childName, elementType, "l")});" - } + val numElements = ctx.freshName("numElements") + val arrayData = ctx.freshName("arrayData") + + val i = ctx.freshName("i") + val j = ctx.freshName("j") + + val initialization = CodeGenerator.createArrayData( + arrayData, elementType, numElements, s" $prettyName failed.") + val assignment = CodeGenerator.createArrayAssignment( + arrayData, elementType, childName, i, j, dataType.asInstanceOf[ArrayType].containsNull) s""" - |final int $length = $childName.numElements(); - |${ev.value} = $initialization; - |for(int k = 0; k < $numberOfIterations; k++) { - | int l = $length - k - 1; - | $swapAssigments + |final int $numElements = $childName.numElements(); + |$initialization + |for (int $i = 0; $i < $numElements; $i++) { + | int $j = $numElements - $i - 1; + | $assignment |} + |${ev.value} = $arrayData; """.stripMargin } @@ -1085,7 +1343,7 @@ case class ArrayContains(left: Expression, right: Expression) if (right.dataType == NullType) { TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as arguments") } else if (!left.dataType.isInstanceOf[ArrayType] - || left.dataType.asInstanceOf[ArrayType].elementType != right.dataType) { + || !left.dataType.asInstanceOf[ArrayType].elementType.sameType(right.dataType)) { TypeCheckResult.TypeCheckFailure( "Arguments must be an array followed by a value of same type as the array members") } else { @@ -1117,17 +1375,29 @@ case class ArrayContains(left: Expression, right: Expression) nullSafeCodeGen(ctx, ev, (arr, value) => { val i = 
ctx.freshName("i") val getValue = CodeGenerator.getValue(arr, right.dataType, i) - s""" - for (int $i = 0; $i < $arr.numElements(); $i ++) { - if ($arr.isNullAt($i)) { - ${ev.isNull} = true; - } else if (${ctx.genEqual(right.dataType, value, getValue)}) { - ${ev.isNull} = false; - ${ev.value} = true; - break; - } + val loopBodyCode = if (nullable) { + s""" + |if ($arr.isNullAt($i)) { + | ${ev.isNull} = true; + |} else if (${ctx.genEqual(right.dataType, value, getValue)}) { + | ${ev.isNull} = false; + | ${ev.value} = true; + | break; + |} + """.stripMargin + } else { + s""" + |if (${ctx.genEqual(right.dataType, value, getValue)}) { + | ${ev.value} = true; + | break; + |} + """.stripMargin } - """ + s""" + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | $loopBodyCode + |} + """.stripMargin }) } @@ -1158,13 +1428,7 @@ case class ArraysOverlap(left: Expression, right: Expression) @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(elementType) - @transient private lazy val elementTypeSupportEquals = elementType match { - case BinaryType => false - case _: AtomicType => true - case _ => false - } - - @transient private lazy val doEvaluation = if (elementTypeSupportEquals) { + @transient private lazy val doEvaluation = if (TypeUtils.typeWithProperEquals(elementType)) { fastEval _ } else { bruteForceEval _ @@ -1246,7 +1510,7 @@ case class ArraysOverlap(left: Expression, right: Expression) nullSafeCodeGen(ctx, ev, (a1, a2) => { val smaller = ctx.freshName("smallerArray") val bigger = ctx.freshName("biggerArray") - val comparisonCode = if (elementTypeSupportEquals) { + val comparisonCode = if (TypeUtils.typeWithProperEquals(elementType)) { fastCodegen(ctx, ev, smaller, bigger) } else { bruteForceCodegen(ctx, ev, smaller, bigger) @@ -1282,12 +1546,13 @@ case class ArraysOverlap(left: Expression, right: Expression) val set = ctx.freshName("set") val addToSetFromSmallerCode = nullSafeElementCodegen( smaller, i, s"$set.add($getFromSmaller);", s"${ev.isNull} = true;") + val setIsNullCode = if (nullable) s"${ev.isNull} = false;" else "" val elementIsInSetCode = nullSafeElementCodegen( bigger, i, s""" |if ($set.contains($getFromBigger)) { - | ${ev.isNull} = false; + | $setIsNullCode | ${ev.value} = true; | break; |} @@ -1312,12 +1577,13 @@ case class ArraysOverlap(left: Expression, right: Expression) val j = ctx.freshName("j") val getFromSmaller = CodeGenerator.getValue(smaller, elementType, j) val getFromBigger = CodeGenerator.getValue(bigger, elementType, i) + val setIsNullCode = if (nullable) s"${ev.isNull} = false;" else "" val compareValues = nullSafeElementCodegen( smaller, j, s""" |if (${ctx.genEqual(elementType, getFromSmaller, getFromBigger)}) { - | ${ev.isNull} = false; + | $setIsNullCode | ${ev.value} = true; |} """.stripMargin, @@ -1368,9 +1634,9 @@ case class ArraysOverlap(left: Expression, right: Expression) examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); - [2,3] + [2, 3] > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); - [3,4] + [3, 4] """, since = "2.4.0") // scalastyle:on line.size.limit case class Slice(x: Expression, start: Expression, length: Expression) @@ -1380,9 +1646,9 @@ case class Slice(x: Expression, start: Expression, length: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) - override def children: Seq[Expression] = Seq(x, start, length) + @transient override lazy val children: Seq[Expression] = Seq(x, start, length) // called from eval - lazy val elementType: DataType = 
x.dataType.asInstanceOf[ArrayType].elementType + @transient private lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { val startInt = startVal.asInstanceOf[Int] @@ -1447,38 +1713,24 @@ case class Slice(x: Expression, start: Expression, length: Expression) resLength: String): String = { val values = ctx.freshName("values") val i = ctx.freshName("i") - val getValue = CodeGenerator.getValue(inputArray, elementType, s"$i + $startIdx") - if (!CodeGenerator.isPrimitiveType(elementType)) { - val arrayClass = classOf[GenericArrayData].getName - s""" - |Object[] $values; - |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { - | $values = new Object[0]; - |} else { - | $values = new Object[$resLength]; - | for (int $i = 0; $i < $resLength; $i ++) { - | $values[$i] = $getValue; - | } - |} - |${ev.value} = new $arrayClass($values); - """.stripMargin - } else { - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) - s""" - |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { - | $resLength = 0; - |} - |${ctx.createUnsafeArray(values, resLength, elementType, s" $prettyName failed.")} - |for (int $i = 0; $i < $resLength; $i ++) { - | if ($inputArray.isNullAt($i + $startIdx)) { - | $values.setNullAt($i); - | } else { - | $values.set$primitiveValueTypeName($i, $getValue); - | } - |} - |${ev.value} = $values; - """.stripMargin - } + val genericArrayData = classOf[GenericArrayData].getName + + val allocation = CodeGenerator.createArrayData( + values, elementType, resLength, s" $prettyName failed.") + val assignment = CodeGenerator.createArrayAssignment(values, elementType, inputArray, + i, s"$i + $startIdx", dataType.asInstanceOf[ArrayType].containsNull) + + s""" + |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { + | ${ev.value} = new $genericArrayData(new Object[0]); + |} else { + | $allocation + | for (int $i = 0; $i < $resLength; $i ++) { + | $assignment + | } + | ${ev.value} = $values; + |} + """.stripMargin } } @@ -1493,11 +1745,11 @@ case class Slice(x: Expression, start: Expression, length: Expression) examples = """ Examples: > SELECT _FUNC_(array('hello', 'world'), ' '); - hello world + "hello world" > SELECT _FUNC_(array('hello', null ,'world'), ' '); - hello world + "hello world" > SELECT _FUNC_(array('hello', null ,'world'), ' ', ','); - hello , world + "hello , world" """, since = "2.4.0") case class ArrayJoin( array: Expression, @@ -1668,7 +1920,7 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) - private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + @transient private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { val typeCheckResult = super.checkInputDataTypes() @@ -1709,7 +1961,7 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast min } - override def dataType: DataType = child.dataType match { + @transient override lazy val dataType: DataType = child.dataType match { case ArrayType(dt, _) => dt case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") } @@ -1733,7 +1985,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) - private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) 
+ @transient private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { val typeCheckResult = super.checkInputDataTypes() @@ -1774,7 +2026,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast max } - override def dataType: DataType = child.dataType match { + @transient override lazy val dataType: DataType = child.dataType match { case ArrayType(dt, _) => dt case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") } @@ -1876,10 +2128,13 @@ case class ArrayPosition(left: Expression, right: Expression) since = "2.4.0") case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil { - @transient private lazy val ordering: Ordering[Any] = - TypeUtils.getInterpretedOrdering(left.dataType.asInstanceOf[MapType].keyType) + @transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType + + @transient private lazy val arrayContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull - override def dataType: DataType = left.dataType match { + @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(mapKeyType) + + @transient override lazy val dataType: DataType = left.dataType match { case ArrayType(elementType, _) => elementType case MapType(_, valueType, _) => valueType } @@ -1888,7 +2143,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti Seq(TypeCollection(ArrayType, MapType), left.dataType match { case _: ArrayType => IntegerType - case _: MapType => left.dataType.asInstanceOf[MapType].keyType + case _: MapType => mapKeyType case _ => AnyDataType // no match for a wrong 'left' expression type } ) @@ -1898,8 +2153,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti super.checkInputDataTypes() match { case f: TypeCheckResult.TypeCheckFailure => f case TypeCheckResult.TypeCheckSuccess if left.dataType.isInstanceOf[MapType] => - TypeUtils.checkForOrderingExpr( - left.dataType.asInstanceOf[MapType].keyType, s"function $prettyName") + TypeUtils.checkForOrderingExpr(mapKeyType, s"function $prettyName") case TypeCheckResult.TypeCheckSuccess => TypeCheckResult.TypeCheckSuccess } } @@ -1921,14 +2175,14 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti } else { array.numElements() + index } - if (left.dataType.asInstanceOf[ArrayType].containsNull && array.isNullAt(idx)) { + if (arrayContainsNull && array.isNullAt(idx)) { null } else { array.get(idx, dataType) } } case _: MapType => - getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType, ordering) + getValueEval(value, ordinal, mapKeyType, ordering) } } @@ -1937,7 +2191,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti case _: ArrayType => nullSafeCodeGen(ctx, ev, (eval1, eval2) => { val index = ctx.freshName("elementAtIndex") - val nullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) { + val nullCheck = if (arrayContainsNull) { s""" |if ($eval1.isNullAt($index)) { | ${ev.isNull} = true; @@ -1982,15 +2236,14 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti examples = """ Examples: > SELECT _FUNC_('Spark', 'SQL'); - SparkSQL + "SparkSQL" > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); - | [1,2,3,4,5,6] - """) -case class Concat(children: Seq[Expression]) extends Expression { - - private val MAX_ARRAY_LENGTH: Int = 
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + [1, 2, 3, 4, 5, 6] + """, + note = "Concat logic for arrays is available since 2.4.0.") +case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpression { - val allowedTypes = Seq(StringType, BinaryType, ArrayType) + private def allowedTypes: Seq[AbstractDataType] = Seq(StringType, BinaryType, ArrayType) override def checkInputDataTypes(): TypeCheckResult = { if (children.isEmpty) { @@ -2001,15 +2254,21 @@ case class Concat(children: Seq[Expression]) extends Expression { return TypeCheckResult.TypeCheckFailure( s"input to function $prettyName should have been ${StringType.simpleString}," + s" ${BinaryType.simpleString} or ${ArrayType.simpleString}, but it's " + - childTypes.map(_.simpleString).mkString("[", ", ", "]")) + childTypes.map(_.catalogString).mkString("[", ", ", "]")) } TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") } } - override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) + @transient override lazy val dataType: DataType = { + if (children.isEmpty) { + StringType + } else { + super.dataType + } + } - lazy val javaType: String = CodeGenerator.javaType(dataType) + private def javaType: String = CodeGenerator.javaType(dataType) override def nullable: Boolean = children.exists(_.nullable) @@ -2029,9 +2288,10 @@ case class Concat(children: Seq[Expression]) extends Expression { } else { val arrayData = inputs.map(_.asInstanceOf[ArrayData]) val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements()) - if (numberOfElements > MAX_ARRAY_LENGTH) { + if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" + - s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") + " elements due to exceeding the array size limit " + + ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".") } val finalData = new Array[AnyRef](numberOfElements.toInt) var position = 0 @@ -2047,125 +2307,113 @@ case class Concat(children: Seq[Expression]) extends Expression { override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evals = children.map(_.genCode(ctx)) val args = ctx.freshName("args") + val hasNull = ctx.freshName("hasNull") - val inputs = evals.zipWithIndex.map { case (eval, index) => - s""" - ${eval.code} - if (!${eval.isNull}) { - $args[$index] = ${eval.value}; - } - """ + val inputs = evals.zip(children.map(_.nullable)).zipWithIndex.map { + case ((eval, true), index) => + s""" + |if (!$hasNull) { + | ${eval.code} + | if (!${eval.isNull}) { + | $args[$index] = ${eval.value}; + | } else { + | $hasNull = true; + | } + |} + """.stripMargin + case ((eval, false), index) => + s""" + |if (!$hasNull) { + | ${eval.code} + | $args[$index] = ${eval.value}; + |} + """.stripMargin } - val (concatenator, initCode) = dataType match { - case BinaryType => - (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") - case StringType => - ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") - case ArrayType(elementType, _) => - val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) { - genCodeForPrimitiveArrays(ctx, elementType) - } else { - genCodeForNonPrimitiveArrays(ctx, elementType) - } - (arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];") - } val codes = ctx.splitExpressionsWithCurrentInputs( expressions = inputs, funcName = "valueConcat", - extraArguments = 
(s"$javaType[]", args) :: Nil) - ev.copy(code""" - $initCode - $codes - $javaType ${ev.value} = $concatenator.concat($args); - boolean ${ev.isNull} = ${ev.value} == null; - """) - } - - private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { - val numElements = ctx.freshName("numElements") + extraArguments = (s"$javaType[]", args) :: ("boolean", hasNull) :: Nil, + returnType = "boolean", + makeSplitFunction = body => + s""" + |$body + |return $hasNull; + """.stripMargin, + foldFunctions = _.map(funcCall => s"$hasNull = $funcCall;").mkString("\n") + ) + + val (concat, initCode) = dataType match { + case BinaryType => + (s"${classOf[ByteArray].getName}.concat", s"byte[][] $args = new byte[${evals.length}][];") + case StringType => + ("UTF8String.concat", s"UTF8String[] $args = new UTF8String[${evals.length}];") + case ArrayType(elementType, containsNull) => + val concat = genCodeForArrays(ctx, elementType, containsNull) + (concat, s"ArrayData[] $args = new ArrayData[${evals.length}];") + } + + ev.copy(code = + code""" + |boolean $hasNull = false; + |$initCode + |$codes + |$javaType ${ev.value} = null; + |if (!$hasNull) { + | ${ev.value} = $concat($args); + |} + |boolean ${ev.isNull} = ${ev.value} == null; + """.stripMargin) + } + + private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { + val numElements = ctx.freshName("numElements") + val z = ctx.freshName("z") val code = s""" |long $numElements = 0L; - |for (int z = 0; z < ${children.length}; z++) { - | $numElements += args[z].numElements(); - |} - |if ($numElements > $MAX_ARRAY_LENGTH) { - | throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements + - | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + |for (int $z = 0; $z < ${children.length}; $z++) { + | $numElements += args[$z].numElements(); |} """.stripMargin (code, numElements) } - private def nullArgumentProtection() : String = { - if (nullable) { - s""" - |for (int z = 0; z < ${children.length}; z++) { - | if (args[z] == null) return null; - |} - """.stripMargin - } else { - "" - } - } - - private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { + private def genCodeForArrays( + ctx: CodegenContext, + elementType: DataType, + checkForNull: Boolean): String = { val counter = ctx.freshName("counter") val arrayData = ctx.freshName("arrayData") + val y = ctx.freshName("y") + val z = ctx.freshName("z") val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + val initialization = CodeGenerator.createArrayData( + arrayData, elementType, numElemName, s" $prettyName failed.") + val assignment = CodeGenerator.createArrayAssignment( + arrayData, elementType, s"args[$y]", counter, z, + dataType.asInstanceOf[ArrayType].containsNull) - s""" - |new Object() { - | public ArrayData concat($javaType[] args) { - | ${nullArgumentProtection()} - | $numElemCode - | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} - | int $counter = 0; - | for (int y = 0; y < ${children.length}; y++) { - | for (int z = 0; z < args[y].numElements(); z++) { - | if (args[y].isNullAt(z)) { - | $arrayData.setNullAt($counter); - | } else { - | $arrayData.set$primitiveValueTypeName( - | $counter, - | ${CodeGenerator.getValue(s"args[y]", elementType, "z")} - | ); - | } - | $counter++; - | } - | } - | return $arrayData; - | } - |}""".stripMargin.stripPrefix("\n") - } 
- - private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { - val genericArrayClass = classOf[GenericArrayData].getName - val arrayData = ctx.freshName("arrayObjects") - val counter = ctx.freshName("counter") - - val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) + val concat = ctx.freshName("concat") + val concatDef = + s""" + |private ArrayData $concat(ArrayData[] args) { + | $numElemCode + | $initialization + | int $counter = 0; + | for (int $y = 0; $y < ${children.length}; $y++) { + | for (int $z = 0; $z < args[$y].numElements(); $z++) { + | $assignment + | $counter++; + | } + | } + | return $arrayData; + |} + """.stripMargin - s""" - |new Object() { - | public ArrayData concat($javaType[] args) { - | ${nullArgumentProtection()} - | $numElemCode - | Object[] $arrayData = new Object[(int)$numElemName]; - | int $counter = 0; - | for (int y = 0; y < ${children.length}; y++) { - | for (int z = 0; z < args[y].numElements(); z++) { - | $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")}; - | $counter++; - | } - | } - | return new $genericArrayClass($arrayData); - | } - |}""".stripMargin.stripPrefix("\n") + ctx.addNewFunction(concat, concatDef) } override def toString: String = s"concat(${children.mkString(", ")})" @@ -2180,21 +2428,19 @@ case class Concat(children: Seq[Expression]) extends Expression { usage = "_FUNC_(arrayOfArrays) - Transforms an array of arrays into a single array.", examples = """ Examples: - > SELECT _FUNC_(array(array(1, 2), array(3, 4)); - [1,2,3,4] + > SELECT _FUNC_(array(array(1, 2), array(3, 4))); + [1, 2, 3, 4] """, since = "2.4.0") case class Flatten(child: Expression) extends UnaryExpression { - private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - - private lazy val childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType] + private def childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType] override def nullable: Boolean = child.nullable || childDataType.containsNull - override def dataType: DataType = childDataType.elementType + @transient override lazy val dataType: DataType = childDataType.elementType - lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType override def checkInputDataTypes(): TypeCheckResult = child.dataType match { case ArrayType(_: ArrayType, _) => @@ -2202,7 +2448,7 @@ case class Flatten(child: Expression) extends UnaryExpression { case _ => TypeCheckResult.TypeCheckFailure( s"The argument should be an array of arrays, " + - s"but '${child.sql}' is of ${child.dataType.simpleString} type." + s"but '${child.sql}' is of ${child.dataType.catalogString} type." 
) } @@ -2214,9 +2460,10 @@ case class Flatten(child: Expression) extends UnaryExpression { } else { val arrayData = elements.map(_.asInstanceOf[ArrayData]) val numberOfElements = arrayData.foldLeft(0L)((sum, e) => sum + e.numElements()) - if (numberOfElements > MAX_ARRAY_LENGTH) { + if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + - s"$numberOfElements elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") + s"$numberOfElements elements due to exceeding the array size limit " + + ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".") } val flattenedData = new Array(numberOfElements.toInt) var position = 0 @@ -2231,11 +2478,7 @@ case class Flatten(child: Expression) extends UnaryExpression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, c => { - val code = if (CodeGenerator.isPrimitiveType(elementType)) { - genCodeForFlattenOfPrimitiveElements(ctx, c, ev.value) - } else { - genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value) - } + val code = genCodeForFlatten(ctx, c, ev.value) ctx.nullArrayElementsSaveExec(childDataType.containsNull, ev.isNull, c)(code) }) } @@ -2249,40 +2492,36 @@ case class Flatten(child: Expression) extends UnaryExpression { |for (int z = 0; z < $childVariableName.numElements(); z++) { | $variableName += $childVariableName.getArray(z).numElements(); |} - |if ($variableName > $MAX_ARRAY_LENGTH) { - | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + - | $variableName + " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); - |} """.stripMargin (code, variableName) } - private def genCodeForFlattenOfPrimitiveElements( + private def genCodeForFlatten( ctx: CodegenContext, childVariableName: String, arrayDataName: String): String = { val counter = ctx.freshName("counter") val tempArrayDataName = ctx.freshName("tempArrayData") + val k = ctx.freshName("k") + val l = ctx.freshName("l") + val arr = ctx.freshName("arr") val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + val allocation = CodeGenerator.createArrayData( + tempArrayDataName, elementType, numElemName, s" $prettyName failed.") + val assignment = CodeGenerator.createArrayAssignment( + tempArrayDataName, elementType, arr, counter, l, + dataType.asInstanceOf[ArrayType].containsNull) s""" |$numElemCode - |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, s" $prettyName failed.")} + |$allocation |int $counter = 0; - |for (int k = 0; k < $childVariableName.numElements(); k++) { - | ArrayData arr = $childVariableName.getArray(k); - | for (int l = 0; l < arr.numElements(); l++) { - | if (arr.isNullAt(l)) { - | $tempArrayDataName.setNullAt($counter); - | } else { - | $tempArrayDataName.set$primitiveValueTypeName( - | $counter, - | ${CodeGenerator.getValue("arr", elementType, "l")} - | ); - | } + |for (int $k = 0; $k < $childVariableName.numElements(); $k++) { + | ArrayData $arr = $childVariableName.getArray($k); + | for (int $l = 0; $l < $arr.numElements(); $l++) { + | $assignment | $counter++; | } |} @@ -2290,30 +2529,6 @@ case class Flatten(child: Expression) extends UnaryExpression { """.stripMargin } - private def genCodeForFlattenOfNonPrimitiveElements( - ctx: CodegenContext, - childVariableName: String, - arrayDataName: String): String = { - val genericArrayClass = 
classOf[GenericArrayData].getName - val arrayName = ctx.freshName("arrayObject") - val counter = ctx.freshName("counter") - val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) - - s""" - |$numElemCode - |Object[] $arrayName = new Object[(int)$numElemName]; - |int $counter = 0; - |for (int k = 0; k < $childVariableName.numElements(); k++) { - | ArrayData arr = $childVariableName.getArray(k); - | for (int l = 0; l < arr.numElements(); l++) { - | $arrayName[$counter] = ${CodeGenerator.getValue("arr", elementType, "l")}; - | $counter++; - | } - |} - |$arrayDataName = new $genericArrayClass($arrayName); - """.stripMargin - } - override def prettyName: String = "flatten" } @@ -2375,7 +2590,7 @@ case class Sequence( override def nullable: Boolean = children.exists(_.nullable) - override lazy val dataType: ArrayType = ArrayType(start.dataType, containsNull = false) + override def dataType: ArrayType = ArrayType(start.dataType, containsNull = false) override def checkInputDataTypes(): TypeCheckResult = { val startType = start.dataType @@ -2406,7 +2621,7 @@ case class Sequence( stepOpt.map(step => if (step.dataType != CalendarIntervalType) Cast(step, widerType) else step), timeZoneId) - private lazy val impl: SequenceImpl = dataType.elementType match { + @transient private lazy val impl: SequenceImpl = dataType.elementType match { case iType: IntegralType => type T = iType.InternalType val ct = ClassTag[T](iType.tag.mirror.runtimeClass(iType.tag.tpe)) @@ -2720,14 +2935,12 @@ object Sequence { examples = """ Examples: > SELECT _FUNC_('123', 2); - ['123', '123'] + ["123", "123"] """, since = "2.4.0") case class ArrayRepeat(left: Expression, right: Expression) extends BinaryExpression with ExpectsInputTypes { - private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - override def dataType: ArrayType = ArrayType(left.dataType, left.nullable) override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType) @@ -2739,9 +2952,9 @@ case class ArrayRepeat(left: Expression, right: Expression) if (count == null) { null } else { - if (count.asInstanceOf[Int] > MAX_ARRAY_LENGTH) { + if (count.asInstanceOf[Int] > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw new RuntimeException(s"Unsuccessful try to create array with $count elements " + - s"due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + s"due to exceeding the array size limit ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); } val element = left.eval(input) new GenericArrayData(Array.fill(count.asInstanceOf[Int])(element)) @@ -2757,11 +2970,7 @@ case class ArrayRepeat(left: Expression, right: Expression) val count = rightGen.value val et = dataType.elementType - val coreLogic = if (CodeGenerator.isPrimitiveType(et)) { - genCodeForPrimitiveElement(ctx, et, element, count, leftGen.isNull, ev.value) - } else { - genCodeForNonPrimitiveElement(ctx, element, count, leftGen.isNull, ev.value) - } + val coreLogic = genCodeForElement(ctx, et, element, count, leftGen.isNull, ev.value) val resultCode = nullElementsProtection(ev, rightGen.isNull, coreLogic) ev.copy(code = @@ -2800,16 +3009,12 @@ case class ArrayRepeat(left: Expression, right: Expression) |if ($count > 0) { | $numElements = $count; |} - |if ($numElements > $MAX_ARRAY_LENGTH) { - | throw new RuntimeException("Unsuccessful try to create array with " + $numElements + - | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); - |} """.stripMargin (numElements, numElementsCode) } - private def 
genCodeForPrimitiveElement( + private def genCodeForElement( ctx: CodegenContext, elementType: DataType, element: String, @@ -2817,48 +3022,30 @@ case class ArrayRepeat(left: Expression, right: Expression) leftIsNull: String, arrayDataName: String): String = { val tempArrayDataName = ctx.freshName("tempArrayData") - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) - val errorMessage = s" $prettyName failed." + val k = ctx.freshName("k") val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count) + val allocation = CodeGenerator.createArrayData( + tempArrayDataName, elementType, numElemName, s" $prettyName failed.") + val assignment = + CodeGenerator.setArrayElement(tempArrayDataName, elementType, k, element) + s""" |$numElemCode - |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, errorMessage)} + |$allocation |if (!$leftIsNull) { - | for (int k = 0; k < $tempArrayDataName.numElements(); k++) { - | $tempArrayDataName.set$primitiveValueTypeName(k, $element); + | for (int $k = 0; $k < $tempArrayDataName.numElements(); $k++) { + | $assignment | } |} else { - | for (int k = 0; k < $tempArrayDataName.numElements(); k++) { - | $tempArrayDataName.setNullAt(k); + | for (int $k = 0; $k < $tempArrayDataName.numElements(); $k++) { + | $tempArrayDataName.setNullAt($k); | } |} |$arrayDataName = $tempArrayDataName; """.stripMargin } - private def genCodeForNonPrimitiveElement( - ctx: CodegenContext, - element: String, - count: String, - leftIsNull: String, - arrayDataName: String): String = { - val genericArrayClass = classOf[GenericArrayData].getName - val arrayName = ctx.freshName("arrayObject") - val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count) - - s""" - |$numElemCode - |Object[] $arrayName = new Object[(int)$numElemName]; - |if (!$leftIsNull) { - | for (int k = 0; k < $numElemName; k++) { - | $arrayName[k] = $element; - | } - |} - |$arrayDataName = new $genericArrayClass($arrayName); - """.stripMargin - } - } /** @@ -2869,7 +3056,7 @@ case class ArrayRepeat(left: Expression, right: Expression) examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3, null, 3), 3); - [1,2,null] + [1, 2, null] """, since = "2.4.0") case class ArrayRemove(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -2884,7 +3071,7 @@ case class ArrayRemove(left: Expression, right: Expression) Seq(ArrayType, elementType) } - lazy val elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType + private def elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(right.dataType) @@ -2940,50 +3127,117 @@ case class ArrayRemove(left: Expression, right: Expression) val pos = ctx.freshName("pos") val getValue = CodeGenerator.getValue(inputArray, elementType, i) val isEqual = ctx.genEqual(elementType, value, getValue) - if (!CodeGenerator.isPrimitiveType(elementType)) { - val arrayClass = classOf[GenericArrayData].getName + + val allocation = CodeGenerator.createArrayData( + values, elementType, newArraySize, s" $prettyName failed.") + val assignment = CodeGenerator.createArrayAssignment( + values, elementType, inputArray, pos, i, false) + + s""" + |$allocation + |int $pos = 0; + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $values.setNullAt($pos); + | $pos = $pos + 1; + | } + | else { + | if (!($isEqual)) { + | $assignment + | $pos = $pos + 1; + 
| } + | } + |} + |${ev.value} = $values; + """.stripMargin + } + + override def prettyName: String = "array_remove" +} + +/** + * Will become common base class for [[ArrayDistinct]], [[ArrayUnion]], [[ArrayIntersect]], + * and [[ArrayExcept]]. + */ +trait ArraySetLike { + protected def dt: DataType + protected def et: DataType + + @transient protected lazy val canUseSpecializedHashSet = et match { + case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true + case _ => false + } + + @transient protected lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(et) + + protected def genGetValue(array: String, i: String): String = + CodeGenerator.getValue(array, et, i) + + @transient protected lazy val (hsPostFix, hsTypeName) = { + val ptName = CodeGenerator.primitiveTypeName(et) + et match { + // we cast byte/short to int when writing to the hash set. + case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int") + case LongType => ("$mcJ$sp", ptName) + case FloatType => ("$mcF$sp", ptName) + case DoubleType => ("$mcD$sp", ptName) + } + } + + // we cast byte/short to int when writing to the hash set. + @transient protected lazy val hsValueCast = et match { + case ByteType | ShortType => "(int) " + case _ => "" + } + + // When hitting a null value, put a null holder in the ArrayBuilder. Finally we will + // convert ArrayBuilder to ArrayData and setNull on the slot with null holder. + @transient protected lazy val nullValueHolder = et match { + case ByteType => "(byte) 0" + case ShortType => "(short) 0" + case _ => "0" + } + + protected def withResultArrayNullCheck( + body: String, + value: String, + nullElementIndex: String): String = { + if (dt.asInstanceOf[ArrayType].containsNull) { s""" - |int $pos = 0; - |Object[] $values = new Object[$newArraySize]; - |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { - | if ($inputArray.isNullAt($i)) { - | $values[$pos] = null; - | $pos = $pos + 1; - | } - | else { - | if (!($isEqual)) { - | $values[$pos] = $getValue; - | $pos = $pos + 1; - | } - | } + |$body + |if ($nullElementIndex >= 0) { + | // result has null element + | $value.setNullAt($nullElementIndex); |} - |${ev.value} = new $arrayClass($values); """.stripMargin } else { - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) - s""" - |${ctx.createUnsafeArray(values, newArraySize, elementType, s" $prettyName failed.")} - |int $pos = 0; - |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { - | if ($inputArray.isNullAt($i)) { - | $values.setNullAt($pos); - | $pos = $pos + 1; - | } - | else { - | if (!($isEqual)) { - | $values.set$primitiveValueTypeName($pos, $getValue); - | $pos = $pos + 1; - | } - | } - |} - |${ev.value} = $values; - """.stripMargin + body } } - override def prettyName: String = "array_remove" + def buildResultArray( + builder: String, + value : String, + size : String, + nullElementIndex : String): String = withResultArrayNullCheck( + s""" + |if ($size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Cannot create array with " + $size + + | " elements of data due to exceeding the limit " + + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for ArrayData."); + |} + | + |if (!UnsafeArrayData.shouldUseGenericArrayData(${et.defaultSize}, $size)) { + | $value = UnsafeArrayData.fromPrimitiveArray($builder.result()); + |} else { + | $value = new ${classOf[GenericArrayData].getName}($builder.result()); + |} + """.stripMargin, value, nullElementIndex) + } + /** * Removes 
duplicate values from the array. */ @@ -2992,19 +3246,19 @@ case class ArrayRemove(left: Expression, right: Expression) examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3, null, 3)); - [1,2,3,null] + [1, 2, 3, null] """, since = "2.4.0") case class ArrayDistinct(child: Expression) - extends UnaryExpression with ExpectsInputTypes { + extends UnaryExpression with ArraySetLike with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) override def dataType: DataType = child.dataType - @transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType - @transient private lazy val ordering: Ordering[Any] = - TypeUtils.getInterpretedOrdering(elementType) + override protected def dt: DataType = dataType + override protected def et: DataType = elementType override def checkInputDataTypes(): TypeCheckResult = { super.checkInputDataTypes() match { @@ -3014,17 +3268,15 @@ case class ArrayDistinct(child: Expression) } } - @transient private lazy val elementTypeSupportEquals = elementType match { - case BinaryType => false - case _: AtomicType => true - case _ => false - } - override def nullSafeEval(array: Any): Any = { val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) - if (elementTypeSupportEquals) { - new GenericArrayData(data.distinct.asInstanceOf[Array[Any]]) - } else { + doEvaluation(data) + } + + @transient private lazy val doEvaluation = if (TypeUtils.typeWithProperEquals(elementType)) { + (data: Array[AnyRef]) => new GenericArrayData(data.distinct.asInstanceOf[Array[Any]]) + } else { + (data: Array[AnyRef]) => { var foundNullElement = false var pos = 0 for (i <- 0 until data.length) { @@ -3052,212 +3304,782 @@ case class ArrayDistinct(child: Expression) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (array) => { - val i = ctx.freshName("i") - val j = ctx.freshName("j") - val sizeOfDistinctArray = ctx.freshName("sizeOfDistinctArray") - val getValue1 = CodeGenerator.getValue(array, elementType, i) - val getValue2 = CodeGenerator.getValue(array, elementType, j) - val foundNullElement = ctx.freshName("foundNullElement") - val openHashSet = classOf[OpenHashSet[_]].getName - val hs = ctx.freshName("hs") - val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()" - if (elementTypeSupportEquals) { + val i = ctx.freshName("i") + val value = ctx.freshName("value") + val size = ctx.freshName("size") + + if (canUseSpecializedHashSet) { + val jt = CodeGenerator.javaType(elementType) + val ptName = CodeGenerator.primitiveTypeName(jt) + + nullSafeCodeGen(ctx, ev, (array) => { + val foundNullElement = ctx.freshName("foundNullElement") + val nullElementIndex = ctx.freshName("nullElementIndex") + val builder = ctx.freshName("builder") + val openHashSet = classOf[OpenHashSet[_]].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" + val hashSet = ctx.freshName("hashSet") + val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName + val arrayBuilderClass = s"$arrayBuilder$$of$ptName" + + // Only need to track null element index when array's element is nullable. 
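The declarations that follow implement the null-holder bookkeeping shared by these specialized set-based paths: a primitive ArrayBuilder cannot store null, so the first null encountered appends a placeholder value (nullValueHolder) and records its slot, and buildResultArray later calls setNullAt on exactly that slot of the finished ArrayData. A plain-Scala model of the idea for Int elements (stand-in collections, not the generated Java):

    import scala.collection.mutable

    object NullHolderSketch {
      // Returns the packed values plus the index that must be nulled out,
      // or -1 when the input contained no null at all.
      def distinctInts(input: Seq[java.lang.Integer]): (Array[Int], Int) = {
        val seen = mutable.HashSet.empty[Int]        // stands in for OpenHashSet$mcI$sp
        val builder = mutable.ArrayBuilder.make[Int] // stands in for ArrayBuilder$ofInt
        var size = 0
        var nullElementIndex = -1
        input.foreach { v =>
          if (v == null) {
            if (nullElementIndex < 0) {              // keep only the first null
              nullElementIndex = size
              size += 1
              builder += 0                           // the Int nullValueHolder
            }
          } else if (seen.add(v)) {
            size += 1
            builder += v
          }
        }
        (builder.result(), nullElementIndex)
      }

      def main(args: Array[String]): Unit = {
        // values = Array(1, 0, 2), nullAt = 1: slot 1 holds the placeholder
        // and is the slot setNullAt would be called on.
        val (values, nullAt) = distinctInts(Seq(1, null, 2, 1, null))
        println((values.toSeq, nullAt))
      }
    }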
+ val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |boolean $foundNullElement = false; + |int $nullElementIndex = -1; + """.stripMargin + } else { + "" + } + + def withArrayNullAssignment(body: String) = + if (dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array.isNullAt($i)) { + | if (!$foundNullElement) { + | $nullElementIndex = $size; + | $foundNullElement = true; + | $size++; + | $builder.$$plus$$eq($nullValueHolder); + | } + |} else { + | $body + |} + """.stripMargin + } else { + body + } + + val processArray = withArrayNullAssignment( + s""" + |$jt $value = ${genGetValue(array, i)}; + |if (!$hashSet.contains($hsValueCast$value)) { + | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | break; + | } + | $hashSet.add$hsPostFix($hsValueCast$value); + | $builder.$$plus$$eq($value); + |} + """.stripMargin) + s""" - |int $sizeOfDistinctArray = 0; - |boolean $foundNullElement = false; - |$openHashSet $hs = new $openHashSet($classTag); - |for (int $i = 0; $i < $array.numElements(); $i ++) { - | if ($array.isNullAt($i)) { - | $foundNullElement = true; - | } else { - | $hs.add($getValue1); - | } + |$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag); + |$declareNullTrackVariables + |$arrayBuilderClass $builder = new $arrayBuilderClass(); + |int $size = 0; + |for (int $i = 0; $i < $array.numElements(); $i++) { + | $processArray |} - |$sizeOfDistinctArray = $hs.size() + ($foundNullElement ? 1 : 0); - |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)} + |${buildResultArray(builder, ev.value, size, nullElementIndex)} """.stripMargin - } else { + }) + } else { + nullSafeCodeGen(ctx, ev, (array) => { + val expr = ctx.addReferenceObj("arrayDistinctExpr", this) + s"${ev.value} = (ArrayData)$expr.nullSafeEval($array);" + }) + } + } + + override def prettyName: String = "array_distinct" +} + +/** + * Will become common base class for [[ArrayUnion]], [[ArrayIntersect]], and [[ArrayExcept]]. + */ +trait ArrayBinaryLike extends BinaryArrayExpressionWithImplicitCast with ArraySetLike { + override protected def dt: DataType = dataType + override protected def et: DataType = elementType + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = super.checkInputDataTypes() + if (typeCheckResult.isSuccess) { + TypeUtils.checkForOrderingExpr(dataType.asInstanceOf[ArrayType].elementType, + s"function $prettyName") + } else { + typeCheckResult + } + } +} + +object ArrayBinaryLike { + def throwUnionLengthOverflowException(length: Int): Unit = { + throw new RuntimeException(s"Unsuccessful try to union arrays with $length " + + s"elements due to exceeding the array size limit " + + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") + } +} + + +/** + * Returns an array of the elements in the union of x and y, without duplicates + */ +@ExpressionDescription( + usage = """ + _FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2, + without duplicates. 
+ """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); + [1, 2, 3, 5] + """, + since = "2.4.0") +case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLike + with ComplexTypeMergingExpression { + + @transient lazy val evalUnion: (ArrayData, ArrayData) => ArrayData = { + if (TypeUtils.typeWithProperEquals(elementType)) { + (array1, array2) => + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + val hs = new OpenHashSet[Any] + var foundNullElement = false + Seq(array1, array2).foreach { array => + var i = 0 + while (i < array.numElements()) { + if (array.isNullAt(i)) { + if (!foundNullElement) { + arrayBuffer += null + foundNullElement = true + } + } else { + val elem = array.get(i, elementType) + if (!hs.contains(elem)) { + if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size) + } + arrayBuffer += elem + hs.add(elem) + } + } + i += 1 + } + } + new GenericArrayData(arrayBuffer) + } else { + (array1, array2) => + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var alreadyIncludeNull = false + Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => { + var found = false + if (elem == null) { + if (alreadyIncludeNull) { + found = true + } else { + alreadyIncludeNull = true + } + } else { + // check elem is already stored in arrayBuffer or not? + var j = 0 + while (!found && j < arrayBuffer.size) { + val va = arrayBuffer(j) + if (va != null && ordering.equiv(va, elem)) { + found = true + } + j = j + 1 + } + } + if (!found) { + if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.length) + } + arrayBuffer += elem + } + })) + new GenericArrayData(arrayBuffer) + } + } + + override def nullSafeEval(input1: Any, input2: Any): Any = { + val array1 = input1.asInstanceOf[ArrayData] + val array2 = input2.asInstanceOf[ArrayData] + + evalUnion(array1, array2) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val i = ctx.freshName("i") + val value = ctx.freshName("value") + val size = ctx.freshName("size") + if (canUseSpecializedHashSet) { + val jt = CodeGenerator.javaType(elementType) + val ptName = CodeGenerator.primitiveTypeName(jt) + + nullSafeCodeGen(ctx, ev, (array1, array2) => { + val foundNullElement = ctx.freshName("foundNullElement") + val nullElementIndex = ctx.freshName("nullElementIndex") + val builder = ctx.freshName("builder") + val array = ctx.freshName("array") + val arrays = ctx.freshName("arrays") + val arrayDataIdx = ctx.freshName("arrayDataIdx") + val openHashSet = classOf[OpenHashSet[_]].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" + val hashSet = ctx.freshName("hashSet") + val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName + val arrayBuilderClass = s"$arrayBuilder$$of$ptName" + + def withArrayNullAssignment(body: String) = + if (dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array.isNullAt($i)) { + | if (!$foundNullElement) { + | $nullElementIndex = $size; + | $foundNullElement = true; + | $size++; + | $builder.$$plus$$eq($nullValueHolder); + | } + |} else { + | $body + |} + """.stripMargin + } else { + body + } + + val processArray = withArrayNullAssignment( + s""" + |$jt $value = ${genGetValue(array, i)}; + |if (!$hashSet.contains($hsValueCast$value)) { + | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | break; + | } + | 
$hashSet.add$hsPostFix($hsValueCast$value); + | $builder.$$plus$$eq($value); + |} + """.stripMargin) + + // Only need to track null element index when result array's element is nullable. + val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |boolean $foundNullElement = false; + |int $nullElementIndex = -1; + """.stripMargin + } else { + "" + } + s""" - |int $sizeOfDistinctArray = 0; - |boolean $foundNullElement = false; - |for (int $i = 0; $i < $array.numElements(); $i ++) { - | if ($array.isNullAt($i)) { - | if (!($foundNullElement)) { - | $sizeOfDistinctArray = $sizeOfDistinctArray + 1; - | $foundNullElement = true; - | } - | } else { - | int $j; - | for ($j = 0; $j < $i; $j ++) { - | if (!$array.isNullAt($j) && ${ctx.genEqual(elementType, getValue1, getValue2)}) { - | break; - | } - | } - | if ($i == $j) { - | $sizeOfDistinctArray = $sizeOfDistinctArray + 1; - | } + |$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag); + |$declareNullTrackVariables + |int $size = 0; + |$arrayBuilderClass $builder = new $arrayBuilderClass(); + |ArrayData[] $arrays = new ArrayData[]{$array1, $array2}; + |for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) { + | ArrayData $array = $arrays[$arrayDataIdx]; + | for (int $i = 0; $i < $array.numElements(); $i++) { + | $processArray | } |} - | - |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)} + |${buildResultArray(builder, ev.value, size, nullElementIndex)} """.stripMargin - } - }) + }) + } else { + nullSafeCodeGen(ctx, ev, (array1, array2) => { + val expr = ctx.addReferenceObj("arrayUnionExpr", this) + s"${ev.value} = (ArrayData)$expr.nullSafeEval($array1, $array2);" + }) + } } - private def setNull( - isPrimitive: Boolean, - foundNullElement: String, - distinctArray: String, - pos: String): String = { - val setNullValue = - if (!isPrimitive) { - s"$distinctArray[$pos] = null"; + override def prettyName: String = "array_union" +} + +object ArrayUnion { + def unionOrdering( + array1: ArrayData, + array2: ArrayData, + elementType: DataType, + ordering: Ordering[Any]): ArrayData = { + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var alreadyIncludeNull = false + Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => { + var found = false + if (elem == null) { + if (alreadyIncludeNull) { + found = true + } else { + alreadyIncludeNull = true + } } else { - s"$distinctArray.setNullAt($pos)"; + // check elem is already stored in arrayBuffer or not? + var j = 0 + while (!found && j < arrayBuffer.size) { + val va = arrayBuffer(j) + if (va != null && ordering.equiv(va, elem)) { + found = true + } + j = j + 1 + } } + if (!found) { + if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.length) + } + arrayBuffer += elem + } + })) + new GenericArrayData(arrayBuffer) + } +} - s""" - |if (!($foundNullElement)) { - | $setNullValue; - | $pos = $pos + 1; - | $foundNullElement = true; - |} - """.stripMargin +/** + * Returns an array of the elements in the intersect of x and y, without duplicates + */ +@ExpressionDescription( + usage = """ + _FUNC_(array1, array2) - Returns an array of the elements in the intersection of array1 and + array2, without duplicates. 
+ """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); + [1, 3] + """, + since = "2.4.0") +case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBinaryLike + with ComplexTypeMergingExpression { + override def dataType: DataType = { + dataTypeCheck + ArrayType(elementType, + left.dataType.asInstanceOf[ArrayType].containsNull && + right.dataType.asInstanceOf[ArrayType].containsNull) + } + + @transient lazy val evalIntersect: (ArrayData, ArrayData) => ArrayData = { + if (TypeUtils.typeWithProperEquals(elementType)) { + (array1, array2) => + if (array1.numElements() != 0 && array2.numElements() != 0) { + val hs = new OpenHashSet[Any] + val hsResult = new OpenHashSet[Any] + var foundNullElement = false + var i = 0 + while (i < array2.numElements()) { + if (array2.isNullAt(i)) { + foundNullElement = true + } else { + val elem = array2.get(i, elementType) + hs.add(elem) + } + i += 1 + } + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + i = 0 + while (i < array1.numElements()) { + if (array1.isNullAt(i)) { + if (foundNullElement) { + arrayBuffer += null + foundNullElement = false + } + } else { + val elem = array1.get(i, elementType) + if (hs.contains(elem) && !hsResult.contains(elem)) { + arrayBuffer += elem + hsResult.add(elem) + } + } + i += 1 + } + new GenericArrayData(arrayBuffer) + } else { + new GenericArrayData(Array.emptyObjectArray) + } + } else { + (array1, array2) => + if (array1.numElements() != 0 && array2.numElements() != 0) { + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var alreadySeenNull = false + var i = 0 + while (i < array1.numElements()) { + var found = false + val elem1 = array1.get(i, elementType) + if (array1.isNullAt(i)) { + if (!alreadySeenNull) { + var j = 0 + while (!found && j < array2.numElements()) { + found = array2.isNullAt(j) + j += 1 + } + // array2 is scanned only once for null element + alreadySeenNull = true + } + } else { + var j = 0 + while (!found && j < array2.numElements()) { + if (!array2.isNullAt(j)) { + val elem2 = array2.get(j, elementType) + if (ordering.equiv(elem1, elem2)) { + // check whether elem1 is already stored in arrayBuffer + var foundArrayBuffer = false + var k = 0 + while (!foundArrayBuffer && k < arrayBuffer.size) { + val va = arrayBuffer(k) + foundArrayBuffer = (va != null) && ordering.equiv(va, elem1) + k += 1 + } + found = !foundArrayBuffer + } + } + j += 1 + } + } + if (found) { + arrayBuffer += elem1 + } + i += 1 + } + new GenericArrayData(arrayBuffer) + } else { + new GenericArrayData(Array.emptyObjectArray) + } + } + } + + override def nullSafeEval(input1: Any, input2: Any): Any = { + val array1 = input1.asInstanceOf[ArrayData] + val array2 = input2.asInstanceOf[ArrayData] + + evalIntersect(array1, array2) } - private def setNotNullValue(isPrimitive: Boolean, - distinctArray: String, - pos: String, - getValue1: String, - primitiveValueTypeName: String): String = { - if (!isPrimitive) { - s"$distinctArray[$pos] = $getValue1"; + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val i = ctx.freshName("i") + val value = ctx.freshName("value") + val size = ctx.freshName("size") + if (canUseSpecializedHashSet) { + val jt = CodeGenerator.javaType(elementType) + val ptName = CodeGenerator.primitiveTypeName(jt) + + nullSafeCodeGen(ctx, ev, (array1, array2) => { + val foundNullElement = ctx.freshName("foundNullElement") + val nullElementIndex = ctx.freshName("nullElementIndex") + val builder = ctx.freshName("builder") + 
val openHashSet = classOf[OpenHashSet[_]].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" + val hashSet = ctx.freshName("hashSet") + val hashSetResult = ctx.freshName("hashSetResult") + val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName + val arrayBuilderClass = s"$arrayBuilder$$of$ptName" + + def withArray2NullCheck(body: String): String = + if (right.dataType.asInstanceOf[ArrayType].containsNull) { + if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array2.isNullAt($i)) { + | $foundNullElement = true; + |} else { + | $body + |} + """.stripMargin + } else { + // if array1's element is not nullable, we don't need to track the null element index. + s""" + |if (!$array2.isNullAt($i)) { + | $body + |} + """.stripMargin + } + } else { + body + } + + val writeArray2ToHashSet = withArray2NullCheck( + s""" + |$jt $value = ${genGetValue(array2, i)}; + |$hashSet.add$hsPostFix($hsValueCast$value); + """.stripMargin) + + def withArray1NullAssignment(body: String) = + if (left.dataType.asInstanceOf[ArrayType].containsNull) { + if (right.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array1.isNullAt($i)) { + | if ($foundNullElement) { + | $nullElementIndex = $size; + | $foundNullElement = false; + | $size++; + | $builder.$$plus$$eq($nullValueHolder); + | } + |} else { + | $body + |} + """.stripMargin + } else { + s""" + |if (!$array1.isNullAt($i)) { + | $body + |} + """.stripMargin + } + } else { + body + } + + val processArray1 = withArray1NullAssignment( + s""" + |$jt $value = ${genGetValue(array1, i)}; + |if ($hashSet.contains($hsValueCast$value) && + | !$hashSetResult.contains($hsValueCast$value)) { + | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | break; + | } + | $hashSetResult.add$hsPostFix($hsValueCast$value); + | $builder.$$plus$$eq($value); + |} + """.stripMargin) + + // Only need to track null element index when result array's element is nullable. 
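Null handling aside, the loops assembled here implement a two-set intersection: hashSet records what array2 contains and hashSetResult records what has already been written out, so duplicates in array1 are emitted at most once. In outline (plain Scala over Int, nulls omitted; a sketch, not the generated code):

    object IntersectSketch {
      def intersect(array1: Seq[Int], array2: Seq[Int]): Seq[Int] = {
        val inArray2 = array2.toSet                            // role of hashSet
        val emitted = scala.collection.mutable.Set.empty[Int]  // role of hashSetResult
        // keep a value only if array2 has it and it was not emitted before;
        // mutable Set.add returns false for elements already present
        array1.filter(v => inArray2.contains(v) && emitted.add(v))
      }

      def main(args: Array[String]): Unit = {
        println(intersect(Seq(1, 2, 3, 3), Seq(1, 3, 5))) // List(1, 3)
      }
    }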
+ val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |boolean $foundNullElement = false; + |int $nullElementIndex = -1; + """.stripMargin + } else { + "" + } + + s""" + |$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag); + |$openHashSet $hashSetResult = new $openHashSet$hsPostFix($classTag); + |$declareNullTrackVariables + |for (int $i = 0; $i < $array2.numElements(); $i++) { + | $writeArray2ToHashSet + |} + |$arrayBuilderClass $builder = new $arrayBuilderClass(); + |int $size = 0; + |for (int $i = 0; $i < $array1.numElements(); $i++) { + | $processArray1 + |} + |${buildResultArray(builder, ev.value, size, nullElementIndex)} + """.stripMargin + }) } else { - s"$distinctArray.set$primitiveValueTypeName($pos, $getValue1)"; + nullSafeCodeGen(ctx, ev, (array1, array2) => { + val expr = ctx.addReferenceObj("arrayIntersectExpr", this) + s"${ev.value} = (ArrayData)$expr.nullSafeEval($array1, $array2);" + }) } } - private def setValueForFastEval( - isPrimitive: Boolean, - hs: String, - distinctArray: String, - pos: String, - getValue1: String, - primitiveValueTypeName: String): String = { - val setValue = setNotNullValue(isPrimitive, - distinctArray, pos, getValue1, primitiveValueTypeName) - s""" - |if (!($hs.contains($getValue1))) { - | $hs.add($getValue1); - | $setValue; - | $pos = $pos + 1; - |} - """.stripMargin + override def prettyName: String = "array_intersect" +} + +/** + * Returns an array of the elements in array1 but not in array2, without duplicates + */ +@ExpressionDescription( + usage = """ + _FUNC_(array1, array2) - Returns an array of the elements in array1 but not in array2, + without duplicates. + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); + [2] + """, + since = "2.4.0") +case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryLike + with ComplexTypeMergingExpression { + + override def dataType: DataType = { + dataTypeCheck + left.dataType } + + @transient lazy val evalExcept: (ArrayData, ArrayData) => ArrayData = { + if (TypeUtils.typeWithProperEquals(elementType)) { + (array1, array2) => + val hs = new OpenHashSet[Any] + var notFoundNullElement = true + var i = 0 + while (i < array2.numElements()) { + if (array2.isNullAt(i)) { + notFoundNullElement = false + } else { + val elem = array2.get(i, elementType) + hs.add(elem) + } + i += 1 + } + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + i = 0 + while (i < array1.numElements()) { + if (array1.isNullAt(i)) { + if (notFoundNullElement) { + arrayBuffer += null + notFoundNullElement = false + } + } else { + val elem = array1.get(i, elementType) + if (!hs.contains(elem)) { + arrayBuffer += elem + hs.add(elem) + } + } + i += 1 + } + new GenericArrayData(arrayBuffer) + } else { + (array1, array2) => + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var scannedNullElements = false + var i = 0 + while (i < array1.numElements()) { + var found = false + 
val elem1 = array1.get(i, elementType) + if (elem1 == null) { + if (!scannedNullElements) { + var j = 0 + while (!found && j < array2.numElements()) { + found = array2.isNullAt(j) + j += 1 + } + // array2 is scanned only once for null element + scannedNullElements = true + } else { + found = true + } + } else { + var j = 0 + while (!found && j < array2.numElements()) { + val elem2 = array2.get(j, elementType) + if (elem2 != null) { + found = ordering.equiv(elem1, elem2) + } + j += 1 + } + if (!found) { + // check whether elem1 is already stored in arrayBuffer + var k = 0 + while (!found && k < arrayBuffer.size) { + val va = arrayBuffer(k) + found = (va != null) && ordering.equiv(va, elem1) + k += 1 + } + } + } + if (!found) { + arrayBuffer += elem1 + } + i += 1 + } + new GenericArrayData(arrayBuffer) + } } - def genCodeForResult( - ctx: CodegenContext, - ev: ExprCode, - inputArray: String, - size: String): String = { - val distinctArray = ctx.freshName("distinctArray") + override def nullSafeEval(input1: Any, input2: Any): Any = { + val array1 = input1.asInstanceOf[ArrayData] + val array2 = input2.asInstanceOf[ArrayData] + + evalExcept(array1, array2) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val i = ctx.freshName("i") - val j = ctx.freshName("j") - val pos = ctx.freshName("pos") - val getValue1 = CodeGenerator.getValue(inputArray, elementType, i) - val getValue2 = CodeGenerator.getValue(inputArray, elementType, j) - val isEqual = ctx.genEqual(elementType, getValue1, getValue2) - val foundNullElement = ctx.freshName("foundNullElement") - val hs = ctx.freshName("hs") - val openHashSet = classOf[OpenHashSet[_]].getName - if (!CodeGenerator.isPrimitiveType(elementType)) { - val arrayClass = classOf[GenericArrayData].getName - val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()" - val setNullForNonPrimitive = - setNull(false, foundNullElement, distinctArray, pos) - if (elementTypeSupportEquals) { - val setValueForFast = setValueForFastEval(false, hs, distinctArray, pos, getValue1, "") + val value = ctx.freshName("value") + val size = ctx.freshName("size") + if (canUseSpecializedHashSet) { + val jt = CodeGenerator.javaType(elementType) + val ptName = CodeGenerator.primitiveTypeName(jt) + + nullSafeCodeGen(ctx, ev, (array1, array2) => { + val notFoundNullElement = ctx.freshName("notFoundNullElement") + val nullElementIndex = ctx.freshName("nullElementIndex") + val builder = ctx.freshName("builder") + val openHashSet = classOf[OpenHashSet[_]].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" + val hashSet = ctx.freshName("hashSet") + val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName + val arrayBuilderClass = s"$arrayBuilder$$of$ptName" + + def withArray2NullCheck(body: String): String = + if (right.dataType.asInstanceOf[ArrayType].containsNull) { + if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array2.isNullAt($i)) { + | $notFoundNullElement = false; + |} else { + | $body + |} + """.stripMargin + } else { + // if array1's element is not nullable, we don't need to track the null element index. 
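The same seed-then-filter shape drives both the interpreted evalExcept above and the generated loops assembled below: array2 is poured into the hash set first, and each array1 element not yet in the set is emitted and then added, so duplicates within array1 are suppressed as well. In miniature (plain Scala over Int, nulls omitted; a sketch, not the generated code):

    object ExceptSketch {
      def except(array1: Seq[Int], array2: Seq[Int]): Seq[Int] = {
        val seen = scala.collection.mutable.Set(array2: _*) // seeded with array2
        array1.filter(seen.add) // add returns false for already-seen elements
      }

      def main(args: Array[String]): Unit = {
        println(except(Seq(1, 2, 2, 3), Seq(1, 3, 5))) // List(2)
      }
    }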
+ s""" + |if (!$array2.isNullAt($i)) { + | $body + |} + """.stripMargin + } + } else { + body + } + + val writeArray2ToHashSet = withArray2NullCheck( + s""" + |$jt $value = ${genGetValue(array2, i)}; + |$hashSet.add$hsPostFix($hsValueCast$value); + """.stripMargin) + + def withArray1NullAssignment(body: String) = + if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array1.isNullAt($i)) { + | if ($notFoundNullElement) { + | $nullElementIndex = $size; + | $notFoundNullElement = false; + | $size++; + | $builder.$$plus$$eq($nullValueHolder); + | } + |} else { + | $body + |} + """.stripMargin + } else { + body + } + + val processArray1 = withArray1NullAssignment( + s""" + |$jt $value = ${genGetValue(array1, i)}; + |if (!$hashSet.contains($hsValueCast$value)) { + | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | break; + | } + | $hashSet.add$hsPostFix($hsValueCast$value); + | $builder.$$plus$$eq($value); + |} + """.stripMargin) + + // Only need to track null element index when array1's element is nullable. + val declareNullTrackVariables = if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |boolean $notFoundNullElement = true; + |int $nullElementIndex = -1; + """.stripMargin + } else { + "" + } + s""" - |int $pos = 0; - |Object[] $distinctArray = new Object[$size]; - |boolean $foundNullElement = false; - |$openHashSet $hs = new $openHashSet($classTag); - |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { - | if ($inputArray.isNullAt($i)) { - | $setNullForNonPrimitive; - | } else { - | $setValueForFast; - | } + |$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag); + |$declareNullTrackVariables + |for (int $i = 0; $i < $array2.numElements(); $i++) { + | $writeArray2ToHashSet |} - |${ev.value} = new $arrayClass($distinctArray); - """.stripMargin - } else { - val setValueForBruteForce = setValueForBruteForceEval( - false, i, j, inputArray, distinctArray, pos, getValue1, isEqual, "") - s""" - |int $pos = 0; - |Object[] $distinctArray = new Object[$size]; - |boolean $foundNullElement = false; - |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { - | if ($inputArray.isNullAt($i)) { - | $setNullForNonPrimitive; - | } else { - | $setValueForBruteForce; - | } + |$arrayBuilderClass $builder = new $arrayBuilderClass(); + |int $size = 0; + |for (int $i = 0; $i < $array1.numElements(); $i++) { + | $processArray1 |} - |${ev.value} = new $arrayClass($distinctArray); - """.stripMargin - } + |${buildResultArray(builder, ev.value, size, nullElementIndex)} + """.stripMargin + }) } else { - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) - val setNullForPrimitive = setNull(true, foundNullElement, distinctArray, pos) - val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$primitiveValueTypeName()" - val setValueForFast = - setValueForFastEval(true, hs, distinctArray, pos, getValue1, primitiveValueTypeName) - s""" - |${ctx.createUnsafeArray(distinctArray, size, elementType, s" $prettyName failed.")} - |int $pos = 0; - |boolean $foundNullElement = false; - |$openHashSet $hs = new $openHashSet($classTag); - |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { - | if ($inputArray.isNullAt($i)) { - | $setNullForPrimitive; - | } else { - | $setValueForFast; - | } - |} - |${ev.value} = $distinctArray; - """.stripMargin + nullSafeCodeGen(ctx, ev, (array1, array2) => { + val expr = ctx.addReferenceObj("arrayExceptExpr", this) + s"${ev.value} = (ArrayData)$expr.nullSafeEval($array1, $array2);" + }) } } - override def 
prettyName: String = "array_distinct" + override def prettyName: String = "array_except" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 0a5f8a907b50a..117fa3e9aa519 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util._ @@ -48,7 +48,8 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def dataType: ArrayType = { ArrayType( - children.headOption.map(_.dataType).getOrElse(StringType), + TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(children.map(_.dataType)) + .getOrElse(StringType), containsNull = children.exists(_.nullable)) } @@ -179,14 +180,14 @@ case class CreateMap(children: Seq[Expression]) extends Expression { if (children.size % 2 != 0) { TypeCheckResult.TypeCheckFailure( s"$prettyName expects a positive even number of arguments.") - } else if (keys.map(_.dataType).distinct.length > 1) { + } else if (!TypeCoercion.haveSameType(keys.map(_.dataType))) { TypeCheckResult.TypeCheckFailure( "The given keys of function map should all be the same type, but they are " + - keys.map(_.dataType.simpleString).mkString("[", ", ", "]")) - } else if (values.map(_.dataType).distinct.length > 1) { + keys.map(_.dataType.catalogString).mkString("[", ", ", "]")) + } else if (!TypeCoercion.haveSameType(values.map(_.dataType))) { TypeCheckResult.TypeCheckFailure( "The given values of function map should all be the same type, but they are " + - values.map(_.dataType.simpleString).mkString("[", ", ", "]")) + values.map(_.dataType.catalogString).mkString("[", ", ", "]")) } else { TypeCheckResult.TypeCheckSuccess } @@ -194,8 +195,10 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def dataType: DataType = { MapType( - keyType = keys.headOption.map(_.dataType).getOrElse(StringType), - valueType = values.headOption.map(_.dataType).getOrElse(StringType), + keyType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(keys.map(_.dataType)) + .getOrElse(StringType), + valueType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(values.map(_.dataType)) + .getOrElse(StringType), valueContainsNull = values.exists(_.nullable)) } @@ -245,8 +248,8 @@ case class CreateMap(children: Seq[Expression]) extends Expression { in keys should not be null""", examples = """ Examples: - > SELECT _FUNC_([1.0, 3.0], ['2', '4']); - {1.0:"2",3.0:"4"} + > SELECT _FUNC_(array(1.0, 3.0), array('2', '4')); + [1.0 -> "2", 3.0 -> "4"] """, since = "2.4.0") case class MapFromArrays(left: Expression, right: Expression) extends BinaryExpression with ExpectsInputTypes { @@ -376,17 +379,14 @@ trait CreateNamedStructLike extends Expression { } override def checkInputDataTypes(): TypeCheckResult = { - if (children.length < 1) { - TypeCheckResult.TypeCheckFailure( - 
s"input to function $prettyName requires at least one argument") - } else if (children.size % 2 != 0) { + if (children.size % 2 != 0) { TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.") } else { val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) if (invalidNames.nonEmpty) { TypeCheckResult.TypeCheckFailure( - "Only foldable StringType expressions are allowed to appear at odd position, got:" + - s" ${invalidNames.mkString(",")}") + s"Only foldable ${StringType.catalogString} expressions are allowed to appear at odd" + + s" position, got: ${invalidNames.mkString(",")}") } else if (!names.contains(null)) { TypeCheckResult.TypeCheckSuccess } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 99671d5b863c4..8994eeff92c7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -68,7 +68,7 @@ object ExtractValue { case StructType(_) => s"Field name should be String Literal, but it's $extraction" case other => - s"Can't extract value from $child: need struct type but got ${other.simpleString}" + s"Can't extract value from $child: need struct type but got ${other.catalogString}" } throw new AnalysisException(errorMsg) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 77ac6c088022e..bed581a61b2dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ @@ -33,7 +33,12 @@ import org.apache.spark.sql.types._ """) // scalastyle:on line.size.limit case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) - extends Expression { + extends ComplexTypeMergingExpression { + + @transient + override lazy val inputTypesForMerging: Seq[DataType] = { + Seq(trueValue.dataType, falseValue.dataType) + } override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil override def nullable: Boolean = trueValue.nullable || falseValue.nullable @@ -42,17 +47,15 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi if (predicate.dataType != BooleanType) { TypeCheckResult.TypeCheckFailure( "type of predicate expression in If should be boolean, " + - s"not ${predicate.dataType.simpleString}") - } else if (!trueValue.dataType.sameType(falseValue.dataType)) { + s"not ${predicate.dataType.catalogString}") + } else if (!TypeCoercion.haveSameType(inputTypesForMerging)) { TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " + - s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).") + 
s"(${trueValue.dataType.catalogString} and ${falseValue.dataType.catalogString}).") } else { TypeCheckResult.TypeCheckSuccess } } - override def dataType: DataType = trueValue.dataType - override def eval(input: InternalRow): Any = { if (java.lang.Boolean.TRUE.equals(predicate.eval(input))) { trueValue.eval(input) @@ -118,27 +121,24 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi case class CaseWhen( branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None) - extends Expression with Serializable { + extends ComplexTypeMergingExpression with Serializable { override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue // both then and else expressions should be considered. - def valueTypes: Seq[DataType] = branches.map(_._2.dataType) ++ elseValue.map(_.dataType) - - def valueTypesEqual: Boolean = valueTypes.size <= 1 || valueTypes.sliding(2, 1).forall { - case Seq(dt1, dt2) => dt1.sameType(dt2) + @transient + override lazy val inputTypesForMerging: Seq[DataType] = { + branches.map(_._2.dataType) ++ elseValue.map(_.dataType) } - override def dataType: DataType = branches.head._2.dataType - override def nullable: Boolean = { // Result is nullable if any of the branch is nullable, or if the else value is nullable branches.exists(_._2.nullable) || elseValue.map(_.nullable).getOrElse(true) } override def checkInputDataTypes(): TypeCheckResult = { - // Make sure all branch conditions are boolean types. - if (valueTypesEqual) { + if (TypeCoercion.haveSameType(inputTypesForMerging)) { + // Make sure all branch conditions are boolean types. if (branches.forall(_._1.dataType == BooleanType)) { TypeCheckResult.TypeCheckSuccess } else { @@ -294,7 +294,7 @@ object CaseWhen { case cond :: value :: Nil => Some((cond, value)) case value :: Nil => None }.toArray.toSeq // force materialization to make the seq serializable - val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None + val elseValue = if (branches.size % 2 != 0) Some(branches.last) else None CaseWhen(cases, elseValue) } } @@ -309,7 +309,7 @@ object CaseKeyWhen { case Seq(cond, value) => Some((EqualTo(key, cond), value)) case Seq(value) => None }.toArray.toSeq // force materialization to make the seq serializable - val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None + val elseValue = if (branches.size % 2 != 0) Some(branches.last) else None CaseWhen(cases, elseValue) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala new file mode 100644 index 0000000000000..2917b0b8c9c53 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral} +import org.apache.spark.sql.types.DataType + +case class KnownNotNull(child: Expression) extends UnaryExpression { + override def nullable: Boolean = false + override def dataType: DataType = child.dataType + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + child.genCode(ctx).copy(isNull = FalseLiteral) + } + + override def eval(input: InternalRow): Any = { + child.eval(input) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 08838d2b2c612..f95798d64db19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1345,7 +1345,7 @@ case class ParseToDate(left: Expression, format: Option[Expression], child: Expr } def this(left: Expression) = { - // backwards compatability + // backwards compatibility this(left, None, Cast(left, DateType)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index b7c52f1d7b40a..d6e67b9ac3d10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -156,8 +156,8 @@ case class Stack(children: Seq[Expression]) extends Generator { val j = (i - 1) % numFields if (children(i).dataType != elementSchema.fields(j).dataType) { return TypeCheckResult.TypeCheckFailure( - s"Argument ${j + 1} (${elementSchema.fields(j).dataType.simpleString}) != " + - s"Argument $i (${children(i).dataType.simpleString})") + s"Argument ${j + 1} (${elementSchema.fields(j).dataType.catalogString}) != " + + s"Argument $i (${children(i).dataType.catalogString})") } } TypeCheckResult.TypeCheckSuccess @@ -223,6 +223,32 @@ case class Stack(children: Seq[Expression]) extends Generator { } } +/** + * Replicates the row N times. N is specified as the first argument to the function. + * This is an internal function used solely by the optimizer to rewrite EXCEPT ALL and + * INTERSECT ALL queries. + */ +case class ReplicateRows(children: Seq[Expression]) extends Generator with CodegenFallback { + private lazy val numColumns = children.length - 1 // remove the multiplier value from output.
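// Illustrative sketch, not part of the patch: evaluated directly, the first child
// supplies the replication count and the remaining children become the columns of every
// emitted row. Assuming Literal and InternalRow from this package:
//   ReplicateRows(Seq(Literal(2L), Literal(1), Literal("a"))).eval(InternalRow.empty)
// yields two identical rows [1, a]; this is how the optimizer preserves duplicate
// multiplicity when rewriting EXCEPT ALL and INTERSECT ALL.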
+ + override def elementSchema: StructType = + StructType(children.tail.zipWithIndex.map { + case (e, index) => StructField(s"col$index", e.dataType) + }) + + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { + val numRows = children.head.eval(input).asInstanceOf[Long] + val values = children.tail.map(_.eval(input)).toArray + Range.Long(0, numRows, 1).map { _ => + val fields = new Array[Any](numColumns) + for (col <- 0 until numColumns) { + fields.update(col, values(col)) + } + InternalRow(fields: _*) + } + } +} + /** * Wrapper around another generator to specify outer behavior. This is used to implement functions * such as explode_outer. This expression gets replaced during analysis. @@ -251,7 +277,7 @@ abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with case _ => TypeCheckResult.TypeCheckFailure( "input to function explode should be array or map type, " + - s"not ${child.dataType.simpleString}") + s"not ${child.dataType.catalogString}") } // hive-compatible default alias for explode function ("col" for array, "key", "value" for map) @@ -381,7 +407,7 @@ case class Inline(child: Expression) extends UnaryExpression with CollectionGene case _ => TypeCheckResult.TypeCheckFailure( s"input to function $prettyName should be array of struct type, " + - s"not ${child.dataType.simpleString}") + s"not ${child.dataType.catalogString}") } override def elementSchema: StructType = child.dataType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index cec00b66f873c..742a4f87a9c04 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -33,7 +33,6 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.hash.Murmur3_x86_32 -import org.apache.spark.unsafe.memory.MemoryBlock import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -362,7 +361,10 @@ abstract class HashExpression[E] extends Expression { } protected def genHashString(input: String, result: String): String = { - s"$result = $hasherClassName.hashUTF8String($input, $result);" + val baseObject = s"$input.getBaseObject()" + val baseOffset = s"$input.getBaseOffset()" + val numBytes = s"$input.numBytes()" + s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);" } protected def genHashForMap( @@ -404,14 +406,15 @@ abstract class HashExpression[E] extends Expression { input: String, result: String, fields: Array[StructField]): String = { + val tmpInput = ctx.freshName("input") val fieldsHash = fields.zipWithIndex.map { case (field, index) => - nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) + nullSafeElementHash(tmpInput, index.toString, field.nullable, field.dataType, result, ctx) } val hashResultType = CodeGenerator.javaType(dataType) - ctx.splitExpressions( + val code = ctx.splitExpressions( expressions = fieldsHash, funcName = "computeHashForStruct", - arguments = Seq("InternalRow" -> input, hashResultType -> result), + arguments = Seq("InternalRow" -> tmpInput, hashResultType -> result), returnType = hashResultType, makeSplitFunction = 
body => s""" @@ -419,6 +422,10 @@ abstract class HashExpression[E] extends Expression { |return $result; """.stripMargin, foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n")) + s""" + |final InternalRow $tmpInput = $input; + |$code + """.stripMargin } @tailrec @@ -464,8 +471,6 @@ abstract class InterpretedHashFunction { protected def hashUnsafeBytes(base: AnyRef, offset: Long, length: Int, seed: Long): Long - protected def hashUnsafeBytesBlock(base: MemoryBlock, seed: Long): Long - /** * Computes hash of a given `value` of type `dataType`. The caller needs to check the validity * of input `value`. @@ -491,7 +496,8 @@ abstract class InterpretedHashFunction { case c: CalendarInterval => hashInt(c.months, hashLong(c.microseconds, seed)) case a: Array[Byte] => hashUnsafeBytes(a, Platform.BYTE_ARRAY_OFFSET, a.length, seed) - case s: UTF8String => hashUnsafeBytesBlock(s.getMemoryBlock(), seed) + case s: UTF8String => + hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes(), seed) case array: ArrayData => val elementType = dataType match { @@ -578,15 +584,9 @@ object Murmur3HashFunction extends InterpretedHashFunction { Murmur3_x86_32.hashLong(l, seed.toInt) } - override protected def hashUnsafeBytes( - base: AnyRef, offset: Long, len: Int, seed: Long): Long = { + override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = { Murmur3_x86_32.hashUnsafeBytes(base, offset, len, seed.toInt) } - - override protected def hashUnsafeBytesBlock( - base: MemoryBlock, seed: Long): Long = { - Murmur3_x86_32.hashUnsafeBytesBlock(base, seed.toInt) - } } /** @@ -611,14 +611,9 @@ object XxHash64Function extends InterpretedHashFunction { override protected def hashLong(l: Long, seed: Long): Long = XXH64.hashLong(l, seed) - override protected def hashUnsafeBytes( - base: AnyRef, offset: Long, len: Int, seed: Long): Long = { + override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = { XXH64.hashUnsafeBytes(base, offset, len, seed) } - - override protected def hashUnsafeBytesBlock(base: MemoryBlock, seed: Long): Long = { - XXH64.hashUnsafeBytesBlock(base, seed) - } } /** @@ -725,7 +720,10 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { """ override protected def genHashString(input: String, result: String): String = { - s"$result = $hasherClassName.hashUTF8String($input);" + val baseObject = s"$input.getBaseObject()" + val baseOffset = s"$input.getBaseOffset()" + val numBytes = s"$input.numBytes()" + s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes);" } override protected def genHashForArray( @@ -778,10 +776,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { input: String, result: String, fields: Array[StructField]): String = { + val tmpInput = ctx.freshName("input") val childResult = ctx.freshName("childResult") val fieldsHash = fields.zipWithIndex.map { case (field, index) => val computeFieldHash = nullSafeElementHash( - input, index.toString, field.nullable, field.dataType, childResult, ctx) + tmpInput, index.toString, field.nullable, field.dataType, childResult, ctx) s""" |$childResult = 0; |$computeFieldHash @@ -789,10 +788,10 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { """.stripMargin } - s"${CodeGenerator.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions( + val code = ctx.splitExpressions( expressions = fieldsHash, funcName = "computeHashForStruct", - arguments = 
Seq("InternalRow" -> input, CodeGenerator.JAVA_INT -> result), + arguments = Seq("InternalRow" -> tmpInput, CodeGenerator.JAVA_INT -> result), returnType = CodeGenerator.JAVA_INT, makeSplitFunction = body => s""" @@ -801,6 +800,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { |return $result; """.stripMargin, foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n")) + s""" + |final InternalRow $tmpInput = $input; + |${CodeGenerator.JAVA_INT} $childResult = 0; + |$code + """.stripMargin } } @@ -813,14 +817,10 @@ object HiveHashFunction extends InterpretedHashFunction { HiveHasher.hashLong(l) } - override protected def hashUnsafeBytes( - base: AnyRef, offset: Long, len: Int, seed: Long): Long = { + override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = { HiveHasher.hashUnsafeBytes(base, offset, len) } - override protected def hashUnsafeBytesBlock( - base: MemoryBlock, seed: Long): Long = HiveHasher.hashUnsafeBytesBlock(base) - private val HIVE_DECIMAL_MAX_PRECISION = 38 private val HIVE_DECIMAL_MAX_SCALE = 38 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala new file mode 100644 index 0000000000000..3ef2ec03099e4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -0,0 +1,846 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.util.concurrent.atomic.AtomicReference + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.array.ByteArrayMethods + +/** + * A named lambda variable. 
+ */ +case class NamedLambdaVariable( + name: String, + dataType: DataType, + nullable: Boolean, + exprId: ExprId = NamedExpression.newExprId, + value: AtomicReference[Any] = new AtomicReference()) + extends LeafExpression + with NamedExpression + with CodegenFallback { + + override def qualifier: Seq[String] = Seq.empty + + override def newInstance(): NamedExpression = + copy(exprId = NamedExpression.newExprId, value = new AtomicReference()) + + override def toAttribute: Attribute = { + AttributeReference(name, dataType, nullable, Metadata.empty)(exprId, Seq.empty) + } + + override def eval(input: InternalRow): Any = value.get + + override def toString: String = s"lambda $name#${exprId.id}$typeSuffix" + + override def simpleString: String = s"lambda $name#${exprId.id}: ${dataType.simpleString}" +} + +/** + * A lambda function and its arguments. A lambda function can be hidden when a user wants to + * process a completely independent expression in a [[HigherOrderFunction]]; the lambda function + * and its variables are then only used for internal bookkeeping within the higher-order function. + */ +case class LambdaFunction( + function: Expression, + arguments: Seq[NamedExpression], + hidden: Boolean = false) + extends Expression with CodegenFallback { + + override def children: Seq[Expression] = function +: arguments + override def dataType: DataType = function.dataType + override def nullable: Boolean = function.nullable + + lazy val bound: Boolean = arguments.forall(_.resolved) + + override def eval(input: InternalRow): Any = function.eval(input) +} + +object LambdaFunction { + val identity: LambdaFunction = { + val id = UnresolvedAttribute.quoted("id") + LambdaFunction(id, Seq(id)) + } +} + +/** + * A higher-order function takes one or more (lambda) functions and applies these to some objects. + * The function produces a number of variables which can be consumed by some lambda function. + */ +trait HigherOrderFunction extends Expression with ExpectsInputTypes { + + override def nullable: Boolean = arguments.exists(_.nullable) + + override def children: Seq[Expression] = arguments ++ functions + + /** + * Arguments of the higher-order function. + */ + def arguments: Seq[Expression] + + def argumentTypes: Seq[AbstractDataType] + + /** + * All arguments have been resolved. This means that the types and nullability of (most of) the + * lambda function arguments are known, and that we can start binding the lambda functions. + */ + lazy val argumentsResolved: Boolean = arguments.forall(_.resolved) + + /** + * Checks the argument data types, returns `TypeCheckResult.TypeCheckSuccess` if valid, + * or returns a `TypeCheckResult` with an error message if invalid. + * Note: it's not valid to call this method until `argumentsResolved == true`. + */ + def checkArgumentDataTypes(): TypeCheckResult = { + ExpectsInputTypes.checkInputDataTypes(arguments, argumentTypes) + } + + /** + * Functions applied by the higher-order function. + */ + def functions: Seq[Expression] + + def functionTypes: Seq[AbstractDataType] + + override def inputTypes: Seq[AbstractDataType] = argumentTypes ++ functionTypes + + /** + * All inputs must be resolved and all functions must be resolved lambda functions. + */ + override lazy val resolved: Boolean = argumentsResolved && functions.forall { + case l: LambdaFunction => l.resolved + case _ => false + } + + /** + * Bind the lambda functions to the [[HigherOrderFunction]] using the given bind function.
The + * bind function takes the potential lambda and its (partial) arguments and converts this into + * a bound lambda function. + */ + def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): HigherOrderFunction + + // Make sure the lambda variables refer to the same instances as the arguments, in case the + // variables are instantiated separately during serialization or for some other reason. + @transient lazy val functionsForEval: Seq[Expression] = functions.map { + case LambdaFunction(function, arguments, hidden) => + val argumentMap = arguments.map { arg => arg.exprId -> arg }.toMap + function.transformUp { + case variable: NamedLambdaVariable if argumentMap.contains(variable.exprId) => + argumentMap(variable.exprId) + } + } +} + +/** + * Trait for functions that take one argument and one function as input. + */ +trait SimpleHigherOrderFunction extends HigherOrderFunction { + + def argument: Expression + + override def arguments: Seq[Expression] = argument :: Nil + + def argumentType: AbstractDataType + + override def argumentTypes(): Seq[AbstractDataType] = argumentType :: Nil + + def function: Expression + + override def functions: Seq[Expression] = function :: Nil + + def functionType: AbstractDataType = AnyDataType + + override def functionTypes: Seq[AbstractDataType] = functionType :: Nil + + def functionForEval: Expression = functionsForEval.head + + /** + * Called by [[eval]]. If a subclass keeps the default nullability, it can override this method + * in order to save null-check code. + */ + protected def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = + sys.error(s"SimpleHigherOrderFunction must override either eval or nullSafeEval") + + override def eval(inputRow: InternalRow): Any = { + val value = argument.eval(inputRow) + if (value == null) { + null + } else { + nullSafeEval(inputRow, value) + } + } +} + +trait ArrayBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { + override def argumentType: AbstractDataType = ArrayType +} + +trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { + override def argumentType: AbstractDataType = MapType +} + +/** + * Transform elements in an array using the transform function. This is similar to + * a `map` in functional programming.
+ */ +@ExpressionDescription( + usage = "_FUNC_(expr, func) - Transforms elements in an array using the function.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), x -> x + 1); + [2, 3, 4] + > SELECT _FUNC_(array(1, 2, 3), (x, i) -> x + i); + [1, 3, 5] + """, + since = "2.4.0") +case class ArrayTransform( + argument: Expression, + function: Expression) + extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { + + override def dataType: ArrayType = ArrayType(function.dataType, function.nullable) + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayTransform = { + val ArrayType(elementType, containsNull) = argument.dataType + function match { + case LambdaFunction(_, arguments, _) if arguments.size == 2 => + copy(function = f(function, (elementType, containsNull) :: (IntegerType, false) :: Nil)) + case _ => + copy(function = f(function, (elementType, containsNull) :: Nil)) + } + } + + @transient lazy val (elementVar, indexVar) = { + val LambdaFunction(_, (elementVar: NamedLambdaVariable) +: tail, _) = function + val indexVar = if (tail.nonEmpty) { + Some(tail.head.asInstanceOf[NamedLambdaVariable]) + } else { + None + } + (elementVar, indexVar) + } + + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val arr = argumentValue.asInstanceOf[ArrayData] + val f = functionForEval + val result = new GenericArrayData(new Array[Any](arr.numElements)) + var i = 0 + while (i < arr.numElements) { + elementVar.value.set(arr.get(i, elementVar.dataType)) + if (indexVar.isDefined) { + indexVar.get.value.set(i) + } + result.update(i, f.eval(inputRow)) + i += 1 + } + result + } + + override def prettyName: String = "transform" +} + +/** + * Filters entries in a map using the provided function. + */ +@ExpressionDescription( +usage = "_FUNC_(expr, func) - Filters entries in a map using the function.", +examples = """ + Examples: + > SELECT _FUNC_(map(1, 0, 2, 2, 3, -1), (k, v) -> k > v); + [1 -> 0, 3 -> -1] + """, +since = "2.4.0") +case class MapFilter( + argument: Expression, + function: Expression) + extends MapBasedSimpleHigherOrderFunction with CodegenFallback { + + @transient lazy val (keyVar, valueVar) = { + val args = function.asInstanceOf[LambdaFunction].arguments + (args.head.asInstanceOf[NamedLambdaVariable], args.tail.head.asInstanceOf[NamedLambdaVariable]) + } + + @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapFilter = { + copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) + } + + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val m = argumentValue.asInstanceOf[MapData] + val f = functionForEval + val retKeys = new mutable.ListBuffer[Any] + val retValues = new mutable.ListBuffer[Any] + m.foreach(keyType, valueType, (k, v) => { + keyVar.value.set(k) + valueVar.value.set(v) + if (f.eval(inputRow).asInstanceOf[Boolean]) { + retKeys += k + retValues += v + } + }) + ArrayBasedMapData(retKeys.toArray, retValues.toArray) + } + + override def dataType: DataType = argument.dataType + + override def functionType: AbstractDataType = BooleanType + + override def prettyName: String = "map_filter" +} + +/** + * Filters the input array using the given lambda function. 
+ */ +@ExpressionDescription( + usage = "_FUNC_(expr, func) - Filters the input array using the given predicate.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 1); + [1, 3] + """, + since = "2.4.0") +case class ArrayFilter( + argument: Expression, + function: Expression) + extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { + + override def dataType: DataType = argument.dataType + + override def functionType: AbstractDataType = BooleanType + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayFilter = { + val ArrayType(elementType, containsNull) = argument.dataType + copy(function = f(function, (elementType, containsNull) :: Nil)) + } + + @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function + + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val arr = argumentValue.asInstanceOf[ArrayData] + val f = functionForEval + val buffer = new mutable.ArrayBuffer[Any](arr.numElements) + var i = 0 + while (i < arr.numElements) { + elementVar.value.set(arr.get(i, elementVar.dataType)) + if (f.eval(inputRow).asInstanceOf[Boolean]) { + buffer += elementVar.value.get + } + i += 1 + } + new GenericArrayData(buffer) + } + + override def prettyName: String = "filter" +} + +/** + * Tests whether a predicate holds for one or more elements in the array. + */ +@ExpressionDescription(usage = + "_FUNC_(expr, pred) - Tests whether a predicate holds for one or more elements in the array.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 0); + true + """, + since = "2.4.0") +case class ArrayExists( + argument: Expression, + function: Expression) + extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { + + override def dataType: DataType = BooleanType + + override def functionType: AbstractDataType = BooleanType + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayExists = { + val ArrayType(elementType, containsNull) = argument.dataType + copy(function = f(function, (elementType, containsNull) :: Nil)) + } + + @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function + + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val arr = argumentValue.asInstanceOf[ArrayData] + val f = functionForEval + var exists = false + var i = 0 + while (i < arr.numElements && !exists) { + elementVar.value.set(arr.get(i, elementVar.dataType)) + if (f.eval(inputRow).asInstanceOf[Boolean]) { + exists = true + } + i += 1 + } + exists + } + + override def prettyName: String = "exists" +} + +/** + * Applies a binary operator to a start value and all elements in the array. + */ +@ExpressionDescription( + usage = + """ + _FUNC_(expr, start, merge, finish) - Applies a binary operator to an initial state and all + elements in the array, and reduces this to a single state. The final state is converted + into the final result by applying a finish function. 
+ """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), 0, (acc, x) -> acc + x); + 6 + > SELECT _FUNC_(array(1, 2, 3), 0, (acc, x) -> acc + x, acc -> acc * 10); + 60 + """, + since = "2.4.0") +case class ArrayAggregate( + argument: Expression, + zero: Expression, + merge: Expression, + finish: Expression) + extends HigherOrderFunction with CodegenFallback { + + def this(argument: Expression, zero: Expression, merge: Expression) = { + this(argument, zero, merge, LambdaFunction.identity) + } + + override def arguments: Seq[Expression] = argument :: zero :: Nil + + override def argumentTypes: Seq[AbstractDataType] = ArrayType :: AnyDataType :: Nil + + override def functions: Seq[Expression] = merge :: finish :: Nil + + override def functionTypes: Seq[AbstractDataType] = zero.dataType :: AnyDataType :: Nil + + override def nullable: Boolean = argument.nullable || finish.nullable + + override def dataType: DataType = finish.dataType + + override def checkInputDataTypes(): TypeCheckResult = { + checkArgumentDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + if (!DataType.equalsStructurally( + zero.dataType, merge.dataType, ignoreNullability = true)) { + TypeCheckResult.TypeCheckFailure( + s"argument 3 requires ${zero.dataType.simpleString} type, " + + s"however, '${merge.sql}' is of ${merge.dataType.catalogString} type.") + } else { + TypeCheckResult.TypeCheckSuccess + } + case failure => failure + } + } + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayAggregate = { + // Be very conservative with nullable. We cannot be sure that the accumulator does not + // evaluate to null. So we always set nullable to true here. + val ArrayType(elementType, containsNull) = argument.dataType + val acc = zero.dataType -> true + val newMerge = f(merge, acc :: (elementType, containsNull) :: Nil) + val newFinish = f(finish, acc :: Nil) + copy(merge = newMerge, finish = newFinish) + } + + @transient lazy val LambdaFunction(_, + Seq(accForMergeVar: NamedLambdaVariable, elementVar: NamedLambdaVariable), _) = merge + @transient lazy val LambdaFunction(_, Seq(accForFinishVar: NamedLambdaVariable), _) = finish + + override def eval(input: InternalRow): Any = { + val arr = argument.eval(input).asInstanceOf[ArrayData] + if (arr == null) { + null + } else { + val Seq(mergeForEval, finishForEval) = functionsForEval + accForMergeVar.value.set(zero.eval(input)) + var i = 0 + while (i < arr.numElements()) { + elementVar.value.set(arr.get(i, elementVar.dataType)) + accForMergeVar.value.set(mergeForEval.eval(input)) + i += 1 + } + accForFinishVar.value.set(accForMergeVar.value.get) + finishForEval.eval(input) + } + } + + override def prettyName: String = "aggregate" +} + +/** + * Transform Keys for every entry of the map by applying the transform_keys function. 
+ * Returns a map with the transformed key entries. + */ +@ExpressionDescription( + usage = "_FUNC_(expr, func) - Transforms elements in a map using the function.", + examples = """ + Examples: + > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + 1); + [2 -> 1, 3 -> 2, 4 -> 3] + > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); + [2 -> 1, 4 -> 2, 6 -> 3] + """, + since = "2.4.0") +case class TransformKeys( + argument: Expression, + function: Expression) + extends MapBasedSimpleHigherOrderFunction with CodegenFallback { + + @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType + + override def dataType: DataType = MapType(function.dataType, valueType, valueContainsNull) + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): TransformKeys = { + copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) + } + + @transient lazy val LambdaFunction( + _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function + + + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val map = argumentValue.asInstanceOf[MapData] + val resultKeys = new GenericArrayData(new Array[Any](map.numElements)) + var i = 0 + while (i < map.numElements) { + keyVar.value.set(map.keyArray().get(i, keyVar.dataType)) + valueVar.value.set(map.valueArray().get(i, valueVar.dataType)) + val result = functionForEval.eval(inputRow) + if (result == null) { + throw new RuntimeException("Cannot use null as map key!") + } + resultKeys.update(i, result) + i += 1 + } + new ArrayBasedMapData(resultKeys, map.valueArray()) + } + + override def prettyName: String = "transform_keys" +} + +/** + * Returns a map that applies the function to each value of the map.
+ */ +@ExpressionDescription( + usage = "_FUNC_(expr, func) - Transforms values in the map using the function.", + examples = """ + Examples: + > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> v + 1); + [1 -> 2, 2 -> 3, 3 -> 4] + > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); + [1 -> 2, 2 -> 4, 3 -> 6] + """, + since = "2.4.0") +case class TransformValues( + argument: Expression, + function: Expression) + extends MapBasedSimpleHigherOrderFunction with CodegenFallback { + + @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType + + override def dataType: DataType = MapType(keyType, function.dataType, function.nullable) + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction) + : TransformValues = { + copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) + } + + @transient lazy val LambdaFunction( + _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function + + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val map = argumentValue.asInstanceOf[MapData] + val resultValues = new GenericArrayData(new Array[Any](map.numElements)) + var i = 0 + while (i < map.numElements) { + keyVar.value.set(map.keyArray().get(i, keyVar.dataType)) + valueVar.value.set(map.valueArray().get(i, valueVar.dataType)) + resultValues.update(i, functionForEval.eval(inputRow)) + i += 1 + } + new ArrayBasedMapData(map.keyArray(), resultValues) + } + + override def prettyName: String = "transform_values" +} + +/** + * Merges two given maps into a single map by applying function to the pair of values with + * the same key. + */ +@ExpressionDescription( + usage = + """ + _FUNC_(map1, map2, function) - Merges two given maps into a single map by applying + function to the pair of values with the same key. For keys only presented in one map, + NULL will be passed as the value for the missing key. If an input map contains duplicated + keys, only the first entry of the duplicated key is passed into the lambda function. 
+ """, + examples = """ + Examples: + > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2)); + [1 -> "ax", 2 -> "by"] + """, + since = "2.4.0") +case class MapZipWith(left: Expression, right: Expression, function: Expression) + extends HigherOrderFunction with CodegenFallback { + + def functionForEval: Expression = functionsForEval.head + + @transient lazy val MapType(leftKeyType, leftValueType, leftValueContainsNull) = left.dataType + + @transient lazy val MapType(rightKeyType, rightValueType, rightValueContainsNull) = right.dataType + + @transient lazy val keyType = + TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(leftKeyType, rightKeyType).get + + @transient lazy val ordering = TypeUtils.getInterpretedOrdering(keyType) + + override def arguments: Seq[Expression] = left :: right :: Nil + + override def argumentTypes: Seq[AbstractDataType] = MapType :: MapType :: Nil + + override def functions: Seq[Expression] = function :: Nil + + override def functionTypes: Seq[AbstractDataType] = AnyDataType :: Nil + + override def dataType: DataType = MapType(keyType, function.dataType, function.nullable) + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapZipWith = { + val arguments = Seq((keyType, false), (leftValueType, true), (rightValueType, true)) + copy(function = f(function, arguments)) + } + + override def checkArgumentDataTypes(): TypeCheckResult = { + super.checkArgumentDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + if (leftKeyType.sameType(rightKeyType)) { + TypeUtils.checkForOrderingExpr(leftKeyType, s"function $prettyName") + } else { + TypeCheckResult.TypeCheckFailure(s"The input to function $prettyName should have " + + s"been two ${MapType.simpleString}s with compatible key types, but the key types are " + + s"[${leftKeyType.catalogString}, ${rightKeyType.catalogString}].") + } + case failure => failure + } + } + + override def checkInputDataTypes(): TypeCheckResult = checkArgumentDataTypes() + + override def eval(input: InternalRow): Any = { + val value1 = left.eval(input) + if (value1 == null) { + null + } else { + val value2 = right.eval(input) + if (value2 == null) { + null + } else { + nullSafeEval(input, value1, value2) + } + } + } + + @transient lazy val LambdaFunction(_, Seq( + keyVar: NamedLambdaVariable, + value1Var: NamedLambdaVariable, + value2Var: NamedLambdaVariable), + _) = function + + /** + * The function accepts two key arrays and returns a collection of keys with indexes + * to value arrays. Indexes are represented as an array of two items. This is a small + * optimization leveraging mutability of arrays. 
+ */ + @transient private lazy val getKeysWithValueIndexes: + (ArrayData, ArrayData) => mutable.Iterable[(Any, Array[Option[Int]])] = { + if (TypeUtils.typeWithProperEquals(keyType)) { + getKeysWithIndexesFast + } else { + getKeysWithIndexesBruteForce + } + } + + private def assertSizeOfArrayBuffer(size: Int): Unit = { + if (size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful try to zip maps with $size " + + s"unique keys due to exceeding the array size limit " + + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") + } + } + + private def getKeysWithIndexesFast(keys1: ArrayData, keys2: ArrayData) = { + val hashMap = new mutable.LinkedHashMap[Any, Array[Option[Int]]] + for((z, array) <- Array((0, keys1), (1, keys2))) { + var i = 0 + while (i < array.numElements()) { + val key = array.get(i, keyType) + hashMap.get(key) match { + case Some(indexes) => + if (indexes(z).isEmpty) { + indexes(z) = Some(i) + } + case None => + val indexes = Array[Option[Int]](None, None) + indexes(z) = Some(i) + hashMap.put(key, indexes) + } + i += 1 + } + } + hashMap + } + + private def getKeysWithIndexesBruteForce(keys1: ArrayData, keys2: ArrayData) = { + val arrayBuffer = new mutable.ArrayBuffer[(Any, Array[Option[Int]])] + for((z, array) <- Array((0, keys1), (1, keys2))) { + var i = 0 + while (i < array.numElements()) { + val key = array.get(i, keyType) + var found = false + var j = 0 + while (!found && j < arrayBuffer.size) { + val (bufferKey, indexes) = arrayBuffer(j) + if (ordering.equiv(bufferKey, key)) { + found = true + if(indexes(z).isEmpty) { + indexes(z) = Some(i) + } + } + j += 1 + } + if (!found) { + assertSizeOfArrayBuffer(arrayBuffer.size) + val indexes = Array[Option[Int]](None, None) + indexes(z) = Some(i) + arrayBuffer += Tuple2(key, indexes) + } + i += 1 + } + } + arrayBuffer + } + + private def nullSafeEval(inputRow: InternalRow, value1: Any, value2: Any): Any = { + val mapData1 = value1.asInstanceOf[MapData] + val mapData2 = value2.asInstanceOf[MapData] + val keysWithIndexes = getKeysWithValueIndexes(mapData1.keyArray(), mapData2.keyArray()) + val size = keysWithIndexes.size + val keys = new GenericArrayData(new Array[Any](size)) + val values = new GenericArrayData(new Array[Any](size)) + val valueData1 = mapData1.valueArray() + val valueData2 = mapData2.valueArray() + var i = 0 + for ((key, Array(index1, index2)) <- keysWithIndexes) { + val v1 = index1.map(valueData1.get(_, leftValueType)).getOrElse(null) + val v2 = index2.map(valueData2.get(_, rightValueType)).getOrElse(null) + keyVar.value.set(key) + value1Var.value.set(v1) + value2Var.value.set(v2) + keys.update(i, key) + values.update(i, functionForEval.eval(inputRow)) + i += 1 + } + new ArrayBasedMapData(keys, values) + } + + override def prettyName: String = "map_zip_with" +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(left, right, func) - Merges the two given arrays, element-wise, into a single array using function. 
If one array is shorter, nulls are appended at the end to match the length of the longer array, before applying function.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array('a', 'b', 'c'), (x, y) -> (y, x)); + [["a", 1], ["b", 2], ["c", 3]] + > SELECT _FUNC_(array(1, 2), array(3, 4), (x, y) -> x + y); + [4, 6] + > SELECT _FUNC_(array('a', 'b', 'c'), array('d', 'e', 'f'), (x, y) -> concat(x, y)); + ["ad", "be", "cf"] + """, + since = "2.4.0") +// scalastyle:on line.size.limit +case class ZipWith(left: Expression, right: Expression, function: Expression) + extends HigherOrderFunction with CodegenFallback { + + def functionForEval: Expression = functionsForEval.head + + override def arguments: Seq[Expression] = left :: right :: Nil + + override def argumentTypes: Seq[AbstractDataType] = ArrayType :: ArrayType :: Nil + + override def functions: Seq[Expression] = List(function) + + override def functionTypes: Seq[AbstractDataType] = AnyDataType :: Nil + + override def dataType: ArrayType = ArrayType(function.dataType, function.nullable) + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ZipWith = { + val ArrayType(leftElementType, _) = left.dataType + val ArrayType(rightElementType, _) = right.dataType + copy(function = f(function, + (leftElementType, true) :: (rightElementType, true) :: Nil)) + } + + @transient lazy val LambdaFunction(_, + Seq(leftElemVar: NamedLambdaVariable, rightElemVar: NamedLambdaVariable), _) = function + + override def eval(input: InternalRow): Any = { + val leftArr = left.eval(input).asInstanceOf[ArrayData] + if (leftArr == null) { + null + } else { + val rightArr = right.eval(input).asInstanceOf[ArrayData] + if (rightArr == null) { + null + } else { + val resultLength = math.max(leftArr.numElements(), rightArr.numElements()) + val f = functionForEval + val result = new GenericArrayData(new Array[Any](resultLength)) + var i = 0 + while (i < resultLength) { + if (i < leftArr.numElements()) { + leftElemVar.value.set(leftArr.get(i, leftElemVar.dataType)) + } else { + leftElemVar.value.set(null) + } + if (i < rightArr.numElements()) { + rightElemVar.value.set(rightArr.get(i, rightElemVar.dataType)) + } else { + rightElemVar.value.set(null) + } + result.update(i, f.eval(input)) + i += 1 + } + result + } + } + } + + override def prettyName: String = "zip_with" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index f6d74f5b74c8e..bd9090a07471b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, CharArrayWriter, InputStreamReader, StringWriter} +import java.io._ import scala.util.parsing.combinator.RegexParsers @@ -28,7 +28,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.json._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, FailFastMode, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.json.JsonInferSchema.inferField +import org.apache.spark.sql.catalyst.util._ 
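A usage sketch for the higher-order functions added in higherOrderFunctions.scala above, following their ExpressionDescription examples; this assumes a SparkSession named `spark` with the patch applied and is purely illustrative:

  spark.sql("SELECT transform(array(1, 2, 3), x -> x + 1)").show()               // [2, 3, 4]
  spark.sql("SELECT filter(array(1, 2, 3), x -> x % 2 == 1)").show()             // [1, 3]
  spark.sql("SELECT exists(array(1, 2, 3), x -> x % 2 == 0)").show()             // true
  spark.sql("SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x)").show()   // 6
  spark.sql("SELECT zip_with(array(1, 2), array(3, 4), (x, y) -> x + y)").show() // [4, 6]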
import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -494,7 +495,7 @@ case class JsonTuple(children: Seq[Expression]) } /** - * Converts an json input string to a [[StructType]] or [[ArrayType]] of [[StructType]]s + * Converts a JSON input string to a [[StructType]], [[ArrayType]] or [[MapType]] * with the specified schema. */ // scalastyle:off line.size.limit @@ -513,10 +514,11 @@ case class JsonToStructs( schema: DataType, options: Map[String, String], child: Expression, - timeZoneId: Option[String], - forceNullableSchema: Boolean) + timeZoneId: Option[String] = None) extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { + val forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA) + // The JSON input data might be missing certain fields. We force the nullability // of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder // can generate incorrect files if values are missing in columns declared as non-nullable. @@ -525,39 +527,27 @@ case class JsonToStructs( override def nullable: Boolean = true // Used in `FunctionRegistry` - def this(child: Expression, schema: Expression) = + def this(child: Expression, schema: Expression, options: Map[String, String]) = this( - schema = JsonExprUtils.validateSchemaLiteral(schema), - options = Map.empty[String, String], + schema = JsonExprUtils.evalSchemaExpr(schema), + options = options, child = child, - timeZoneId = None, - forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) + timeZoneId = None) + + def this(child: Expression, schema: Expression) = this(child, schema, Map.empty[String, String]) def this(child: Expression, schema: Expression, options: Expression) = this( - schema = JsonExprUtils.validateSchemaLiteral(schema), + schema = JsonExprUtils.evalSchemaExpr(schema), options = JsonExprUtils.convertToMapData(options), child = child, - timeZoneId = None, - forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) - - // Used in `org.apache.spark.sql.functions` - def this(schema: DataType, options: Map[String, String], child: Expression) = - this(schema, options, child, timeZoneId = None, - forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) + timeZoneId = None) override def checkInputDataTypes(): TypeCheckResult = nullableSchema match { - case _: StructType | ArrayType(_: StructType, _) | _: MapType => + case _: StructType | _: ArrayType | _: MapType => super.checkInputDataTypes() case _ => TypeCheckResult.TypeCheckFailure( - s"Input schema ${nullableSchema.simpleString} must be a struct or an array of structs.") - } - - @transient - lazy val rowSchema = nullableSchema match { - case st: StructType => st - case ArrayType(st: StructType, _) => st - case mt: MapType => mt + s"Input schema ${nullableSchema.catalogString} must be a struct, an array or a map.") } // This converts parsed rows to the desired output by the given schema.
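A sketch of what the relaxed schema handling above enables (illustrative, same assumed `spark` session as before): the schema may now be any struct, array or map, and evalSchemaExpr below also accepts a schema computed by the new schema_of_json expression rather than only a hand-written DDL literal:

  spark.sql("SELECT from_json('[1, 2, 3]', 'array<int>')").show()
  spark.sql("""SELECT from_json('{"a": 1}', schema_of_json('{"a": 0}'))""").show()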
@@ -565,8 +555,8 @@ case class JsonToStructs( lazy val converter = nullableSchema match { case _: StructType => (rows: Seq[InternalRow]) => if (rows.length == 1) rows.head else null - case ArrayType(_: StructType, _) => - (rows: Seq[InternalRow]) => new GenericArrayData(rows) + case _: ArrayType => + (rows: Seq[InternalRow]) => rows.head.getArray(0) case _: MapType => (rows: Seq[InternalRow]) => rows.head.getMap(0) } @@ -574,7 +564,7 @@ case class JsonToStructs( @transient lazy val parser = new JacksonParser( - rowSchema, + nullableSchema, new JSONOptions(options + ("mode" -> FailFastMode.name), timeZoneId.get)) override def dataType: DataType = nullableSchema @@ -623,19 +613,18 @@ case class JsonToStructs( } /** - * Converts a [[StructType]], [[ArrayType]] of [[StructType]]s, [[MapType]] - * or [[ArrayType]] of [[MapType]]s to a json output string. + * Converts a [[StructType]], [[ArrayType]] or [[MapType]] to a JSON output string. */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr[, options]) - Returns a json string with a given struct value", + usage = "_FUNC_(expr[, options]) - Returns a JSON string with a given struct value", examples = """ Examples: > SELECT _FUNC_(named_struct('a', 1, 'b', 2)); {"a":1,"b":2} > SELECT _FUNC_(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); {"time":"26/08/2015"} - > SELECT _FUNC_(array(named_struct('a', 1, 'b', 2)); + > SELECT _FUNC_(array(named_struct('a', 1, 'b', 2))); [{"a":1,"b":2}] > SELECT _FUNC_(map('a', named_struct('b', 1))); {"a":{"b":1}} @@ -670,15 +659,10 @@ case class StructsToJson( @transient lazy val gen = new JacksonGenerator( - rowSchema, writer, new JSONOptions(options, timeZoneId.get)) + inputSchema, writer, new JSONOptions(options, timeZoneId.get)) @transient - lazy val rowSchema = child.dataType match { - case st: StructType => st - case ArrayType(st: StructType, _) => st - case mt: MapType => mt - case ArrayType(mt: MapType, _) => mt - } + lazy val inputSchema = child.dataType // This converts rows to the JSON output according to the given schema. 
@transient @@ -690,12 +674,12 @@ case class StructsToJson( UTF8String.fromString(json) } - child.dataType match { + inputSchema match { case _: StructType => (row: Any) => gen.write(row.asInstanceOf[InternalRow]) getAndReset() - case ArrayType(_: StructType, _) => + case _: ArrayType => (arr: Any) => gen.write(arr.asInstanceOf[ArrayData]) getAndReset() @@ -703,36 +687,40 @@ case class StructsToJson( (map: Any) => gen.write(map.asInstanceOf[MapData]) getAndReset() - case ArrayType(_: MapType, _) => - (arr: Any) => - gen.write(arr.asInstanceOf[ArrayData]) - getAndReset() } } override def dataType: DataType = StringType - override def checkInputDataTypes(): TypeCheckResult = child.dataType match { - case _: StructType | ArrayType(_: StructType, _) => + override def checkInputDataTypes(): TypeCheckResult = inputSchema match { + case struct: StructType => try { - JacksonUtils.verifySchema(rowSchema.asInstanceOf[StructType]) + JacksonUtils.verifySchema(struct) TypeCheckResult.TypeCheckSuccess } catch { case e: UnsupportedOperationException => TypeCheckResult.TypeCheckFailure(e.getMessage) } - case _: MapType | ArrayType(_: MapType, _) => + case map: MapType => // TODO: let `JacksonUtils.verifySchema` verify a `MapType` try { - val st = StructType(StructField("a", rowSchema.asInstanceOf[MapType]) :: Nil) + val st = StructType(StructField("a", map) :: Nil) JacksonUtils.verifySchema(st) TypeCheckResult.TypeCheckSuccess } catch { case e: UnsupportedOperationException => TypeCheckResult.TypeCheckFailure(e.getMessage) } + case array: ArrayType => + try { + JacksonUtils.verifyType(prettyName, array) + TypeCheckResult.TypeCheckSuccess + } catch { + case e: UnsupportedOperationException => + TypeCheckResult.TypeCheckFailure(e.getMessage) + } case _ => TypeCheckResult.TypeCheckFailure( + s"Input type ${child.dataType.catalogString} must be a struct, array of structs or " + "a map or array of map.") @@ -744,11 +732,44 @@ case class StructsToJson( override def inputTypes: Seq[AbstractDataType] = TypeCollection(ArrayType, StructType) :: Nil } +/** + * A function that infers the schema of a JSON string.
+ */ +@ExpressionDescription( + usage = "_FUNC_(json[, options]) - Returns schema in the DDL format of JSON string.", + examples = """ + Examples: + > SELECT _FUNC_('[{"col":0}]'); + array> + """, + since = "2.4.0") +case class SchemaOfJson(child: Expression) + extends UnaryExpression with String2StringExpression with CodegenFallback { + + private val jsonOptions = new JSONOptions(Map.empty, "UTC") + private val jsonFactory = new JsonFactory() + jsonOptions.setJacksonOptions(jsonFactory) + + override def convert(v: UTF8String): UTF8String = { + val dt = Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, v)) { parser => + parser.nextToken() + inferField(parser, jsonOptions) + } + + UTF8String.fromString(dt.catalogString) + } +} + object JsonExprUtils { - def validateSchemaLiteral(exp: Expression): DataType = exp match { + def evalSchemaExpr(exp: Expression): DataType = exp match { case Literal(s, StringType) => DataType.fromDDL(s.toString) - case e => throw new AnalysisException(s"Expected a string literal instead of $e") + case e @ SchemaOfJson(_: Literal) => + val ddlSchema = e.eval().asInstanceOf[UTF8String] + DataType.fromDDL(ddlSchema.toString) + case e => throw new AnalysisException( + "Schema should be specified in DDL format as a string literal" + + s" or output of the schema_of_json function instead of ${e.sql}") } def convertToMapData(exp: Expression): Map[String, String] = exp match { @@ -760,7 +781,7 @@ object JsonExprUtils { } case m: CreateMap => throw new AnalysisException( - s"A type of keys and values in map() must be string, but got ${m.dataType}") + s"A type of keys and values in map() must be string, but got ${m.dataType.catalogString}") case _ => throw new AnalysisException("Must use a map() function for options") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 0cc2a332f2c30..2bcbb92f1a469 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -128,30 +128,36 @@ object Literal { val dataType = DataType.parseDataType(json \ "dataType") json \ "value" match { case JNull => Literal.create(null, dataType) - case JString(str) => - val value = dataType match { - case BooleanType => str.toBoolean - case ByteType => str.toByte - case ShortType => str.toShort - case IntegerType => str.toInt - case LongType => str.toLong - case FloatType => str.toFloat - case DoubleType => str.toDouble - case StringType => UTF8String.fromString(str) - case DateType => java.sql.Date.valueOf(str) - case TimestampType => java.sql.Timestamp.valueOf(str) - case CalendarIntervalType => CalendarInterval.fromString(str) - case t: DecimalType => - val d = Decimal(str) - assert(d.changePrecision(t.precision, t.scale)) - d - case _ => null - } - Literal.create(value, dataType) + case JString(str) => fromString(str, dataType) case other => sys.error(s"$other is not a valid Literal json value") } } + /** + * Constructs a Literal from a String + */ + def fromString(str: String, dataType: DataType): Literal = { + val value = dataType match { + case BooleanType => str.toBoolean + case ByteType => str.toByte + case ShortType => str.toShort + case IntegerType => str.toInt + case LongType => str.toLong + case FloatType => str.toFloat + case DoubleType => str.toDouble + case StringType => UTF8String.fromString(str) + case 
DateType => java.sql.Date.valueOf(str) + case TimestampType => java.sql.Timestamp.valueOf(str) + case CalendarIntervalType => CalendarInterval.fromString(str) + case t: DecimalType => + val d = Decimal(str) + assert(d.changePrecision(t.precision, t.scale)) + d + case _ => null + } + Literal.create(value, dataType) + } + def create(v: Any, dataType: DataType): Literal = { Literal(CatalystTypeConverters.convertToCatalyst(v), dataType) } @@ -186,7 +192,7 @@ object Literal { case map: MapType => create(Map(), map) case struct: StructType => create(InternalRow.fromSeq(struct.fields.map(f => default(f.dataType).value)), struct) - case udt: UserDefinedType[_] => default(udt.sqlType) + case udt: UserDefinedType[_] => Literal(default(udt.sqlType).value, udt) case other => throw new RuntimeException(s"no default for type $dataType") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala deleted file mode 100644 index 276a57266a6e0..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala +++ /dev/null @@ -1,569 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.commons.codec.digest.DigestUtils - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.MaskExpressionsUtils._ -import org.apache.spark.sql.catalyst.expressions.MaskLike._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - - -trait MaskLike { - def upper: String - def lower: String - def digit: String - - protected lazy val upperReplacement: Int = getReplacementChar(upper, defaultMaskedUppercase) - protected lazy val lowerReplacement: Int = getReplacementChar(lower, defaultMaskedLowercase) - protected lazy val digitReplacement: Int = getReplacementChar(digit, defaultMaskedDigit) - - protected val maskUtilsClassName: String = classOf[MaskExpressionsUtils].getName - - def inputStringLengthCode(inputString: String, length: String): String = { - s"${CodeGenerator.JAVA_INT} $length = $inputString.codePointCount(0, $inputString.length());" - } - - def appendMaskedToStringBuilderCode( - ctx: CodegenContext, - sb: String, - inputString: String, - offset: String, - numChars: String): String = { - val i = ctx.freshName("i") - val codePoint = ctx.freshName("codePoint") - s""" - |for (${CodeGenerator.JAVA_INT} $i = 0; $i < $numChars; $i++) { - | ${CodeGenerator.JAVA_INT} $codePoint = $inputString.codePointAt($offset); - | $sb.appendCodePoint($maskUtilsClassName.transformChar($codePoint, - | $upperReplacement, $lowerReplacement, - | $digitReplacement, $defaultMaskedOther)); - | $offset += Character.charCount($codePoint); - |} - """.stripMargin - } - - def appendUnchangedToStringBuilderCode( - ctx: CodegenContext, - sb: String, - inputString: String, - offset: String, - numChars: String): String = { - val i = ctx.freshName("i") - val codePoint = ctx.freshName("codePoint") - s""" - |for (${CodeGenerator.JAVA_INT} $i = 0; $i < $numChars; $i++) { - | ${CodeGenerator.JAVA_INT} $codePoint = $inputString.codePointAt($offset); - | $sb.appendCodePoint($codePoint); - | $offset += Character.charCount($codePoint); - |} - """.stripMargin - } - - def appendMaskedToStringBuilder( - sb: java.lang.StringBuilder, - inputString: String, - startOffset: Int, - numChars: Int): Int = { - var offset = startOffset - (1 to numChars) foreach { _ => - val codePoint = inputString.codePointAt(offset) - sb.appendCodePoint(transformChar( - codePoint, - upperReplacement, - lowerReplacement, - digitReplacement, - defaultMaskedOther)) - offset += Character.charCount(codePoint) - } - offset - } - - def appendUnchangedToStringBuilder( - sb: java.lang.StringBuilder, - inputString: String, - startOffset: Int, - numChars: Int): Int = { - var offset = startOffset - (1 to numChars) foreach { _ => - val codePoint = inputString.codePointAt(offset) - sb.appendCodePoint(codePoint) - offset += Character.charCount(codePoint) - } - offset - } -} - -trait MaskLikeWithN extends MaskLike { - def n: Int - protected lazy val charCount: Int = if (n < 0) 0 else n -} - -/** - * Utils for mask operations. 
- */ -object MaskLike { - val defaultCharCount = 4 - val defaultMaskedUppercase: Int = 'X' - val defaultMaskedLowercase: Int = 'x' - val defaultMaskedDigit: Int = 'n' - val defaultMaskedOther: Int = MaskExpressionsUtils.UNMASKED_VAL - - def extractCharCount(e: Expression): Int = e match { - case Literal(i, IntegerType | NullType) => - if (i == null) defaultCharCount else i.asInstanceOf[Int] - case Literal(_, dt) => throw new AnalysisException("Expected literal expression of type " + - s"${IntegerType.simpleString}, but got literal of ${dt.simpleString}") - case other => throw new AnalysisException(s"Expected literal expression, but got ${other.sql}") - } - - def extractReplacement(e: Expression): String = e match { - case Literal(s, StringType | NullType) => if (s == null) null else s.toString - case Literal(_, dt) => throw new AnalysisException("Expected literal expression of type " + - s"${StringType.simpleString}, but got literal of ${dt.simpleString}") - case other => throw new AnalysisException(s"Expected literal expression, but got ${other.sql}") - } -} - -/** - * Masks the input string. Additional parameters can be set to change the masking chars for - * uppercase letters, lowercase letters and digits. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(str[, upper[, lower[, digit]]]) - Masks str. By default, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", - examples = """ - Examples: - > SELECT _FUNC_("abcd-EFGH-8765-4321", "U", "l", "#"); - llll-UUUU-####-#### - """) -// scalastyle:on line.size.limit -case class Mask(child: Expression, upper: String, lower: String, digit: String) - extends UnaryExpression with ExpectsInputTypes with MaskLike { - - def this(child: Expression) = this(child, null.asInstanceOf[String], null, null) - - def this(child: Expression, upper: Expression) = - this(child, extractReplacement(upper), null, null) - - def this(child: Expression, upper: Expression, lower: Expression) = - this(child, extractReplacement(upper), extractReplacement(lower), null) - - def this(child: Expression, upper: Expression, lower: Expression, digit: Expression) = - this(child, extractReplacement(upper), extractReplacement(lower), extractReplacement(digit)) - - override def nullSafeEval(input: Any): Any = { - val str = input.asInstanceOf[UTF8String].toString - val length = str.codePointCount(0, str.length()) - val sb = new java.lang.StringBuilder(length) - appendMaskedToStringBuilder(sb, str, 0, length) - UTF8String.fromString(sb.toString) - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val sb = ctx.freshName("sb") - val length = ctx.freshName("length") - val offset = ctx.freshName("offset") - val inputString = ctx.freshName("inputString") - s""" - |String $inputString = $input.toString(); - |${inputStringLengthCode(inputString, length)} - |StringBuilder $sb = new StringBuilder($length); - |${CodeGenerator.JAVA_INT} $offset = 0; - |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, length)} - |${ev.value} = UTF8String.fromString($sb.toString()); - """.stripMargin - }) - } - - override def dataType: DataType = StringType - - override def 
inputTypes: Seq[AbstractDataType] = Seq(StringType) -} - -/** - * Masks the first N chars of the input string. N defaults to 4. Additional parameters can be set - * to change the masking chars for uppercase letters, lowercase letters and digits. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks the first n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", - examples = """ - Examples: - > SELECT _FUNC_("1234-5678-8765-4321", 4); - nnnn-5678-8765-4321 - """) -// scalastyle:on line.size.limit -case class MaskFirstN( - child: Expression, - n: Int, - upper: String, - lower: String, - digit: String) - extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { - - def this(child: Expression) = - this(child, defaultCharCount, null, null, null) - - def this(child: Expression, n: Expression) = - this(child, extractCharCount(n), null, null, null) - - def this(child: Expression, n: Expression, upper: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), null, null) - - def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) - - def this( - child: Expression, - n: Expression, - upper: Expression, - lower: Expression, - digit: Expression) = - this(child, - extractCharCount(n), - extractReplacement(upper), - extractReplacement(lower), - extractReplacement(digit)) - - override def nullSafeEval(input: Any): Any = { - val str = input.asInstanceOf[UTF8String].toString - val length = str.codePointCount(0, str.length()) - val endOfMask = if (charCount > length) length else charCount - val sb = new java.lang.StringBuilder(length) - val offset = appendMaskedToStringBuilder(sb, str, 0, endOfMask) - appendUnchangedToStringBuilder(sb, str, offset, length - endOfMask) - UTF8String.fromString(sb.toString) - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val sb = ctx.freshName("sb") - val length = ctx.freshName("length") - val offset = ctx.freshName("offset") - val inputString = ctx.freshName("inputString") - val endOfMask = ctx.freshName("endOfMask") - s""" - |String $inputString = $input.toString(); - |${inputStringLengthCode(inputString, length)} - |${CodeGenerator.JAVA_INT} $endOfMask = $charCount > $length ? $length : $charCount; - |${CodeGenerator.JAVA_INT} $offset = 0; - |StringBuilder $sb = new StringBuilder($length); - |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, endOfMask)} - |${appendUnchangedToStringBuilderCode( - ctx, sb, inputString, offset, s"$length - $endOfMask")} - |${ev.value} = UTF8String.fromString($sb.toString()); - |""".stripMargin - }) - } - - override def dataType: DataType = StringType - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) - - override def prettyName: String = "mask_first_n" -} - -/** - * Masks the last N chars of the input string. N defaults to 4. Additional parameters can be set - * to change the masking chars for uppercase letters, lowercase letters and digits. 
- */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks the last n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", - examples = """ - Examples: - > SELECT _FUNC_("1234-5678-8765-4321", 4); - 1234-5678-8765-nnnn - """, since = "2.4.0") -// scalastyle:on line.size.limit -case class MaskLastN( - child: Expression, - n: Int, - upper: String, - lower: String, - digit: String) - extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { - - def this(child: Expression) = - this(child, defaultCharCount, null, null, null) - - def this(child: Expression, n: Expression) = - this(child, extractCharCount(n), null, null, null) - - def this(child: Expression, n: Expression, upper: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), null, null) - - def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) - - def this( - child: Expression, - n: Expression, - upper: Expression, - lower: Expression, - digit: Expression) = - this(child, - extractCharCount(n), - extractReplacement(upper), - extractReplacement(lower), - extractReplacement(digit)) - - override def nullSafeEval(input: Any): Any = { - val str = input.asInstanceOf[UTF8String].toString - val length = str.codePointCount(0, str.length()) - val startOfMask = if (charCount >= length) 0 else length - charCount - val sb = new java.lang.StringBuilder(length) - val offset = appendUnchangedToStringBuilder(sb, str, 0, startOfMask) - appendMaskedToStringBuilder(sb, str, offset, length - startOfMask) - UTF8String.fromString(sb.toString) - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val sb = ctx.freshName("sb") - val length = ctx.freshName("length") - val offset = ctx.freshName("offset") - val inputString = ctx.freshName("inputString") - val startOfMask = ctx.freshName("startOfMask") - s""" - |String $inputString = $input.toString(); - |${inputStringLengthCode(inputString, length)} - |${CodeGenerator.JAVA_INT} $startOfMask = $charCount >= $length ? - | 0 : $length - $charCount; - |${CodeGenerator.JAVA_INT} $offset = 0; - |StringBuilder $sb = new StringBuilder($length); - |${appendUnchangedToStringBuilderCode(ctx, sb, inputString, offset, startOfMask)} - |${appendMaskedToStringBuilderCode( - ctx, sb, inputString, offset, s"$length - $startOfMask")} - |${ev.value} = UTF8String.fromString($sb.toString()); - |""".stripMargin - }) - } - - override def dataType: DataType = StringType - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) - - override def prettyName: String = "mask_last_n" -} - -/** - * Masks all but the first N chars of the input string. N defaults to 4. Additional parameters can - * be set to change the masking chars for uppercase letters, lowercase letters and digits. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks all but the first n values of str. 
By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", - examples = """ - Examples: - > SELECT _FUNC_("1234-5678-8765-4321", 4); - 1234-nnnn-nnnn-nnnn - """, since = "2.4.0") -// scalastyle:on line.size.limit -case class MaskShowFirstN( - child: Expression, - n: Int, - upper: String, - lower: String, - digit: String) - extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { - - def this(child: Expression) = - this(child, defaultCharCount, null, null, null) - - def this(child: Expression, n: Expression) = - this(child, extractCharCount(n), null, null, null) - - def this(child: Expression, n: Expression, upper: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), null, null) - - def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) - - def this( - child: Expression, - n: Expression, - upper: Expression, - lower: Expression, - digit: Expression) = - this(child, - extractCharCount(n), - extractReplacement(upper), - extractReplacement(lower), - extractReplacement(digit)) - - override def nullSafeEval(input: Any): Any = { - val str = input.asInstanceOf[UTF8String].toString - val length = str.codePointCount(0, str.length()) - val startOfMask = if (charCount > length) length else charCount - val sb = new java.lang.StringBuilder(length) - val offset = appendUnchangedToStringBuilder(sb, str, 0, startOfMask) - appendMaskedToStringBuilder(sb, str, offset, length - startOfMask) - UTF8String.fromString(sb.toString) - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val sb = ctx.freshName("sb") - val length = ctx.freshName("length") - val offset = ctx.freshName("offset") - val inputString = ctx.freshName("inputString") - val startOfMask = ctx.freshName("startOfMask") - s""" - |String $inputString = $input.toString(); - |${inputStringLengthCode(inputString, length)} - |${CodeGenerator.JAVA_INT} $startOfMask = $charCount > $length ? $length : $charCount; - |${CodeGenerator.JAVA_INT} $offset = 0; - |StringBuilder $sb = new StringBuilder($length); - |${appendUnchangedToStringBuilderCode(ctx, sb, inputString, offset, startOfMask)} - |${appendMaskedToStringBuilderCode( - ctx, sb, inputString, offset, s"$length - $startOfMask")} - |${ev.value} = UTF8String.fromString($sb.toString()); - |""".stripMargin - }) - } - - override def dataType: DataType = StringType - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) - - override def prettyName: String = "mask_show_first_n" -} - -/** - * Masks all but the last N chars of the input string. N defaults to 4. Additional parameters can - * be set to change the masking chars for uppercase letters, lowercase letters and digits. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks all but the last n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". 
You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", - examples = """ - Examples: - > SELECT _FUNC_("1234-5678-8765-4321", 4); - nnnn-nnnn-nnnn-4321 - """, since = "2.4.0") -// scalastyle:on line.size.limit -case class MaskShowLastN( - child: Expression, - n: Int, - upper: String, - lower: String, - digit: String) - extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { - - def this(child: Expression) = - this(child, defaultCharCount, null, null, null) - - def this(child: Expression, n: Expression) = - this(child, extractCharCount(n), null, null, null) - - def this(child: Expression, n: Expression, upper: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), null, null) - - def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) - - def this( - child: Expression, - n: Expression, - upper: Expression, - lower: Expression, - digit: Expression) = - this(child, - extractCharCount(n), - extractReplacement(upper), - extractReplacement(lower), - extractReplacement(digit)) - - override def nullSafeEval(input: Any): Any = { - val str = input.asInstanceOf[UTF8String].toString - val length = str.codePointCount(0, str.length()) - val endOfMask = if (charCount >= length) 0 else length - charCount - val sb = new java.lang.StringBuilder(length) - val offset = appendMaskedToStringBuilder(sb, str, 0, endOfMask) - appendUnchangedToStringBuilder(sb, str, offset, length - endOfMask) - UTF8String.fromString(sb.toString) - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val sb = ctx.freshName("sb") - val length = ctx.freshName("length") - val offset = ctx.freshName("offset") - val inputString = ctx.freshName("inputString") - val endOfMask = ctx.freshName("endOfMask") - s""" - |String $inputString = $input.toString(); - |${inputStringLengthCode(inputString, length)} - |${CodeGenerator.JAVA_INT} $endOfMask = $charCount >= $length ? 0 : $length - $charCount; - |${CodeGenerator.JAVA_INT} $offset = 0; - |StringBuilder $sb = new StringBuilder($length); - |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, endOfMask)} - |${appendUnchangedToStringBuilderCode( - ctx, sb, inputString, offset, s"$length - $endOfMask")} - |${ev.value} = UTF8String.fromString($sb.toString()); - |""".stripMargin - }) - } - - override def dataType: DataType = StringType - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) - - override def prettyName: String = "mask_show_last_n" -} - -/** - * Returns a hashed value based on str. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(str) - Returns a hashed value based on str. 
The hash is consistent and can be used to join masked values together across tables.", - examples = """ - Examples: - > SELECT _FUNC_("abcd-EFGH-8765-4321"); - 60c713f5ec6912229d2060df1c322776 - """) -// scalastyle:on line.size.limit -case class MaskHash(child: Expression) - extends UnaryExpression with ExpectsInputTypes { - - override def nullSafeEval(input: Any): Any = { - UTF8String.fromString(DigestUtils.md5Hex(input.asInstanceOf[UTF8String].toString)) - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val digestUtilsClass = classOf[DigestUtils].getName.stripSuffix("$") - s""" - |${ev.value} = UTF8String.fromString($digestUtilsClass.md5Hex($input.toString())); - |""".stripMargin - }) - } - - override def dataType: DataType = StringType - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) - - override def prettyName: String = "mask_hash" -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 5d98dac46cf17..0cdeda9b10516 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -126,10 +126,13 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { """, note = "The function is non-deterministic.") // scalastyle:on line.size.limit -case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Stateful { +case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Stateful + with ExpressionWithRandomSeed { def this() = this(None) + override def withNewSeed(seed: Long): Uuid = Uuid(Some(seed)) + override lazy val resolved: Boolean = randomSeed.isDefined override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 8df870468c2ad..584a2946bd564 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -40,7 +40,16 @@ object NamedExpression { * * The `id` field is unique within a given JVM, while the `uuid` is used to uniquely identify JVMs. */ -case class ExprId(id: Long, jvmId: UUID) +case class ExprId(id: Long, jvmId: UUID) { + + override def equals(other: Any): Boolean = other match { + case ExprId(id, jvmId) => this.id == id && this.jvmId == jvmId + case _ => false + } + + override def hashCode(): Int = id.hashCode() + +} object ExprId { def apply(id: Long): ExprId = ExprId(id, NamedExpression.jvmId) @@ -62,19 +71,22 @@ trait NamedExpression extends Expression { * * multiple qualifiers, it is possible that there are other possible way to refer to this * attribute. */ - def qualifiedName: String = (qualifier.toSeq :+ name).mkString(".") + def qualifiedName: String = (qualifier :+ name).mkString(".") /** * Optional qualifier for the expression. + * The qualifier can also contain the fully qualified information, e.g. a sequence of strings + * containing the database and the table name. * * For now, since we do not allow using original table name to qualify a column name once the * table is aliased, this can only be: * * 1.
Empty Seq: when an attribute doesn't have a qualifier, * e.g. top level attributes aliased in the SELECT clause, or column from a LocalRelation. - * 2. Single element: either the table name or the alias name of the table. + * 2. Seq with a single element: either the table name or the alias name of the table. + * 3. Seq with two elements: database name and table name. */ - def qualifier: Option[String] + def qualifier: Seq[String] def toAttribute: Attribute @@ -100,7 +112,7 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn override def references: AttributeSet = AttributeSet(this) def withNullability(newNullability: Boolean): Attribute - def withQualifier(newQualifier: Option[String]): Attribute + def withQualifier(newQualifier: Seq[String]): Attribute def withName(newName: String): Attribute def withMetadata(newMetadata: Metadata): Attribute @@ -121,14 +133,14 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn * @param name The name to be associated with the result of computing [[child]]. * @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this * alias. Auto-assigned if left blank. - * @param qualifier An optional string that can be used to referred to this attribute in a fully - * qualified way. Consider the examples tableName.name, subQueryAlias.name. - * tableName and subQueryAlias are possible qualifiers. + * @param qualifier An optional sequence of strings that can be used to refer to this attribute + * in a fully qualified way. Consider the examples tableName.name, subQueryAlias.name. + * tableName and subQueryAlias are possible qualifiers. * @param explicitMetadata Explicit metadata associated with this alias that overwrites child's. */ case class Alias(child: Expression, name: String)( val exprId: ExprId = NamedExpression.newExprId, - val qualifier: Option[String] = None, + val qualifier: Seq[String] = Seq.empty, val explicitMetadata: Option[Metadata] = None) extends UnaryExpression with NamedExpression { @@ -192,7 +204,7 @@ case class Alias(child: Expression, name: String)( } override def sql: String = { - val qualifierPrefix = qualifier.map(_ + ".").getOrElse("") + val qualifierPrefix = if (qualifier.nonEmpty) qualifier.mkString(".") + "." else "" s"${child.sql} AS $qualifierPrefix${quoteIdentifier(name)}" } } @@ -216,9 +228,11 @@ case class AttributeReference( nullable: Boolean = true, override val metadata: Metadata = Metadata.empty)( val exprId: ExprId = NamedExpression.newExprId, - val qualifier: Option[String] = None) + val qualifier: Seq[String] = Seq.empty[String]) extends Attribute with Unevaluable { + // Currently we can only handle qualifiers with at most two parts (database and table name). + require(qualifier.length <= 2) /** * Returns true iff the expression id is the same for both attributes. */ @@ -277,7 +291,7 @@ case class AttributeReference( /** * Returns a copy of this [[AttributeReference]] with new qualifier. */ - override def withQualifier(newQualifier: Option[String]): AttributeReference = { + override def withQualifier(newQualifier: Seq[String]): AttributeReference = { if (newQualifier == qualifier) { this } else { @@ -315,7 +329,7 @@ case class AttributeReference( override def simpleString: String = s"$name#${exprId.id}: ${dataType.simpleString}" override def sql: String = { - val qualifierPrefix = qualifier.map(_ + ".").getOrElse("") + val qualifierPrefix = if (qualifier.nonEmpty) qualifier.mkString(".") + "."
else "" s"$qualifierPrefix${quoteIdentifier(name)}" } } @@ -341,12 +355,12 @@ case class PrettyAttribute( override def withNullability(newNullability: Boolean): Attribute = throw new UnsupportedOperationException override def newInstance(): Attribute = throw new UnsupportedOperationException - override def withQualifier(newQualifier: Option[String]): Attribute = + override def withQualifier(newQualifier: Seq[String]): Attribute = throw new UnsupportedOperationException override def withName(newName: String): Attribute = throw new UnsupportedOperationException override def withMetadata(newMetadata: Metadata): Attribute = throw new UnsupportedOperationException - override def qualifier: Option[String] = throw new UnsupportedOperationException + override def qualifier: Seq[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException override def nullable: Boolean = true } @@ -362,7 +376,7 @@ case class OuterReference(e: NamedExpression) override def prettyName: String = "outer" override def name: String = e.name - override def qualifier: Option[String] = e.qualifier + override def qualifier: Seq[String] = e.qualifier override def exprId: ExprId = e.exprId override def toAttribute: Attribute = e.toAttribute override def newInstance(): NamedExpression = OuterReference(e.newInstance()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 2eeed3bbb2d91..b683d2a7e9ef3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.TypeUtils @@ -44,7 +44,7 @@ import org.apache.spark.sql.types._ 1 """) // scalastyle:on line.size.limit -case class Coalesce(children: Seq[Expression]) extends Expression { +case class Coalesce(children: Seq[Expression]) extends ComplexTypeMergingExpression { /** Coalesce is nullable if all of its children are nullable, or if it has no children. 
*/ override def nullable: Boolean = children.forall(_.nullable) @@ -61,8 +61,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } } - override def dataType: DataType = children.head.dataType - override def eval(input: InternalRow): Any = { var result: Any = null val childIterator = children.iterator diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 2bf4203d0fec3..3189e6841a525 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1727,7 +1727,7 @@ case class ValidateExternalType(child: Expression, expected: DataType) override val dataType: DataType = RowEncoder.externalDataTypeForInput(expected) - private val errMsg = s" is not a valid external type for schema of ${expected.simpleString}" + private val errMsg = s" is not a valid external type for schema of ${expected.catalogString}" private lazy val checkType: (Any) => Boolean = expected match { case _: DecimalType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 8a06daa37132d..11dcc3ebf798c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -152,10 +152,22 @@ package object expressions { unique(attrs.groupBy(_.name.toLowerCase(Locale.ROOT))) } - /** Map to use for qualified case insensitive attribute lookups. */ - @transient private val qualified: Map[(String, String), Seq[Attribute]] = { - val grouped = attrs.filter(_.qualifier.isDefined).groupBy { a => - (a.qualifier.get.toLowerCase(Locale.ROOT), a.name.toLowerCase(Locale.ROOT)) + /** Map to use for qualified case insensitive attribute lookups with a 2-part key */ + @transient private lazy val qualified: Map[(String, String), Seq[Attribute]] = { + // The key has 2 parts: table/alias name and column name + val grouped = attrs.filter(_.qualifier.nonEmpty).groupBy { + a => (a.qualifier.last.toLowerCase(Locale.ROOT), a.name.toLowerCase(Locale.ROOT)) + } + unique(grouped) + } + + /** Map to use for qualified case insensitive attribute lookups with a 3-part key */ + @transient private val qualified3Part: Map[(String, String, String), Seq[Attribute]] = { + // The key has 3 parts: database name, table name and column name + val grouped = attrs.filter(_.qualifier.length == 2).groupBy { a => + (a.qualifier.head.toLowerCase(Locale.ROOT), + a.qualifier.last.toLowerCase(Locale.ROOT), + a.name.toLowerCase(Locale.ROOT)) } unique(grouped) } @@ -169,25 +181,48 @@ package object expressions { }) } - // Find matches for the given name assuming that the 1st part is a qualifier (i.e. table name, - // alias, or subquery alias) and the 2nd part is the actual name. This returns a tuple of + // Find matches for the given name assuming that the first two parts are the qualifier + // (i.e. database name and table name) and the third part is the actual column name. + // + // For example, consider a case where "db1" is the database name, "a" is the table name, + // "b" is the column name and "c" is the struct field name.
+ // If the name parts are db1.a.b.c, then the Attribute will match + // Attribute(b, qualifier("db1","a")) and List("c") will be the second element. + var matches: (Seq[Attribute], Seq[String]) = nameParts match { + case dbPart +: tblPart +: name +: nestedFields => + val key = (dbPart.toLowerCase(Locale.ROOT), + tblPart.toLowerCase(Locale.ROOT), name.toLowerCase(Locale.ROOT)) + val attributes = collectMatches(name, qualified3Part.get(key)).filter { + a => (resolver(dbPart, a.qualifier.head) && resolver(tblPart, a.qualifier.last)) + } + (attributes, nestedFields) + case all => + (Seq.empty, Seq.empty) + } + + // If there are no matches, then find matches for the given name assuming that + // the 1st part is a qualifier (i.e. table name, alias, or subquery alias) and the + // 2nd part is the actual name. This returns a tuple of // matched attributes and a list of parts that are to be resolved. // // For example, consider an example where "a" is the table name, "b" is the column name, // and "c" is the struct field name, i.e. "a.b.c". In this case, Attribute will be "a.b", // and the second element will be List("c"). - val matches = nameParts match { - case qualifier +: name +: nestedFields => - val key = (qualifier.toLowerCase(Locale.ROOT), name.toLowerCase(Locale.ROOT)) - val attributes = collectMatches(name, qualified.get(key)).filter { a => - resolver(qualifier, a.qualifier.get) - } - (attributes, nestedFields) - case all => - (Nil, all) + if (matches._1.isEmpty) { + matches = nameParts match { + case qualifier +: name +: nestedFields => + val key = (qualifier.toLowerCase(Locale.ROOT), name.toLowerCase(Locale.ROOT)) + val attributes = collectMatches(name, qualified.get(key)).filter { a => + resolver(qualifier, a.qualifier.last) + } + (attributes, nestedFields) + case all => + (Seq.empty[Attribute], Seq.empty[String]) + } } - // If none of attributes match `table.column` pattern, we try to resolve it as a column. + // If none of the attributes match the database.table.column pattern or the + // `table.column` pattern, we try to resolve it as a column. val (candidates, nestedFields) = matches match { case (Seq(), _) => val name = nameParts.head diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index f54103c4fbfba..149bd79278a54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -138,6 +138,66 @@ case class Not(child: Expression) override def sql: String = s"(NOT ${child.sql})" } +/** + * Evaluates to `true` if `values` are returned in `query`'s result set. + */ +case class InSubquery(values: Seq[Expression], query: ListQuery) + extends Predicate with Unevaluable { + + @transient lazy val value: Expression = if (values.length > 1) { + CreateNamedStruct(values.zipWithIndex.flatMap { + case (v: NamedExpression, _) => Seq(Literal(v.name), v) + case (v, idx) => Seq(Literal(s"_$idx"), v) + }) + } else { + values.head + } + + + override def checkInputDataTypes(): TypeCheckResult = { + val mismatchOpt = !DataType.equalsStructurally(query.dataType, value.dataType, + ignoreNullability = true) + if (mismatchOpt) { + if (values.length != query.childOutputs.length) { + TypeCheckResult.TypeCheckFailure( + s""" + |The number of columns in the left hand side of an IN subquery does not match the + |number of columns in the output of subquery.
+ |#columns in left hand side: ${values.length}. + |#columns in right hand side: ${query.childOutputs.length}. + |Left side columns: + |[${values.map(_.sql).mkString(", ")}]. + |Right side columns: + |[${query.childOutputs.map(_.sql).mkString(", ")}].""".stripMargin) + } else { + val mismatchedColumns = values.zip(query.childOutputs).flatMap { + case (l, r) if l.dataType != r.dataType => + Seq(s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})") + case _ => None + } + TypeCheckResult.TypeCheckFailure( + s""" + |The data type of one or more elements in the left hand side of an IN subquery + |is not compatible with the data type of the output of the subquery + |Mismatched columns: + |[${mismatchedColumns.mkString(", ")}] + |Left side: + |[${values.map(_.dataType.catalogString).mkString(", ")}]. + |Right side: + |[${query.childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin) + } + } else { + TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName") + } + } + + override def children: Seq[Expression] = values :+ query + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + override def toString: String = s"$value IN ($query)" + override def sql: String = s"(${value.sql} IN (${query.sql}))" +} + /** * Evaluates to `true` if `list` contains `value`. @@ -169,44 +229,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType, ignoreNullability = true)) if (mismatchOpt.isDefined) { - list match { - case ListQuery(_, _, _, childOutputs) :: Nil => - val valExprs = value match { - case cns: CreateNamedStruct => cns.valExprs - case expr => Seq(expr) - } - if (valExprs.length != childOutputs.length) { - TypeCheckResult.TypeCheckFailure( - s""" - |The number of columns in the left hand side of an IN subquery does not match the - |number of columns in the output of subquery. - |#columns in left hand side: ${valExprs.length}. - |#columns in right hand side: ${childOutputs.length}. - |Left side columns: - |[${valExprs.map(_.sql).mkString(", ")}]. - |Right side columns: - |[${childOutputs.map(_.sql).mkString(", ")}].""".stripMargin) - } else { - val mismatchedColumns = valExprs.zip(childOutputs).flatMap { - case (l, r) if l.dataType != r.dataType => - s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})" - case _ => None - } - TypeCheckResult.TypeCheckFailure( - s""" - |The data type of one or more elements in the left hand side of an IN subquery - |is not compatible with the data type of the output of the subquery - |Mismatched columns: - |[${mismatchedColumns.mkString(", ")}] - |Left side: - |[${valExprs.map(_.dataType.catalogString).mkString(", ")}]. 
- |Right side: - |[${childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin) - } - case _ => - TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " + - s"${value.dataType.simpleString} != ${mismatchOpt.get.dataType.simpleString}") - } + TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " + + s"${value.dataType.catalogString} != ${mismatchOpt.get.dataType.catalogString}") } else { TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName") } @@ -307,9 +331,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } override def sql: String = { - val childrenSQL = children.map(_.sql) - val valueSQL = childrenSQL.head - val listSQL = childrenSQL.tail.mkString(", ") + val valueSQL = value.sql + val listSQL = list.map(_.sql).mkString(", ") s"($valueSQL IN ($listSQL))" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 926c2f00d430d..b70c34141b97d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -57,6 +57,14 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType)) } +/** + * Represents the behavior of expressions which have a random seed and can renew the seed. + * Usually the random seed needs to be renewed at each execution under streaming queries. + */ +trait ExpressionWithRandomSeed { + def withNewSeed(seed: Long): Expression +} + /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). 
*/ // scalastyle:off line.size.limit @ExpressionDescription( @@ -72,10 +80,12 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful """, note = "The function is non-deterministic in general case.") // scalastyle:on line.size.limit -case class Rand(child: Expression) extends RDG { +case class Rand(child: Expression) extends RDG with ExpressionWithRandomSeed { def this() = this(Literal(Utils.random.nextLong(), LongType)) + override def withNewSeed(seed: Long): Rand = Rand(Literal(seed, LongType)) + override protected def evalInternal(input: InternalRow): Double = rng.nextDouble() override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -110,10 +120,12 @@ object Rand { """, note = "The function is non-deterministic in general case.") // scalastyle:on line.size.limit -case class Randn(child: Expression) extends RDG { +case class Randn(child: Expression) extends RDG with ExpressionWithRandomSeed { def this() = this(Literal(Utils.random.nextLong(), LongType)) + override def withNewSeed(seed: Long): Randn = Randn(Literal(seed, LongType)) + override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian() override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 7b68bb771faf3..bf0c35fe61018 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -272,7 +272,7 @@ case class StringSplit(str: Expression, pattern: Expression) usage = "_FUNC_(str, regexp, rep) - Replaces all substrings of `str` that match `regexp` with `rep`.", examples = """ Examples: - > SELECT _FUNC_('100-200', '(\d+)', 'num'); + > SELECT _FUNC_('100-200', '(\\d+)', 'num'); num-num """) // scalastyle:on line.size.limit @@ -371,7 +371,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio usage = "_FUNC_(str, regexp[, idx]) - Extracts a group that matches `regexp`.", examples = """ Examples: - > SELECT _FUNC_('100-200', '(\d+)-(\d+)', 1); + > SELECT _FUNC_('100-200', '(\\d+)-(\\d+)', 1); 100 """) case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index bedad7da334ae..14faa62bde7d0 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -91,7 +91,7 @@ case class ConcatWs(children: Seq[Expression]) val args = ctx.freshName("args") val inputs = strings.zipWithIndex.map { case (eval, index) => - if (eval.isNull != "true") { + if (eval.isNull != TrueLiteral) { s""" ${eval.code} if (!${eval.isNull}) { @@ -123,14 +123,14 @@ case class ConcatWs(children: Seq[Expression]) child.dataType match { case StringType => ("", // we count all the StringType arguments num at once below. - if (eval.isNull == "true") { + if (eval.isNull == TrueLiteral) { "" } else { s"$array[$idxVararg ++] = ${eval.isNull} ? 
(UTF8String) null : ${eval.value};" }) case _: ArrayType => val size = ctx.freshName("n") - if (eval.isNull == "true") { + if (eval.isNull == TrueLiteral) { ("", "") } else { (s""" @@ -222,12 +222,13 @@ case class Elt(children: Seq[Expression]) extends Expression { val (indexType, inputTypes) = (indexExpr.dataType, inputExprs.map(_.dataType)) if (indexType != IntegerType) { return TypeCheckResult.TypeCheckFailure(s"first input to function $prettyName should " + - s"have IntegerType, but it's $indexType") + s"have ${IntegerType.catalogString}, but it's ${indexType.catalogString}") } if (inputTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { return TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName should have StringType or BinaryType, but it's " + - inputTypes.map(_.simpleString).mkString("[", ", ", "]")) + s"input to function $prettyName should have ${StringType.catalogString} or " + + s"${BinaryType.catalogString}, but it's " + + inputTypes.map(_.catalogString).mkString("[", ", ", "]")) } TypeUtils.checkForSameTypeInputExpr(inputTypes, s"function $prettyName") } @@ -1553,10 +1554,9 @@ case class Left(str: Expression, len: Expression, child: Expression) extends Run * A function that returns the char length of the given string expression or * number of bytes of the given binary expression. */ +// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the character length of string data or number of bytes of " + - "binary data. The length of string data includes the trailing spaces. The length of binary " + - "data includes binary zeros.", + usage = "_FUNC_(expr) - Returns the character length of string data or number of bytes of binary data. The length of string data includes the trailing spaces. 
The length of binary data includes binary zeros.", examples = """ Examples: > SELECT _FUNC_('Spark SQL '); @@ -1566,6 +1566,7 @@ case class Left(str: Expression, len: Expression, child: Expression) extends Run > SELECT CHARACTER_LENGTH('Spark SQL '); 10 """) +// scalastyle:on line.size.limit case class Length(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 6acc87a3e7367..fc1caed84e272 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -117,10 +117,10 @@ object SubExprUtils extends PredicateHelper { def hasNullAwarePredicateWithinNot(condition: Expression): Boolean = { splitConjunctivePredicates(condition).exists { case _: Exists | Not(_: Exists) => false - case In(_, Seq(_: ListQuery)) | Not(In(_, Seq(_: ListQuery))) => false + case _: InSubquery | Not(_: InSubquery) => false case e => e.find { x => x.isInstanceOf[Not] && e.find { - case In(_, Seq(_: ListQuery)) => true + case _: InSubquery => true case _ => false }.isDefined }.isDefined diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index f957aaa96e98c..707f312499734 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -21,6 +21,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate, NoOp} import org.apache.spark.sql.types._ @@ -70,9 +71,9 @@ case class WindowSpecDefinition( case f: SpecifiedWindowFrame if f.frameType == RangeFrame && f.isValueBound && !isValidFrameType(f.valueBoundary.head.dataType) => TypeCheckFailure( - s"The data type '${orderSpec.head.dataType.simpleString}' used in the order " + + s"The data type '${orderSpec.head.dataType.catalogString}' used in the order " + "specification does not match the data type " + - s"'${f.valueBoundary.head.dataType.simpleString}' which is used in the range frame.") + s"'${f.valueBoundary.head.dataType.catalogString}' which is used in the range frame.") case _ => TypeCheckSuccess } } @@ -251,7 +252,7 @@ case class SpecifiedWindowFrame( TypeCheckFailure(s"Window frame $location bound '$e' is not a literal.") case e: Expression if !frameType.inputType.acceptsType(e.dataType) => TypeCheckFailure( - s"The data type of the $location bound '${e.dataType.simpleString}' does not match " + + s"The data type of the $location bound '${e.dataType.catalogString}' does not match " + s"the expected data type '${frameType.inputType.simpleString}'.") case _ => TypeCheckSuccess } @@ -476,7 +477,7 @@ abstract class RowNumberLike extends AggregateWindowFunction { protected val rowNumber = 
AttributeReference("rowNumber", IntegerType, nullable = false)() override val aggBufferAttributes: Seq[AttributeReference] = rowNumber :: Nil override val initialValues: Seq[Expression] = zero :: Nil - override val updateExpressions: Seq[Expression] = Add(rowNumber, one) :: Nil + override val updateExpressions: Seq[Expression] = rowNumber + one :: Nil } /** @@ -527,7 +528,7 @@ case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction { // The frame for CUME_DIST is Range based instead of Row based, because CUME_DIST must // return the same value for equal values in the partition. override val frame = SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) - override val evaluateExpression = Divide(Cast(rowNumber, DoubleType), Cast(n, DoubleType)) + override val evaluateExpression = rowNumber.cast(DoubleType) / n.cast(DoubleType) override def prettyName: String = "cume_dist" } @@ -587,8 +588,7 @@ case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindow private val bucketSize = AttributeReference("bucketSize", IntegerType, nullable = false)() private val bucketsWithPadding = AttributeReference("bucketsWithPadding", IntegerType, nullable = false)() - private def bucketOverflow(e: Expression) = - If(GreaterThanOrEqual(rowNumber, bucketThreshold), e, zero) + private def bucketOverflow(e: Expression) = If(rowNumber >= bucketThreshold, e, zero) override val aggBufferAttributes = Seq( rowNumber, @@ -602,15 +602,14 @@ case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindow zero, zero, zero, - Cast(Divide(n, buckets), IntegerType), - Cast(Remainder(n, buckets), IntegerType) + (n / buckets).cast(IntegerType), + (n % buckets).cast(IntegerType) ) override val updateExpressions = Seq( - Add(rowNumber, one), - Add(bucket, bucketOverflow(one)), - Add(bucketThreshold, bucketOverflow( - Add(bucketSize, If(LessThan(bucket, bucketsWithPadding), one, zero)))), + rowNumber + one, + bucket + bucketOverflow(one), + bucketThreshold + bucketOverflow(bucketSize + If(bucket < bucketsWithPadding, one, zero)), NoOp, NoOp ) @@ -644,7 +643,7 @@ abstract class RankLike extends AggregateWindowFunction { protected val rowNumber = AttributeReference("rowNumber", IntegerType, nullable = false)() protected val zero = Literal(0) protected val one = Literal(1) - protected val increaseRowNumber = Add(rowNumber, one) + protected val increaseRowNumber = rowNumber + one /** * Different RankLike implementations use different source expressions to update their rank value. @@ -653,7 +652,7 @@ abstract class RankLike extends AggregateWindowFunction { protected def rankSource: Expression = rowNumber /** Increase the rank when the current rank == 0 or when the one of order attributes changes. 
*/ - protected val increaseRank = If(And(orderEquals, Not(EqualTo(rank, zero))), rank, rankSource) + protected val increaseRank = If(orderEquals && rank =!= zero, rank, rankSource) override val aggBufferAttributes: Seq[AttributeReference] = rank +: rowNumber +: orderAttrs override val initialValues = zero +: one +: orderInit @@ -707,7 +706,7 @@ case class Rank(children: Seq[Expression]) extends RankLike { case class DenseRank(children: Seq[Expression]) extends RankLike { def this() = this(Nil) override def withOrder(order: Seq[Expression]): DenseRank = DenseRank(order) - override protected def rankSource = Add(rank, one) + override protected def rankSource = rank + one override val updateExpressions = increaseRank +: children override val aggBufferAttributes = rank +: orderAttrs override val initialValues = zero +: orderInit @@ -736,8 +735,7 @@ case class PercentRank(children: Seq[Expression]) extends RankLike with SizeBase def this() = this(Nil) override def withOrder(order: Seq[Expression]): PercentRank = PercentRank(order) override def dataType: DataType = DoubleType - override val evaluateExpression = If(GreaterThan(n, one), - Divide(Cast(Subtract(rank, one), DoubleType), Cast(Subtract(n, one), DoubleType)), - Literal(0.0d)) + override val evaluateExpression = + If(n > one, (rank - one).cast(DoubleType) / (n - one).cast(DoubleType), 0.0d) override def prettyName: String = "percent_rank" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala index a3cc4529b5456..deceec73dda30 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala @@ -47,6 +47,22 @@ sealed trait IdentifierWithDatabase { override def toString: String = quotedString } +/** + * Encapsulates an identifier that is either an alias name or an identifier that has a table + * name and optionally a database name. + * The SubqueryAlias node keeps track of the qualifier using the information in this structure. + * @param identifier - an alias name or a table name + * @param database - an optional database name + */ +case class AliasIdentifier(identifier: String, database: Option[String]) + extends IdentifierWithDatabase { + + def this(identifier: String) = this(identifier, None) +} + +object AliasIdentifier { + def apply(identifier: String): AliasIdentifier = new AliasIdentifier(identifier) +} /** * Identifies a table in a database. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index 9c413de752a8c..9b86d865622dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.json import java.io.Writer -import java.nio.charset.StandardCharsets import com.fasterxml.jackson.core._ @@ -28,7 +27,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, MapData} import org.apache.spark.sql.types._ /** - * `JackGenerator` can only be initialized with a `StructType` or a `MapType`. + * `JacksonGenerator` can only be initialized with a `StructType`, a `MapType` or an `ArrayType`.
* Once it is initialized with `StructType`, it can be used to write out a struct or an array of * struct. Once it is initialized with `MapType`, it can be used to write out a map or an array * of map. An exception will be thrown if trying to write out a struct if it is initialized with @@ -43,34 +42,32 @@ private[sql] class JacksonGenerator( // we can directly access data in `ArrayData` without the help of `SpecificMutableRow`. private type ValueWriter = (SpecializedGetters, Int) => Unit - // `JackGenerator` can only be initialized with a `StructType` or a `MapType`. - require(dataType.isInstanceOf[StructType] || dataType.isInstanceOf[MapType], - "JacksonGenerator only supports to be initialized with a StructType " + - s"or MapType but got ${dataType.simpleString}") + // `JacksonGenerator` can only be initialized with a `StructType`, a `MapType` or an `ArrayType`. + require(dataType.isInstanceOf[StructType] || dataType.isInstanceOf[MapType] + || dataType.isInstanceOf[ArrayType], + s"JacksonGenerator only supports being initialized with a ${StructType.simpleString}, " + + s"${MapType.simpleString} or ${ArrayType.simpleString} but got ${dataType.catalogString}") // `ValueWriter`s for all fields of the schema private lazy val rootFieldWriters: Array[ValueWriter] = dataType match { case st: StructType => st.map(_.dataType).map(makeWriter).toArray case _ => throw new UnsupportedOperationException( - s"Initial type ${dataType.simpleString} must be a struct") + s"Initial type ${dataType.catalogString} must be a ${StructType.simpleString}") } // `ValueWriter` for array data storing rows of the schema. private lazy val arrElementWriter: ValueWriter = dataType match { - case st: StructType => - (arr: SpecializedGetters, i: Int) => { - writeObject(writeFields(arr.getStruct(i, st.length), st, rootFieldWriters)) - } - case mt: MapType => - (arr: SpecializedGetters, i: Int) => { - writeObject(writeMapData(arr.getMap(i), mt, mapElementWriter)) - } + case at: ArrayType => makeWriter(at.elementType) + case _: StructType | _: MapType => makeWriter(dataType) + case _ => throw new UnsupportedOperationException( + s"Initial type ${dataType.catalogString} must be " + + s"an ${ArrayType.simpleString}, a ${StructType.simpleString} or a ${MapType.simpleString}") } private lazy val mapElementWriter: ValueWriter = dataType match { case mt: MapType => makeWriter(mt.valueType) case _ => throw new UnsupportedOperationException( - s"Initial type ${dataType.simpleString} must be a map") + s"Initial type ${dataType.catalogString} must be a ${MapType.simpleString}") } private val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index c3a4ca8f64bf6..984979ac5e9b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.json import java.io.{ByteArrayOutputStream, CharConversionException} +import java.nio.charset.MalformedInputException import scala.collection.mutable.ArrayBuffer import scala.util.Try @@ -61,6 +62,7 @@ class JacksonParser( dt match { case st: StructType => makeStructRootConverter(st) case mt: MapType => makeMapRootConverter(mt) + case at: ArrayType => makeArrayRootConverter(at) } } @@ -101,6 +103,35 @@ class JacksonParser( } } +
private def makeArrayRootConverter(at: ArrayType): JsonParser => Seq[InternalRow] = { + val elemConverter = makeConverter(at.elementType) + (parser: JsonParser) => parseJsonToken[Seq[InternalRow]](parser, at) { + case START_ARRAY => Seq(InternalRow(convertArray(parser, elemConverter))) + case START_OBJECT if at.elementType.isInstanceOf[StructType] => + // This handles the case when an input JSON object is a structure but + // the specified schema is an array of structures. In that case, the input JSON is + // treated as an array with a single element of the struct type. + // This behavior was introduced by changes for SPARK-19595. + // + // For example, if the specified schema is ArrayType(new StructType().add("i", IntegerType)) + // and the JSON input is as below: + // + // [{"i": 1}, {"i": 2}] + // [{"i": 3}] + // {"i": 4} + // + // The last row is treated as an array with one element, and the result of the conversion is: + // + // Seq(Row(1), Row(2)) + // Seq(Row(3)) + // Seq(Row(4)) + // + val st = at.elementType.asInstanceOf[StructType] + val fieldConverters = st.map(_.dataType).map(makeConverter).toArray + Seq(InternalRow(new GenericArrayData(Seq(convertObject(parser, st, fieldConverters))))) + } + } + /** * Create a converter which converts the JSON documents held by the `JsonParser` * to a value according to a desired schema. @@ -143,7 +174,8 @@ class JacksonParser( case "NaN" => Float.NaN case "Infinity" => Float.PositiveInfinity case "-Infinity" => Float.NegativeInfinity - case other => throw new RuntimeException(s"Cannot parse $other as FloatType.") + case other => throw new RuntimeException( + s"Cannot parse $other as ${FloatType.catalogString}.") } } @@ -158,7 +190,8 @@ class JacksonParser( case "NaN" => Double.NaN case "Infinity" => Double.PositiveInfinity case "-Infinity" => Double.NegativeInfinity - case other => throw new RuntimeException(s"Cannot parse $other as DoubleType.") + case other => + throw new RuntimeException(s"Cannot parse $other as ${DoubleType.catalogString}.") } } @@ -370,7 +403,7 @@ class JacksonParser( } } } catch { - case e @ (_: RuntimeException | _: JsonProcessingException) => + case e @ (_: RuntimeException | _: JsonProcessingException | _: MalformedInputException) => // JSON parser currently doesn't support partial results for corrupted records. // For such records, all fields other than the field configured by // `columnNameOfCorruptRecord` are set to `null`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala index 134d16e981a15..2d89c7066d080 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala @@ -32,11 +32,8 @@ object JacksonUtils { } } - /** - * Verify if the schema is supported in JSON parsing.
- */ - def verifySchema(schema: StructType): Unit = { - def verifyType(name: String, dataType: DataType): Unit = dataType match { + def verifyType(name: String, dataType: DataType): Unit = { + dataType match { case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | StringType | TimestampType | DateType | BinaryType | _: DecimalType => @@ -52,9 +49,14 @@ object JacksonUtils { case _ => throw new UnsupportedOperationException( - s"Unable to convert column $name of type ${dataType.simpleString} to JSON.") + s"Unable to convert column $name of type ${dataType.catalogString} to JSON.") } + } + /** + * Verify if the schema is supported in JSON parsing. + */ + def verifySchema(schema: StructType): Unit = { schema.foreach(field => verifyType(field.name, field.dataType)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala similarity index 93% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index 8e1b430f4eb33..9999a005106f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.json +package org.apache.spark.sql.catalyst.json import java.util.Comparator @@ -25,8 +25,8 @@ import org.apache.spark.SparkException import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil -import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -70,10 +70,17 @@ private[sql] object JsonInferSchema { }.reduceOption(typeMerger).toIterator } - // Here we get RDD local iterator then fold, instead of calling `RDD.fold` directly, because - // `RDD.fold` will run the fold function in DAGScheduler event loop thread, which may not have - // active SparkSession and `SQLConf.get` may point to the wrong configs. - val rootType = mergedTypesFromPartitions.toLocalIterator.fold(StructType(Nil))(typeMerger) + // Here we manually submit a fold-like Spark job, so that we can set the SQLConf when running + // the fold functions in the scheduler event loop thread. 
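The comment above replaces `RDD.fold` with a manually submitted job; a self-contained sketch of the same fold-by-`runJob` shape on a toy RDD (names and values illustrative):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").getOrCreate()
val sc = spark.sparkContext

var total = 0
// Fold each partition on the executors, then merge per-partition results on
// the driver; the merge callback runs in the scheduler event loop, which is
// why the code below re-installs the captured SQLConf around typeMerger.
val foldPartition = (iter: Iterator[Int]) => iter.sum
val mergeResult = (index: Int, taskResult: Int) => total += taskResult
sc.runJob(sc.parallelize(1 to 100, 4), foldPartition, mergeResult)
assert(total == 5050)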
+ val existingConf = SQLConf.get + var rootType: DataType = StructType(Nil) + val foldPartition = (iter: Iterator[DataType]) => iter.fold(StructType(Nil))(typeMerger) + val mergeResult = (index: Int, taskResult: DataType) => { + rootType = SQLConf.withExistingConf(existingConf) { + typeMerger(rootType, taskResult) + } + } + json.sparkContext.runJob(mergedTypesFromPartitions, foldPartition, mergeResult) canonicalizeType(rootType, configOptions) match { case Some(st: StructType) => st @@ -103,7 +110,7 @@ /** * Infer the type of a json document from the parser's token stream */ - private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = { + def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = { import com.fasterxml.jackson.core.JsonToken._ parser.getCurrentToken match { case null | VALUE_NULL => NullType @@ -295,8 +302,10 @@ // Both fields1 and fields2 should be sorted by name, since inferField performs sorting. // Therefore, we can take advantage of the fact that we're merging sorted lists and skip // building a hash map or performing additional sorting. - assert(isSorted(fields1), s"StructType's fields were not sorted: ${fields1.toSeq}") - assert(isSorted(fields2), s"StructType's fields were not sorted: ${fields2.toSeq}") + assert(isSorted(fields1), + s"${StructType.simpleString}'s fields were not sorted: ${fields1.toSeq}") + assert(isSorted(fields2), + s"${StructType.simpleString}'s fields were not sorted: ${fields2.toSeq}") val newFields = new java.util.ArrayList[StructField]() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2cc27d82f7d20..b432ce24e1ef7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -46,7 +46,14 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) protected def fixedPoint = FixedPoint(SQLConf.get.optimizerMaxIterations) - def batches: Seq[Batch] = { + /** + * Defines the default rule batches in the Optimizer. + * + * Implementations of this class should override this method, and [[nonExcludableRules]] if + * necessary, instead of [[batches]]. The rule batches that eventually run in the Optimizer, + * i.e., returned by [[batches]], will be (defaultBatches - (excludedRules - nonExcludableRules)). + */ + def defaultBatches: Seq[Batch] = { val operatorOptimizationRuleSet = Seq( // Operator push down @@ -123,11 +130,21 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) // since the other rules might make two separate Unions operators adjacent. Batch("Union", Once, CombineUnions) :: + // Run this once earlier. This might simplify the plan and reduce the cost of the optimizer. + // For example, a query such as Filter(LocalRelation) would go through all the heavy + // optimizer rules that are triggered when there is a filter + // (e.g. InferFiltersFromConstraints);
if we run this batch earlier, the query becomes just + // LocalRelation and does not trigger many rules + Batch("LocalRelation early", fixedPoint, + ConvertToLocalRelation, + PropagateEmptyRelation) :: Batch("Pullup Correlated Expressions", Once, PullupCorrelatedPredicates) :: Batch("Subquery", Once, OptimizeSubqueries) :: Batch("Replace Operators", fixedPoint, + RewriteExceptAll, + RewriteIntersectAll, ReplaceIntersectWithSemiJoin, ReplaceExceptWithFilter, ReplaceExceptWithAntiJoin, @@ -160,14 +177,51 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) UpdateNullabilityInAttributeReferences) } + /** + * Defines rules that cannot be excluded from the Optimizer even if they are specified in + * SQL config "excludedRules". + * + * Implementations of this class can override this method if necessary. The rule batches + * that eventually run in the Optimizer, i.e., returned by [[batches]], will be + * (defaultBatches - (excludedRules - nonExcludableRules)). + */ + def nonExcludableRules: Seq[String] = + EliminateDistinct.ruleName :: + EliminateSubqueryAliases.ruleName :: + EliminateView.ruleName :: + ReplaceExpressions.ruleName :: + ComputeCurrentTime.ruleName :: + GetCurrentDatabase(sessionCatalog).ruleName :: + RewriteDistinctAggregates.ruleName :: + ReplaceDeduplicateWithAggregate.ruleName :: + ReplaceIntersectWithSemiJoin.ruleName :: + ReplaceExceptWithFilter.ruleName :: + ReplaceExceptWithAntiJoin.ruleName :: + RewriteExceptAll.ruleName :: + RewriteIntersectAll.ruleName :: + ReplaceDistinctWithAggregate.ruleName :: + PullupCorrelatedPredicates.ruleName :: + RewriteCorrelatedScalarSubquery.ruleName :: + RewritePredicateSubquery.ruleName :: Nil + /** * Optimize all the subqueries inside expression. */ object OptimizeSubqueries extends Rule[LogicalPlan] { + private def removeTopLevelSort(plan: LogicalPlan): LogicalPlan = { + plan match { + case Sort(_, _, child) => child + case Project(fields, child) => Project(fields, removeTopLevelSort(child)) + case other => other + } + } def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case s: SubqueryExpression => val Subquery(newPlan) = Optimizer.this.execute(Subquery(s.plan)) - s.withNewPlan(newPlan) + // At this point we have an optimized subquery plan that we are going to attach + // to this subquery expression. Here we can safely remove any top level sort + // in the plan as tuples produced by a subquery are un-ordered. + s.withNewPlan(removeTopLevelSort(newPlan)) } } @@ -175,6 +229,48 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) * Override to provide additional rules for the operator optimization batch. */ def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil + + /** + * Returns (defaultBatches - (excludedRules - nonExcludableRules)), the rule batches that + * eventually run in the Optimizer. + * + * Implementations of this class should override [[defaultBatches]], and [[nonExcludableRules]] + * if necessary, instead of this method. 
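A usage sketch of the exclusion mechanism described above, assuming the `spark.sql.optimizer.excludedRules` key that `optimizerExcludedRules` reads (session setup illustrative):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[*]")
  // Rules listed here are dropped from defaultBatches; any rule present in
  // nonExcludableRules is kept anyway, with a warning logged instead.
  .config("spark.sql.optimizer.excludedRules",
    "org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation")
  .getOrCreate()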
+ */ + final override def batches: Seq[Batch] = { + val excludedRulesConf = + SQLConf.get.optimizerExcludedRules.toSeq.flatMap(Utils.stringToSeq) + val excludedRules = excludedRulesConf.filter { ruleName => + val nonExcludable = nonExcludableRules.contains(ruleName) + if (nonExcludable) { + logWarning(s"Optimization rule '${ruleName}' was not excluded from the optimizer " + + s"because this rule is a non-excludable rule.") + } + !nonExcludable + } + if (excludedRules.isEmpty) { + defaultBatches + } else { + defaultBatches.flatMap { batch => + val filteredRules = batch.rules.filter { rule => + val exclude = excludedRules.contains(rule.ruleName) + if (exclude) { + logInfo(s"Optimization rule '${rule.ruleName}' is excluded from the optimizer.") + } + !exclude + } + if (batch.rules == filteredRules) { + Some(batch) + } else if (filteredRules.nonEmpty) { + Some(Batch(batch.name, batch.strategy, filteredRules: _*)) + } else { + logInfo(s"Optimization batch '${batch.name}' is excluded from the optimizer " + + s"as all enclosed rules have been excluded.") + None + } + } + } + } } /** @@ -450,13 +546,16 @@ object ColumnPruning extends Rule[LogicalPlan] { case d @ DeserializeToObject(_, _, child) if (child.outputSet -- d.references).nonEmpty => d.copy(child = prunedChild(child, d.references)) - // Prunes the unused columns from child of Aggregate/Expand/Generate + // Prunes the unused columns from child of Aggregate/Expand/Generate/ScriptTransformation case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => a.copy(child = prunedChild(child, a.references)) case f @ FlatMapGroupsInPandas(_, _, _, child) if (child.outputSet -- f.references).nonEmpty => f.copy(child = prunedChild(child, f.references)) case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty => e.copy(child = prunedChild(child, e.references)) + case s @ ScriptTransformation(_, _, _, child, _) + if (child.outputSet -- s.references).nonEmpty => + s.copy(child = prunedChild(child, s.references)) // prune unrequired references case p @ Project(_, g: Generate) if p.references != g.outputSet => @@ -635,6 +734,28 @@ object CollapseWindow extends Rule[LogicalPlan] { } } +/** + * Transpose Adjacent Window Expressions. + * - If the partition spec of the parent Window expression is compatible with the partition spec + * of the child window expression, transpose them. 
+ */ +object TransposeWindow extends Rule[LogicalPlan] { + private def compatiblePartitions(ps1: Seq[Expression], ps2: Seq[Expression]): Boolean = { + ps1.length < ps2.length && ps2.take(ps1.length).permutations.exists(ps1.zip(_).forall { + case (l, r) => l.semanticEquals(r) + }) + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case w1 @ Window(we1, ps1, os1, w2 @ Window(we2, ps2, os2, grandChild)) + if w1.references.intersect(w2.windowOutputSet).isEmpty && + w1.expressions.forall(_.deterministic) && + w2.expressions.forall(_.deterministic) && + compatiblePartitions(ps1, ps2) => + Project(w1.output, Window(we2, ps2, os2, Window(we1, ps1, os1, grandChild))) + } +} + /** * Generate a list of additional filters from an operator's existing constraint but remove those * that are either already part of the operator's condition or are part of the operator's child @@ -1258,6 +1379,12 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] { case Limit(IntegerLiteral(limit), LocalRelation(output, data, isStreaming)) => LocalRelation(output, data.take(limit), isStreaming) + + case Filter(condition, LocalRelation(output, data, isStreaming)) + if !hasUnevaluableExpr(condition) => + val predicate = InterpretedPredicate.create(condition, output) + predicate.initialize(0) + LocalRelation(output, data.filter(row => predicate.eval(row)), isStreaming) } private def hasUnevaluableExpr(expr: Expression): Boolean = { @@ -1315,7 +1442,7 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] { */ object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Intersect(left, right) => + case Intersect(left, right, false) => assert(left.output.size == right.output.size) val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) } Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And))) @@ -1336,13 +1463,149 @@ object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] { */ object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Except(left, right) => + case Except(left, right, false) => assert(left.output.size == right.output.size) val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) } Distinct(Join(left, right, LeftAnti, joinCond.reduceLeftOption(And))) } } +/** + * Replaces logical [[Except]] operator using a combination of Union, Aggregate + * and Generate operators.
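Before moving on, a sketch of a query shape the TransposeWindow rule above targets (data and column names invented, assuming an active SparkSession `spark`):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{max, sum}
import spark.implicits._

val df = Seq((1, 1, 10), (1, 2, 20)).toDF("a", "b", "c")
// The plan is Window(partition by a, Window(partition by (a, b), child)); the
// parent's partition spec is a compatible subset of the child's, so the rule
// swaps them. The intended effect is that the (a, b) window can reuse the
// clustering produced for (a), avoiding an extra exchange.
val out = df
  .withColumn("s", sum("c").over(Window.partitionBy("a", "b")))
  .withColumn("m", max("c").over(Window.partitionBy("a")))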
+ * + * Input Query: + * {{{ + * SELECT c1 FROM ut1 EXCEPT ALL SELECT c1 FROM ut2 + * }}} + * + * Rewritten Query: + * {{{ + * SELECT c1 + * FROM ( + * SELECT replicate_rows(sum_val, c1) + * FROM ( + * SELECT c1, sum_val + * FROM ( + * SELECT c1, sum(vcol) AS sum_val + * FROM ( + * SELECT 1L as vcol, c1 FROM ut1 + * UNION ALL + * SELECT -1L as vcol, c1 FROM ut2 + * ) AS union_all + * GROUP BY union_all.c1 + * ) + * WHERE sum_val > 0 + * ) + * ) + * }}} + */ + +object RewriteExceptAll extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Except(left, right, true) => + assert(left.output.size == right.output.size) + + val newColumnLeft = Alias(Literal(1L), "vcol")() + val newColumnRight = Alias(Literal(-1L), "vcol")() + val modifiedLeftPlan = Project(Seq(newColumnLeft) ++ left.output, left) + val modifiedRightPlan = Project(Seq(newColumnRight) ++ right.output, right) + val unionPlan = Union(modifiedLeftPlan, modifiedRightPlan) + val aggSumCol = + Alias(AggregateExpression(Sum(unionPlan.output.head.toAttribute), Complete, false), "sum")() + val aggOutputColumns = left.output ++ Seq(aggSumCol) + val aggregatePlan = Aggregate(left.output, aggOutputColumns, unionPlan) + val filteredAggPlan = Filter(GreaterThan(aggSumCol.toAttribute, Literal(0L)), aggregatePlan) + val genRowPlan = Generate( + ReplicateRows(Seq(aggSumCol.toAttribute) ++ left.output), + unrequiredChildIndex = Nil, + outer = false, + qualifier = None, + left.output, + filteredAggPlan + ) + Project(left.output, genRowPlan) + } +} + +/** + * Replaces logical [[Intersect]] operator using a combination of Union, Aggregate + * and Generate operators. + * + * Input Query: + * {{{ + * SELECT c1 FROM ut1 INTERSECT ALL SELECT c1 FROM ut2 + * }}} + * + * Rewritten Query: + * {{{ + * SELECT c1 + * FROM ( + * SELECT replicate_rows(min_count, c1) + * FROM ( + * SELECT c1, If (vcol1_cnt > vcol2_cnt, vcol2_cnt, vcol1_cnt) AS min_count + * FROM ( + * SELECT c1, count(vcol1) as vcol1_cnt, count(vcol2) as vcol2_cnt + * FROM ( + * SELECT true as vcol1, null as vcol2, c1 FROM ut1 + * UNION ALL + * SELECT null as vcol1, true as vcol2, c1 FROM ut2 + * ) AS union_all + * GROUP BY c1 + * HAVING vcol1_cnt >= 1 AND vcol2_cnt >= 1 + * ) + * ) + * ) + * }}} + */ +object RewriteIntersectAll extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Intersect(left, right, true) => + assert(left.output.size == right.output.size) + + val trueVcol1 = Alias(Literal(true), "vcol1")() + val nullVcol1 = Alias(Literal(null, BooleanType), "vcol1")() + + val trueVcol2 = Alias(Literal(true), "vcol2")() + val nullVcol2 = Alias(Literal(null, BooleanType), "vcol2")() + + // Add a projection on the top of left and right plans to project out + // the additional virtual columns. + val leftPlanWithAddedVirtualCols = Project(Seq(trueVcol1, nullVcol2) ++ left.output, left) + val rightPlanWithAddedVirtualCols = Project(Seq(nullVcol1, trueVcol2) ++ right.output, right) + + val unionPlan = Union(leftPlanWithAddedVirtualCols, rightPlanWithAddedVirtualCols) + + // Expressions to compute the two counts and the minimum of the two counts.
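Stepping back from the rewrite internals: end to end, these rules back the new SQL surface. A small illustration (tables and data invented, assuming a SparkSession `spark`):

import spark.implicits._

Seq(1, 1, 2).toDF("c1").createOrReplaceTempView("ut1")
Seq(1, 2).toDF("c1").createOrReplaceTempView("ut2")

// EXCEPT ALL respects multiplicity: 1 occurs twice on the left and once on
// the right, so one copy survives; the single 2 on each side cancels out.
spark.sql("SELECT c1 FROM ut1 EXCEPT ALL SELECT c1 FROM ut2").show()

// INTERSECT ALL keeps min(left count, right count) copies of each value.
spark.sql("SELECT c1 FROM ut1 INTERSECT ALL SELECT c1 FROM ut2").show()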
+ val vCol1AggrExpr = + Alias(AggregateExpression(Count(unionPlan.output(0)), Complete, false), "vcol1_count")() + val vCol2AggrExpr = + Alias(AggregateExpression(Count(unionPlan.output(1)), Complete, false), "vcol2_count")() + val ifExpression = Alias(If( + GreaterThan(vCol1AggrExpr.toAttribute, vCol2AggrExpr.toAttribute), + vCol2AggrExpr.toAttribute, + vCol1AggrExpr.toAttribute + ), "min_count")() + + val aggregatePlan = Aggregate(left.output, + Seq(vCol1AggrExpr, vCol2AggrExpr) ++ left.output, unionPlan) + val filterPlan = Filter(And(GreaterThanOrEqual(vCol1AggrExpr.toAttribute, Literal(1L)), + GreaterThanOrEqual(vCol2AggrExpr.toAttribute, Literal(1L))), aggregatePlan) + val projectMinPlan = Project(left.output ++ Seq(ifExpression), filterPlan) + + // Apply the replicator to replicate rows based on min_count + val genRowPlan = Generate( + ReplicateRows(Seq(ifExpression.toAttribute) ++ left.output), + unrequiredChildIndex = Nil, + outer = false, + qualifier = None, + left.output, + projectMinPlan + ) + Project(left.output, genRowPlan) + } +} + /** * Removes literals from group expressions in [[Aggregate]], as they have no effect to the result * but only makes the grouping key bigger. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala index 45edf266bbce4..efd3944eba7f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala @@ -46,7 +46,7 @@ object ReplaceExceptWithFilter extends Rule[LogicalPlan] { } plan.transform { - case e @ Except(left, right) if isEligible(left, right) => + case e @ Except(left, right, false) if isEligible(left, right) => val newCondition = transformCondition(left, skipProject(right)) newCondition.map { c => Distinct(Filter(Not(c), left)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 1d363b8146e3f..f8037588fa71e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -218,15 +218,24 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { - case In(v, list) if list.isEmpty && !v.nullable => FalseLiteral + case In(v, list) if list.isEmpty => + // When v is not nullable, the following expression will be optimized + // to FalseLiteral, which is tested in OptimizeInSuite.scala + If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType)) case expr @ In(v, list) if expr.inSetConvertible => val newList = ExpressionSet(list).toSeq - if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) { + if (newList.length == 1 + // TODO: `EqualTo` for structural types is not working. Until SPARK-24443 is addressed, + // TODO: we exclude them in this rule.
+ && !v.isInstanceOf[CreateNamedStructLike] + && !newList.head.isInstanceOf[CreateNamedStructLike]) { + EqualTo(v, newList.head) + } else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) { val hSet = newList.map(e => e.eval(EmptyRow)) InSet(v, HashSet() ++ hSet) - } else if (newList.size < list.size) { + } else if (newList.length < list.length) { expr.copy(list = newList) - } else { // newList.length == list.length + } else { // newList.length == list.length && newList.length > 1 expr } } @@ -254,10 +263,15 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { case TrueLiteral Or _ => TrueLiteral case _ Or TrueLiteral => TrueLiteral - case a And b if Not(a).semanticEquals(b) => FalseLiteral - case a Or b if Not(a).semanticEquals(b) => TrueLiteral - case a And b if a.semanticEquals(Not(b)) => FalseLiteral - case a Or b if a.semanticEquals(Not(b)) => TrueLiteral + case a And b if Not(a).semanticEquals(b) => + If(IsNull(a), Literal.create(null, a.dataType), FalseLiteral) + case a And b if a.semanticEquals(Not(b)) => + If(IsNull(b), Literal.create(null, b.dataType), FalseLiteral) + + case a Or b if Not(a).semanticEquals(b) => + If(IsNull(a), Literal.create(null, a.dataType), TrueLiteral) + case a Or b if a.semanticEquals(Not(b)) => + If(IsNull(b), Literal.create(null, b.dataType), TrueLiteral) case a And b if a.semanticEquals(b) => a case a Or b if a.semanticEquals(b) => a @@ -381,6 +395,8 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { case If(TrueLiteral, trueValue, _) => trueValue case If(FalseLiteral, _, falseValue) => falseValue case If(Literal(null, _), _, falseValue) => falseValue + case If(cond, trueValue, falseValue) + if cond.deterministic && trueValue.semanticEquals(falseValue) => trueValue case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) => // If there are branches that are always false, remove them. @@ -394,17 +410,35 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { e.copy(branches = newBranches) } - case e @ CaseWhen(branches, _) if branches.headOption.map(_._1) == Some(TrueLiteral) => + case CaseWhen(branches, _) if branches.headOption.map(_._1).contains(TrueLiteral) => // If the first branch is a true literal, remove the entire CaseWhen and use the value // from that. Note that CaseWhen.branches should never be empty, and as a result the // headOption (rather than head) added above is just an extra (and unnecessary) safeguard. branches.head._2 case CaseWhen(branches, _) if branches.exists(_._1 == TrueLiteral) => - // a branc with a TRue condition eliminates all following branches, + // a branch with a true condition eliminates all following branches, // these branches can be pruned away val (h, t) = branches.span(_._1 != TrueLiteral) CaseWhen( h :+ t.head, None) + + case e @ CaseWhen(branches, Some(elseValue)) + if branches.forall(_._2.semanticEquals(elseValue)) => + // For non-deterministic conditions with side effects, we cannot remove or reorder them. + // As a result, we try to remove the deterministic conditions from the tail.
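An aside on the BooleanSimplification change above: it trades a constant fold for null-safety. A quick check of the three-valued semantics it now preserves (assuming a SparkSession `spark`):

import spark.implicits._

val df = Seq(Some(true), Some(false), None).toDF("a")
// `a AND NOT a` used to fold straight to false; for the NULL row the correct
// answer under three-valued logic is NULL, which the new rewrite to
// If(IsNull(a), null, false) produces.
df.selectExpr("a AND NOT a AS v").show()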
+ var hitNonDeterministicCond = false + var i = branches.length + while (i > 0 && !hitNonDeterministicCond) { + hitNonDeterministicCond = !branches(i - 1)._1.deterministic + if (!hitNonDeterministicCond) { + i -= 1 + } + } + if (i == 0) { + elseValue + } else { + e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue))) + } } } } @@ -494,6 +528,7 @@ object NullPropagation extends Rule[LogicalPlan] { // If the value expression is NULL then transform the In expression to null literal. case In(Literal(null, _), _) => Literal.create(null, BooleanType) + case InSubquery(Seq(Literal(null, _)), _) => Literal.create(null, BooleanType) // Non-leaf NullIntolerant expressions will return null, if at least one of its children is // a null literal. @@ -642,6 +677,7 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] { } } + /** * Combine nested [[Concat]] expressions. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index de89e17e51f1b..e9b7a8b76e683 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -42,13 +43,6 @@ import org.apache.spark.sql.types._ * condition. */ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { - private def getValueExpression(e: Expression): Seq[Expression] = { - e match { - case cns : CreateNamedStruct => cns.valExprs - case expr => Seq(expr) - } - } - private def dedupJoin(joinPlan: LogicalPlan): LogicalPlan = joinPlan match { // SPARK-21835: It is possibly that the two sides of the join have conflicting attributes, // the produced join then becomes unresolved and break structural integrity. We should @@ -97,19 +91,19 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) // Deduplicate conflicting attributes if any. dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond)) - case (p, In(value, Seq(ListQuery(sub, conditions, _, _)))) => - val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) + case (p, InSubquery(values, ListQuery(sub, conditions, _, _))) => + val inConditions = values.zip(sub.output).map(EqualTo.tupled) val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) // Deduplicate conflicting attributes if any. dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond)) - case (p, Not(In(value, Seq(ListQuery(sub, conditions, _, _))))) => + case (p, Not(InSubquery(values, ListQuery(sub, conditions, _, _)))) => // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr // Construct the condition. A NULL in one of the conditions is regarded as a positive // result; such a row will be filtered out by the Anti-Join operator. // Note that will almost certainly be planned as a Broadcast Nested Loop join. // Use EXISTS if performance matters to you. 
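A worked reminder of the NULL-aware semantic described in the comment above (toy data, assuming a SparkSession `spark`):

import spark.implicits._

Seq(1, 2).toDF("a").createOrReplaceTempView("l")
Seq(Some(1), None).toDF("b").createOrReplaceTempView("r")

// Because r contains a NULL, `a NOT IN (SELECT b FROM r)` can never be TRUE:
// for a = 2 the comparison 2 = NULL is unknown, so the row is filtered out
// and the query returns no rows.
spark.sql("SELECT a FROM l WHERE a NOT IN (SELECT b FROM r)").show()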
- val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) + val inConditions = values.zip(sub.output).map(EqualTo.tupled) val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions, p) // Expand the NOT IN expression with the NULL-aware semantic // to its full form. That is from: @@ -150,9 +144,9 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { newPlan = dedupJoin( Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))) exists - case In(value, Seq(ListQuery(sub, conditions, _, _))) => + case InSubquery(values, ListQuery(sub, conditions, _, _)) => val exists = AttributeReference("exists", BooleanType, nullable = false)() - val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) + val inConditions = values.zip(sub.output).map(EqualTo.tupled) val newConditions = (inConditions ++ conditions).reduceLeftOption(And) // Deduplicate conflicting attributes if any. newPlan = dedupJoin(Join(newPlan, sub, ExistenceJoin(exists), newConditions)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 383ebde3229d6..7bc1f63e30540 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -517,11 +517,10 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging * Connect two queries by a Set operator. * * Supported Set operators are: - * - UNION [DISTINCT] - * - UNION ALL - * - EXCEPT [DISTINCT] - * - MINUS [DISTINCT] - * - INTERSECT [DISTINCT] + * - UNION [ DISTINCT | ALL ] + * - EXCEPT [ DISTINCT | ALL ] + * - MINUS [ DISTINCT | ALL ] + * - INTERSECT [DISTINCT | ALL] */ override def visitSetOperation(ctx: SetOperationContext): LogicalPlan = withOrigin(ctx) { val left = plan(ctx.left) @@ -533,17 +532,17 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case SqlBaseParser.UNION => Distinct(Union(left, right)) case SqlBaseParser.INTERSECT if all => - throw new ParseException("INTERSECT ALL is not supported.", ctx) + Intersect(left, right, isAll = true) case SqlBaseParser.INTERSECT => - Intersect(left, right) + Intersect(left, right, isAll = false) case SqlBaseParser.EXCEPT if all => - throw new ParseException("EXCEPT ALL is not supported.", ctx) + Except(left, right, isAll = true) case SqlBaseParser.EXCEPT => - Except(left, right) + Except(left, right, isAll = false) case SqlBaseParser.SETMINUS if all => - throw new ParseException("MINUS ALL is not supported.", ctx) + Except(left, right, isAll = true) case SqlBaseParser.SETMINUS => - Except(left, right) + Except(left, right, isAll = false) } } @@ -630,11 +629,29 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val aggregates = Option(ctx.aggregates).toSeq .flatMap(_.namedExpression.asScala) .map(typedVisit[Expression]) - val pivotColumn = UnresolvedAttribute.quoted(ctx.pivotColumn.getText) - val pivotValues = ctx.pivotValues.asScala.map(typedVisit[Expression]).map(Literal.apply) + val pivotColumn = if (ctx.pivotColumn.identifiers.size == 1) { + UnresolvedAttribute.quoted(ctx.pivotColumn.identifier.getText) + } else { + CreateStruct( + ctx.pivotColumn.identifiers.asScala.map( + identifier => UnresolvedAttribute.quoted(identifier.getText))) + } + val pivotValues = ctx.pivotValues.asScala.map(visitPivotValue) 
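The pivot-column change above allows pivoting on several columns at once (wrapped in a struct) and, via visitPivotValue below, aliasing pivot values. A SQL sketch; the `sales` table and its columns are invented:

spark.sql("""
  SELECT * FROM sales
  PIVOT (
    sum(amount) FOR (year, quarter) IN ((2018, 1) AS q1, (2018, 2) AS q2)
  )
""").show()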
Pivot(None, pivotColumn, pivotValues, aggregates, query) } + /** + * Create a Pivot column value with or without an alias. + */ + override def visitPivotValue(ctx: PivotValueContext): Expression = withOrigin(ctx) { + val e = expression(ctx.expression) + if (ctx.identifier != null) { + Alias(e, ctx.identifier.getText)() + } else { + e + } + } + /** * Add a [[Generate]] (Lateral View) to a logical plan. */ @@ -1086,6 +1103,11 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case not => Not(e) } + def getValueExpressions(e: Expression): Seq[Expression] = e match { + case c: CreateNamedStruct => c.valExprs + case other => Seq(other) + } + // Create the predicate. ctx.kind.getType match { case SqlBaseParser.BETWEEN => @@ -1094,7 +1116,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging GreaterThanOrEqual(e, expression(ctx.lower)), LessThanOrEqual(e, expression(ctx.upper)))) case SqlBaseParser.IN if ctx.query != null => - invertIfNotDefined(In(e, Seq(ListQuery(plan(ctx.query))))) + invertIfNotDefined(InSubquery(getValueExpressions(e), ListQuery(plan(ctx.query)))) case SqlBaseParser.IN => invertIfNotDefined(In(e, ctx.expression.asScala.map(expression))) case SqlBaseParser.LIKE => @@ -1294,6 +1316,16 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } } + /** + * Create an [[LambdaFunction]]. + */ + override def visitLambda(ctx: LambdaContext): Expression = withOrigin(ctx) { + val arguments = ctx.IDENTIFIER().asScala.map { name => + UnresolvedAttribute.quoted(name.getText) + } + LambdaFunction(expression(ctx.expression), arguments) + } + /** * Create a reference to a window frame, i.e. [[WindowSpecReference]]. */ @@ -1507,7 +1539,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case "TIMESTAMP" => Literal(Timestamp.valueOf(value)) case "X" => - val padding = if (value.length % 2 == 1) "0" else "" + val padding = if (value.length % 2 != 0) "0" else "" Literal(DatatypeConverter.parseHexBinary(padding + value)) case other => throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index 4c20f2368bded..7d8cb1f18b4b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -84,12 +84,14 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { val lexer = new SqlBaseLexer(new UpperCaseCharStream(CharStreams.fromString(command))) lexer.removeErrorListeners() lexer.addErrorListener(ParseErrorListener) + lexer.legacy_setops_precedence_enbled = SQLConf.get.setOpsPrecedenceEnforced val tokenStream = new CommonTokenStream(lexer) val parser = new SqlBaseParser(tokenStream) parser.addParseListener(PostProcessor) parser.removeErrorListeners() parser.addErrorListener(ParseErrorListener) + parser.legacy_setops_precedence_enbled = SQLConf.get.setOpsPrecedenceEnforced try { try { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala index bc41dd0465e34..6fa5203a06f7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -81,7 +81,7 @@ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] { childPlans.map { childPlan => // Replace the placeholder by the child plan candidateWithPlaceholders.transformUp { - case p if p == placeholder => childPlan + case p if p.eq(placeholder) => childPlan } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index e431c9523a9da..b1ffdca091461 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -27,8 +27,6 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT /** * The active config object within the current scope. - * Note that if you want to refer config values during execution, you have to capture them - * in Driver and use the captured values in Executors. * See [[SQLConf.get]] for more information. */ def conf: SQLConf = SQLConf.get @@ -286,7 +284,7 @@ object QueryPlan extends PredicateHelper { if (ordinal == -1) { ar } else { - ar.withExprId(ExprId(ordinal)) + ar.withExprId(ExprId(ordinal)).canonicalized } }.canonicalized.asInstanceOf[T] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala new file mode 100644 index 0000000000000..9404a809b453c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.analysis.CheckAnalysis +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode} +import org.apache.spark.util.Utils + + +/** + * [[AnalysisHelper]] defines some infrastructure for the query analyzer. In particular, in query + * analysis we don't want to repeatedly re-analyze sub-plans that have previously been analyzed. + * + * This trait defines a flag `analyzed` that can be set to true once analysis is done on the tree. + * This also provides a set of resolve methods that do not recurse down to sub-plans that have the + * analyzed flag set to true. + * + * The analyzer rules should use the various resolve methods, in lieu of the various transform + * methods defined in [[TreeNode]] and [[QueryPlan]]. 
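Backing up briefly to the QueryPlanner change above: it swaps value equality (==) for reference equality (eq) so that only the actual placeholder node is substituted. The distinction in miniature:

case class Node(id: Int)

val placeholder = Node(0)
val lookalike = Node(0)
assert(placeholder == lookalike)    // case-class equality: true, could match the wrong node
assert(!(placeholder eq lookalike)) // reference equality: only the exact object matches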
+ * + * To prevent accidental use of the transform methods, this trait also overrides the transform + * methods to throw exceptions in test mode, if they are used in the analyzer. + */ +trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan => + + private var _analyzed: Boolean = false + + /** + * Recursively marks all nodes in this plan tree as analyzed. + * This should only be called by [[CheckAnalysis]]. + */ + private[catalyst] def setAnalyzed(): Unit = { + if (!_analyzed) { + _analyzed = true + children.foreach(_.setAnalyzed()) + } + } + + /** + * Returns true if this node and its children have already gone through analysis and + * verification. Note that this is only an optimization used to avoid analyzing trees that + * have already been analyzed, and can be reset by transformations. + */ + def analyzed: Boolean = _analyzed + + /** + * Returns a copy of this node where `rule` has been recursively applied to the tree. When + * `rule` does not apply to a given node, it is left unchanged. This function is similar to + * `transform`, but skips sub-trees that have already been marked as analyzed. + * Users should not expect a specific directionality. If a specific directionality is needed, + * [[resolveOperatorsUp]] or [[resolveOperatorsDown]] should be used. + * + * @param rule the function used to transform this node's children + */ + def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + resolveOperatorsDown(rule) + } + + /** + * Returns a copy of this node where `rule` has been recursively applied first to all of its + * children and then itself (post-order, bottom-up). When `rule` does not apply to a given node, + * it is left unchanged. This function is similar to `transformUp`, but skips sub-trees that + * have already been marked as analyzed. + * + * @param rule the function used to transform this node's children + */ + def resolveOperatorsUp(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + if (!analyzed) { + AnalysisHelper.allowInvokingTransformsInAnalyzer { + val afterRuleOnChildren = mapChildren(_.resolveOperatorsUp(rule)) + if (self fastEquals afterRuleOnChildren) { + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(self, identity[LogicalPlan]) + } + } else { + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(afterRuleOnChildren, identity[LogicalPlan]) + } + } + } + } else { + self + } + } + + /** Similar to [[resolveOperatorsUp]], but does it top-down. */ + def resolveOperatorsDown(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + if (!analyzed) { + AnalysisHelper.allowInvokingTransformsInAnalyzer { + val afterRule = CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(self, identity[LogicalPlan]) + } + + // Check if unchanged and then possibly return old copy to avoid gc churn. + if (self fastEquals afterRule) { + mapChildren(_.resolveOperatorsDown(rule)) + } else { + afterRule.mapChildren(_.resolveOperatorsDown(rule)) + } + } + } else { + self + } + } + + /** + * Recursively transforms the expressions of a tree, skipping nodes that have already + * been analyzed.
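For orientation, a sketch of how an analyzer rule would use the resolve helpers above instead of `transformUp` (the rule body is a placeholder, not a real rule):

import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule

object MyIllustrativeRule extends Rule[LogicalPlan] {
  // resolveOperatorsUp skips any subtree whose `analyzed` flag is set, where
  // transformUp would re-traverse it (and, per the overrides below, throws in
  // test mode when invoked from inside the analyzer).
  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
    case f: Filter => f // rewrite the node here
  }
}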
+ */ + def resolveExpressions(r: PartialFunction[Expression, Expression]): LogicalPlan = { + resolveOperators { + case p => p.transformExpressions(r) + } + } + + protected def assertNotAnalysisRule(): Unit = { + if (Utils.isTesting && + AnalysisHelper.inAnalyzer.get > 0 && + AnalysisHelper.resolveOperatorDepth.get == 0) { + throw new RuntimeException("This method should not be called in the analyzer") + } + } + + /** + * In the analyzer, use [[resolveOperatorsDown()]] instead. If this is used in the analyzer, + * an exception will be thrown in test mode. It is however OK to call this function within + * the scope of a [[resolveOperatorsDown()]] call. + * @see [[TreeNode.transformDown()]]. + */ + override def transformDown(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + assertNotAnalysisRule() + super.transformDown(rule) + } + + /** + * Use [[resolveOperators()]] in the analyzer. + * @see [[TreeNode.transformUp()]] + */ + override def transformUp(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + assertNotAnalysisRule() + super.transformUp(rule) + } + + /** + * Use [[resolveExpressions()]] in the analyzer. + * @see [[QueryPlan.transformAllExpressions()]] + */ + override def transformAllExpressions(rule: PartialFunction[Expression, Expression]): this.type = { + assertNotAnalysisRule() + super.transformAllExpressions(rule) + } + +} + + +object AnalysisHelper { + + /** + * A thread local to track whether we are in a resolveOperator call (for the purpose of analysis). + * This is an int because resolve* calls might be nested (e.g. a rule might trigger another + * query compilation within the rule itself), so we are tracking the depth here. + */ + private val resolveOperatorDepth: ThreadLocal[Int] = new ThreadLocal[Int] { + override def initialValue(): Int = 0 + } + + /** + * A thread local to track whether we are in the analysis phase of query compilation. This is an + * int rather than a boolean in case our analyzer recursively calls itself.
+ */ + private val inAnalyzer: ThreadLocal[Int] = new ThreadLocal[Int] { + override def initialValue(): Int = 0 + } + + def allowInvokingTransformsInAnalyzer[T](f: => T): T = { + resolveOperatorDepth.set(resolveOperatorDepth.get + 1) + try f finally { + resolveOperatorDepth.set(resolveOperatorDepth.get - 1) + } + } + + def markInAnalyzer[T](f: => T): T = { + inAnalyzer.set(inAnalyzer.get + 1) + try f finally { + inAnalyzer.set(inAnalyzer.get - 1) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index c486ad700f362..5f136629eb15b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -23,12 +23,12 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.LogicalPlanStats -import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.types.StructType abstract class LogicalPlan extends QueryPlan[LogicalPlan] + with AnalysisHelper with LogicalPlanStats with QueryPlanConstraints with Logging { @@ -159,7 +159,7 @@ abstract class UnaryNode extends LogicalPlan { var allConstraints = child.constraints.asInstanceOf[Set[Expression]] projectList.foreach { case a @ Alias(l: Literal, _) => - allConstraints += EqualTo(a.toAttribute, l) + allConstraints += EqualNullSafe(a.toAttribute, l) case a @ Alias(e, _) => // For every alias in `projectList`, replace the reference in constraints by its attribute. allConstraints ++= allConstraints.map(_ transform { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 3bf32ef7884e5..7ff83a9be3622 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.catalyst.{AliasIdentifier} +import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, - RangePartitioning, RoundRobinPartitioning} +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.RandomSampler @@ -74,7 +74,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) * their output. * * @param generator the generator expression - * @param unrequiredChildIndex this paramter starts as Nil and gets filled by the Optimizer. 
+ * @param unrequiredChildIndex this parameter starts as Nil and gets filled by the Optimizer. * It's used as an optimization for omitting data generation that will * be discarded next by a projection. * A common use case is when we explode(array(..)) and are interested @@ -113,7 +113,7 @@ case class Generate( def qualifiedGeneratorOutput: Seq[Attribute] = { val qualifiedOutput = qualifier.map { q => // prepend the new qualifier to the existed one - generatorOutput.map(a => a.withQualifier(Some(q))) + generatorOutput.map(a => a.withQualifier(Seq(q))) }.getOrElse(generatorOutput) val nullableOutput = qualifiedOutput.map { // if outer, make all attributes nullable, otherwise keep existing nullability @@ -164,7 +164,12 @@ object SetOperation { def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right)) } -case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { +case class Intersect( + left: LogicalPlan, + right: LogicalPlan, + isAll: Boolean) extends SetOperation(left, right) { + + override def nodeName: String = getClass.getSimpleName + ( if ( isAll ) "All" else "" ) override def output: Seq[Attribute] = left.output.zip(right.output).map { case (leftAttr, rightAttr) => @@ -183,8 +188,11 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation } } -case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { - +case class Except( + left: LogicalPlan, + right: LogicalPlan, + isAll: Boolean) extends SetOperation(left, right) { + override def nodeName: String = getClass.getSimpleName + ( if ( isAll ) "All" else "" ) /** We don't use right.output because those rows get excluded from the set. */ override def output: Seq[Attribute] = left.output @@ -344,6 +352,38 @@ case class Join( } } +/** + * Append data to an existing table. + */ +case class AppendData( + table: NamedRelation, + query: LogicalPlan, + isByName: Boolean) extends LogicalPlan { + override def children: Seq[LogicalPlan] = Seq(query) + override def output: Seq[Attribute] = Seq.empty + + override lazy val resolved: Boolean = { + table.resolved && query.resolved && query.output.size == table.output.size && + query.output.zip(table.output).forall { + case (inAttr, outAttr) => + // names and types must match, nullability must be compatible + inAttr.name == outAttr.name && + DataType.equalsIgnoreCompatibleNullability(outAttr.dataType, inAttr.dataType) && + (outAttr.nullable || !inAttr.nullable) + } + } +} + +object AppendData { + def byName(table: NamedRelation, df: LogicalPlan): AppendData = { + new AppendData(table, df, true) + } + + def byPosition(table: NamedRelation, query: LogicalPlan): AppendData = { + new AppendData(table, query, false) + } +} + /** * Insert some data into a table. Note that this plan is unresolved and has to be replaced by the * concrete implementations during analysis. @@ -700,7 +740,7 @@ case class GroupingSets( case class Pivot( groupByExprsOpt: Option[Seq[NamedExpression]], pivotColumn: Expression, - pivotValues: Seq[Literal], + pivotValues: Seq[Expression], aggregates: Seq[Expression], child: LogicalPlan) extends UnaryNode { override lazy val resolved = false // Pivot will be replaced after being resolved. @@ -786,19 +826,37 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPr /** * Aliased subquery. * - * @param alias the alias name for this subquery. + * @param name the alias identifier for this subquery. 
* @param child the logical plan of this subquery. */ case class SubqueryAlias( - alias: String, + name: AliasIdentifier, child: LogicalPlan) extends OrderPreservingUnaryNode { - override def doCanonicalize(): LogicalPlan = child.canonicalized + def alias: String = name.identifier - override def output: Seq[Attribute] = child.output.map(_.withQualifier(Some(alias))) + override def output: Seq[Attribute] = { + val qualifierList = name.database.map(Seq(_, alias)).getOrElse(Seq(alias)) + child.output.map(_.withQualifier(qualifierList)) + } + override def doCanonicalize(): LogicalPlan = child.canonicalized } +object SubqueryAlias { + def apply( + identifier: String, + child: LogicalPlan): SubqueryAlias = { + SubqueryAlias(AliasIdentifier(identifier), child) + } + + def apply( + identifier: String, + database: String, + child: LogicalPlan): SubqueryAlias = { + SubqueryAlias(AliasIdentifier(identifier, Some(database)), child) + } +} /** * Sample the dataset. * @@ -916,23 +974,3 @@ case class Deduplicate( override def output: Seq[Attribute] = child.output } - -/** - * A logical plan for setting a barrier of analysis. - * - * The SQL Analyzer goes through a whole query plan even most part of it is analyzed. This - * increases the time spent on query analysis for long pipelines in ML, especially. - * - * This logical plan wraps an analyzed logical plan to prevent it from analysis again. The barrier - * is applied to the analyzed logical plan in Dataset. It won't change the output of wrapped - * logical plan and just acts as a wrapper to hide it from analyzer. New operations on the dataset - * will be put on the barrier, so only the new nodes created will be analyzed. - * - * This analysis barrier will be removed at the end of analysis stage. - */ -case class AnalysisBarrier(child: LogicalPlan) extends LeafNode { - override protected def innerChildren: Seq[LogicalPlan] = Seq(child) - override def output: Seq[Attribute] = child.output - override def isStreaming: Boolean = child.isStreaming - override def doCanonicalize(): LogicalPlan = child.canonicalized -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala index f46b4ed764e27..693d2a7210ab8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala @@ -69,6 +69,8 @@ object ValueInterval { false case (n1: NumericValueInterval, n2: NumericValueInterval) => n1.min.compareTo(n2.max) <= 0 && n1.max.compareTo(n2.min) >= 0 + case _ => + throw new UnsupportedOperationException(s"Not supported pair: $r1, $r2 at isIntersected()") } /** @@ -86,6 +88,8 @@ object ValueInterval { val newMax = if (n1.max <= n2.max) n1.max else n2.max (Some(EstimationUtils.fromDouble(newMin, dt)), Some(EstimationUtils.fromDouble(newMax, dt))) + case _ => + throw new UnsupportedOperationException(s"Not supported pair: $r1, $r2 at intersect()") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index cc1a5e835d9cd..cd28c733f3613 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{DataType, IntegerType} @@ -206,6 +208,18 @@ case object SinglePartition extends Partitioning { } } +/** + * Represents a partitioning where rows are only serialized/deserialized locally. Neither the + * number of partitions nor the distribution of rows is changed. This is mainly used to + * obtain statistics of map tasks, such as the number of output rows. + */ +case class LocalPartitioning(childRDD: RDD[InternalRow]) extends Partitioning { + val numPartitions = childRDD.getNumPartitions + + // We will perform this partitioning no matter what the data distribution is. + override def satisfies0(required: Distribution): Boolean = false +} + /** * Represents a partitioning where rows are split up across partitions based on the hash * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index dccb44ddebfa4..183be5a027ec5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -21,6 +21,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.sideBySide +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils object RuleExecutor { @@ -72,6 +73,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { def execute(plan: TreeType): TreeType = { var curPlan = plan val queryExecutionMetrics = RuleExecutor.queryExecutionMeter + val planChangeLogger = new PlanChangeLogger() batches.foreach { batch => val batchStartPlan = curPlan @@ -90,11 +92,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { if (!result.fastEquals(plan)) { queryExecutionMetrics.incNumEffectiveExecution(rule.ruleName) queryExecutionMetrics.incTimeEffectiveExecutionBy(rule.ruleName, runTime) - logTrace( - s""" - |=== Applying Rule ${rule.ruleName} === - |${sideBySide(plan.treeString, result.treeString).mkString("\n")} - """.stripMargin) + planChangeLogger.log(rule.ruleName, plan, result) } queryExecutionMetrics.incExecutionTimeBy(rule.ruleName, runTime) queryExecutionMetrics.incNumExecution(rule.ruleName) @@ -143,4 +141,29 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { curPlan } + + private class PlanChangeLogger { + + private val logLevel = SQLConf.get.optimizerPlanChangeLogLevel.toUpperCase + + private val logRules = SQLConf.get.optimizerPlanChangeRules.map(Utils.stringToSeq) + + def log(ruleName: String, oldPlan: TreeType, newPlan: TreeType): Unit = { + if (logRules.isEmpty || logRules.get.contains(ruleName)) { + lazy val message = + s""" + |=== Applying Rule ${ruleName} === + |${sideBySide(oldPlan.treeString, newPlan.treeString).mkString("\n")} + """.stripMargin + logLevel match { + case "TRACE" => logTrace(message) + case "DEBUG" => logDebug(message) + case "INFO" => logInfo(message) + case "WARN" => logWarning(message) +
case "ERROR" => logError(message) + case _ => logTrace(message) + } + } + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala index 104b428614849..4da8ce05fe8a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala @@ -22,6 +22,8 @@ import scala.reflect.ClassTag import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, UnsafeArrayData} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.array.ByteArrayMethods object ArrayData { def toArrayData(input: Any): ArrayData = input match { @@ -34,6 +36,31 @@ object ArrayData { case a: Array[Double] => UnsafeArrayData.fromPrimitiveArray(a) case other => new GenericArrayData(other) } + + + /** + * Allocate [[UnsafeArrayData]] or [[GenericArrayData]] based on given parameters. + * + * @param elementSize a size of an element in bytes. If less than zero, the type of an element is + * non-primitive type + * @param numElements the number of elements the array should contain + * @param additionalErrorMessage string to include in the error message + */ + def allocateArrayData( + elementSize: Int, + numElements: Long, + additionalErrorMessage: String): ArrayData = { + if (elementSize >= 0 && !UnsafeArrayData.shouldUseGenericArrayData(elementSize, numElements)) { + UnsafeArrayData.createFreshArray(numElements.toInt, elementSize) + } else if (numElements <= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toLong) { + new GenericArrayData(new Array[Any](numElements.toInt)) + } else { + throw new RuntimeException(s"Cannot create array with $numElements " + + "elements of data due to exceeding the limit " + + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for ArrayData. " + + additionalErrorMessage) + } + } } abstract class ArrayData extends SpecializedGetters with Serializable { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 80f15053005ff..02813d3939796 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -96,9 +96,9 @@ object DateTimeUtils { } } - def getThreadLocalDateFormat(): DateFormat = { + def getThreadLocalDateFormat(timeZone: TimeZone): DateFormat = { val sdf = threadLocalDateFormat.get() - sdf.setTimeZone(defaultTimeZone()) + sdf.setTimeZone(timeZone) sdf } @@ -144,7 +144,11 @@ object DateTimeUtils { } def dateToString(days: SQLDate): String = - getThreadLocalDateFormat.format(toJavaDate(days)) + getThreadLocalDateFormat(defaultTimeZone()).format(toJavaDate(days)) + + def dateToString(days: SQLDate, timeZone: TimeZone): String = { + getThreadLocalDateFormat(timeZone).format(toJavaDate(days)) + } // Converts Timestamp to string according to Hive TimestampWritable convention. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index 80f15053005ff..02813d3939796 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -96,9 +96,9 @@ object DateTimeUtils {
     }
   }
 
-  def getThreadLocalDateFormat(): DateFormat = {
+  def getThreadLocalDateFormat(timeZone: TimeZone): DateFormat = {
     val sdf = threadLocalDateFormat.get()
-    sdf.setTimeZone(defaultTimeZone())
+    sdf.setTimeZone(timeZone)
     sdf
   }
 
@@ -144,7 +144,11 @@ object DateTimeUtils {
   }
 
   def dateToString(days: SQLDate): String =
-    getThreadLocalDateFormat.format(toJavaDate(days))
+    getThreadLocalDateFormat(defaultTimeZone()).format(toJavaDate(days))
+
+  def dateToString(days: SQLDate, timeZone: TimeZone): String = {
+    getThreadLocalDateFormat(timeZone).format(toJavaDate(days))
+  }
 
   // Converts Timestamp to string according to Hive TimestampWritable convention.
   def timestampToString(us: SQLTimestamp): String = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomIndicesGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomIndicesGenerator.scala
new file mode 100644
index 0000000000000..ae05128f94777
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomIndicesGenerator.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.commons.math3.random.MersenneTwister
+
+/**
+ * This class is used to generate a random permutation of indices of a given length.
+ *
+ * This implementation uses the "inside-out" version of the Fisher-Yates algorithm.
+ * Reference:
+ * https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_%22inside-out%22_algorithm
+ */
+case class RandomIndicesGenerator(randomSeed: Long) {
+  private val random = new MersenneTwister(randomSeed)
+
+  def getNextIndices(length: Int): Array[Int] = {
+    val indices = new Array[Int](length)
+    var i = 0
+    while (i < length) {
+      val j = random.nextInt(i + 1)
+      if (j != i) {
+        indices(i) = indices(j)
+      }
+      indices(j) = i
+      i += 1
+    }
+    indices
+  }
+}
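For illustration, the generator's contract under a fixed seed: output is deterministic, and every index in [0, length) appears exactly once (the sample permutation shown is illustrative):

import org.apache.spark.sql.catalyst.util.RandomIndicesGenerator

val generator = RandomIndicesGenerator(randomSeed = 42L)
val indices = generator.getNextIndices(5) // e.g. Array(3, 0, 4, 1, 2)
assert(indices.sorted.sameElements(0 until 5))
// The MersenneTwister state advances, so a second call returns a new permutation.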
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 1dcda49a3af6a..76218b459ef0d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -17,19 +17,19 @@
 package org.apache.spark.sql.catalyst.util
 
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.expressions.RowOrdering
 import org.apache.spark.sql.types._
 
 /**
- * Helper functions to check for valid data types.
+ * Functions to help with checking for valid data types and value comparison of various types.
 */
 object TypeUtils {
   def checkForNumericExpr(dt: DataType, caller: String): TypeCheckResult = {
     if (dt.isInstanceOf[NumericType] || dt == NullType) {
       TypeCheckResult.TypeCheckSuccess
     } else {
-      TypeCheckResult.TypeCheckFailure(s"$caller requires numeric types, not $dt")
+      TypeCheckResult.TypeCheckFailure(s"$caller requires numeric types, not ${dt.catalogString}")
     }
   }
 
@@ -37,23 +37,18 @@ object TypeUtils {
     if (RowOrdering.isOrderable(dt)) {
       TypeCheckResult.TypeCheckSuccess
     } else {
-      TypeCheckResult.TypeCheckFailure(s"$caller does not support ordering on type $dt")
+      TypeCheckResult.TypeCheckFailure(
+        s"$caller does not support ordering on type ${dt.catalogString}")
     }
   }
 
   def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = {
-    if (types.size <= 1) {
+    if (TypeCoercion.haveSameType(types)) {
       TypeCheckResult.TypeCheckSuccess
     } else {
-      val firstType = types.head
-      types.foreach { t =>
-        if (!t.sameType(firstType)) {
-          return TypeCheckResult.TypeCheckFailure(
-            s"input to $caller should all be the same type, but it's " +
-              types.map(_.simpleString).mkString("[", ", ", "]"))
-        }
-      }
-      TypeCheckResult.TypeCheckSuccess
+      TypeCheckResult.TypeCheckFailure(
+        s"input to $caller should all be the same type, but it's " +
+          types.map(_.catalogString).mkString("[", ", ", "]"))
     }
   }
 
@@ -78,4 +73,15 @@ object TypeUtils {
     }
     x.length - y.length
   }
+
+  /**
+   * Returns true if the equals method of the elements of the data type is implemented properly.
+   * This also means that they can be safely used in collections relying on the equals method,
+   * such as sets or maps.
+   */
+  def typeWithProperEquals(dataType: DataType): Boolean = dataType match {
+    case BinaryType => false
+    case _: AtomicType => true
+    case _ => false
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
index 4005087dad05a..0978e92dd4f72 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
@@ -155,6 +155,18 @@ package object util {
 
   def toPrettySQL(e: Expression): String = usePrettyExpression(e).sql
 
+  def escapeSingleQuotedString(str: String): String = {
+    val builder = StringBuilder.newBuilder
+
+    str.foreach {
+      case '\'' => builder ++= s"\\\'"
+      case ch => builder += ch
+    }
+
+    builder.toString()
+  }
+
   /* FIX ME
   implicit class debugLogging(a: Any) {
     def debugLogging() {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala
index 19f67236c8979..ef4b339730807 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.internal
 
 import java.util.{Map => JMap}
 
-import org.apache.spark.{TaskContext, TaskContextImpl}
+import org.apache.spark.TaskContext
 import org.apache.spark.internal.config.{ConfigEntry, ConfigProvider, ConfigReader}
 
 /**
@@ -29,7 +29,7 @@ import org.apache.spark.internal.config.{ConfigEntry, ConfigProvider, ConfigRead
 class ReadOnlySQLConf(context: TaskContext) extends SQLConf {
 
   @transient override val settings: JMap[String, String] = {
-    context.asInstanceOf[TaskContextImpl].getLocalProperties().asInstanceOf[JMap[String, String]]
+
context.getLocalProperties.asInstanceOf[JMap[String, String]] } @transient override protected val reader: ConfigReader = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index da1c34cdc78f2..4928560eacb1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -20,12 +20,14 @@ package org.apache.spark.sql.internal import java.util.{Locale, NoSuchElementException, Properties, TimeZone} import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicReference +import java.util.zip.Deflater import scala.collection.JavaConverters._ import scala.collection.immutable import scala.util.matching.Regex import org.apache.hadoop.fs.Path +import org.tukaani.xz.LZMA2Options import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.internal.Logging @@ -80,6 +82,19 @@ object SQLConf { /** See [[get]] for more information. */ def getFallbackConf: SQLConf = fallbackConf.get() + private lazy val existingConf = new ThreadLocal[SQLConf] { + override def initialValue: SQLConf = null + } + + def withExistingConf[T](conf: SQLConf)(f: => T): T = { + existingConf.set(conf) + try { + f + } finally { + existingConf.remove() + } + } + /** * Defines a getter that returns the SQLConf within scope. * See [[get]] for more information. @@ -114,19 +129,35 @@ object SQLConf { if (TaskContext.get != null) { new ReadOnlySQLConf(TaskContext.get()) } else { - if (Utils.isTesting && SparkContext.getActive.isDefined) { + val isSchedulerEventLoopThread = SparkContext.getActive + .map(_.dagScheduler.eventProcessLoop.eventThread) + .exists(_.getId == Thread.currentThread().getId) + if (isSchedulerEventLoopThread) { // DAGScheduler event loop thread does not have an active SparkSession, the `confGetter` - // will return `fallbackConf` which is unexpected. Here we prevent it from happening. - val schedulerEventLoopThread = - SparkContext.getActive.get.dagScheduler.eventProcessLoop.eventThread - if (schedulerEventLoopThread.getId == Thread.currentThread().getId) { + // will return `fallbackConf` which is unexpected. Here we require the caller to get the + // conf within `withExistingConf`, otherwise fail the query. + val conf = existingConf.get() + if (conf != null) { + conf + } else if (Utils.isTesting) { throw new RuntimeException("Cannot get SQLConf inside scheduler event loop thread.") + } else { + confGetter.get()() } + } else { + confGetter.get()() } - confGetter.get()() } } + val OPTIMIZER_EXCLUDED_RULES = buildConf("spark.sql.optimizer.excludedRules") + .doc("Configures a list of rules to be disabled in the optimizer, in which the rules are " + + "specified by their rule names and separated by comma. It is not guaranteed that all the " + + "rules in this configuration will eventually be excluded, as some rules are necessary " + + "for correctness. 
The optimizer will log the rules that have indeed been excluded.")
+    .stringConf
+    .createOptional
+
   val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations")
     .internal()
     .doc("The max number of iterations the optimizer and analyzer runs.")
@@ -140,6 +171,26 @@ object SQLConf {
     .intConf
     .createWithDefault(10)
 
+  val OPTIMIZER_PLAN_CHANGE_LOG_LEVEL = buildConf("spark.sql.optimizer.planChangeLog.level")
+    .internal()
+    .doc("Configures the log level for logging the change from the original plan to the new " +
+      "plan after a rule is applied. The value can be 'trace', 'debug', 'info', 'warn', or " +
+      "'error'. The default log level is 'trace'.")
+    .stringConf
+    .checkValue(
+      str => Set("TRACE", "DEBUG", "INFO", "WARN", "ERROR").contains(str.toUpperCase),
+      "Invalid value for 'spark.sql.optimizer.planChangeLog.level'. Valid values are " +
+        "'trace', 'debug', 'info', 'warn' and 'error'.")
+    .createWithDefault("trace")
+
+  val OPTIMIZER_PLAN_CHANGE_LOG_RULES = buildConf("spark.sql.optimizer.planChangeLog.rules")
+    .internal()
+    .doc("If this configuration is set, the optimizer will only log plan changes caused by " +
+      "applying the rules specified in this configuration. The value can be a list of rule " +
+      "names separated by commas.")
+    .stringConf
+    .createOptional
+
   val COMPRESS_CACHED = buildConf("spark.sql.inMemoryColumnarStorage.compressed")
     .doc("When set to true Spark SQL will automatically select a compression codec for each " +
       "column based on statistics of the data.")
@@ -204,6 +255,13 @@ object SQLConf {
     .intConf
     .createWithDefault(4)
 
+  val LIMIT_FLAT_GLOBAL_LIMIT = buildConf("spark.sql.limit.flatGlobalLimit")
+    .internal()
+    .doc("When applying a global limit, try to distribute the limited rows evenly across " +
+      "data partitions. If disabled, data partitions are scanned sequentially until the " +
+      "limit is reached.")
+    .booleanConf
+    .createWithDefault(true)
+
   val ADVANCED_PARTITION_PREDICATE_PUSHDOWN =
     buildConf("spark.sql.hive.advancedPartitionPredicatePushdown.enabled")
       .internal()
@@ -360,7 +418,7 @@ object SQLConf {
       "`parquet.compression` is specified in the table-specific options/properties, the " +
       "precedence would be `compression`, `parquet.compression`, " +
       "`spark.sql.parquet.compression.codec`. Acceptable values include: none, uncompressed, " +
-      "snappy, gzip, lzo.")
+      "snappy, gzip, lzo, brotli, lz4, zstd.")
     .stringConf
     .transform(_.toLowerCase(Locale.ROOT))
     .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo", "lz4", "brotli", "zstd"))
@@ -378,6 +436,23 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED =
+    buildConf("spark.sql.parquet.filterPushdown.timestamp")
+      .doc("If true, enables Parquet filter push-down optimization for Timestamp. " +
+        "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is " +
+        "enabled and Timestamp is stored as TIMESTAMP_MICROS or TIMESTAMP_MILLIS type.")
+      .internal()
+      .booleanConf
+      .createWithDefault(true)
+
+  val PARQUET_FILTER_PUSHDOWN_DECIMAL_ENABLED =
+    buildConf("spark.sql.parquet.filterPushdown.decimal")
+      .doc("If true, enables Parquet filter push-down optimization for Decimal. " +
+        "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is enabled.")
+      .internal()
+      .booleanConf
+      .createWithDefault(true)
+
   val PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED =
     buildConf("spark.sql.parquet.filterPushdown.string.startsWith")
      .doc("If true, enables Parquet filter push-down optimization for string startsWith function. " +
@@ -386,6 +461,18 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD =
+    buildConf("spark.sql.parquet.pushdown.inFilterThreshold")
+      .doc("The maximum number of values in an IN predicate for which filter push-down is " +
+        "applied. A larger threshold does not necessarily give better performance; " +
+        "experiments suggest 300 as a practical upper bound. " +
+        "Setting this value to 0 disables the feature. " +
+        "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is enabled.")
+      .internal()
+      .intConf
+      .checkValue(threshold => threshold >= 0, "The threshold must not be negative.")
+      .createWithDefault(10)
+
   val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat")
     .doc("Whether to be compatible with the legacy Parquet format adopted by Spark 1.4 and prior " +
       "versions, when converting Parquet schema to Spark SQL schema and vice versa.")
@@ -607,6 +694,12 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets")
+    .doc("The maximum number of buckets allowed. Defaults to 100000.")
+    .intConf
+    .checkValue(_ > 0, "the value of spark.sql.sources.bucketing.maxBuckets must be larger than 0")
+    .createWithDefault(100000)
+
   val CROSS_JOINS_ENABLED = buildConf("spark.sql.crossJoin.enabled")
     .doc("When false, we will throw an error if a query contains a cartesian product without " +
       "explicit CROSS JOIN syntax.")
@@ -814,6 +907,14 @@ object SQLConf {
     .intConf
     .createWithDefault(10)
 
+  val FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION =
+    buildConf("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion")
+      .internal()
+      .doc("State format version used by the flatMapGroupsWithState operation in a streaming query.")
+      .intConf
+      .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2")
+      .createWithDefault(2)
+
   val CHECKPOINT_LOCATION = buildConf("spark.sql.streaming.checkpointLocation")
     .doc("The default location for storing checkpoint data for streaming queries.")
     .stringConf
@@ -825,6 +926,25 @@ object SQLConf {
     .intConf
     .createWithDefault(100)
 
+  val MAX_BATCHES_TO_RETAIN_IN_MEMORY = buildConf("spark.sql.streaming.maxBatchesToRetainInMemory")
+    .internal()
+    .doc("The maximum number of batches which will be retained in memory to avoid " +
+      "loading from files. The value adjusts a trade-off between memory usage and cache misses: " +
+      "'2' covers both the success and direct failure cases, '1' covers only the success case, " +
+      "and '0' covers the extreme case - the cache is disabled entirely, maximizing the memory " +
+      "available to executors.")
+    .intConf
+    .createWithDefault(2)
+
+  val STREAMING_AGGREGATION_STATE_FORMAT_VERSION =
+    buildConf("spark.sql.streaming.aggregation.stateFormatVersion")
+      .internal()
+      .doc("State format version used by streaming aggregation operations in a streaming query. " +
+        "State formats tend to be incompatible between versions, so the state format version " +
+        "shouldn't be modified after running.")
+      .intConf
+      .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2")
+      .createWithDefault(2)
+
   val UNSUPPORTED_OPERATION_CHECK_ENABLED =
     buildConf("spark.sql.streaming.unsupportedOperationCheck")
       .internal()
@@ -875,6 +995,21 @@ object SQLConf {
     .stringConf
     .createWithDefault("org.apache.spark.sql.execution.streaming.ManifestFileCommitProtocol")
 
+  val STREAMING_MULTIPLE_WATERMARK_POLICY =
+    buildConf("spark.sql.streaming.multipleWatermarkPolicy")
+      .doc("Policy to calculate the global watermark value when there are multiple watermark " +
+        "operators in a streaming query. The default value is 'min', which chooses " +
+        "the minimum watermark reported across multiple operators. The alternative value is " +
+        "'max', which chooses the maximum across multiple operators. " +
+        "Note: This configuration cannot be changed between query restarts from the same " +
+        "checkpoint location.")
+      .stringConf
+      .checkValue(
+        str => Set("min", "max").contains(str.toLowerCase),
+        "Invalid value for 'spark.sql.streaming.multipleWatermarkPolicy'. " +
+          "Valid values are 'min' and 'max'")
+      .createWithDefault("min") // must be same as MultipleWatermarkPolicy.DEFAULT_POLICY_NAME
+
   val OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD =
     buildConf("spark.sql.objectHashAggregate.sortBased.fallbackThreshold")
       .internal()
@@ -1291,7 +1426,11 @@ object SQLConf {
     "overwriting. In dynamic mode, Spark doesn't delete partitions ahead, and only overwrite " +
     "those partitions that have data written into it at runtime. By default we use static " +
     "mode to keep the same behavior of Spark prior to 2.3. Note that this config doesn't " +
-    "affect Hive serde tables, as they are always overwritten with dynamic mode.")
+    "affect Hive serde tables, as they are always overwritten with dynamic mode. This can " +
+    "also be set as an output option for a data source using key partitionOverwriteMode " +
+    "(which takes precedence over this setting), e.g. " +
+    "dataframe.write.option(\"partitionOverwriteMode\", \"dynamic\").save(path)."
+  )
    .stringConf
    .transform(_.toUpperCase(Locale.ROOT))
    .checkValues(PartitionOverwriteMode.values.map(_.toString))
@@ -1306,8 +1445,18 @@ object SQLConf {
     "issues. Turn on this config to insert a local sort before actually doing repartition " +
     "to generate consistent repartition results. The performance of repartition() may go " +
     "down since we insert extra local sort before it.")
+    .booleanConf
+    .createWithDefault(true)
+
+  val NESTED_SCHEMA_PRUNING_ENABLED =
+    buildConf("spark.sql.nestedSchemaPruning.enabled")
+      .internal()
+      .doc("Prune nested fields from a logical relation's output which are unnecessary in " +
+        "satisfying a query. This optimization allows columnar file format readers to avoid " +
+        "reading unnecessary nested column data. Currently Parquet is the only data source that " +
+        "implements this optimization.")
       .booleanConf
-      .createWithDefault(true)
+      .createWithDefault(false)
 
   val TOP_K_SORT_FALLBACK_THRESHOLD =
     buildConf("spark.sql.execution.topKSortFallbackThreshold")
@@ -1361,6 +1510,58 @@ object SQLConf {
       "This only takes effect when spark.sql.repl.eagerEval.enabled is set to true.")
     .intConf
     .createWithDefault(20)
+
+  val FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT =
+    buildConf("spark.sql.codegen.aggregate.fastHashMap.capacityBit")
+      .internal()
+      .doc("Capacity bit for the maximum number of rows to be held in memory " +
+        "by the fast hash aggregate operator. The value is a bit width rather than an actual " +
+        "row count; the actual numBuckets is determined by the load factor " +
+        "(e.g. with the default bit value of 16, the actual numBuckets is ((1 << 16) / 0.5)).")
+      .intConf
+      .checkValue(bit => bit >= 10 && bit <= 30, "The bit value must be in [10, 30].")
+      .createWithDefault(16)
+
+  val AVRO_COMPRESSION_CODEC = buildConf("spark.sql.avro.compression.codec")
+    .doc("Compression codec used in writing of AVRO files. Supported codecs: " +
+      "uncompressed, deflate, snappy, bzip2 and xz. Default codec is snappy.")
+    .stringConf
+    .checkValues(Set("uncompressed", "deflate", "snappy", "bzip2", "xz"))
+    .createWithDefault("snappy")
+
+  val AVRO_DEFLATE_LEVEL = buildConf("spark.sql.avro.deflate.level")
+    .doc("Compression level for the deflate codec used in writing of AVRO files. " +
+      "Valid values are in the range 1 to 9 inclusive, or -1. " +
+      "The default value is -1, which corresponds to level 6 in the current implementation.")
+    .intConf
+    .checkValues((1 to 9).toSet + Deflater.DEFAULT_COMPRESSION)
+    .createWithDefault(Deflater.DEFAULT_COMPRESSION)
+
+  val LEGACY_REPLACE_DATABRICKS_SPARK_AVRO_ENABLED =
+    buildConf("spark.sql.legacy.replaceDatabricksSparkAvro.enabled")
+    .doc("If it is set to true, the data source provider com.databricks.spark.avro is mapped " +
+      "to the built-in but external Avro data source module for backward compatibility.")
+    .booleanConf
+    .createWithDefault(true)
+
+  val LEGACY_SETOPS_PRECEDENCE_ENABLED =
+    buildConf("spark.sql.legacy.setopsPrecedence.enabled")
+      .internal()
+      .doc("When set to true and the order of evaluation is not specified by parentheses, the " +
+        "set operations are performed from left to right as they appear in the query. When set " +
+        "to false and the order of evaluation is not specified by parentheses, INTERSECT " +
+        "operations are performed before any UNION, EXCEPT and MINUS operations.")
+      .booleanConf
+      .createWithDefault(false)
+
+  val PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION =
+    buildConf("spark.sql.parallelFileListingInStatsComputation.enabled")
+      .internal()
+      .doc("When true, SQL commands use parallel file listing, " +
+        "as opposed to single-threaded listing."
+ + "This usually speeds up commands that need to list many directories.") + .booleanConf + .createWithDefault(true) } /** @@ -1383,10 +1584,16 @@ class SQLConf extends Serializable with Logging { /** ************************ Spark SQL Params/Hints ******************* */ + def optimizerExcludedRules: Option[String] = getConf(OPTIMIZER_EXCLUDED_RULES) + def optimizerMaxIterations: Int = getConf(OPTIMIZER_MAX_ITERATIONS) def optimizerInSetConversionThreshold: Int = getConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD) + def optimizerPlanChangeLogLevel: String = getConf(OPTIMIZER_PLAN_CHANGE_LOG_LEVEL) + + def optimizerPlanChangeRules: Option[String] = getConf(OPTIMIZER_PLAN_CHANGE_LOG_RULES) + def stateStoreProviderClass: String = getConf(STATE_STORE_PROVIDER_CLASS) def stateStoreMinDeltasForSnapshot: Int = getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT) @@ -1463,13 +1670,22 @@ class SQLConf extends Serializable with Logging { def minBatchesToRetain: Int = getConf(MIN_BATCHES_TO_RETAIN) + def maxBatchesToRetainInMemory: Int = getConf(MAX_BATCHES_TO_RETAIN_IN_MEMORY) + def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) def parquetFilterPushDownDate: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_DATE_ENABLED) + def parquetFilterPushDownTimestamp: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED) + + def parquetFilterPushDownDecimal: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_DECIMAL_ENABLED) + def parquetFilterPushDownStringStartWith: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED) + def parquetFilterPushDownInFilterThreshold: Int = + getConf(PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD) + def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) @@ -1498,6 +1714,8 @@ class SQLConf extends Serializable with Logging { def codegenFallback: Boolean = getConf(CODEGEN_FALLBACK) + def codegenComments: Boolean = getConf(StaticSQLConf.CODEGEN_COMMENTS) + def loggingMaxLinesForCodegen: Int = getConf(CODEGEN_LOGGING_MAX_LINES) def hugeMethodLimit: Int = getConf(WHOLESTAGE_HUGE_METHOD_LIMIT) @@ -1508,6 +1726,8 @@ class SQLConf extends Serializable with Logging { def tableRelationCacheSize: Int = getConf(StaticSQLConf.FILESOURCE_TABLE_RELATION_CACHE_SIZE) + def codegenCacheMaxEntries: Int = getConf(StaticSQLConf.CODEGEN_CACHE_MAX_ENTRIES) + def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED) def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) @@ -1524,6 +1744,8 @@ class SQLConf extends Serializable with Logging { def topKSortFallbackThreshold: Int = getConf(TOP_K_SORT_FALLBACK_THRESHOLD) + def fastHashAggregateRowMaxCapacityBit: Int = getConf(FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT) + /** * Returns the [[Resolver]] for the current configuration, which can be used to determine if two * identifiers are equal. 
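A sketch of the multiple-watermark policy defined above in use, assuming two schema-compatible streaming DataFrames df1 and df2 (hypothetical) that each define an event-time watermark:

// With 'max' the global watermark follows the faster stream; the default 'min'
// follows the slower one.
spark.conf.set("spark.sql.streaming.multipleWatermarkPolicy", "max")
val unioned = df1.withWatermark("eventTime", "10 minutes")
  .union(df2.withWatermark("eventTime", "1 hour"))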
@@ -1543,6 +1765,8 @@ class SQLConf extends Serializable with Logging { def limitScaleUpFactor: Int = getConf(LIMIT_SCALE_UP_FACTOR) + def limitFlatGlobalLimit: Boolean = getConf(LIMIT_FLAT_GLOBAL_LIMIT) + def advancedPartitionPredicatePushdownEnabled: Boolean = getConf(ADVANCED_PARTITION_PREDICATE_PUSHDOWN) @@ -1609,6 +1833,8 @@ class SQLConf extends Serializable with Logging { def bucketingEnabled: Boolean = getConf(SQLConf.BUCKETING_ENABLED) + def bucketingMaxBuckets: Int = getConf(SQLConf.BUCKETING_MAX_BUCKETS) + def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY) @@ -1724,10 +1950,30 @@ class SQLConf extends Serializable with Logging { def partitionOverwriteMode: PartitionOverwriteMode.Value = PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) + def nestedSchemaPruningEnabled: Boolean = getConf(NESTED_SCHEMA_PRUNING_ENABLED) + def csvColumnPruning: Boolean = getConf(SQLConf.CSV_PARSER_COLUMN_PRUNING) def legacySizeOfNull: Boolean = getConf(SQLConf.LEGACY_SIZE_OF_NULL) + def isReplEagerEvalEnabled: Boolean = getConf(SQLConf.REPL_EAGER_EVAL_ENABLED) + + def replEagerEvalMaxNumRows: Int = getConf(SQLConf.REPL_EAGER_EVAL_MAX_NUM_ROWS) + + def replEagerEvalTruncate: Int = getConf(SQLConf.REPL_EAGER_EVAL_TRUNCATE) + + def avroCompressionCodec: String = getConf(SQLConf.AVRO_COMPRESSION_CODEC) + + def avroDeflateLevel: Int = getConf(SQLConf.AVRO_DEFLATE_LEVEL) + + def replaceDatabricksSparkAvroEnabled: Boolean = + getConf(SQLConf.LEGACY_REPLACE_DATABRICKS_SPARK_AVRO_ENABLED) + + def setOpsPrecedenceEnforced: Boolean = getConf(SQLConf.LEGACY_SETOPS_PRECEDENCE_ENABLED) + + def parallelFileListingInStatsComputation: Boolean = + getConf(SQLConf.PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ @@ -1884,4 +2130,8 @@ class SQLConf extends Serializable with Logging { } cloned } + + def isModifiable(key: String): Boolean = { + sqlConfEntries.containsKey(key) && !staticConfKeys.contains(key) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala index 382ef28f49a7a..d9c354b165e52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -66,6 +66,22 @@ object StaticSQLConf { .checkValue(cacheSize => cacheSize >= 0, "The maximum size of the cache must not be negative") .createWithDefault(1000) + val CODEGEN_CACHE_MAX_ENTRIES = buildStaticConf("spark.sql.codegen.cache.maxEntries") + .internal() + .doc("When nonzero, enable caching of generated classes for operators and expressions. " + + "All jobs share the cache that can use up to the specified number for generated classes.") + .intConf + .checkValue(maxEntries => maxEntries >= 0, "The maximum must not be negative") + .createWithDefault(100) + + val CODEGEN_COMMENTS = buildStaticConf("spark.sql.codegen.comments") + .internal() + .doc("When true, put comment in the generated code. 
Since computing huge comments " + + "can be extremely expensive in certain cases, such as deeply-nested expressions which " + + "operate over inputs with wide schemas, default is false.") + .booleanConf + .createWithDefault(false) + // When enabling the debug, Spark SQL internal table properties are not filtered out; however, // some related DDL commands (e.g., ANALYZE TABLE and CREATE TABLE LIKE) might not work properly. val DEBUG_MODE = buildStaticConf("spark.sql.debug") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index 3041f44b116ea..c43cc748655e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -145,7 +145,7 @@ abstract class NumericType extends AtomicType { } -private[sql] object NumericType extends AbstractDataType { +private[spark] object NumericType extends AbstractDataType { /** * Enables matching against NumericType for expressions: * {{{ @@ -155,11 +155,12 @@ private[sql] object NumericType extends AbstractDataType { */ def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] - override private[sql] def defaultConcreteType: DataType = DoubleType + override private[spark] def defaultConcreteType: DataType = DoubleType - override private[sql] def simpleString: String = "numeric" + override private[spark] def simpleString: String = "numeric" - override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[NumericType] + override private[spark] def acceptsType(other: DataType): Boolean = + other.isInstanceOf[NumericType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 38c40482fa4d9..58c75b5dc7a35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -42,7 +42,7 @@ object ArrayType extends AbstractDataType { other.isInstanceOf[ArrayType] } - override private[sql] def simpleString: String = "array" + override private[spark] def simpleString: String = "array" } /** @@ -103,7 +103,8 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT case a : ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]] case other => - throw new IllegalArgumentException(s"Type $other does not support ordered operations") + throw new IllegalArgumentException( + s"Type ${other.catalogString} does not support ordered operations") } def compare(x: ArrayData, y: ArrayData): Int = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index fd40741cfb5f1..e53628d11ccf3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -27,7 +27,8 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.InterfaceStability -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} import 
org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils @@ -113,6 +114,8 @@ abstract class DataType extends AbstractDataType { @InterfaceStability.Stable object DataType { + private val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\-?\d+)\s*\)""".r + def fromDDL(ddl: String): DataType = { try { CatalystSqlParser.parseDataType(ddl) @@ -131,7 +134,6 @@ object DataType { /** Given the string representation of a type, return its DataType */ private def nameToType(name: String): DataType = { - val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\-?\d+)\s*\)""".r name match { case "decimal" => DecimalType.USER_DEFAULT case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) @@ -336,4 +338,124 @@ object DataType { case (fromDataType, toDataType) => fromDataType == toDataType } } + + private val SparkGeneratedName = """col\d+""".r + private def isSparkGeneratedName(name: String): Boolean = name match { + case SparkGeneratedName(_*) => true + case _ => false + } + + /** + * Returns true if the write data type can be read using the read data type. + * + * The write type is compatible with the read type if: + * - Both types are arrays, the array element types are compatible, and element nullability is + * compatible (read allows nulls or write does not contain nulls). + * - Both types are maps and the map key and value types are compatible, and value nullability + * is compatible (read allows nulls or write does not contain nulls). + * - Both types are structs and each field in the read struct is present in the write struct and + * compatible (including nullability), or is nullable if the write struct does not contain the + * field. Write-side structs are not compatible if they contain fields that are not present in + * the read-side struct. + * - Both types are atomic and the write type can be safely cast to the read type. + * + * Extra fields in write-side structs are not allowed to avoid accidentally writing data that + * the read schema will not read, and to ensure map key equality is not changed when data is read. + * + * @param write a write-side data type to validate against the read type + * @param read a read-side data type + * @return true if data written with the write type can be read using the read type + */ + def canWrite( + write: DataType, + read: DataType, + resolver: Resolver, + context: String, + addError: String => Unit = (_: String) => {}): Boolean = { + (write, read) match { + case (wArr: ArrayType, rArr: ArrayType) => + // run compatibility check first to produce all error messages + val typesCompatible = + canWrite(wArr.elementType, rArr.elementType, resolver, context + ".element", addError) + + if (wArr.containsNull && !rArr.containsNull) { + addError(s"Cannot write nullable elements to array of non-nulls: '$context'") + false + } else { + typesCompatible + } + + case (wMap: MapType, rMap: MapType) => + // map keys cannot include data fields not in the read schema without changing equality when + // read. map keys can be missing fields as long as they are nullable in the read schema. 
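// For illustration (a hypothetical check, not part of this file): writing a
// map with nullable values into a map of non-null values fails the rule above.
//   import org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
//   val errors = scala.collection.mutable.ArrayBuffer[String]()
//   DataType.canWrite(
//     MapType(StringType, IntegerType, valueContainsNull = true),
//     MapType(StringType, IntegerType, valueContainsNull = false),
//     caseSensitiveResolution, "m", errors += _)  // returns false
//   errors.head == "Cannot write nullable values to map of non-nulls: 'm'"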
+ + // run compatibility check first to produce all error messages + val keyCompatible = + canWrite(wMap.keyType, rMap.keyType, resolver, context + ".key", addError) + val valueCompatible = + canWrite(wMap.valueType, rMap.valueType, resolver, context + ".value", addError) + val typesCompatible = keyCompatible && valueCompatible + + if (wMap.valueContainsNull && !rMap.valueContainsNull) { + addError(s"Cannot write nullable values to map of non-nulls: '$context'") + false + } else { + typesCompatible + } + + case (StructType(writeFields), StructType(readFields)) => + var fieldCompatible = true + readFields.zip(writeFields).foreach { + case (rField, wField) => + val namesMatch = resolver(wField.name, rField.name) || isSparkGeneratedName(wField.name) + val fieldContext = s"$context.${rField.name}" + val typesCompatible = + canWrite(wField.dataType, rField.dataType, resolver, fieldContext, addError) + + if (!namesMatch) { + addError(s"Struct '$context' field name does not match (may be out of order): " + + s"expected '${rField.name}', found '${wField.name}'") + fieldCompatible = false + } else if (!rField.nullable && wField.nullable) { + addError(s"Cannot write nullable values to non-null field: '$fieldContext'") + fieldCompatible = false + } else if (!typesCompatible) { + // errors are added in the recursive call to canWrite above + fieldCompatible = false + } + } + + if (readFields.size > writeFields.size) { + val missingFieldsStr = readFields.takeRight(readFields.size - writeFields.size) + .map(f => s"'${f.name}'").mkString(", ") + if (missingFieldsStr.nonEmpty) { + addError(s"Struct '$context' missing fields: $missingFieldsStr") + fieldCompatible = false + } + + } else if (writeFields.size > readFields.size) { + val extraFieldsStr = writeFields.takeRight(writeFields.size - readFields.size) + .map(f => s"'${f.name}'").mkString(", ") + addError(s"Cannot write extra fields to struct '$context': $extraFieldsStr") + fieldCompatible = false + } + + fieldCompatible + + case (w: AtomicType, r: AtomicType) => + if (!Cast.canSafeCast(w, r)) { + addError(s"Cannot safely cast '$context': $w to $r") + false + } else { + true + } + + case (w, r) if w.sameType(r) && !w.isInstanceOf[NullType] => + true + + case (w, r) => + addError(s"Cannot write '$context': $w is incompatible with $r") + false + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 6da4f28b12962..9eed2eb202045 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -479,6 +479,25 @@ object Decimal { dec } + // Max precision of a decimal value stored in `numBytes` bytes + def maxPrecisionForBytes(numBytes: Int): Int = { + Math.round( // convert double to long + Math.floor(Math.log10( // number of base-10 digits + Math.pow(2, 8 * numBytes - 1) - 1))) // max value stored in numBytes + .asInstanceOf[Int] + } + + // Returns the minimum number of bytes needed to store a decimal with a given `precision`. + lazy val minBytesForPrecision = Array.tabulate[Int](39)(computeMinBytesForPrecision) + + private def computeMinBytesForPrecision(precision : Int) : Int = { + var numBytes = 1 + while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) { + numBytes += 1 + } + numBytes + } + // Evidence parameters for Decimal considered either as Fractional or Integral. 
We provide two
  // parameters inheriting from a common trait since both traits define mkNumericOps.
  // See scala.math's Numeric.scala for examples for Scala's built-in types.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index dbf51c398fa47..15004e4b9667d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -48,7 +48,8 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType {
   }
 
   if (precision > DecimalType.MAX_PRECISION) {
-    throw new AnalysisException(s"DecimalType can only support precision up to 38")
+    throw new AnalysisException(
+      s"${DecimalType.simpleString} can only support precision up to ${DecimalType.MAX_PRECISION}")
   }
 
   // default constructor for Java
@@ -120,6 +121,7 @@ object DecimalType extends AbstractDataType {
   val MINIMUM_ADJUSTED_SCALE = 6
 
   // The decimal types compatible with other numeric types
+  private[sql] val BooleanDecimal = DecimalType(1, 0)
   private[sql] val ByteDecimal = DecimalType(3, 0)
   private[sql] val ShortDecimal = DecimalType(5, 0)
   private[sql] val IntDecimal = DecimalType(10, 0)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala
index e0bca937d1d84..4eb3226c5786e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala
@@ -56,14 +56,18 @@ object HiveStringType {
 }
 
 /**
- * Hive char type.
+ * Hive char type. Similar to the other HiveStringTypes, these data types should only be used
+ * for parsing, and should NOT be used anywhere else. Any instance of these data types should
+ * be replaced by a [[StringType]] before analysis.
 */
 case class CharType(length: Int) extends HiveStringType {
   override def simpleString: String = s"char($length)"
}
 
 /**
- * Hive varchar type.
+ * Hive varchar type. Similar to the other HiveStringTypes, these data types should only be used
+ * for parsing, and should NOT be used anywhere else. Any instance of these data types should
+ * be replaced by a [[StringType]] before analysis.
*/ case class VarcharType(length: Int) extends HiveStringType { override def simpleString: String = s"varchar($length)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index 6691b81dcea8d..594e155268bf6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -42,9 +42,9 @@ case class MapType( private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append(s"$prefix-- key: ${keyType.typeName}\n") + DataType.buildFormattedString(keyType, s"$prefix |", builder) builder.append(s"$prefix-- value: ${valueType.typeName} " + s"(valueContainsNull = $valueContainsNull)\n") - DataType.buildFormattedString(keyType, s"$prefix |", builder) DataType.buildFormattedString(valueType, s"$prefix |", builder) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala index 352fb545f4b6b..7c15dc0de4b6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -215,6 +215,8 @@ object Metadata { x.## case x: Metadata => hash(x.map) + case null => + 0 case other => throw new RuntimeException(s"Do not support type ${other.getClass}.") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala index 2d49fe076786a..203e85e1c99bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala @@ -24,7 +24,8 @@ import org.apache.spark.annotation.InterfaceStability @InterfaceStability.Evolving object ObjectType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = - throw new UnsupportedOperationException("null literals can't be casted to ObjectType") + throw new UnsupportedOperationException( + s"null literals can't be casted to ${ObjectType.simpleString}") override private[sql] def acceptsType(other: DataType): Boolean = other match { case ObjectType(_) => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala index 2c18fdcc497fe..35f9970a0aaec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala @@ -21,6 +21,7 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier} /** * A field inside a StructType. @@ -74,4 +75,18 @@ case class StructField( def getComment(): Option[String] = { if (metadata.contains("comment")) Option(metadata.getString("comment")) else None } + + /** + * Returns a string containing a schema in DDL format. For example, the following value: + * `StructField("eventId", IntegerType)` will be converted to `eventId` INT. 
+   *
+   * @since 2.4.0
+   */
+  def toDDL: String = {
+    val comment = getComment()
+      .map(escapeSingleQuotedString)
+      .map(" COMMENT '" + _ + "'")
+
+    s"${quoteIdentifier(name)} ${dataType.sql}${comment.getOrElse("")}"
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 362676b252126..06289b1483203 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -27,7 +27,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.annotation.InterfaceStability
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering}
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser}
-import org.apache.spark.sql.catalyst.util.quoteIdentifier
+import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier}
 import org.apache.spark.util.Utils
 
 /**
@@ -360,6 +360,16 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
     s"STRUCT<${fieldTypes.mkString(", ")}>"
   }
 
+  /**
+   * Returns a string containing a schema in DDL format. For example, the following value:
+   * `StructType(Seq(StructField("eventId", IntegerType), StructField("s", StringType)))`
+   * will be converted to `eventId` INT, `s` STRING.
+   * The returned DDL schema can be used when creating a table.
+   *
+   * @since 2.4.0
+   */
+  def toDDL: String = fields.map(_.toDDL).mkString(",")
+
   private[sql] override def simpleString(maxNumberFields: Int): String = {
     val builder = new StringBuilder
     val fieldTypes = fields.take(maxNumberFields).map {
@@ -426,13 +436,15 @@ object StructType extends AbstractDataType {
   private[sql] def fromString(raw: String): StructType = {
     Try(DataType.fromJson(raw)).getOrElse(LegacyTypeStringParser.parse(raw)) match {
       case t: StructType => t
-      case _ => throw new RuntimeException(s"Failed parsing StructType: $raw")
+      case _ => throw new RuntimeException(s"Failed parsing ${StructType.simpleString}: $raw")
     }
   }
 
   /**
   * Creates StructType for a given DDL-formatted string, which is a comma separated list of field
   * definitions, e.g., a INT, b STRING.
+ * + * @since 2.2.0 */ def fromDDL(ddl: String): StructType = CatalystSqlParser.parseTableSchema(ddl) @@ -528,7 +540,8 @@ object StructType extends AbstractDataType { leftType case _ => - throw new SparkException(s"Failed to merge incompatible data types $left and $right") + throw new SparkException(s"Failed to merge incompatible data types ${left.catalogString}" + + s" and ${right.catalogString}") } private[sql] def fieldsMap(fields: Array[StructField]): Map[String, StructField] = { diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java index 76930f9368514..b67c6f3e6e85e 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.expressions; -import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; -import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.UTF8String; import org.junit.Assert; import org.junit.Test; @@ -54,7 +53,7 @@ public void testKnownStringAndIntInputs() { for (int i = 0; i < inputs.length; i++) { UTF8String s = UTF8String.fromString("val_" + inputs[i]); - int hash = HiveHasher.hashUnsafeBytesBlock(s.getMemoryBlock()); + int hash = HiveHasher.hashUnsafeBytes(s.getBaseObject(), s.getBaseOffset(), s.numBytes()); Assert.assertEquals(expected[i], ((31 * inputs[i]) + hash)); } } @@ -90,13 +89,13 @@ public void randomizedStressTestBytes() { int byteArrSize = rand.nextInt(100) * 8; byte[] bytes = new byte[byteArrSize]; rand.nextBytes(bytes); - MemoryBlock mb = ByteArrayMemoryBlock.fromArray(bytes); Assert.assertEquals( - HiveHasher.hashUnsafeBytesBlock(mb), - HiveHasher.hashUnsafeBytesBlock(mb)); + HiveHasher.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + HiveHasher.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); - hashcodes.add(HiveHasher.hashUnsafeBytesBlock(mb)); + hashcodes.add(HiveHasher.hashUnsafeBytes( + bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); } // A very loose bound. @@ -113,13 +112,13 @@ public void randomizedStressTestPaddedStrings() { byte[] strBytes = String.valueOf(i).getBytes(StandardCharsets.UTF_8); byte[] paddedBytes = new byte[byteArrSize]; System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); - MemoryBlock mb = ByteArrayMemoryBlock.fromArray(paddedBytes); Assert.assertEquals( - HiveHasher.hashUnsafeBytesBlock(mb), - HiveHasher.hashUnsafeBytesBlock(mb)); + HiveHasher.hashUnsafeBytes(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + HiveHasher.hashUnsafeBytes(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); - hashcodes.add(HiveHasher.hashUnsafeBytesBlock(mb)); + hashcodes.add(HiveHasher.hashUnsafeBytes( + paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); } // A very loose bound. 
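A usage sketch for the toDDL methods added above (assuming the catalyst classes are on the classpath; the field names are illustrative):

import org.apache.spark.sql.types._

val schema = StructType(Seq(
  StructField("eventId", IntegerType),
  StructField("s", StringType).withComment("user's note")))
schema.toDDL
// `eventId` INT,`s` STRING COMMENT 'user\'s note'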
diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java index cd8bce623c5df..1baee91b3439c 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java @@ -24,8 +24,6 @@ import java.util.Set; import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; -import org.apache.spark.unsafe.memory.MemoryBlock; import org.junit.Assert; import org.junit.Test; @@ -144,13 +142,13 @@ public void randomizedStressTestBytes() { int byteArrSize = rand.nextInt(100) * 8; byte[] bytes = new byte[byteArrSize]; rand.nextBytes(bytes); - MemoryBlock mb = ByteArrayMemoryBlock.fromArray(bytes); Assert.assertEquals( - hasher.hashUnsafeWordsBlock(mb), - hasher.hashUnsafeWordsBlock(mb)); + hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); - hashcodes.add(hasher.hashUnsafeWordsBlock(mb)); + hashcodes.add(hasher.hashUnsafeWords( + bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); } // A very loose bound. @@ -167,13 +165,13 @@ public void randomizedStressTestPaddedStrings() { byte[] strBytes = String.valueOf(i).getBytes(StandardCharsets.UTF_8); byte[] paddedBytes = new byte[byteArrSize]; System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); - MemoryBlock mb = ByteArrayMemoryBlock.fromArray(paddedBytes); Assert.assertEquals( - hasher.hashUnsafeWordsBlock(mb), - hasher.hashUnsafeWordsBlock(mb)); + hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); - hashcodes.add(hasher.hashUnsafeWordsBlock(mb)); + hashcodes.add(hasher.hashUnsafeWords( + paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); } // A very loose bound. 
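A sketch of the Platform-offset-based hashing API these tests now exercise (replacing the MemoryBlock overloads); the byte values are arbitrary:

import org.apache.spark.sql.catalyst.expressions.{HiveHasher, XXH64}
import org.apache.spark.unsafe.Platform

val bytes = Array.tabulate[Byte](16)(_.toByte)
val h1: Int = HiveHasher.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length)
// hashUnsafeWords requires a length that is a multiple of 8 bytes.
val h2: Long = new XXH64(42L).hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length)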
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 353b8344658f2..f9ee948b97e0a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -261,23 +261,6 @@ class ScalaReflectionSuite extends SparkFunSuite { } } - test("get parameter type from a function object") { - val primitiveFunc = (i: Int, j: Long) => "x" - val primitiveTypes = getParameterTypes(primitiveFunc) - assert(primitiveTypes.forall(_.isPrimitive)) - assert(primitiveTypes === Seq(classOf[Int], classOf[Long])) - - val boxedFunc = (i: java.lang.Integer, j: java.lang.Long) => "x" - val boxedTypes = getParameterTypes(boxedFunc) - assert(boxedTypes.forall(!_.isPrimitive)) - assert(boxedTypes === Seq(classOf[java.lang.Integer], classOf[java.lang.Long])) - - val anyFunc = (i: Any, j: AnyRef) => "x" - val anyTypes = getParameterTypes(anyFunc) - assert(anyTypes.forall(!_.isPrimitive)) - assert(anyTypes === Seq(classOf[java.lang.Object], classOf[java.lang.Object])) - } - test("SPARK-15062: Get correct serializer for List[_]") { val list = List(1, 2, 3) val serializer = serializerFor[List[Int]](BoundReference( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SchemaPruningTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SchemaPruningTest.scala new file mode 100644 index 0000000000000..68e76fc013c18 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SchemaPruningTest.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.internal.SQLConf.NESTED_SCHEMA_PRUNING_ENABLED + +/** + * A PlanTest that ensures that all tests in this suite are run with nested schema pruning enabled. + * Remove this trait once the default value of SQLConf.NESTED_SCHEMA_PRUNING_ENABLED is set to true. 
+ */ +private[sql] trait SchemaPruningTest extends PlanTest with BeforeAndAfterAll { + private var originalConfSchemaPruningEnabled = false + + override protected def beforeAll(): Unit = { + originalConfSchemaPruningEnabled = conf.nestedSchemaPruningEnabled + conf.setConf(NESTED_SCHEMA_PRUNING_ENABLED, true) + super.beforeAll() + } + + override protected def afterAll(): Unit = { + try { + super.afterAll() + } finally { + conf.setConf(NESTED_SCHEMA_PRUNING_ENABLED, originalConfSchemaPruningEnabled) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 5d2f8e735e3d4..94778840d706b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -277,13 +277,13 @@ class AnalysisErrorSuite extends AnalysisTest { errorTest( "intersect with unequal number of columns", - testRelation.intersect(testRelation2), + testRelation.intersect(testRelation2, isAll = false), "intersect" :: "number of columns" :: testRelation2.output.length.toString :: testRelation.output.length.toString :: Nil) errorTest( "except with unequal number of columns", - testRelation.except(testRelation2), + testRelation.except(testRelation2, isAll = false), "except" :: "number of columns" :: testRelation2.output.length.toString :: testRelation.output.length.toString :: Nil) @@ -299,22 +299,22 @@ class AnalysisErrorSuite extends AnalysisTest { errorTest( "intersect with incompatible column types", - testRelation.intersect(nestedRelation), + testRelation.intersect(nestedRelation, isAll = false), "intersect" :: "the compatible column types" :: Nil) errorTest( "intersect with a incompatible column type and compatible column types", - testRelation3.intersect(testRelation4), + testRelation3.intersect(testRelation4, isAll = false), "intersect" :: "the compatible column types" :: "map" :: "decimal" :: Nil) errorTest( "except with incompatible column types", - testRelation.except(nestedRelation), + testRelation.except(nestedRelation, isAll = false), "except" :: "the compatible column types" :: Nil) errorTest( "except with a incompatible column type and compatible column types", - testRelation3.except(testRelation4), + testRelation3.except(testRelation4, isAll = false), "except" :: "the compatible column types" :: "map" :: "decimal" :: Nil) errorTest( @@ -334,14 +334,28 @@ class AnalysisErrorSuite extends AnalysisTest { "start time greater than slide duration in time window", testRelation.select( TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "1 minute").as("window")), - "The start time " :: " must be less than the slideDuration " :: Nil + "The absolute value of start time " :: " must be less than the slideDuration " :: Nil ) errorTest( "start time equal to slide duration in time window", testRelation.select( TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "1 second").as("window")), - "The start time " :: " must be less than the slideDuration " :: Nil + "The absolute value of start time " :: " must be less than the slideDuration " :: Nil + ) + + errorTest( + "SPARK-21590: absolute value of start time greater than slide duration in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "-1 minute").as("window")), + "The absolute value of start time " :: " 
must be less than the slideDuration " :: Nil + ) + + errorTest( + "SPARK-21590: absolute value of start time equal to slide duration in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "-1 second").as("window")), + "The absolute value of start time " :: " must be less than the slideDuration " :: Nil ) errorTest( @@ -372,13 +386,6 @@ class AnalysisErrorSuite extends AnalysisTest { "The slide duration" :: " must be greater than 0." :: Nil ) - errorTest( - "negative start time in time window", - testRelation.select( - TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "-5 second").as("window")), - "The start time" :: "must be greater than or equal to 0." :: Nil - ) - errorTest( "generator nested in expressions", listRelation.select(Explode('list) + 1), @@ -392,6 +399,12 @@ class AnalysisErrorSuite extends AnalysisTest { "Generators are not supported outside the SELECT clause, but got: Sort" :: Nil ) + errorTest( + "an evaluated limit clause must not be null", + testRelation.limit(Literal(null, IntegerType)), + "The evaluated limit expression must not be null, but got " :: Nil + ) + errorTest( "num_rows in limit clause must be equal to or greater than 0", listRelation.limit(-1), @@ -514,14 +527,14 @@ class AnalysisErrorSuite extends AnalysisTest { right, joinType = Cross, condition = Some('b === 'd)) - assertAnalysisError(plan2, "EqualTo does not support ordering on type MapType" :: Nil) + assertAnalysisError(plan2, "EqualTo does not support ordering on type map" :: Nil) } test("PredicateSubQuery is used outside of a filter") { val a = AttributeReference("a", IntegerType)() val b = AttributeReference("b", IntegerType)() val plan = Project( - Seq(a, Alias(In(a, Seq(ListQuery(LocalRelation(b)))), "c")()), + Seq(a, Alias(InSubquery(Seq(a), ListQuery(LocalRelation(b))), "c")()), LocalRelation(a)) assertAnalysisError(plan, "Predicate sub-queries can only be used in a Filter" :: Nil) } @@ -530,12 +543,13 @@ class AnalysisErrorSuite extends AnalysisTest { val a = AttributeReference("a", IntegerType)() val b = AttributeReference("b", IntegerType)() val c = AttributeReference("c", BooleanType)() - val plan1 = Filter(Cast(Not(In(a, Seq(ListQuery(LocalRelation(b))))), BooleanType), + val plan1 = Filter(Cast(Not(InSubquery(Seq(a), ListQuery(LocalRelation(b)))), BooleanType), LocalRelation(a)) assertAnalysisError(plan1, "Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil) - val plan2 = Filter(Or(Not(In(a, Seq(ListQuery(LocalRelation(b))))), c), LocalRelation(a, c)) + val plan2 = Filter( + Or(Not(InSubquery(Seq(a), ListQuery(LocalRelation(b)))), c), LocalRelation(a, c)) assertAnalysisError(plan2, "Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index cd8579584eada..3b3edac0a314e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -19,8 +19,11 @@ package org.apache.spark.sql.catalyst.analysis import java.util.TimeZone +import scala.reflect.ClassTag + import org.scalatest.Matchers +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ import
org.apache.spark.sql.catalyst.dsl.plans._ @@ -232,7 +235,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { checkAnalysis(plan, expected) } - test("Analysis may leave unnecassary aliases") { + test("Analysis may leave unnecessary aliases") { val att1 = testRelation.output.head var plan = testRelation.select( CreateStruct(Seq(att1, ((att1.as("aa")) + 1).as("a_plus_1"))).as("col"), @@ -270,7 +273,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { } test("self intersect should resolve duplicate expression IDs") { - val plan = testRelation.intersect(testRelation) + val plan = testRelation.intersect(testRelation, isAll = false) assertAnalysisSuccess(plan) } @@ -314,16 +317,19 @@ class AnalysisSuite extends AnalysisTest with Matchers { checkUDF(udf1, expected1) // only primitive parameter needs special null handling - val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil) - val expected2 = If(IsNull(double), nullResult, udf2) + val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil, + nullableTypes = true :: false :: Nil) + val expected2 = + If(IsNull(double), nullResult, udf2.copy(children = string :: KnownNotNull(double) :: Nil)) checkUDF(udf2, expected2) // special null handling should apply to all primitive parameters - val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil) + val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil, + nullableTypes = false :: false :: Nil) val expected3 = If( IsNull(short) || IsNull(double), nullResult, - udf3) + udf3.copy(children = KnownNotNull(short) :: KnownNotNull(double) :: Nil)) checkUDF(udf3, expected3) // we can skip special null handling for primitive parameters that are not nullable @@ -331,14 +337,24 @@ class AnalysisSuite extends AnalysisTest with Matchers { val udf4 = ScalaUDF( (s: Short, d: Double) => "x", StringType, - short :: double.withNullability(false) :: Nil) + short :: double.withNullability(false) :: Nil, + nullableTypes = false :: false :: Nil) val expected4 = If( IsNull(short), nullResult, - udf4) + udf4.copy(children = KnownNotNull(short) :: double.withNullability(false) :: Nil)) // checkUDF(udf4, expected4) } + test("SPARK-24891 Fix HandleNullInputsForUDF rule") { + val a = testRelation.output(0) + val func = (x: Int, y: Int) => x + y + val udf1 = ScalaUDF(func, IntegerType, a :: a :: Nil) + val udf2 = ScalaUDF(func, IntegerType, a :: udf1 :: Nil) + val plan = Project(Alias(udf2, "")() :: Nil, testRelation) + comparePlans(plan.analyze, plan.analyze.analyze) + } + test("SPARK-11863 mixture of aliases and real columns in order by clause - tpcds 19,55,71") { val a = testRelation2.output(0) val c = testRelation2.output(2) @@ -426,8 +442,8 @@ class AnalysisSuite extends AnalysisTest with Matchers { val unionPlan = Union(firstTable, secondTable) assertAnalysisSuccess(unionPlan) - val r1 = Except(firstTable, secondTable) - val r2 = Intersect(firstTable, secondTable) + val r1 = Except(firstTable, secondTable, isAll = false) + val r2 = Intersect(firstTable, secondTable, isAll = false) assertAnalysisSuccess(r1) assertAnalysisSuccess(r2) @@ -518,9 +534,11 @@ class AnalysisSuite extends AnalysisTest with Matchers { } test("SPARK-22614 RepartitionByExpression partitioning") { - def checkPartitioning[T <: Partitioning](numPartitions: Int, exprs: Expression*): Unit = { + def checkPartitioning[T <: Partitioning: ClassTag]( + numPartitions: Int, exprs: Expression*): Unit = { val partitioning = 
RepartitionByExpression(exprs, testRelation2, numPartitions).partitioning - assert(partitioning.isInstanceOf[T]) + val clazz = implicitly[ClassTag[T]].runtimeClass + assert(clazz.isInstance(partitioning)) } checkPartitioning[HashPartitioning](numPartitions = 10, exprs = Literal(20)) @@ -544,17 +562,28 @@ class AnalysisSuite extends AnalysisTest with Matchers { } } - test("SPARK-20392: analysis barrier") { - // [[AnalysisBarrier]] will be removed after analysis - checkAnalysis( - Project(Seq(UnresolvedAttribute("tbl.a")), - AnalysisBarrier(SubqueryAlias("tbl", testRelation))), - Project(testRelation.output, SubqueryAlias("tbl", testRelation))) - - // Verify we won't go through a plan wrapped in a barrier. - // Since we wrap an unresolved plan and analyzer won't go through it. It remains unresolved. - val barrier = AnalysisBarrier(Project(Seq(UnresolvedAttribute("tbl.b")), - SubqueryAlias("tbl", testRelation))) - assertAnalysisError(barrier, Seq("cannot resolve '`tbl.b`'")) + test("SPARK-24208: analysis fails on self-join with FlatMapGroupsInPandas") { + val pythonUdf = PythonUDF("pyUDF", null, + StructType(Seq(StructField("a", LongType))), + Seq.empty, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + true) + val output = pythonUdf.dataType.asInstanceOf[StructType].toAttributes + val project = Project(Seq(UnresolvedAttribute("a")), testRelation) + val flatMapGroupsInPandas = FlatMapGroupsInPandas( + Seq(UnresolvedAttribute("a")), pythonUdf, output, project) + val left = SubqueryAlias("temp0", flatMapGroupsInPandas) + val right = SubqueryAlias("temp1", flatMapGroupsInPandas) + val join = Join(left, right, Inner, None) + assertAnalysisSuccess( + Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join)) + } + + test("SPARK-24488 Generator with multiple aliases") { + assertAnalysisSuccess( + listRelation.select(Explode('list).as("first_alias").as("second_alias"))) + assertAnalysisSuccess( + listRelation.select(MultiAlias(MultiAlias( + PosExplode('list), Seq("first_pos", "first_val")), Seq("second_pos", "second_val")))) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala new file mode 100644 index 0000000000000..6c899b610ac5b --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala @@ -0,0 +1,379 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import java.util.Locale + +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, UpCast} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LeafNode, LogicalPlan, Project} +import org.apache.spark.sql.types.{DoubleType, FloatType, StructField, StructType} + +case class TestRelation(output: Seq[AttributeReference]) extends LeafNode with NamedRelation { + override def name: String = "table-name" +} + +class DataSourceV2AnalysisSuite extends AnalysisTest { + val table = TestRelation(StructType(Seq( + StructField("x", FloatType), + StructField("y", FloatType))).toAttributes) + + val requiredTable = TestRelation(StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", FloatType, nullable = false))).toAttributes) + + val widerTable = TestRelation(StructType(Seq( + StructField("x", DoubleType), + StructField("y", DoubleType))).toAttributes) + + test("Append.byName: basic behavior") { + val query = TestRelation(table.schema.toAttributes) + + val parsedPlan = AppendData.byName(table, query) + + checkAnalysis(parsedPlan, parsedPlan) + assertResolved(parsedPlan) + } + + test("Append.byName: does not match by position") { + val query = TestRelation(StructType(Seq( + StructField("a", FloatType), + StructField("b", FloatType))).toAttributes) + + val parsedPlan = AppendData.byName(table, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot find data for output column", "'x'", "'y'")) + } + + test("Append.byName: case sensitive column resolution") { + val query = TestRelation(StructType(Seq( + StructField("X", FloatType), // doesn't match case! + StructField("y", FloatType))).toAttributes) + + val parsedPlan = AppendData.byName(table, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot find data for output column", "'x'"), + caseSensitive = true) + } + + test("Append.byName: case insensitive column resolution") { + val query = TestRelation(StructType(Seq( + StructField("X", FloatType), // doesn't match case! 
+ StructField("y", FloatType))).toAttributes) + + val X = query.output.head + val y = query.output.last + + val parsedPlan = AppendData.byName(table, query) + val expectedPlan = AppendData.byName(table, + Project(Seq( + Alias(Cast(toLower(X), FloatType, Some(conf.sessionLocalTimeZone)), "x")(), + Alias(Cast(y, FloatType, Some(conf.sessionLocalTimeZone)), "y")()), + query)) + + assertNotResolved(parsedPlan) + checkAnalysis(parsedPlan, expectedPlan, caseSensitive = false) + assertResolved(expectedPlan) + } + + test("Append.byName: data columns are reordered by name") { + // out of order + val query = TestRelation(StructType(Seq( + StructField("y", FloatType), + StructField("x", FloatType))).toAttributes) + + val y = query.output.head + val x = query.output.last + + val parsedPlan = AppendData.byName(table, query) + val expectedPlan = AppendData.byName(table, + Project(Seq( + Alias(Cast(x, FloatType, Some(conf.sessionLocalTimeZone)), "x")(), + Alias(Cast(y, FloatType, Some(conf.sessionLocalTimeZone)), "y")()), + query)) + + assertNotResolved(parsedPlan) + checkAnalysis(parsedPlan, expectedPlan) + assertResolved(expectedPlan) + } + + test("Append.byName: fail nullable data written to required columns") { + val parsedPlan = AppendData.byName(requiredTable, table) + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot write nullable values to non-null column", "'x'", "'y'")) + } + + test("Append.byName: allow required data written to nullable columns") { + val parsedPlan = AppendData.byName(table, requiredTable) + assertResolved(parsedPlan) + checkAnalysis(parsedPlan, parsedPlan) + } + + test("Append.byName: missing required columns cause failure and are identified by name") { + // missing required field x + val query = TestRelation(StructType(Seq( + StructField("y", FloatType, nullable = false))).toAttributes) + + val parsedPlan = AppendData.byName(requiredTable, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot find data for output column", "'x'")) + } + + test("Append.byName: missing optional columns cause failure and are identified by name") { + // missing optional field x + val query = TestRelation(StructType(Seq( + StructField("y", FloatType))).toAttributes) + + val parsedPlan = AppendData.byName(table, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot find data for output column", "'x'")) + } + + test("Append.byName: fail canWrite check") { + val parsedPlan = AppendData.byName(table, widerTable) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write", "'table-name'", + "Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType")) + } + + test("Append.byName: insert safe cast") { + val x = table.output.head + val y = table.output.last + + val parsedPlan = AppendData.byName(widerTable, table) + val expectedPlan = AppendData.byName(widerTable, + Project(Seq( + Alias(Cast(x, DoubleType, Some(conf.sessionLocalTimeZone)), "x")(), + Alias(Cast(y, DoubleType, Some(conf.sessionLocalTimeZone)), "y")()), + table)) + + assertNotResolved(parsedPlan) + checkAnalysis(parsedPlan, expectedPlan) + assertResolved(expectedPlan) + } + + test("Append.byName: fail extra data fields") { + val query = TestRelation(StructType(Seq( + StructField("x", FloatType), + StructField("y", FloatType), + 
StructField("z", FloatType))).toAttributes) + + val parsedPlan = AppendData.byName(table, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write", "'table-name'", "too many data columns", + "Table columns: 'x', 'y'", + "Data columns: 'x', 'y', 'z'")) + } + + test("Append.byName: multiple field errors are reported") { + val xRequiredTable = TestRelation(StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", DoubleType))).toAttributes) + + val query = TestRelation(StructType(Seq( + StructField("x", DoubleType), + StructField("b", FloatType))).toAttributes) + + val parsedPlan = AppendData.byName(xRequiredTable, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot safely cast", "'x'", "DoubleType to FloatType", + "Cannot write nullable values to non-null column", "'x'", + "Cannot find data for output column", "'y'")) + } + + test("Append.byPosition: basic behavior") { + val query = TestRelation(StructType(Seq( + StructField("a", FloatType), + StructField("b", FloatType))).toAttributes) + + val a = query.output.head + val b = query.output.last + + val parsedPlan = AppendData.byPosition(table, query) + val expectedPlan = AppendData.byPosition(table, + Project(Seq( + Alias(Cast(a, FloatType, Some(conf.sessionLocalTimeZone)), "x")(), + Alias(Cast(b, FloatType, Some(conf.sessionLocalTimeZone)), "y")()), + query)) + + assertNotResolved(parsedPlan) + checkAnalysis(parsedPlan, expectedPlan, caseSensitive = false) + assertResolved(expectedPlan) + } + + test("Append.byPosition: data columns are not reordered") { + // out of order + val query = TestRelation(StructType(Seq( + StructField("y", FloatType), + StructField("x", FloatType))).toAttributes) + + val y = query.output.head + val x = query.output.last + + val parsedPlan = AppendData.byPosition(table, query) + val expectedPlan = AppendData.byPosition(table, + Project(Seq( + Alias(Cast(y, FloatType, Some(conf.sessionLocalTimeZone)), "x")(), + Alias(Cast(x, FloatType, Some(conf.sessionLocalTimeZone)), "y")()), + query)) + + assertNotResolved(parsedPlan) + checkAnalysis(parsedPlan, expectedPlan) + assertResolved(expectedPlan) + } + + test("Append.byPosition: fail nullable data written to required columns") { + val parsedPlan = AppendData.byPosition(requiredTable, table) + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot write nullable values to non-null column", "'x'", "'y'")) + } + + test("Append.byPosition: allow required data written to nullable columns") { + val parsedPlan = AppendData.byPosition(table, requiredTable) + assertResolved(parsedPlan) + checkAnalysis(parsedPlan, parsedPlan) + } + + test("Append.byPosition: missing required columns cause failure") { + // missing required field x + val query = TestRelation(StructType(Seq( + StructField("y", FloatType, nullable = false))).toAttributes) + + val parsedPlan = AppendData.byPosition(requiredTable, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write", "'table-name'", "not enough data columns", + "Table columns: 'x', 'y'", + "Data columns: 'y'")) + } + + test("Append.byPosition: missing optional columns cause failure") { + // missing optional field x + val query = TestRelation(StructType(Seq( + StructField("y", FloatType))).toAttributes) + + val parsedPlan = AppendData.byPosition(table, query) + +
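+ // Note: by-position resolution requires the query to produce exactly as many columns
+ // as the table declares, so a shorter query fails with "not enough data columns" even
+ // when the missing table column is nullable, as this and the previous test verify.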
assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write", "'table-name'", "not enough data columns", + "Table columns: 'x', 'y'", + "Data columns: 'y'")) + } + + test("Append.byPosition: fail canWrite check") { + val widerTable = TestRelation(StructType(Seq( + StructField("a", DoubleType), + StructField("b", DoubleType))).toAttributes) + + val parsedPlan = AppendData.byPosition(table, widerTable) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write", "'table-name'", + "Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType")) + } + + test("Append.byPosition: insert safe cast") { + val widerTable = TestRelation(StructType(Seq( + StructField("a", DoubleType), + StructField("b", DoubleType))).toAttributes) + + val x = table.output.head + val y = table.output.last + + val parsedPlan = AppendData.byPosition(widerTable, table) + val expectedPlan = AppendData.byPosition(widerTable, + Project(Seq( + Alias(Cast(x, DoubleType, Some(conf.sessionLocalTimeZone)), "a")(), + Alias(Cast(y, DoubleType, Some(conf.sessionLocalTimeZone)), "b")()), + table)) + + assertNotResolved(parsedPlan) + checkAnalysis(parsedPlan, expectedPlan) + assertResolved(expectedPlan) + } + + test("Append.byPosition: fail extra data fields") { + val query = TestRelation(StructType(Seq( + StructField("a", FloatType), + StructField("b", FloatType), + StructField("c", FloatType))).toAttributes) + + val parsedPlan = AppendData.byPosition(table, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write", "'table-name'", "too many data columns", + "Table columns: 'x', 'y'", + "Data columns: 'a', 'b', 'c'")) + } + + test("Append.byPosition: multiple field errors are reported") { + val xRequiredTable = TestRelation(StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", DoubleType))).toAttributes) + + val query = TestRelation(StructType(Seq( + StructField("x", DoubleType), + StructField("b", FloatType))).toAttributes) + + val parsedPlan = AppendData.byPosition(xRequiredTable, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot write nullable values to non-null column", "'x'", + "Cannot safely cast", "'x'", "DoubleType to FloatType")) + } + + def assertNotResolved(logicalPlan: LogicalPlan): Unit = { + assert(!logicalPlan.resolved, s"Plan should not be resolved: $logicalPlan") + } + + def assertResolved(logicalPlan: LogicalPlan): Unit = { + assert(logicalPlan.resolved, s"Plan should be resolved: $logicalPlan") + } + + def toLower(attr: AttributeReference): AttributeReference = { + AttributeReference(attr.name.toLowerCase(Locale.ROOT), attr.dataType)(attr.exprId) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 36714bd631b0e..8eec14842c7e7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -109,17 +109,17 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) - assertError(EqualTo('mapField, 'mapField), "EqualTo does not
support ordering on type MapType") + assertError(EqualTo('mapField, 'mapField), "EqualTo does not support ordering on type map") assertError(EqualNullSafe('mapField, 'mapField), - "EqualNullSafe does not support ordering on type MapType") + "EqualNullSafe does not support ordering on type map") assertError(LessThan('mapField, 'mapField), - "LessThan does not support ordering on type MapType") + "LessThan does not support ordering on type map") assertError(LessThanOrEqual('mapField, 'mapField), - "LessThanOrEqual does not support ordering on type MapType") + "LessThanOrEqual does not support ordering on type map") assertError(GreaterThan('mapField, 'mapField), - "GreaterThan does not support ordering on type MapType") + "GreaterThan does not support ordering on type map") assertError(GreaterThanOrEqual('mapField, 'mapField), - "GreaterThanOrEqual does not support ordering on type MapType") + "GreaterThanOrEqual does not support ordering on type map") assertError(If('intField, 'stringField, 'stringField), "type of predicate expression in If should be boolean") @@ -169,10 +169,10 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments") assertError( CreateNamedStruct(Seq(1, "a", "b", 2.0)), - "Only foldable StringType expressions are allowed to appear at odd position") + "Only foldable string expressions are allowed to appear at odd position") assertError( CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)), - "Only foldable StringType expressions are allowed to appear at odd position") + "Only foldable string expressions are allowed to appear at odd position") assertError( CreateNamedStruct(Seq(Literal.create(null, StringType), "a")), "Field name should not be null") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala new file mode 100644 index 0000000000000..cea0f2a9cbc97 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import java.net.URI + +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf + +class LookupFunctionsSuite extends PlanTest { + + test("SPARK-23486: the functionExists for the Persistent function check") { + val externalCatalog = new CustomInMemoryCatalog + val conf = new SQLConf() + val catalog = new SessionCatalog(externalCatalog, FunctionRegistry.builtin, conf) + val analyzer = { + catalog.createDatabase( + CatalogDatabase("default", "", new URI("loc"), Map.empty), + ignoreIfExists = false) + new Analyzer(catalog, conf) + } + + def table(ref: String): LogicalPlan = UnresolvedRelation(TableIdentifier(ref)) + val unresolvedPersistentFunc = UnresolvedFunction("func", Seq.empty, false) + val unresolvedRegisteredFunc = UnresolvedFunction("max", Seq.empty, false) + val plan = Project( + Seq(Alias(unresolvedPersistentFunc, "call1")(), Alias(unresolvedPersistentFunc, "call2")(), + Alias(unresolvedPersistentFunc, "call3")(), Alias(unresolvedRegisteredFunc, "call4")(), + Alias(unresolvedRegisteredFunc, "call5")()), + table("TaBlE")) + analyzer.LookupFunctions.apply(plan) + + assert(externalCatalog.getFunctionExistsCalledTimes == 1) + assert(analyzer.LookupFunctions.normalizeFuncName + (unresolvedPersistentFunc.name).database == Some("default")) + } + + test("SPARK-23486: the functionExists for the Registered function check") { + val externalCatalog = new InMemoryCatalog + val conf = new SQLConf() + val customerFunctionReg = new CustomerFunctionRegistry + val catalog = new SessionCatalog(externalCatalog, customerFunctionReg, conf) + val analyzer = { + catalog.createDatabase( + CatalogDatabase("default", "", new URI("loc"), Map.empty), + ignoreIfExists = false) + new Analyzer(catalog, conf) + } + + def table(ref: String): LogicalPlan = UnresolvedRelation(TableIdentifier(ref)) + val unresolvedRegisteredFunc = UnresolvedFunction("max", Seq.empty, false) + val plan = Project( + Seq(Alias(unresolvedRegisteredFunc, "call1")(), Alias(unresolvedRegisteredFunc, "call2")()), + table("TaBlE")) + analyzer.LookupFunctions.apply(plan) + + assert(customerFunctionReg.getIsRegisteredFunctionCalledTimes == 2) + assert(analyzer.LookupFunctions.normalizeFuncName + (unresolvedRegisteredFunc.name).database == Some("default")) + } +} + +class CustomerFunctionRegistry extends SimpleFunctionRegistry { + + private var isRegisteredFunctionCalledTimes: Int = 0 + + override def functionExists(funcN: FunctionIdentifier): Boolean = synchronized { + isRegisteredFunctionCalledTimes = isRegisteredFunctionCalledTimes + 1 + true + } + + def getIsRegisteredFunctionCalledTimes: Int = isRegisteredFunctionCalledTimes +} + +class CustomInMemoryCatalog extends InMemoryCatalog { + + private var functionExistsCalledTimes: Int = 0 + + override def functionExists(db: String, funcName: String): Boolean = synchronized { + functionExistsCalledTimes = functionExistsCalledTimes + 1 + true + } + + def getFunctionExistsCalledTimes: Int = functionExistsCalledTimes +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala index 553b1598e7750..8da4d7e3aa372 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala @@ -91,6 +91,34 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { assertAnalysisError(originalPlan3, Seq("doesn't show up in the GROUP BY list")) } + test("grouping sets with no explicit group by expressions") { + val originalPlan = GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), + Nil, r1, + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)))) + val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")), + Expand( + Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)), + Seq(a, b, c, a, b, gid), + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + checkAnalysis(originalPlan, expected) + + // Computation of grouping expressions should remove duplicate expressions based on their + // semantics (semanticEqual). + val originalPlan2 = GroupingSets(Seq(Seq(Multiply(unresolved_a, Literal(2))), + Seq(Multiply(Literal(2), unresolved_a), unresolved_b)), Nil, r1, + Seq(UnresolvedAlias(Multiply(unresolved_a, Literal(2))), + unresolved_b, UnresolvedAlias(count(unresolved_c)))) + + val resultPlan = getAnalyzer(true).executeAndCheck(originalPlan2) + val gExpressions = resultPlan.asInstanceOf[Aggregate].groupingExpressions + assert(gExpressions.size == 3) + val firstGroupingExprAttrName = + gExpressions(0).asInstanceOf[AttributeReference].name.replaceAll("#[0-9]*", "#0") + assert(firstGroupingExprAttrName == "(a#0 * 2)") + assert(gExpressions(1).asInstanceOf[AttributeReference].name == "b") + assert(gExpressions(2).asInstanceOf[AttributeReference].name == VirtualColumn.groupingIdName) + } + test("cube") { val originalPlan = Aggregate(Seq(Cube(Seq(unresolved_a, unresolved_b))), Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala index 9782b5fb0d266..bd66ee5355f45 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical._ @@ -120,4 +121,38 @@ class ResolveHintsSuite extends AnalysisTest { testRelation.where('a > 1).select('a).select('a).analyze, caseSensitive = false) } + + test("coalesce and repartition hint") { + checkAnalysis( + UnresolvedHint("COALESCE", Seq(Literal(10)), table("TaBlE")), + Repartition(numPartitions = 10, shuffle = false, child = testRelation)) + checkAnalysis( + UnresolvedHint("coalesce", Seq(Literal(20)), table("TaBlE")), + Repartition(numPartitions = 20, shuffle = false, child = testRelation)) + checkAnalysis( + UnresolvedHint("REPARTITION", Seq(Literal(100)), table("TaBlE")), +
Repartition(numPartitions = 100, shuffle = true, child = testRelation)) + checkAnalysis( + UnresolvedHint("RePARTITion", Seq(Literal(200)), table("TaBlE")), + Repartition(numPartitions = 200, shuffle = true, child = testRelation)) + + val errMsgCoal = "COALESCE Hint expects a partition number as parameter" + assertAnalysisError( + UnresolvedHint("COALESCE", Seq.empty, table("TaBlE")), + Seq(errMsgCoal)) + assertAnalysisError( + UnresolvedHint("COALESCE", Seq(Literal(10), Literal(false)), table("TaBlE")), + Seq(errMsgCoal)) + assertAnalysisError( + UnresolvedHint("COALESCE", Seq(Literal(1.0)), table("TaBlE")), + Seq(errMsgCoal)) + + val errMsgRepa = "REPARTITION Hint expects a partition number as parameter" + assertAnalysisError( + UnresolvedHint("REPARTITION", Seq(UnresolvedAttribute("a")), table("TaBlE")), + Seq(errMsgRepa)) + assertAnalysisError( + UnresolvedHint("REPARTITION", Seq(Literal(true)), table("TaBlE")), + Seq(errMsgRepa)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala new file mode 100644 index 0000000000000..c4171c75ecd03 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.{ArrayType, IntegerType} + +/** + * Test suite for [[ResolveLambdaVariables]]. 
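+ *
+ * As the cases below exercise, resolution rewrites the symbolic arguments of a
+ * [[LambdaFunction]] (e.g. 'x.attr) into [[NamedLambdaVariable]] instances; a plain
+ * expression used where a lambda is expected is wrapped in a hidden [[LambdaFunction]]
+ * whose synthetic argument is named col0.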
+ */ +class ResolveLambdaVariablesSuite extends PlanTest { + import org.apache.spark.sql.catalyst.dsl.expressions._ + import org.apache.spark.sql.catalyst.dsl.plans._ + + object Analyzer extends RuleExecutor[LogicalPlan] { + val batches = Batch("Resolution", FixedPoint(4), ResolveLambdaVariables(conf)) :: Nil + } + + private val key = 'key.int + private val values1 = 'values1.array(IntegerType) + private val values2 = 'values2.array(ArrayType(ArrayType(IntegerType))) + private val data = LocalRelation(Seq(key, values1, values2)) + private val lvInt = NamedLambdaVariable("x", IntegerType, nullable = true) + private val lvHiddenInt = NamedLambdaVariable("col0", IntegerType, nullable = true) + private val lvArray = NamedLambdaVariable("x", ArrayType(IntegerType), nullable = true) + + private def plan(e: Expression): LogicalPlan = data.select(e.as("res")) + + private def checkExpression(e1: Expression, e2: Expression): Unit = { + comparePlans(Analyzer.execute(plan(e1)), plan(e2)) + } + + test("resolution - no op") { + checkExpression(key, key) + } + + test("resolution - simple") { + val in = ArrayTransform(values1, LambdaFunction('x.attr + 1, 'x.attr :: Nil)) + val out = ArrayTransform(values1, LambdaFunction(lvInt + 1, lvInt :: Nil)) + checkExpression(in, out) + } + + test("resolution - nested") { + val in = ArrayTransform(values2, LambdaFunction( + ArrayTransform('x.attr, LambdaFunction('x.attr + 1, 'x.attr :: Nil)), 'x.attr :: Nil)) + val out = ArrayTransform(values2, LambdaFunction( + ArrayTransform(lvArray, LambdaFunction(lvInt + 1, lvInt :: Nil)), lvArray :: Nil)) + checkExpression(in, out) + } + + test("resolution - hidden") { + val in = ArrayTransform(values1, key) + val out = ArrayTransform(values1, LambdaFunction(key, lvHiddenInt :: Nil, hidden = true)) + checkExpression(in, out) + } + + test("fail - name collisions") { + val p = plan(ArrayTransform(values1, + LambdaFunction('x.attr + 'X.attr, 'x.attr :: 'X.attr :: Nil))) + val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage + assert(msg.contains("arguments should not have names that are semantically the same")) + } + + test("fail - lambda arguments") { + val p = plan(ArrayTransform(values1, + LambdaFunction('x.attr + 'y.attr + 'z.attr, 'x.attr :: 'y.attr :: 'z.attr :: Nil))) + val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage + assert(msg.contains("does not match the number of arguments expected")) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala index 1bf8d76da04d8..74a8590b5eefe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{In, ListQuery, OuterReference} +import org.apache.spark.sql.catalyst.expressions.{InSubquery, ListQuery} import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, Project} /** @@ -33,7 +33,8 @@ class ResolveSubquerySuite extends AnalysisTest { val t2 = LocalRelation(b) test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") { - val expr = Filter(In(a, Seq(ListQuery(Project(Seq(UnresolvedAttribute("a")), t2)))), 
t1) + val expr = Filter( + InSubquery(Seq(a), ListQuery(Project(Seq(UnresolvedAttribute("a")), t2))), t1) val m = intercept[AnalysisException] { SimpleAnalyzer.checkAnalysis(SimpleAnalyzer.ResolveSubquery(expr)) }.getMessage diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 0acd3b490447d..461eda4334bb9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -54,8 +54,9 @@ class TypeCoercionSuite extends AnalysisTest { // | NullType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType | MapType | StructType | NullType | CalendarIntervalType | DecimalType(38, 18) | DoubleType | IntegerType | // | CalendarIntervalType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | CalendarIntervalType | X | X | X | // +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+ - // Note: MapType*, StructType* are castable only when the internal child types also match; otherwise, not castable. + // Note: StructType* is castable when all the internal child types are castable according to the table. // Note: ArrayType* is castable when the element type is castable according to the table. + // Note: MapType* is castable when both the key type and the value type are castable according to the table. 
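+ // For example, map<int, float> is castable to map<long, double>, since the table allows
+ // casting int to long and float to double; if either the key cast or the value cast is
+ // not allowed by the table, the map cast is not allowed either.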
// scalastyle:on line.size.limit private def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = { @@ -396,7 +397,7 @@ class TypeCoercionSuite extends AnalysisTest { widenTest( StructType(Seq(StructField("a", IntegerType, nullable = false))), StructType(Seq(StructField("a", DoubleType, nullable = false))), - None) + Some(StructType(Seq(StructField("a", DoubleType, nullable = false))))) widenTest( StructType(Seq(StructField("a", IntegerType, nullable = false))), @@ -453,15 +454,18 @@ class TypeCoercionSuite extends AnalysisTest { def widenTestWithStringPromotion( t1: DataType, t2: DataType, - expected: Option[DataType]): Unit = { - checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected) + expected: Option[DataType], + isSymmetric: Boolean = true): Unit = { + checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected, isSymmetric) } def widenTestWithoutStringPromotion( t1: DataType, t2: DataType, - expected: Option[DataType]): Unit = { - checkWidenType(TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, expected) + expected: Option[DataType], + isSymmetric: Boolean = true): Unit = { + checkWidenType( + TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, expected, isSymmetric) } // Decimal @@ -487,12 +491,140 @@ class TypeCoercionSuite extends AnalysisTest { ArrayType(ArrayType(IntegerType), containsNull = false), ArrayType(ArrayType(LongType), containsNull = false), Some(ArrayType(ArrayType(LongType), containsNull = false))) + widenTestWithStringPromotion( + ArrayType(MapType(IntegerType, FloatType), containsNull = false), + ArrayType(MapType(LongType, DoubleType), containsNull = false), + Some(ArrayType(MapType(LongType, DoubleType), containsNull = false))) + widenTestWithStringPromotion( + ArrayType(new StructType().add("num", ShortType), containsNull = false), + ArrayType(new StructType().add("num", LongType), containsNull = false), + Some(ArrayType(new StructType().add("num", LongType), containsNull = false))) + widenTestWithStringPromotion( + ArrayType(IntegerType, containsNull = false), + ArrayType(DecimalType.IntDecimal, containsNull = false), + Some(ArrayType(DecimalType.IntDecimal, containsNull = false))) + widenTestWithStringPromotion( + ArrayType(DecimalType(36, 0), containsNull = false), + ArrayType(DecimalType(36, 35), containsNull = false), + Some(ArrayType(DecimalType(38, 35), containsNull = true))) + + // MapType + widenTestWithStringPromotion( + MapType(ShortType, TimestampType, valueContainsNull = true), + MapType(DoubleType, StringType, valueContainsNull = false), + Some(MapType(DoubleType, StringType, valueContainsNull = true))) + widenTestWithStringPromotion( + MapType(IntegerType, ArrayType(TimestampType), valueContainsNull = false), + MapType(LongType, ArrayType(StringType), valueContainsNull = true), + Some(MapType(LongType, ArrayType(StringType), valueContainsNull = true))) + widenTestWithStringPromotion( + MapType(IntegerType, MapType(ShortType, TimestampType), valueContainsNull = false), + MapType(LongType, MapType(DoubleType, StringType), valueContainsNull = false), + Some(MapType(LongType, MapType(DoubleType, StringType), valueContainsNull = false))) + widenTestWithStringPromotion( + MapType(IntegerType, new StructType().add("num", ShortType), valueContainsNull = false), + MapType(LongType, new StructType().add("num", LongType), valueContainsNull = false), + Some(MapType(LongType, new StructType().add("num", LongType), valueContainsNull = false))) + widenTestWithStringPromotion( + 
MapType(StringType, IntegerType, valueContainsNull = false), + MapType(StringType, DecimalType.IntDecimal, valueContainsNull = false), + Some(MapType(StringType, DecimalType.IntDecimal, valueContainsNull = false))) + widenTestWithStringPromotion( + MapType(StringType, DecimalType(36, 0), valueContainsNull = false), + MapType(StringType, DecimalType(36, 35), valueContainsNull = false), + Some(MapType(StringType, DecimalType(38, 35), valueContainsNull = true))) + widenTestWithStringPromotion( + MapType(IntegerType, StringType, valueContainsNull = false), + MapType(DecimalType.IntDecimal, StringType, valueContainsNull = false), + Some(MapType(DecimalType.IntDecimal, StringType, valueContainsNull = false))) + widenTestWithStringPromotion( + MapType(DecimalType(36, 0), StringType, valueContainsNull = false), + MapType(DecimalType(36, 35), StringType, valueContainsNull = false), + None) + + // StructType + widenTestWithStringPromotion( + new StructType() + .add("num", ShortType, nullable = true).add("ts", StringType, nullable = false), + new StructType() + .add("num", DoubleType, nullable = false).add("ts", TimestampType, nullable = true), + Some(new StructType() + .add("num", DoubleType, nullable = true).add("ts", StringType, nullable = true))) + widenTestWithStringPromotion( + new StructType() + .add("arr", ArrayType(ShortType, containsNull = false), nullable = false), + new StructType() + .add("arr", ArrayType(DoubleType, containsNull = true), nullable = false), + Some(new StructType() + .add("arr", ArrayType(DoubleType, containsNull = true), nullable = false))) + widenTestWithStringPromotion( + new StructType() + .add("map", MapType(ShortType, TimestampType, valueContainsNull = true), nullable = false), + new StructType() + .add("map", MapType(DoubleType, StringType, valueContainsNull = false), nullable = false), + Some(new StructType() + .add("map", MapType(DoubleType, StringType, valueContainsNull = true), nullable = false))) + widenTestWithStringPromotion( + new StructType().add("num", IntegerType, nullable = false), + new StructType().add("num", DecimalType.IntDecimal, nullable = false), + Some(new StructType().add("num", DecimalType.IntDecimal, nullable = false))) + widenTestWithStringPromotion( + new StructType().add("num", DecimalType(36, 0), nullable = false), + new StructType().add("num", DecimalType(36, 35), nullable = false), + Some(new StructType().add("num", DecimalType(38, 35), nullable = true))) + + widenTestWithStringPromotion( + new StructType().add("num", IntegerType), + new StructType().add("num", LongType).add("str", StringType), + None) + widenTestWithoutStringPromotion( + new StructType().add("num", IntegerType), + new StructType().add("num", LongType).add("str", StringType), + None) + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + widenTestWithStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + None) + widenTestWithoutStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + None) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + widenTestWithStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + Some(new StructType().add("a", LongType)), + isSymmetric = false) + widenTestWithoutStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + Some(new StructType().add("a", LongType)), + isSymmetric = false) + } // Without string promotion widenTestWithoutStringPromotion(IntegerType, 
StringType, None) widenTestWithoutStringPromotion(StringType, TimestampType, None) widenTestWithoutStringPromotion(ArrayType(LongType), ArrayType(StringType), None) widenTestWithoutStringPromotion(ArrayType(StringType), ArrayType(TimestampType), None) + widenTestWithoutStringPromotion( + MapType(LongType, IntegerType), MapType(StringType, IntegerType), None) + widenTestWithoutStringPromotion( + MapType(IntegerType, LongType), MapType(IntegerType, StringType), None) + widenTestWithoutStringPromotion( + MapType(StringType, IntegerType), MapType(TimestampType, IntegerType), None) + widenTestWithoutStringPromotion( + MapType(IntegerType, StringType), MapType(IntegerType, TimestampType), None) + widenTestWithoutStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("a", StringType), + None) + widenTestWithoutStringPromotion( + new StructType().add("a", StringType), + new StructType().add("a", IntegerType), + None) // String promotion widenTestWithStringPromotion(IntegerType, StringType, Some(StringType)) @@ -501,6 +633,30 @@ class TypeCoercionSuite extends AnalysisTest { ArrayType(LongType), ArrayType(StringType), Some(ArrayType(StringType))) widenTestWithStringPromotion( ArrayType(StringType), ArrayType(TimestampType), Some(ArrayType(StringType))) + widenTestWithStringPromotion( + MapType(LongType, IntegerType), + MapType(StringType, IntegerType), + Some(MapType(StringType, IntegerType))) + widenTestWithStringPromotion( + MapType(IntegerType, LongType), + MapType(IntegerType, StringType), + Some(MapType(IntegerType, StringType))) + widenTestWithStringPromotion( + MapType(StringType, IntegerType), + MapType(TimestampType, IntegerType), + Some(MapType(StringType, IntegerType))) + widenTestWithStringPromotion( + MapType(IntegerType, StringType), + MapType(IntegerType, TimestampType), + Some(MapType(IntegerType, StringType))) + widenTestWithStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("a", StringType), + Some(new StructType().add("a", StringType))) + widenTestWithStringPromotion( + new StructType().add("a", StringType), + new StructType().add("a", IntegerType), + Some(new StructType().add("a", StringType))) } private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) { @@ -563,46 +719,43 @@ class TypeCoercionSuite extends AnalysisTest { ruleTest(rule, Coalesce(Seq(doubleLit, intLit, floatLit)), - Coalesce(Seq(Cast(doubleLit, DoubleType), - Cast(intLit, DoubleType), Cast(floatLit, DoubleType)))) + Coalesce(Seq(doubleLit, Cast(intLit, DoubleType), Cast(floatLit, DoubleType)))) ruleTest(rule, Coalesce(Seq(longLit, intLit, decimalLit)), Coalesce(Seq(Cast(longLit, DecimalType(22, 0)), - Cast(intLit, DecimalType(22, 0)), Cast(decimalLit, DecimalType(22, 0))))) + Cast(intLit, DecimalType(22, 0)), decimalLit))) ruleTest(rule, Coalesce(Seq(nullLit, intLit)), - Coalesce(Seq(Cast(nullLit, IntegerType), Cast(intLit, IntegerType)))) + Coalesce(Seq(Cast(nullLit, IntegerType), intLit))) ruleTest(rule, Coalesce(Seq(timestampLit, stringLit)), - Coalesce(Seq(Cast(timestampLit, StringType), Cast(stringLit, StringType)))) + Coalesce(Seq(Cast(timestampLit, StringType), stringLit))) ruleTest(rule, Coalesce(Seq(nullLit, floatNullLit, intLit)), - Coalesce(Seq(Cast(nullLit, FloatType), Cast(floatNullLit, FloatType), - Cast(intLit, FloatType)))) + Coalesce(Seq(Cast(nullLit, FloatType), floatNullLit, Cast(intLit, FloatType)))) ruleTest(rule, Coalesce(Seq(nullLit, intLit, decimalLit, doubleLit)), Coalesce(Seq(Cast(nullLit, 
DoubleType), Cast(intLit, DoubleType), - Cast(decimalLit, DoubleType), Cast(doubleLit, DoubleType)))) + Cast(decimalLit, DoubleType), doubleLit))) ruleTest(rule, Coalesce(Seq(nullLit, floatNullLit, doubleLit, stringLit)), Coalesce(Seq(Cast(nullLit, StringType), Cast(floatNullLit, StringType), - Cast(doubleLit, StringType), Cast(stringLit, StringType)))) + Cast(doubleLit, StringType), stringLit))) ruleTest(rule, Coalesce(Seq(timestampLit, intLit, stringLit)), - Coalesce(Seq(Cast(timestampLit, StringType), Cast(intLit, StringType), - Cast(stringLit, StringType)))) + Coalesce(Seq(Cast(timestampLit, StringType), Cast(intLit, StringType), stringLit))) ruleTest(rule, Coalesce(Seq(tsArrayLit, intArrayLit, strArrayLit)), Coalesce(Seq(Cast(tsArrayLit, ArrayType(StringType)), - Cast(intArrayLit, ArrayType(StringType)), Cast(strArrayLit, ArrayType(StringType))))) + Cast(intArrayLit, ArrayType(StringType)), strArrayLit))) } test("CreateArray casts") { @@ -611,7 +764,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(1) :: Literal.create(1.0, FloatType) :: Nil), - CreateArray(Cast(Literal(1.0), DoubleType) + CreateArray(Literal(1.0) :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) @@ -623,7 +776,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Nil), CreateArray(Cast(Literal(1.0), StringType) :: Cast(Literal(1), StringType) - :: Cast(Literal("a"), StringType) + :: Literal("a") :: Nil)) ruleTest(TypeCoercion.FunctionArgumentConversion, @@ -641,7 +794,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Nil), CreateArray(Literal.create(null, DecimalType(5, 3)).cast(DecimalType(38, 38)) :: Literal.create(null, DecimalType(22, 10)).cast(DecimalType(38, 38)) - :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38)) + :: Literal.create(null, DecimalType(38, 38)) :: Nil)) } @@ -655,7 +808,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Nil), CreateMap(Cast(Literal(1), FloatType) :: Literal("a") - :: Cast(Literal.create(2.0, FloatType), FloatType) + :: Literal.create(2.0, FloatType) :: Literal("b") :: Nil)) ruleTest(TypeCoercion.FunctionArgumentConversion, @@ -677,7 +830,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(3.0) :: Nil), CreateMap(Literal(1) - :: Cast(Literal("a"), StringType) + :: Literal("a") :: Literal(2) :: Cast(Literal(3.0), StringType) :: Nil)) @@ -690,7 +843,7 @@ class TypeCoercionSuite extends AnalysisTest { CreateMap(Literal(1) :: Literal.create(null, DecimalType(38, 0)).cast(DecimalType(38, 38)) :: Literal(2) - :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38)) + :: Literal.create(null, DecimalType(38, 38)) :: Nil)) // type coercion for both map keys and values ruleTest(TypeCoercion.FunctionArgumentConversion, @@ -700,8 +853,8 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(3.0) :: Nil), CreateMap(Cast(Literal(1), DoubleType) - :: Cast(Literal("a"), StringType) - :: Cast(Literal(2.0), DoubleType) + :: Literal("a") + :: Literal(2.0) :: Cast(Literal(3.0), StringType) :: Nil)) } @@ -713,7 +866,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(1) :: Literal.create(1.0, FloatType) :: Nil), - operator(Cast(Literal(1.0), DoubleType) + operator(Literal(1.0) :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) @@ -724,14 +877,14 @@ class TypeCoercionSuite extends AnalysisTest { :: Nil), operator(Cast(Literal(1L), DecimalType(22, 0)) :: Cast(Literal(1), DecimalType(22, 0)) - :: Cast(Literal(new 
java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) + :: Literal(new java.math.BigDecimal("1000000000000000000000")) :: Nil)) ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal(1.0) :: Literal.create(null, DecimalType(10, 5)) :: Literal(1) :: Nil), - operator(Literal(1.0).cast(DoubleType) + operator(Literal(1.0) :: Literal.create(null, DecimalType(10, 5)).cast(DoubleType) :: Literal(1).cast(DoubleType) :: Nil)) @@ -1102,8 +1255,10 @@ class TypeCoercionSuite extends AnalysisTest { val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) - val r1 = widenSetOperationTypes(Except(firstTable, secondTable)).asInstanceOf[Except] - val r2 = widenSetOperationTypes(Intersect(firstTable, secondTable)).asInstanceOf[Intersect] + val r1 = widenSetOperationTypes( + Except(firstTable, secondTable, isAll = false)).asInstanceOf[Except] + val r2 = widenSetOperationTypes( + Intersect(firstTable, secondTable, isAll = false)).asInstanceOf[Intersect] checkOutput(r1.left, expectedTypes) checkOutput(r1.right, expectedTypes) checkOutput(r2.left, expectedTypes) @@ -1168,8 +1323,10 @@ class TypeCoercionSuite extends AnalysisTest { val expectedType1 = Seq(DecimalType(10, 8)) val r1 = widenSetOperationTypes(Union(left1, right1)).asInstanceOf[Union] - val r2 = widenSetOperationTypes(Except(left1, right1)).asInstanceOf[Except] - val r3 = widenSetOperationTypes(Intersect(left1, right1)).asInstanceOf[Intersect] + val r2 = widenSetOperationTypes( + Except(left1, right1, isAll = false)).asInstanceOf[Except] + val r3 = widenSetOperationTypes( + Intersect(left1, right1, isAll = false)).asInstanceOf[Intersect] checkOutput(r1.children.head, expectedType1) checkOutput(r1.children.last, expectedType1) @@ -1189,16 +1346,20 @@ class TypeCoercionSuite extends AnalysisTest { AttributeReference("r", rType)()) val r1 = widenSetOperationTypes(Union(plan1, plan2)).asInstanceOf[Union] - val r2 = widenSetOperationTypes(Except(plan1, plan2)).asInstanceOf[Except] - val r3 = widenSetOperationTypes(Intersect(plan1, plan2)).asInstanceOf[Intersect] + val r2 = widenSetOperationTypes( + Except(plan1, plan2, isAll = false)).asInstanceOf[Except] + val r3 = widenSetOperationTypes( + Intersect(plan1, plan2, isAll = false)).asInstanceOf[Intersect] checkOutput(r1.children.last, Seq(expectedType)) checkOutput(r2.right, Seq(expectedType)) checkOutput(r3.right, Seq(expectedType)) val r4 = widenSetOperationTypes(Union(plan2, plan1)).asInstanceOf[Union] - val r5 = widenSetOperationTypes(Except(plan2, plan1)).asInstanceOf[Except] - val r6 = widenSetOperationTypes(Intersect(plan2, plan1)).asInstanceOf[Intersect] + val r5 = widenSetOperationTypes( + Except(plan2, plan1, isAll = false)).asInstanceOf[Except] + val r6 = widenSetOperationTypes( + Intersect(plan2, plan1, isAll = false)).asInstanceOf[Intersect] checkOutput(r4.children.last, Seq(expectedType)) checkOutput(r5.left, Seq(expectedType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index cb487c8893541..28a164b5d0cad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -575,14 +575,14 @@ class UnsupportedOperationsSuite extends SparkFunSuite { // Except: *-stream not supported testBinaryOperationInStreamingPlan( 
"except", - _.except(_), + _.except(_, isAll = false), streamStreamSupported = false, batchStreamSupported = false) // Intersect: stream-stream not supported testBinaryOperationInStreamingPlan( "intersect", - _.intersect(_), + _.intersect(_, isAll = false), streamStreamSupported = false) // Sort: supported only on batch subplans and after aggregation on streaming plan + complete mode @@ -766,7 +766,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite { * * To test this correctly, the given logical plan is wrapped in a fake operator that makes the * whole plan look like a streaming plan. Otherwise, a batch plan may throw not supported - * exception simply for not being a streaming plan, even though that plan could exists as batch + * exception simply for not being a streaming plan, even though that plan could exist as batch * subplan inside some streaming plan. */ def assertSupportedInStreamingPlan( @@ -793,7 +793,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite { * * To test this correctly, the given logical plan is wrapped in a fake operator that makes the * whole plan look like a streaming plan. Otherwise, a batch plan may throw not supported - * exception simply for not being a streaming plan, even though that plan could exists as batch + * exception simply for not being a streaming plan, even though that plan could exist as batch * subplan inside some streaming plan. */ def assertNotSupportedInStreamingPlan( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 6a7375ee186fa..89fabd4774065 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.catalog import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.{AliasIdentifier, FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser @@ -537,11 +537,11 @@ abstract class SessionCatalogSuite extends AnalysisTest { val view = View(desc = metadata, output = metadata.schema.toAttributes, child = CatalystSqlParser.parsePlan(metadata.viewText.get)) comparePlans(catalog.lookupRelation(TableIdentifier("view1", Some("db3"))), - SubqueryAlias("view1", view)) + SubqueryAlias("view1", "db3", view)) // Look up a view using current database of the session catalog. 
catalog.setCurrentDatabase("db3") comparePlans(catalog.lookupRelation(TableIdentifier("view1")), - SubqueryAlias("view1", view)) + SubqueryAlias("view1", "db3", view)) } } @@ -1217,6 +1217,42 @@ abstract class SessionCatalogSuite extends AnalysisTest { } } + test("isRegisteredFunction") { + withBasicCatalog { catalog => + // Returns false when the function is not registered + assert(!catalog.isRegisteredFunction(FunctionIdentifier("temp1"))) + + // Returns true when the function is registered + val tempFunc1 = (e: Seq[Expression]) => e.head + catalog.registerFunction(newFunc("iff", None), overrideIfExists = false, + functionBuilder = Some(tempFunc1)) + assert(catalog.isRegisteredFunction(FunctionIdentifier("iff"))) + + // Returns false for a persistent function created via createFunction + catalog.createFunction(newFunc("sum", Some("db2")), ignoreIfExists = false) + assert(!catalog.isRegisteredFunction(FunctionIdentifier("sum"))) + assert(!catalog.isRegisteredFunction(FunctionIdentifier("sum", Some("db2")))) + } + } + + test("isPersistentFunction") { + withBasicCatalog { catalog => + // Returns false when the function is not registered + assert(!catalog.isPersistentFunction(FunctionIdentifier("temp2"))) + + // Returns false when the function is registered only as a temporary function + val tempFunc2 = (e: Seq[Expression]) => e.head + catalog.registerFunction(newFunc("iff", None), overrideIfExists = false, + functionBuilder = Some(tempFunc2)) + assert(!catalog.isPersistentFunction(FunctionIdentifier("iff"))) + + // Returns true for a persistent function created via createFunction + catalog.createFunction(newFunc("sum", Some("db2")), ignoreIfExists = false) + assert(catalog.isPersistentFunction(FunctionIdentifier("sum", Some("db2")))) + assert(!catalog.isPersistentFunction(FunctionIdentifier("db2.sum"))) + } + } + test("drop function") { withBasicCatalog { catalog => assert(catalog.externalCatalog.listFunctions("db2", "*").toSet == Set("func1")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 630113ce2d948..dd20e6497fbb4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -144,7 +144,7 @@ class EncoderResolutionSuite extends PlanTest { // It should pass analysis val bound = encoder.resolveAndBind(attrs) - // If no null values appear, it should works fine + // If no null values appear, it should work fine bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2)))) // If there is null value, it should throw runtime exception diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index e6d09bdae67d7..f0d61de97ffcd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} import org.apache.spark.sql.catalyst.analysis.AnalysisTest import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} -import org.apache.spark.sql.catalyst.plans.PlanTest +import
org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types._ @@ -112,7 +112,7 @@ object ReferenceValueClass { case class Container(data: Int) } -class ExpressionEncoderSuite extends PlanTest with AnalysisTest { +class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTest { OuterScopes.addOuterScope(this) implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = verifyNotLeakingReflectionObjects { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 6ed175f86ca77..8d89f9c6c41d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.encoders import scala.util.Random -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{RandomDataGenerator, Row} +import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.types._ @@ -71,7 +71,7 @@ class ExamplePointUDT extends UserDefinedType[ExamplePoint] { private[spark] override def asNullable: ExamplePointUDT = this } -class RowEncoderSuite extends SparkFunSuite { +class RowEncoderSuite extends CodegenInterpretedPlanTest { private val structOfString = new StructType().add("str", StringType) private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 6edb4348f8309..9a752af523ffc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -282,6 +282,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper DataTypeTestUtils.ordered.foreach { dt => checkConsistencyBetweenInterpretedAndCodegen(Least, dt, 2) } + + val least = Least(Seq( + Literal.create(Seq(1, 2), ArrayType(IntegerType, containsNull = false)), + Literal.create(Seq(1, 3, null), ArrayType(IntegerType, containsNull = true)))) + assert(least.dataType === ArrayType(IntegerType, containsNull = true)) + checkEvaluation(least, Seq(1, 2)) } test("function greatest") { @@ -334,10 +340,16 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper DataTypeTestUtils.ordered.foreach { dt => checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2) } + + val greatest = Greatest(Seq( + Literal.create(Seq(1, 2), ArrayType(IntegerType, containsNull = false)), + Literal.create(Seq(1, 3, null), ArrayType(IntegerType, containsNull = true)))) + assert(greatest.dataType === ArrayType(IntegerType, containsNull = true)) + checkEvaluation(greatest, Seq(1, 3, null)) } test("SPARK-22499: Least and greatest should not generate codes beyond 64KB") { - val N = 3000 + val N = 2000 val strings = (1 to N).map(x => "s" * x) val inputsExpr = strings.map(Literal.create(_, StringType)) diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 5b25bdf907c3a..d9f32c000a885 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -399,21 +399,35 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("casting to fixed-precision decimals") { - // Overflow and rounding for casting to fixed-precision decimals: - // - Values should round with HALF_UP mode by default when you lower scale - // - Values that would overflow the target precision should turn into null - // - Because of this, casts to fixed-precision decimals should be nullable - - assert(cast(123, DecimalType.USER_DEFAULT).nullable === true) + assert(cast(123, DecimalType.USER_DEFAULT).nullable === false) assert(cast(10.03f, DecimalType.SYSTEM_DEFAULT).nullable === true) assert(cast(10.03, DecimalType.SYSTEM_DEFAULT).nullable === true) - assert(cast(Decimal(10.03), DecimalType.SYSTEM_DEFAULT).nullable === true) + assert(cast(Decimal(10.03), DecimalType.SYSTEM_DEFAULT).nullable === false) assert(cast(123, DecimalType(2, 1)).nullable === true) assert(cast(10.03f, DecimalType(2, 1)).nullable === true) assert(cast(10.03, DecimalType(2, 1)).nullable === true) assert(cast(Decimal(10.03), DecimalType(2, 1)).nullable === true) + assert(cast(123, DecimalType.IntDecimal).nullable === false) + assert(cast(10.03f, DecimalType.FloatDecimal).nullable === true) + assert(cast(10.03, DecimalType.DoubleDecimal).nullable === true) + assert(cast(Decimal(10.03), DecimalType(4, 2)).nullable === false) + assert(cast(Decimal(10.03), DecimalType(5, 3)).nullable === false) + + assert(cast(Decimal(10.03), DecimalType(3, 1)).nullable === true) + assert(cast(Decimal(10.03), DecimalType(4, 1)).nullable === false) + assert(cast(Decimal(9.95), DecimalType(2, 1)).nullable === true) + assert(cast(Decimal(9.95), DecimalType(3, 1)).nullable === false) + + assert(cast(Decimal("1003"), DecimalType(3, -1)).nullable === true) + assert(cast(Decimal("1003"), DecimalType(4, -1)).nullable === false) + assert(cast(Decimal("995"), DecimalType(2, -1)).nullable === true) + assert(cast(Decimal("995"), DecimalType(3, -1)).nullable === false) + + assert(cast(true, DecimalType.SYSTEM_DEFAULT).nullable === false) + assert(cast(true, DecimalType(1, 1)).nullable === true) + checkEvaluation(cast(10.03, DecimalType.SYSTEM_DEFAULT), Decimal(10.03)) checkEvaluation(cast(10.03, DecimalType(4, 2)), Decimal(10.03)) @@ -451,6 +465,20 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(Decimal(-9.95), DecimalType(3, 1)), Decimal(-10.0)) checkEvaluation(cast(Decimal(-9.95), DecimalType(1, 0)), null) + checkEvaluation(cast(Decimal("1003"), DecimalType.SYSTEM_DEFAULT), Decimal(1003)) + checkEvaluation(cast(Decimal("1003"), DecimalType(4, 0)), Decimal(1003)) + checkEvaluation(cast(Decimal("1003"), DecimalType(3, -1)), Decimal(1000)) + checkEvaluation(cast(Decimal("1003"), DecimalType(2, -2)), Decimal(1000)) + checkEvaluation(cast(Decimal("1003"), DecimalType(1, -2)), null) + checkEvaluation(cast(Decimal("1003"), DecimalType(2, -1)), null) + checkEvaluation(cast(Decimal("1003"), DecimalType(3, 0)), null) + + checkEvaluation(cast(Decimal("995"), DecimalType(3, 0)), Decimal(995)) + checkEvaluation(cast(Decimal("995"), DecimalType(3, -1)), Decimal(1000)) + 
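// ---------------------------------------------------------------------------
// Editor's aside (hedged, not part of the patch): the nullability assertions
// above all follow one rule - a cast to a fixed-precision decimal can be marked
// non-nullable exactly when the target type can represent every value of the
// source type. For Decimal(p1, s1) -> Decimal(p2, s2) that means keeping all
// integral digits, plus one spare digit whenever the scale shrinks, because
// rounding can carry (9.95 cast to scale 1 rounds to 10.0). A hypothetical
// helper capturing that check, not Spark's actual code:
def canNullSafeCast(p1: Int, s1: Int, p2: Int, s2: Int): Boolean = {
  val integralDigits = p1 - s1
  if (s2 >= s1) {
    p2 - s2 >= integralDigits        // scale kept or widened: no rounding possible
  } else {
    p2 - s2 >= integralDigits + 1    // scale shrinks: rounding may add one integral digit
  }
}
assert(canNullSafeCast(3, 2, 3, 1))   // Decimal(9.95) -> DecimalType(3, 1): 10.0 fits
assert(!canNullSafeCast(3, 2, 2, 1))  // Decimal(9.95) -> DecimalType(2, 1): 10.0 overflows
// ---------------------------------------------------------------------------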
checkEvaluation(cast(Decimal("995"), DecimalType(2, -2)), Decimal(1000)) + checkEvaluation(cast(Decimal("995"), DecimalType(2, -1)), null) + checkEvaluation(cast(Decimal("995"), DecimalType(1, -2)), null) + checkEvaluation(cast(Double.NaN, DecimalType.SYSTEM_DEFAULT), null) checkEvaluation(cast(1.0 / 0.0, DecimalType.SYSTEM_DEFAULT), null) checkEvaluation(cast(Float.NaN, DecimalType.SYSTEM_DEFAULT), null) @@ -460,6 +488,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(1.0 / 0.0, DecimalType(2, 1)), null) checkEvaluation(cast(Float.NaN, DecimalType(2, 1)), null) checkEvaluation(cast(1.0f / 0.0f, DecimalType(2, 1)), null) + + checkEvaluation(cast(true, DecimalType(2, 1)), Decimal(1)) + checkEvaluation(cast(true, DecimalType(1, 1)), null) } test("cast from date") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 5b71becee2de0..c383eec3d56b4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -19,12 +19,16 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.Timestamp +import org.apache.log4j.{Appender, AppenderSkeleton, Logger} +import org.apache.log4j.spi.LoggingEvent + import org.apache.spark.SparkFunSuite import org.apache.spark.metrics.source.CodegenMetrics import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils} import org.apache.spark.sql.types._ @@ -499,4 +503,64 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { ctx.freshName("a_1") :: ctx.freshName("a_0") :: Nil assert(names2.distinct.length == 4) } + + test("SPARK-25113: should log when there exists generated methods above HugeMethodLimit") { + class MockAppender extends AppenderSkeleton { + var seenMessage = false + + override def append(loggingEvent: LoggingEvent): Unit = { + if (loggingEvent.getRenderedMessage().contains("Generated method too long")) { + seenMessage = true + } + } + + override def close(): Unit = {} + override def requiresLayout(): Boolean = false + } + + val appender = new MockAppender() + withLogAppender(appender) { + val x = 42 + val expr = HugeCodeIntExpression(x) + val proj = GenerateUnsafeProjection.generate(Seq(expr)) + val actual = proj(null) + assert(actual.getInt(0) == x) + } + assert(appender.seenMessage) + } + + private def withLogAppender(appender: Appender)(f: => Unit): Unit = { + val logger = + Logger.getLogger(classOf[CodeGenerator[_, _]].getName) + logger.addAppender(appender) + try f finally { + logger.removeAppender(appender) + } + } +} + +case class HugeCodeIntExpression(value: Int) extends Expression { + override def nullable: Boolean = true + override def dataType: DataType = IntegerType + override def children: Seq[Expression] = Nil + override def eval(input: InternalRow): Any = value + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + // Assuming HugeMethodLimit to be 8000 + val HugeMethodLimit = 
CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT + // A single "int dummyN = 0;" will be at least 2 bytes of bytecode: + // 0: iconst_0 + // 1: istore_1 + // and it'll become bigger as the number of local variables increases. + // So 4000 such dummy local variable definitions are sufficient to bump the bytecode size + // of a generated method to above 8000 bytes. + val hugeCode = (0 until (HugeMethodLimit / 2)).map(i => s"int dummy$i = 0;").mkString("\n") + val code = + code"""{ + | $hugeCode + |} + |boolean ${ev.isNull} = false; + |int ${ev.value} = $value; + """.stripMargin + ev.copy(code = code) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala index 531ca9a87370a..28edd85ab6e87 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala @@ -17,17 +17,33 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.concurrent.ExecutionException + import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator} import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{IntegerType, LongType} +import org.apache.spark.sql.types.IntegerType class CodeGeneratorWithInterpretedFallbackSuite extends SparkFunSuite with PlanTestBase { - test("UnsafeProjection with codegen factory mode") { - val input = Seq(LongType, IntegerType) - .zipWithIndex.map(x => BoundReference(x._2, x._1, true)) + object FailedCodegenProjection + extends CodeGeneratorWithInterpretedFallback[Seq[Expression], UnsafeProjection] { + + override protected def createCodeGeneratedObject(in: Seq[Expression]): UnsafeProjection = { + val invalidCode = new CodeAndComment("invalid code", Map.empty) + // We assume this compilation throws an exception + CodeGenerator.compile(invalidCode) + null + } + + override protected def createInterpretedObject(in: Seq[Expression]): UnsafeProjection = { + InterpretedUnsafeProjection.createProjection(in) + } + } + test("UnsafeProjection with codegen factory mode") { + val input = Seq(BoundReference(0, IntegerType, nullable = true)) val codegenOnly = CodegenObjectFactoryMode.CODEGEN_ONLY.toString withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) { val obj = UnsafeProjection.createObject(input) @@ -40,4 +56,24 @@ class CodeGeneratorWithInterpretedFallbackSuite extends SparkFunSuite with PlanT assert(obj.isInstanceOf[InterpretedUnsafeProjection]) } } + + test("fallback to the interpreter mode") { + val input = Seq(BoundReference(0, IntegerType, nullable = true)) + val fallback = CodegenObjectFactoryMode.FALLBACK.toString + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallback) { + val obj = FailedCodegenProjection.createObject(input) + assert(obj.isInstanceOf[InterpretedUnsafeProjection]) + } + } + + test("codegen failures in the CODEGEN_ONLY mode") { + val errMsg = intercept[ExecutionException] { + val input = Seq(BoundReference(0, IntegerType, nullable = true)) + val codegenOnly = CodegenObjectFactoryMode.CODEGEN_ONLY.toString + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) { + FailedCodegenProjection.createObject(input) + } + 
}.getMessage + assert(errMsg.contains("failed to compile: org.codehaus.commons.compiler.CompileException:")) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index d7744eb4c7dc7..c7db4ec9e16b1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -20,47 +20,54 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import java.util.TimeZone +import scala.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeTestUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH import org.apache.spark.unsafe.types.CalendarInterval class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { - def testSize(legacySizeOfNull: Boolean, sizeOfNull: Any): Unit = { + def testSize(sizeOfNull: Any): Unit = { val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) val a2 = Literal.create(Seq(1, 2), ArrayType(IntegerType)) - checkEvaluation(Size(a0, legacySizeOfNull), 3) - checkEvaluation(Size(a1, legacySizeOfNull), 0) - checkEvaluation(Size(a2, legacySizeOfNull), 2) + checkEvaluation(Size(a0), 3) + checkEvaluation(Size(a1), 0) + checkEvaluation(Size(a2), 2) val m0 = Literal.create(Map("a" -> "a", "b" -> "b"), MapType(StringType, StringType)) val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) val m2 = Literal.create(Map("a" -> "a"), MapType(StringType, StringType)) - checkEvaluation(Size(m0, legacySizeOfNull), 2) - checkEvaluation(Size(m1, legacySizeOfNull), 0) - checkEvaluation(Size(m2, legacySizeOfNull), 1) + checkEvaluation(Size(m0), 2) + checkEvaluation(Size(m1), 0) + checkEvaluation(Size(m2), 1) checkEvaluation( - Size(Literal.create(null, MapType(StringType, StringType)), legacySizeOfNull), + Size(Literal.create(null, MapType(StringType, StringType))), expected = sizeOfNull) checkEvaluation( - Size(Literal.create(null, ArrayType(StringType)), legacySizeOfNull), + Size(Literal.create(null, ArrayType(StringType))), expected = sizeOfNull) } test("Array and Map Size - legacy") { - testSize(legacySizeOfNull = true, sizeOfNull = -1) + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") { + testSize(sizeOfNull = -1) + } } test("Array and Map Size") { - testSize(legacySizeOfNull = false, sizeOfNull = null) + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "false") { + testSize(sizeOfNull = null) + } } test("MapKeys/MapValues") { @@ -83,10 +90,12 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val mi0 = Literal.create(Map(1 -> 1, 2 -> null, 3 -> 2), MapType(IntegerType, IntegerType)) val mi1 = Literal.create(Map[Int, Int](), MapType(IntegerType, IntegerType)) val mi2 = Literal.create(null, MapType(IntegerType, IntegerType)) + val mid0 = Literal.create(Map(1 -> 1.1, 2 -> 2.2), MapType(IntegerType, DoubleType)) checkEvaluation(MapEntries(mi0), Seq(r(1, 1), r(2, null), r(3, 2))) checkEvaluation(MapEntries(mi1), Seq.empty) 
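// ---------------------------------------------------------------------------
// Editor's aside (hedged): the testSize rewrite above replaces the old
// Size(child, legacySizeOfNull) constructor flag with the SQL conf
// spark.sql.legacy.sizeOfNull, which Size now presumably reads when it is
// constructed. Usage sketch under that assumption:
withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") {
  checkEvaluation(Size(Literal.create(null, ArrayType(IntegerType))), -1)    // legacy: -1
}
withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "false") {
  checkEvaluation(Size(Literal.create(null, ArrayType(IntegerType))), null)  // standard: null
}
// ---------------------------------------------------------------------------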
checkEvaluation(MapEntries(mi2), null) + checkEvaluation(MapEntries(mid0), Seq(r(1, 1.1), r(2, 2.2))) // Non-primitive-type keys/values val ms0 = Literal.create(Map("a" -> "c", "b" -> null), MapType(StringType, StringType)) @@ -98,6 +107,165 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapEntries(ms2), null) } + test("Map Concat") { + val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType, + valueContainsNull = false)) + val m1 = Literal.create(Map("c" -> "3", "a" -> "4"), MapType(StringType, StringType, + valueContainsNull = false)) + val m2 = Literal.create(Map("d" -> "4", "e" -> "5"), MapType(StringType, StringType)) + val m3 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType)) + val m4 = Literal.create(Map("a" -> null, "c" -> "3"), MapType(StringType, StringType)) + val m5 = Literal.create(Map("a" -> 1, "b" -> 2), MapType(StringType, IntegerType)) + val m6 = Literal.create(Map("a" -> null, "c" -> 3), MapType(StringType, IntegerType)) + val m7 = Literal.create(Map(List(1, 2) -> 1, List(3, 4) -> 2), + MapType(ArrayType(IntegerType), IntegerType)) + val m8 = Literal.create(Map(List(5, 6) -> 3, List(1, 2) -> 4), + MapType(ArrayType(IntegerType), IntegerType)) + val m9 = Literal.create(Map(Map(1 -> 2, 3 -> 4) -> 1, Map(5 -> 6, 7 -> 8) -> 2), + MapType(MapType(IntegerType, IntegerType), IntegerType)) + val m10 = Literal.create(Map(Map(9 -> 10, 11 -> 12) -> 3, Map(1 -> 2, 3 -> 4) -> 4), + MapType(MapType(IntegerType, IntegerType), IntegerType)) + val m11 = Literal.create(Map(1 -> "1", 2 -> "2"), MapType(IntegerType, StringType, + valueContainsNull = false)) + val m12 = Literal.create(Map(3 -> "3", 4 -> "4"), MapType(IntegerType, StringType, + valueContainsNull = false)) + val m13 = Literal.create(Map(1 -> 2, 3 -> 4), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val m14 = Literal.create(Map(5 -> 6), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val m15 = Literal.create(Map(7 -> null), + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val mNull = Literal.create(null, MapType(StringType, StringType)) + + // overlapping maps + checkEvaluation(MapConcat(Seq(m0, m1)), + ( + Array("a", "b", "c", "a"), // keys + Array("1", "2", "3", "4") // values + ) + ) + + // maps with no overlap + checkEvaluation(MapConcat(Seq(m0, m2)), + Map("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5")) + + // 3 maps + checkEvaluation(MapConcat(Seq(m0, m1, m2)), + ( + Array("a", "b", "c", "a", "d", "e"), // keys + Array("1", "2", "3", "4", "4", "5") // values + ) + ) + + // null reference values + checkEvaluation(MapConcat(Seq(m3, m4)), + ( + Array("a", "b", "a", "c"), // keys + Array("1", "2", null, "3") // values + ) + ) + + // null primitive values + checkEvaluation(MapConcat(Seq(m5, m6)), + ( + Array("a", "b", "a", "c"), // keys + Array(1, 2, null, 3) // values + ) + ) + + // keys that are primitive + checkEvaluation(MapConcat(Seq(m11, m12)), + ( + Array(1, 2, 3, 4), // keys + Array("1", "2", "3", "4") // values + ) + ) + + // keys that are arrays, with overlap + checkEvaluation(MapConcat(Seq(m7, m8)), + ( + Array(List(1, 2), List(3, 4), List(5, 6), List(1, 2)), // keys + Array(1, 2, 3, 4) // values + ) + ) + + // keys that are maps, with overlap + checkEvaluation(MapConcat(Seq(m9, m10)), + ( + Array(Map(1 -> 2, 3 -> 4), Map(5 -> 6, 7 -> 8), Map(9 -> 10, 11 -> 12), + Map(1 -> 2, 3 -> 4)), // keys + Array(1, 2, 3, 4) // values + ) + ) + + // both keys and 
values are primitive and valueContainsNull = false + checkEvaluation(MapConcat(Seq(m13, m14)), Map(1 -> 2, 3 -> 4, 5 -> 6)) + + // both keys and values are primitive and valueContainsNull = true + checkEvaluation(MapConcat(Seq(m13, m15)), Map(1 -> 2, 3 -> 4, 7 -> null)) + + // null map + checkEvaluation(MapConcat(Seq(m0, mNull)), null) + checkEvaluation(MapConcat(Seq(mNull, m0)), null) + checkEvaluation(MapConcat(Seq(mNull, mNull)), null) + checkEvaluation(MapConcat(Seq(mNull)), null) + + // single map + checkEvaluation(MapConcat(Seq(m0)), Map("a" -> "1", "b" -> "2")) + + // no map + checkEvaluation(MapConcat(Seq.empty), Map.empty) + + // force split expressions for input in generated code + val expectedKeys = Array.fill(65)(Seq("a", "b")).flatten ++ Array("d", "e") + val expectedValues = Array.fill(65)(Seq("1", "2")).flatten ++ Array("4", "5") + checkEvaluation(MapConcat( + Seq( + m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, + m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, + m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m2 + )), + (expectedKeys, expectedValues)) + + // argument checking + assert(MapConcat(Seq(m0, m1)).checkInputDataTypes().isSuccess) + assert(MapConcat(Seq(m5, m6)).checkInputDataTypes().isSuccess) + assert(MapConcat(Seq(m0, m5)).checkInputDataTypes().isFailure) + assert(MapConcat(Seq(m0, Literal(12))).checkInputDataTypes().isFailure) + assert(MapConcat(Seq(m0, m1)).dataType.keyType == StringType) + assert(MapConcat(Seq(m0, m1)).dataType.valueType == StringType) + assert(!MapConcat(Seq(m0, m1)).dataType.valueContainsNull) + assert(MapConcat(Seq(m5, m6)).dataType.keyType == StringType) + assert(MapConcat(Seq(m5, m6)).dataType.valueType == IntegerType) + assert(MapConcat(Seq.empty).dataType.keyType == StringType) + assert(MapConcat(Seq.empty).dataType.valueType == StringType) + assert(MapConcat(Seq(m5, m6)).dataType.valueContainsNull) + assert(MapConcat(Seq(m6, m5)).dataType.valueContainsNull) + assert(!MapConcat(Seq(m1, m2)).nullable) + assert(MapConcat(Seq(m1, mNull)).nullable) + + val mapConcat = MapConcat(Seq( + Literal.create(Map(Seq(1, 2) -> Seq("a", "b")), + MapType( + ArrayType(IntegerType, containsNull = false), + ArrayType(StringType, containsNull = false), + valueContainsNull = false)), + Literal.create(Map(Seq(3, 4, null) -> Seq("c", "d", null), Seq(6) -> null), + MapType( + ArrayType(IntegerType, containsNull = true), + ArrayType(StringType, containsNull = true), + valueContainsNull = true)))) + assert(mapConcat.dataType === + MapType( + ArrayType(IntegerType, containsNull = true), + ArrayType(StringType, containsNull = true), + valueContainsNull = true)) + checkEvaluation(mapConcat, Map( + Seq(1, 2) -> Seq("a", "b"), + Seq(3, 4, null) -> Seq("c", "d", null), + Seq(6) -> null)) + } + test("MapFromEntries") { def arrayType(keyType: DataType, valueType: DataType) : DataType = { ArrayType( @@ -158,12 +326,19 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val d2 = new Decimal().set(100) val a4 = Literal.create(Seq(d2, d1), ArrayType(DecimalType(10, 0))) val a5 = Literal.create(Seq(null, null), ArrayType(NullType)) + val a6 = Literal.create(Seq(true, false, true, false), + ArrayType(BooleanType, containsNull = false)) + val a7 = Literal.create(Seq(true, false, true, false), ArrayType(BooleanType)) + val a8 = Literal.create(Seq(true, false, true, null, false), ArrayType(BooleanType)) checkEvaluation(new
SortArray(a0), Seq(1, 2, 3)) checkEvaluation(new SortArray(a1), Seq[Integer]()) checkEvaluation(new SortArray(a2), Seq("a", "b")) checkEvaluation(new SortArray(a3), Seq(null, "a", "b")) checkEvaluation(new SortArray(a4), Seq(d1, d2)) + checkEvaluation(new SortArray(a6), Seq(false, false, true, true)) + checkEvaluation(new SortArray(a7), Seq(false, false, true, true)) + checkEvaluation(new SortArray(a8), Seq(null, false, false, true, true)) checkEvaluation(SortArray(a0, Literal(true)), Seq(1, 2, 3)) checkEvaluation(SortArray(a1, Literal(true)), Seq[Integer]()) checkEvaluation(SortArray(a2, Literal(true)), Seq("a", "b")) @@ -213,10 +388,15 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) val a2 = Literal.create(Seq(null), ArrayType(LongType)) val a3 = Literal.create(null, ArrayType(StringType)) + val a4 = Literal.create(Seq(create_row(1)), ArrayType(StructType(Seq( + StructField("a", IntegerType, true))))) + // Explicitly mark the array type not nullable (spark-25308) + val a5 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) checkEvaluation(ArrayContains(a0, Literal(1)), true) checkEvaluation(ArrayContains(a0, Literal(0)), false) checkEvaluation(ArrayContains(a0, Literal.create(null, IntegerType)), null) + checkEvaluation(ArrayContains(a5, Literal(1)), true) checkEvaluation(ArrayContains(a1, Literal("")), true) checkEvaluation(ArrayContains(a1, Literal("a")), null) @@ -228,6 +408,11 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) + checkEvaluation(ArrayContains(a4, Literal.create(create_row(1), StructType(Seq( + StructField("a", IntegerType, false))))), true) + checkEvaluation(ArrayContains(a4, Literal.create(create_row(0), StructType(Seq( + StructField("a", IntegerType, false))))), false) + // binary val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), ArrayType(BinaryType)) @@ -264,6 +449,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a4 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) val a5 = Literal.create(Seq[String]("", "abc"), ArrayType(StringType)) val a6 = Literal.create(Seq[String]("def", "ghi"), ArrayType(StringType)) + val a7 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) val emptyIntArray = Literal.create(Seq.empty[Int], ArrayType(IntegerType)) @@ -278,6 +464,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArraysOverlap(a4, a5), true) checkEvaluation(ArraysOverlap(a4, a6), null) checkEvaluation(ArraysOverlap(a5, a6), false) + checkEvaluation(ArraysOverlap(a7, a7), true) // null handling checkEvaluation(ArraysOverlap(emptyIntArray, a2), false) @@ -296,9 +483,12 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper ArrayType(BinaryType)) val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), ArrayType(BinaryType)) + val b3 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4)), + ArrayType(BinaryType, containsNull = false)) checkEvaluation(ArraysOverlap(b0, b1), true) checkEvaluation(ArraysOverlap(b0, b2), false) + checkEvaluation(ArraysOverlap(b3, b3), true) // arrays of complex data types val aa0 = Literal.create(Seq[Array[String]](Array[String]("a", 
"b"), Array[String]("c", "d")), @@ -912,11 +1102,11 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper test("Concat") { // Primitive-type elements - val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) - val ai1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) - val ai2 = Literal.create(Seq(4, null, 5), ArrayType(IntegerType)) - val ai3 = Literal.create(Seq(null, null), ArrayType(IntegerType)) - val ai4 = Literal.create(null, ArrayType(IntegerType)) + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val ai1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType, containsNull = false)) + val ai2 = Literal.create(Seq(4, null, 5), ArrayType(IntegerType, containsNull = true)) + val ai3 = Literal.create(Seq(null, null), ArrayType(IntegerType, containsNull = true)) + val ai4 = Literal.create(null, ArrayType(IntegerType, containsNull = false)) checkEvaluation(Concat(Seq(ai0)), Seq(1, 2, 3)) checkEvaluation(Concat(Seq(ai0, ai1)), Seq(1, 2, 3)) @@ -929,14 +1119,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Concat(Seq(ai4, ai0)), null) // Non-primitive-type elements - val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType)) - val as1 = Literal.create(Seq.empty[String], ArrayType(StringType)) - val as2 = Literal.create(Seq("d", null, "e"), ArrayType(StringType)) - val as3 = Literal.create(Seq(null, null), ArrayType(StringType)) - val as4 = Literal.create(null, ArrayType(StringType)) - - val aa0 = Literal.create(Seq(Seq("a", "b"), Seq("c")), ArrayType(ArrayType(StringType))) - val aa1 = Literal.create(Seq(Seq("d"), Seq("e", "f")), ArrayType(ArrayType(StringType))) + val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false)) + val as1 = Literal.create(Seq.empty[String], ArrayType(StringType, containsNull = false)) + val as2 = Literal.create(Seq("d", null, "e"), ArrayType(StringType, containsNull = true)) + val as3 = Literal.create(Seq(null, null), ArrayType(StringType, containsNull = true)) + val as4 = Literal.create(null, ArrayType(StringType, containsNull = false)) + + val aa0 = Literal.create(Seq(Seq("a", "b"), Seq("c")), + ArrayType(ArrayType(StringType, containsNull = false), containsNull = false)) + val aa1 = Literal.create(Seq(Seq("d"), Seq("e", "f")), + ArrayType(ArrayType(StringType, containsNull = false), containsNull = false)) + val aa2 = Literal.create(Seq(Seq("g", null), null), + ArrayType(ArrayType(StringType, containsNull = true), containsNull = true)) checkEvaluation(Concat(Seq(as0)), Seq("a", "b", "c")) checkEvaluation(Concat(Seq(as0, as1)), Seq("a", "b", "c")) @@ -949,6 +1143,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Concat(Seq(as4, as0)), null) checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"), Seq("d"), Seq("e", "f"))) + + assert(Concat(Seq(ai0, ai1)).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(Concat(Seq(ai0, ai2)).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(Concat(Seq(as0, as1)).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(Concat(Seq(as0, as2)).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(Concat(Seq(aa0, aa1)).dataType === + ArrayType(ArrayType(StringType, containsNull = false), containsNull = false)) + assert(Concat(Seq(aa0, aa2)).dataType === + ArrayType(ArrayType(StringType, containsNull = true), containsNull = true)) + + // 
force split expressions for input in generated code + checkEvaluation(Concat(Seq.fill(100)(ai0)), Seq.fill(100)(Seq(1, 2, 3)).flatten) } test("Flatten") { @@ -1166,4 +1372,388 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayDistinct(c1), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1))) checkEvaluation(ArrayDistinct(c2), Seq[Seq[Int]](null, Seq[Int](2, 1))) } + + test("Array Union") { + val a00 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, containsNull = false)) + val a02 = Literal.create(Seq(1, 2, null, 4, 5), ArrayType(IntegerType, containsNull = true)) + val a03 = Literal.create(Seq(-5, 4, -3, 2, 4), ArrayType(IntegerType, containsNull = false)) + val a04 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false)) + val abl0 = Literal.create(Seq[Boolean](true, true), ArrayType(BooleanType, false)) + val abl1 = Literal.create(Seq[Boolean](false, false), ArrayType(BooleanType, false)) + val ab0 = Literal.create(Seq[Byte](1, 2, 3, 2), ArrayType(ByteType, false)) + val ab1 = Literal.create(Seq[Byte](4, 2, 4), ArrayType(ByteType, false)) + val as0 = Literal.create(Seq[Short](1, 2, 3, 2), ArrayType(ShortType, false)) + val as1 = Literal.create(Seq[Short](4, 2, 4), ArrayType(ShortType, false)) + val af0 = Literal.create(Seq[Float](1.1F, 2.2F, 3.3F, 2.2F), ArrayType(FloatType, false)) + val af1 = Literal.create(Seq[Float](4.4F, 2.2F, 4.4F), ArrayType(FloatType, false)) + val ad0 = Literal.create(Seq[Double](1.1, 2.2, 3.3, 2.2), ArrayType(DoubleType, false)) + val ad1 = Literal.create(Seq[Double](4.4, 2.2, 4.4), ArrayType(DoubleType, false)) + + val a10 = Literal.create(Seq(1L, 2L, 3L), ArrayType(LongType, containsNull = false)) + val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, containsNull = false)) + val a12 = Literal.create(Seq(1L, 2L, null, 4L, 5L), ArrayType(LongType, containsNull = true)) + val a13 = Literal.create(Seq(-5L, 4L, -3L, 2L, -1L), ArrayType(LongType, containsNull = false)) + val a14 = Literal.create(Seq.empty[Long], ArrayType(LongType, containsNull = false)) + + val a20 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, containsNull = false)) + val a21 = Literal.create(Seq("c", "d", "a", "f"), ArrayType(StringType, containsNull = false)) + val a22 = Literal.create(Seq("b", null, "a", "g"), ArrayType(StringType, containsNull = true)) + + val a30 = Literal.create(Seq(null, null), ArrayType(IntegerType)) + val a31 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayUnion(a00, a01), Seq(1, 2, 3, 4)) + checkEvaluation(ArrayUnion(a02, a03), Seq(1, 2, null, 4, 5, -5, -3)) + checkEvaluation(ArrayUnion(a03, a02), Seq(-5, 4, -3, 2, 1, null, 5)) + checkEvaluation(ArrayUnion(a02, a04), Seq(1, 2, null, 4, 5)) + checkEvaluation(ArrayUnion(abl0, abl1), Seq[Boolean](true, false)) + checkEvaluation(ArrayUnion(ab0, ab1), Seq[Byte](1, 2, 3, 4)) + checkEvaluation(ArrayUnion(as0, as1), Seq[Short](1, 2, 3, 4)) + checkEvaluation(ArrayUnion(af0, af1), Seq[Float](1.1F, 2.2F, 3.3F, 4.4F)) + checkEvaluation(ArrayUnion(ad0, ad1), Seq[Double](1.1, 2.2, 3.3, 4.4)) + + checkEvaluation(ArrayUnion(a10, a11), Seq(1L, 2L, 3L, 4L)) + checkEvaluation(ArrayUnion(a12, a13), Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L)) + checkEvaluation(ArrayUnion(a13, a12), Seq(-5L, 4L, -3L, 2L, -1L, 1L, null, 5L)) + checkEvaluation(ArrayUnion(a12, a14), Seq(1L, 2L, null, 4L, 5L)) + + checkEvaluation(ArrayUnion(a20, a21), Seq("b", "a", 
"c", "d", "f")) + checkEvaluation(ArrayUnion(a20, a22), Seq("b", "a", "c", null, "g")) + + checkEvaluation(ArrayUnion(a30, a30), Seq(null)) + checkEvaluation(ArrayUnion(a20, a31), null) + checkEvaluation(ArrayUnion(a31, a20), null) + + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](4, 3)), + ArrayType(BinaryType)) + val b3 = Literal.create(Seq[Array[Byte]]( + Array[Byte](1, 2), Array[Byte](4, 3), Array[Byte](1, 2)), ArrayType(BinaryType)) + val b4 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), null), ArrayType(BinaryType)) + val b5 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)), ArrayType(BinaryType)) + val b6 = Literal.create(Seq.empty, ArrayType(BinaryType)) + val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType)) + + checkEvaluation(ArrayUnion(b0, b1), + Seq(Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](2, 1), Array[Byte](4, 3))) + checkEvaluation(ArrayUnion(b0, b2), + Seq(Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](4, 3))) + checkEvaluation(ArrayUnion(b2, b4), Seq(Array[Byte](1, 2), Array[Byte](4, 3), null)) + checkEvaluation(ArrayUnion(b3, b0), + Seq(Array[Byte](1, 2), Array[Byte](4, 3), Array[Byte](5, 6))) + checkEvaluation(ArrayUnion(b4, b0), Seq(Array[Byte](1, 2), null, Array[Byte](5, 6))) + checkEvaluation(ArrayUnion(b4, b5), Seq(Array[Byte](1, 2), null)) + checkEvaluation(ArrayUnion(b6, b4), Seq(Array[Byte](1, 2), null)) + checkEvaluation(ArrayUnion(b4, arrayWithBinaryNull), Seq(Array[Byte](1, 2), null)) + + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + checkEvaluation(ArrayUnion(aa0, aa1), + Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](5, 6), Seq[Int](2, 1))) + + assert(ArrayUnion(a00, a01).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayUnion(a00, a02).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(ArrayUnion(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayUnion(a20, a22).dataType.asInstanceOf[ArrayType].containsNull === true) + } + + test("Shuffle") { + // Primitive-type elements + val ai0 = Literal.create(Seq(1, 2, 3, 4, 5), ArrayType(IntegerType, containsNull = false)) + val ai1 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val ai2 = Literal.create(Seq(null, 1, null, 3), ArrayType(IntegerType, containsNull = true)) + val ai3 = Literal.create(Seq(2, null, 4, null), ArrayType(IntegerType, containsNull = true)) + val ai4 = Literal.create(Seq(null, null, null), ArrayType(IntegerType, containsNull = true)) + val ai5 = Literal.create(Seq(1), ArrayType(IntegerType, containsNull = false)) + val ai6 = Literal.create(Seq.empty, ArrayType(IntegerType, containsNull = false)) + val ai7 = Literal.create(null, ArrayType(IntegerType, containsNull = true)) + + checkEvaluation(Shuffle(ai0, Some(0)), Seq(4, 1, 2, 3, 5)) + checkEvaluation(Shuffle(ai1, Some(0)), Seq(3, 1, 2)) + checkEvaluation(Shuffle(ai2, Some(0)), Seq(3, null, 1, null)) + checkEvaluation(Shuffle(ai3, Some(0)), Seq(null, 2, null, 4)) + checkEvaluation(Shuffle(ai4, Some(0)), Seq(null, null, null)) + checkEvaluation(Shuffle(ai5, Some(0)), Seq(1)) + 
checkEvaluation(Shuffle(ai6, Some(0)), Seq.empty) + checkEvaluation(Shuffle(ai7, Some(0)), null) + + // Non-primitive-type elements + val as0 = Literal.create(Seq("a", "b", "c", "d"), ArrayType(StringType, containsNull = false)) + val as1 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false)) + val as2 = Literal.create(Seq(null, "a", null, "c"), ArrayType(StringType, containsNull = true)) + val as3 = Literal.create(Seq("b", null, "d", null), ArrayType(StringType, containsNull = true)) + val as4 = Literal.create(Seq(null, null, null), ArrayType(StringType, containsNull = true)) + val as5 = Literal.create(Seq("a"), ArrayType(StringType, containsNull = false)) + val as6 = Literal.create(Seq.empty, ArrayType(StringType, containsNull = false)) + val as7 = Literal.create(null, ArrayType(StringType, containsNull = true)) + val aa = Literal.create( + Seq(Seq("a", "b"), Seq("c", "d"), Seq("e")), + ArrayType(ArrayType(StringType))) + + checkEvaluation(Shuffle(as0, Some(0)), Seq("d", "a", "b", "c")) + checkEvaluation(Shuffle(as1, Some(0)), Seq("c", "a", "b")) + checkEvaluation(Shuffle(as2, Some(0)), Seq("c", null, "a", null)) + checkEvaluation(Shuffle(as3, Some(0)), Seq(null, "b", null, "d")) + checkEvaluation(Shuffle(as4, Some(0)), Seq(null, null, null)) + checkEvaluation(Shuffle(as5, Some(0)), Seq("a")) + checkEvaluation(Shuffle(as6, Some(0)), Seq.empty) + checkEvaluation(Shuffle(as7, Some(0)), null) + checkEvaluation(Shuffle(aa, Some(0)), Seq(Seq("e"), Seq("a", "b"), Seq("c", "d"))) + + val r = new Random(1234) + val seed1 = Some(r.nextLong()) + assert(evaluateWithoutCodegen(Shuffle(ai0, seed1)) === + evaluateWithoutCodegen(Shuffle(ai0, seed1))) + assert(evaluateWithGeneratedMutableProjection(Shuffle(ai0, seed1)) === + evaluateWithGeneratedMutableProjection(Shuffle(ai0, seed1))) + assert(evaluateWithUnsafeProjection(Shuffle(ai0, seed1)) === + evaluateWithUnsafeProjection(Shuffle(ai0, seed1))) + + val seed2 = Some(r.nextLong()) + assert(evaluateWithoutCodegen(Shuffle(ai0, seed1)) !== + evaluateWithoutCodegen(Shuffle(ai0, seed2))) + assert(evaluateWithGeneratedMutableProjection(Shuffle(ai0, seed1)) !== + evaluateWithGeneratedMutableProjection(Shuffle(ai0, seed2))) + assert(evaluateWithUnsafeProjection(Shuffle(ai0, seed1)) !== + evaluateWithUnsafeProjection(Shuffle(ai0, seed2))) + + val shuffle = Shuffle(ai0, seed1) + assert(shuffle.fastEquals(shuffle)) + assert(!shuffle.fastEquals(Shuffle(ai0, seed1))) + assert(!shuffle.fastEquals(shuffle.freshCopy())) + assert(!shuffle.fastEquals(Shuffle(ai0, seed2))) + } + + test("Array Except") { + val a00 = Literal.create(Seq(1, 2, 4, 3), ArrayType(IntegerType, false)) + val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, false)) + val a02 = Literal.create(Seq(1, 2, 4, 2), ArrayType(IntegerType, false)) + val a03 = Literal.create(Seq(4, 2, 4), ArrayType(IntegerType, false)) + val a04 = Literal.create(Seq(1, 2, null, 4, 5, 1), ArrayType(IntegerType, true)) + val a05 = Literal.create(Seq(-5, 4, null, 2, -1), ArrayType(IntegerType, true)) + val a06 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, false)) + val abl0 = Literal.create(Seq[Boolean](true, true), ArrayType(BooleanType, false)) + val abl1 = Literal.create(Seq[Boolean](false, false), ArrayType(BooleanType, false)) + val ab0 = Literal.create(Seq[Byte](1, 2, 3, 2), ArrayType(ByteType, false)) + val ab1 = Literal.create(Seq[Byte](4, 2, 4), ArrayType(ByteType, false)) + val as0 = Literal.create(Seq[Short](1, 2, 3, 2), ArrayType(ShortType, false)) + val as1 = 
Literal.create(Seq[Short](4, 2, 4), ArrayType(ShortType, false)) + val af0 = Literal.create(Seq[Float](1.1F, 2.2F, 3.3F, 2.2F), ArrayType(FloatType, false)) + val af1 = Literal.create(Seq[Float](4.4F, 2.2F, 4.4F), ArrayType(FloatType, false)) + val ad0 = Literal.create(Seq[Double](1.1, 2.2, 3.3, 2.2), ArrayType(DoubleType, false)) + val ad1 = Literal.create(Seq[Double](4.4, 2.2, 4.4), ArrayType(DoubleType, false)) + + val a10 = Literal.create(Seq(1L, 2L, 4L, 3L), ArrayType(LongType, false)) + val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, false)) + val a12 = Literal.create(Seq(1L, 2L, 4L, 2L), ArrayType(LongType, false)) + val a13 = Literal.create(Seq(4L, 2L), ArrayType(LongType, false)) + val a14 = Literal.create(Seq(1L, 2L, null, 4L, 5L, 1L), ArrayType(LongType, true)) + val a15 = Literal.create(Seq(-5L, 4L, null, 2L, -1L), ArrayType(LongType, true)) + val a16 = Literal.create(Seq.empty[Long], ArrayType(LongType, false)) + + val a20 = Literal.create(Seq("b", "a", "c", "d"), ArrayType(StringType, false)) + val a21 = Literal.create(Seq("c", "a"), ArrayType(StringType, false)) + val a22 = Literal.create(Seq("b", "a", "c", "a"), ArrayType(StringType, false)) + val a23 = Literal.create(Seq("c", "a", "c"), ArrayType(StringType, false)) + val a24 = Literal.create(Seq("c", null, "a", "f", "c"), ArrayType(StringType, true)) + val a25 = Literal.create(Seq("b", null, "a", "g"), ArrayType(StringType, true)) + val a26 = Literal.create(Seq.empty[String], ArrayType(StringType, false)) + + val a30 = Literal.create(Seq(null, null), ArrayType(IntegerType)) + val a31 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayExcept(a00, a01), Seq(1, 3)) + checkEvaluation(ArrayExcept(a02, a01), Seq(1)) + checkEvaluation(ArrayExcept(a02, a02), Seq.empty) + checkEvaluation(ArrayExcept(a02, a03), Seq(1)) + checkEvaluation(ArrayExcept(a04, a02), Seq(null, 5)) + checkEvaluation(ArrayExcept(a04, a05), Seq(1, 5)) + checkEvaluation(ArrayExcept(a04, a06), Seq(1, 2, null, 4, 5)) + checkEvaluation(ArrayExcept(a06, a04), Seq.empty) + checkEvaluation(ArrayExcept(abl0, abl1), Seq[Boolean](true)) + checkEvaluation(ArrayExcept(ab0, ab1), Seq[Byte](1, 3)) + checkEvaluation(ArrayExcept(as0, as1), Seq[Short](1, 3)) + checkEvaluation(ArrayExcept(af0, af1), Seq[Float](1.1F, 3.3F)) + checkEvaluation(ArrayExcept(ad0, ad1), Seq[Double](1.1, 3.3)) + + checkEvaluation(ArrayExcept(a10, a11), Seq(1L, 3L)) + checkEvaluation(ArrayExcept(a12, a11), Seq(1L)) + checkEvaluation(ArrayExcept(a12, a12), Seq.empty) + checkEvaluation(ArrayExcept(a12, a13), Seq(1L)) + checkEvaluation(ArrayExcept(a14, a12), Seq(null, 5L)) + checkEvaluation(ArrayExcept(a14, a15), Seq(1L, 5L)) + checkEvaluation(ArrayExcept(a14, a16), Seq(1L, 2L, null, 4L, 5L)) + checkEvaluation(ArrayExcept(a16, a14), Seq.empty) + + checkEvaluation(ArrayExcept(a20, a21), Seq("b", "d")) + checkEvaluation(ArrayExcept(a22, a21), Seq("b")) + checkEvaluation(ArrayExcept(a22, a22), Seq.empty) + checkEvaluation(ArrayExcept(a22, a23), Seq("b")) + checkEvaluation(ArrayExcept(a24, a22), Seq(null, "f")) + checkEvaluation(ArrayExcept(a24, a25), Seq("c", "f")) + checkEvaluation(ArrayExcept(a24, a26), Seq("c", null, "a", "f")) + checkEvaluation(ArrayExcept(a26, a24), Seq.empty) + + checkEvaluation(ArrayExcept(a30, a30), Seq.empty) + checkEvaluation(ArrayExcept(a20, a31), null) + checkEvaluation(ArrayExcept(a31, a20), null) + + val b0 = Literal.create( + Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](3, 4), Array[Byte](7, 8)), + ArrayType(BinaryType)) + 
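// ---------------------------------------------------------------------------
// Editor's aside (hedged): the expectations above pin down array_except
// semantics - keep the left operand's order, drop every value that occurs in
// the right operand, and emit each surviving value at most once
// (ArrayExcept(a02, a01) is Seq(1), not Seq(1, 1)). An equivalent plain-Scala
// model of that behaviour, for illustration only (it ignores the binary/NaN
// equality subtleties the real expression handles):
def arrayExceptModel[T](left: Seq[T], right: Seq[T]): Seq[T] =
  left.distinct.filterNot(right.toSet)
assert(arrayExceptModel(Seq(1, 2, 4, 2), Seq(4, 2)) == Seq(1))
// ---------------------------------------------------------------------------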
val b1 = Literal.create( + Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](3, 4), Array[Byte](5, 6)), + ArrayType(BinaryType)) + val b2 = Literal.create( + Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b3 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](3, 4), null), + ArrayType(BinaryType)) + val b4 = Literal.create(Seq[Array[Byte]](null, Array[Byte](3, 4), null), ArrayType(BinaryType)) + val b5 = Literal.create(Seq.empty, ArrayType(BinaryType)) + val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType)) + + checkEvaluation(ArrayExcept(b0, b1), Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](7, 8))) + checkEvaluation(ArrayExcept(b1, b0), Seq[Array[Byte]](Array[Byte](2, 1))) + checkEvaluation(ArrayExcept(b0, b2), Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](7, 8))) + checkEvaluation(ArrayExcept(b2, b0), Seq.empty) + checkEvaluation(ArrayExcept(b2, b3), Seq[Array[Byte]](Array[Byte](1, 2))) + checkEvaluation(ArrayExcept(b3, b2), Seq[Array[Byte]](Array[Byte](2, 1), null)) + checkEvaluation(ArrayExcept(b3, b4), Seq[Array[Byte]](Array[Byte](2, 1))) + checkEvaluation(ArrayExcept(b4, b3), Seq.empty) + checkEvaluation(ArrayExcept(b4, b5), Seq[Array[Byte]](null, Array[Byte](3, 4))) + checkEvaluation(ArrayExcept(b5, b4), Seq.empty) + checkEvaluation(ArrayExcept(b4, arrayWithBinaryNull), Seq[Array[Byte]](Array[Byte](3, 4))) + + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](1, 2)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](3, 4), Seq[Int](2, 1), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + checkEvaluation(ArrayExcept(aa0, aa1), Seq[Seq[Int]](Seq[Int](1, 2))) + checkEvaluation(ArrayExcept(aa1, aa0), Seq[Seq[Int]](Seq[Int](2, 1))) + + assert(ArrayExcept(a00, a01).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayExcept(a04, a02).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(ArrayExcept(a04, a05).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(ArrayExcept(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayExcept(a24, a22).dataType.asInstanceOf[ArrayType].containsNull === true) + } + + test("Array Intersect") { + val a00 = Literal.create(Seq(1, 2, 4), ArrayType(IntegerType, false)) + val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, false)) + val a02 = Literal.create(Seq(1, 2, 1, 4), ArrayType(IntegerType, false)) + val a03 = Literal.create(Seq(4, 2, 4), ArrayType(IntegerType, false)) + val a04 = Literal.create(Seq(1, 2, null, 4, 5, null), ArrayType(IntegerType, true)) + val a05 = Literal.create(Seq(-5, 4, null, 2, -1, null), ArrayType(IntegerType, true)) + val a06 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, false)) + val abl0 = Literal.create(Seq[Boolean](true, false, true), ArrayType(BooleanType, false)) + val abl1 = Literal.create(Seq[Boolean](true, true), ArrayType(BooleanType, false)) + val ab0 = Literal.create(Seq[Byte](1, 2, 3, 2), ArrayType(ByteType, containsNull = false)) + val ab1 = Literal.create(Seq[Byte](4, 2, 4), ArrayType(ByteType, containsNull = false)) + val as0 = Literal.create(Seq[Short](1, 2, 3, 2), ArrayType(ShortType, containsNull = false)) + val as1 = Literal.create(Seq[Short](4, 2, 4), ArrayType(ShortType, containsNull = false)) + val af0 = Literal.create(Seq[Float](1.1F, 2.2F, 3.3F, 2.2F), ArrayType(FloatType, false)) + val af1 = Literal.create(Seq[Float](4.4F, 2.2F, 4.4F), 
ArrayType(FloatType, false)) + val ad0 = Literal.create(Seq[Double](1.1, 2.2, 3.3, 2.2), ArrayType(DoubleType, false)) + val ad1 = Literal.create(Seq[Double](4.4, 2.2, 4.4), ArrayType(DoubleType, false)) + + val a10 = Literal.create(Seq(1L, 2L, 4L), ArrayType(LongType, false)) + val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, false)) + val a12 = Literal.create(Seq(1L, 2L, 1L, 4L), ArrayType(LongType, false)) + val a13 = Literal.create(Seq(4L, 2L, 4L), ArrayType(LongType, false)) + val a14 = Literal.create(Seq(1L, 2L, null, 4L, 5L, null), ArrayType(LongType, true)) + val a15 = Literal.create(Seq(-5L, 4L, null, 2L, -1L, null), ArrayType(LongType, true)) + val a16 = Literal.create(Seq.empty[Long], ArrayType(LongType, false)) + + val a20 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, false)) + val a21 = Literal.create(Seq("c", "a"), ArrayType(StringType, false)) + val a22 = Literal.create(Seq("b", "a", "c", "a"), ArrayType(StringType, false)) + val a23 = Literal.create(Seq("c", "a", null, "f"), ArrayType(StringType, true)) + val a24 = Literal.create(Seq("b", null, "a", "g", null), ArrayType(StringType, true)) + val a25 = Literal.create(Seq.empty[String], ArrayType(StringType, false)) + + val a30 = Literal.create(Seq(null, null), ArrayType(IntegerType)) + val a31 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayIntersect(a00, a01), Seq(2, 4)) + checkEvaluation(ArrayIntersect(a01, a00), Seq(4, 2)) + checkEvaluation(ArrayIntersect(a02, a03), Seq(2, 4)) + checkEvaluation(ArrayIntersect(a03, a02), Seq(4, 2)) + checkEvaluation(ArrayIntersect(a00, a04), Seq(1, 2, 4)) + checkEvaluation(ArrayIntersect(a04, a05), Seq(2, null, 4)) + checkEvaluation(ArrayIntersect(a02, a06), Seq.empty) + checkEvaluation(ArrayIntersect(a06, a04), Seq.empty) + checkEvaluation(ArrayIntersect(abl0, abl1), Seq[Boolean](true)) + checkEvaluation(ArrayIntersect(ab0, ab1), Seq[Byte](2)) + checkEvaluation(ArrayIntersect(as0, as1), Seq[Short](2)) + checkEvaluation(ArrayIntersect(af0, af1), Seq[Float](2.2F)) + checkEvaluation(ArrayIntersect(ad0, ad1), Seq[Double](2.2D)) + + checkEvaluation(ArrayIntersect(a10, a11), Seq(2L, 4L)) + checkEvaluation(ArrayIntersect(a11, a10), Seq(4L, 2L)) + checkEvaluation(ArrayIntersect(a12, a13), Seq(2L, 4L)) + checkEvaluation(ArrayIntersect(a13, a12), Seq(4L, 2L)) + checkEvaluation(ArrayIntersect(a14, a15), Seq(2L, null, 4L)) + checkEvaluation(ArrayIntersect(a12, a16), Seq.empty) + checkEvaluation(ArrayIntersect(a16, a14), Seq.empty) + + checkEvaluation(ArrayIntersect(a20, a21), Seq("a", "c")) + checkEvaluation(ArrayIntersect(a21, a20), Seq("c", "a")) + checkEvaluation(ArrayIntersect(a22, a21), Seq("a", "c")) + checkEvaluation(ArrayIntersect(a21, a22), Seq("c", "a")) + checkEvaluation(ArrayIntersect(a23, a24), Seq("a", null)) + checkEvaluation(ArrayIntersect(a24, a23), Seq(null, "a")) + checkEvaluation(ArrayIntersect(a24, a25), Seq.empty) + checkEvaluation(ArrayIntersect(a25, a24), Seq.empty) + + checkEvaluation(ArrayIntersect(a30, a30), Seq(null)) + checkEvaluation(ArrayIntersect(a20, a31), null) + checkEvaluation(ArrayIntersect(a31, a20), null) + + val b0 = Literal.create( + Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](3, 4)), + ArrayType(BinaryType)) + val b1 = Literal.create( + Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](3, 4), Array[Byte](5, 6)), + ArrayType(BinaryType)) + val b2 = Literal.create( + Seq[Array[Byte]](Array[Byte](3, 4), Array[Byte](1, 2), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b3 = 
Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4), null), + ArrayType(BinaryType)) + val b4 = Literal.create(Seq[Array[Byte]](null, Array[Byte](3, 4), null), ArrayType(BinaryType)) + val b5 = Literal.create(Seq.empty, ArrayType(BinaryType)) + val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType)) + checkEvaluation(ArrayIntersect(b0, b1), Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b1, b0), Seq[Array[Byte]](Array[Byte](3, 4), Array[Byte](5, 6))) + checkEvaluation(ArrayIntersect(b0, b2), Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b2, b0), Seq[Array[Byte]](Array[Byte](3, 4), Array[Byte](1, 2))) + checkEvaluation(ArrayIntersect(b2, b3), Seq[Array[Byte]](Array[Byte](3, 4), Array[Byte](1, 2))) + checkEvaluation(ArrayIntersect(b3, b2), Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b3, b4), Seq[Array[Byte]](Array[Byte](3, 4), null)) + checkEvaluation(ArrayIntersect(b4, b3), Seq[Array[Byte]](null, Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b4, b5), Seq.empty) + checkEvaluation(ArrayIntersect(b5, b4), Seq.empty) + checkEvaluation(ArrayIntersect(b4, arrayWithBinaryNull), Seq[Array[Byte]](null)) + + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](1, 2)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](3, 4), Seq[Int](2, 1), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + checkEvaluation(ArrayIntersect(aa0, aa1), Seq[Seq[Int]](Seq[Int](3, 4))) + checkEvaluation(ArrayIntersect(aa1, aa0), Seq[Seq[Int]](Seq[Int](3, 4))) + + assert(ArrayIntersect(a00, a01).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayIntersect(a00, a04).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayIntersect(a04, a05).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(ArrayIntersect(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayIntersect(a23, a24).dataType.asInstanceOf[ArrayType].containsNull === true) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 726193b411737..77aaf55480ec2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -144,6 +144,13 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(CreateArray(byteWithNull), byteSeq :+ null, EmptyRow) checkEvaluation(CreateArray(strWithNull), strSeq :+ null, EmptyRow) checkEvaluation(CreateArray(Literal.create(null, IntegerType) :: Nil), null :: Nil) + + val array = CreateArray(Seq( + Literal.create(intSeq, ArrayType(IntegerType, containsNull = false)), + Literal.create(intSeq :+ null, ArrayType(IntegerType, containsNull = true)))) + assert(array.dataType === + ArrayType(ArrayType(IntegerType, containsNull = true), containsNull = false)) + checkEvaluation(array, Seq(intSeq, intSeq :+ null)) } test("CreateMap") { @@ -184,6 +191,18 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))), null, null) } + + val map = CreateMap(Seq( + Literal.create(intSeq, ArrayType(IntegerType, containsNull = false)), + 
Literal.create(strSeq, ArrayType(StringType, containsNull = false)), + Literal.create(intSeq :+ null, ArrayType(IntegerType, containsNull = true)), + Literal.create(strSeq :+ null, ArrayType(StringType, containsNull = true)))) + assert(map.dataType === + MapType( + ArrayType(IntegerType, containsNull = true), + ArrayType(StringType, containsNull = true), + valueContainsNull = false)) + checkEvaluation(map, createMap(Seq(intSeq, intSeq :+ null), Seq(strSeq, strSeq :+ null))) } test("MapFromArrays") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index a099119732e25..f489d330cf453 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -113,6 +113,76 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5))).nullable === true) } + test("if/case when - null flags of non-primitive types") { + val arrayWithNulls = Literal.create(Seq("a", null, "b"), ArrayType(StringType, true)) + val arrayWithoutNulls = Literal.create(Seq("c", "d"), ArrayType(StringType, false)) + val structWithNulls = Literal.create( + create_row(null, null), + StructType(Seq(StructField("a", IntegerType, true), StructField("b", StringType, true)))) + val structWithoutNulls = Literal.create( + create_row(1, "a"), + StructType(Seq(StructField("a", IntegerType, false), StructField("b", StringType, false)))) + val mapWithNulls = Literal.create(Map(1 -> null), MapType(IntegerType, StringType, true)) + val mapWithoutNulls = Literal.create(Map(1 -> "a"), MapType(IntegerType, StringType, false)) + + val arrayIf1 = If(Literal.FalseLiteral, arrayWithNulls, arrayWithoutNulls) + val arrayIf2 = If(Literal.FalseLiteral, arrayWithoutNulls, arrayWithNulls) + val arrayIf3 = If(Literal.TrueLiteral, arrayWithNulls, arrayWithoutNulls) + val arrayIf4 = If(Literal.TrueLiteral, arrayWithoutNulls, arrayWithNulls) + val structIf1 = If(Literal.FalseLiteral, structWithNulls, structWithoutNulls) + val structIf2 = If(Literal.FalseLiteral, structWithoutNulls, structWithNulls) + val structIf3 = If(Literal.TrueLiteral, structWithNulls, structWithoutNulls) + val structIf4 = If(Literal.TrueLiteral, structWithoutNulls, structWithNulls) + val mapIf1 = If(Literal.FalseLiteral, mapWithNulls, mapWithoutNulls) + val mapIf2 = If(Literal.FalseLiteral, mapWithoutNulls, mapWithNulls) + val mapIf3 = If(Literal.TrueLiteral, mapWithNulls, mapWithoutNulls) + val mapIf4 = If(Literal.TrueLiteral, mapWithoutNulls, mapWithNulls) + + val arrayCaseWhen1 = CaseWhen(Seq((Literal.FalseLiteral, arrayWithNulls)), arrayWithoutNulls) + val arrayCaseWhen2 = CaseWhen(Seq((Literal.FalseLiteral, arrayWithoutNulls)), arrayWithNulls) + val arrayCaseWhen3 = CaseWhen(Seq((Literal.TrueLiteral, arrayWithNulls)), arrayWithoutNulls) + val arrayCaseWhen4 = CaseWhen(Seq((Literal.TrueLiteral, arrayWithoutNulls)), arrayWithNulls) + val structCaseWhen1 = CaseWhen(Seq((Literal.FalseLiteral, structWithNulls)), structWithoutNulls) + val structCaseWhen2 = CaseWhen(Seq((Literal.FalseLiteral, structWithoutNulls)), structWithNulls) + val structCaseWhen3 = CaseWhen(Seq((Literal.TrueLiteral, structWithNulls)), structWithoutNulls) + val structCaseWhen4 = CaseWhen(Seq((Literal.TrueLiteral, 
structWithoutNulls)), structWithNulls) + val mapCaseWhen1 = CaseWhen(Seq((Literal.FalseLiteral, mapWithNulls)), mapWithoutNulls) + val mapCaseWhen2 = CaseWhen(Seq((Literal.FalseLiteral, mapWithoutNulls)), mapWithNulls) + val mapCaseWhen3 = CaseWhen(Seq((Literal.TrueLiteral, mapWithNulls)), mapWithoutNulls) + val mapCaseWhen4 = CaseWhen(Seq((Literal.TrueLiteral, mapWithoutNulls)), mapWithNulls) + + def checkResult(expectedType: DataType, expectedValue: Any, result: Expression): Unit = { + assert(expectedType == result.dataType) + checkEvaluation(result, expectedValue) + } + + checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayIf1) + checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayIf2) + checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayIf3) + checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayIf4) + checkResult(structWithNulls.dataType, structWithoutNulls.value, structIf1) + checkResult(structWithNulls.dataType, structWithNulls.value, structIf2) + checkResult(structWithNulls.dataType, structWithNulls.value, structIf3) + checkResult(structWithNulls.dataType, structWithoutNulls.value, structIf4) + checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapIf1) + checkResult(mapWithNulls.dataType, mapWithNulls.value, mapIf2) + checkResult(mapWithNulls.dataType, mapWithNulls.value, mapIf3) + checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapIf4) + + checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayCaseWhen1) + checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayCaseWhen2) + checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayCaseWhen3) + checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayCaseWhen4) + checkResult(structWithNulls.dataType, structWithoutNulls.value, structCaseWhen1) + checkResult(structWithNulls.dataType, structWithNulls.value, structCaseWhen2) + checkResult(structWithNulls.dataType, structWithNulls.value, structCaseWhen3) + checkResult(structWithNulls.dataType, structWithoutNulls.value, structCaseWhen4) + checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapCaseWhen1) + checkResult(mapWithNulls.dataType, mapWithNulls.value, mapCaseWhen2) + checkResult(mapWithNulls.dataType, mapWithNulls.value, mapCaseWhen3) + checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapCaseWhen4) + } + test("case key when") { val row = create_row(null, 1, 2, "a", "b", "c") val c1 = 'a.int.at(0) @@ -139,7 +209,7 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), null, row) } - test("case key whn - internal pattern matching expects a List while apply takes a Seq") { + test("case key when - internal pattern matching expects a List while apply takes a Seq") { val indexedSeq = IndexedSeq(Literal(1), Literal(42), Literal(42), Literal(1)) val caseKeyWhaen = CaseKeyWhen(Literal(12), indexedSeq) assert(caseKeyWhaen.branches == diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExprIdSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExprIdSuite.scala new file mode 100644 index 0000000000000..2352db405b1a8 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExprIdSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.util.UUID + +import org.apache.spark.SparkFunSuite + +class ExprIdSuite extends SparkFunSuite { + + private val jvmId = UUID.randomUUID() + private val otherJvmId = UUID.randomUUID() + + test("hashcode independent of jvmId") { + val exprId1 = ExprId(12, jvmId) + val exprId2 = ExprId(12, otherJvmId) + assert(exprId1 != exprId2) + assert(exprId1.hashCode() == exprId2.hashCode()) + } + + test("equality should depend on both id and jvmId") { + val exprId1 = ExprId(1, jvmId) + val exprId2 = ExprId(1, jvmId) + assert(exprId1 == exprId2) + + val exprId3 = ExprId(1, jvmId) + val exprId4 = ExprId(2, jvmId) + assert(exprId3 != exprId4) + + val exprId5 = ExprId(1, jvmId) + val exprId6 = ExprId(1, otherJvmId) + assert(exprId5 != exprId6) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 14bfa212b5496..6684e5ce18d4c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -79,6 +79,12 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa java.util.Arrays.equals(result, expected) case (result: Double, expected: Spread[Double @unchecked]) => expected.asInstanceOf[Spread[Double]].isWithin(result) + case (result: InternalRow, expected: InternalRow) => + val st = dataType.asInstanceOf[StructType] + assert(result.numFields == st.length && expected.numFields == st.length) + st.zipWithIndex.forall { case (f, i) => + checkResult(result.get(i, f.dataType), expected.get(i, f.dataType), f.dataType) + } case (result: ArrayData, expected: ArrayData) => result.numElements == expected.numElements && { val et = dataType.asInstanceOf[ArrayType].elementType @@ -99,9 +105,6 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa if (expected.isNaN) result.isNaN else expected == result case (result: Float, expected: Float) => if (expected.isNaN) result.isNaN else expected == result - case (result: UnsafeRow, expected: GenericInternalRow) => - val structType = exprDataType.asInstanceOf[StructType] - result.toSeq(structType) == expected.toSeq(structType) case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema) case _ => result == expected diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala index 12eddf557109f..3ccaa5976cc28 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala @@ -41,7 +41,7 @@ class ExpressionSetSuite extends SparkFunSuite { // maxHash's hashcode is calculated based on this exprId's hashcode, so we set this // exprId's hashCode to this specific value to make sure maxHash's hashcode is // `Int.MaxValue` - override def hashCode: Int = -1030353449 + override def hashCode: Int = 1394598635 // We are implementing this equals() only because the style-checking rule "you should // implement equals and hashCode together" requires us to override def equals(obj: Any): Boolean = super.equals(obj) @@ -57,7 +57,7 @@ class ExpressionSetSuite extends SparkFunSuite { // minHash's hashcode is calculated based on this exprId's hashcode, so we set this // exprId's hashCode to this specific value to make sure minHash's hashcode is // `Int.MinValue` - override def hashCode: Int = 1407330692 + override def hashCode: Int = -462684520 // We are implementing this equals() only because the style-checking rule "you should // implement equals and hashCode together" requires us to override def equals(obj: Any): Boolean = super.equals(obj) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala new file mode 100644 index 0000000000000..e13f4d98295be --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -0,0 +1,614 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData +import org.apache.spark.sql.types._ + +class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + import org.apache.spark.sql.catalyst.dsl.expressions._ + + private def createLambda( + dt: DataType, + nullable: Boolean, + f: Expression => Expression): Expression = { + val lv = NamedLambdaVariable("arg", dt, nullable) + val function = f(lv) + LambdaFunction(function, Seq(lv)) + } + + private def createLambda( + dt1: DataType, + nullable1: Boolean, + dt2: DataType, + nullable2: Boolean, + f: (Expression, Expression) => Expression): Expression = { + val lv1 = NamedLambdaVariable("arg1", dt1, nullable1) + val lv2 = NamedLambdaVariable("arg2", dt2, nullable2) + val function = f(lv1, lv2) + LambdaFunction(function, Seq(lv1, lv2)) + } + + private def createLambda( + dt1: DataType, + nullable1: Boolean, + dt2: DataType, + nullable2: Boolean, + dt3: DataType, + nullable3: Boolean, + f: (Expression, Expression, Expression) => Expression): Expression = { + val lv1 = NamedLambdaVariable("arg1", dt1, nullable1) + val lv2 = NamedLambdaVariable("arg2", dt2, nullable2) + val lv3 = NamedLambdaVariable("arg3", dt3, nullable3) + val function = f(lv1, lv2, lv3) + LambdaFunction(function, Seq(lv1, lv2, lv3)) + } + + private def validateBinding( + e: Expression, + argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match { + case f: LambdaFunction => + assert(f.arguments.size === argInfo.size) + f.arguments.zip(argInfo).foreach { + case (arg, (dataType, nullable)) => + assert(arg.dataType === dataType) + assert(arg.nullable === nullable) + } + f + } + + def transform(expr: Expression, f: Expression => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + ArrayTransform(expr, createLambda(et, cn, f)).bind(validateBinding) + } + + def transform(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + ArrayTransform(expr, createLambda(et, cn, IntegerType, false, f)).bind(validateBinding) + } + + def filter(expr: Expression, f: Expression => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + ArrayFilter(expr, createLambda(et, cn, f)).bind(validateBinding) + } + + def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val MapType(kt, vt, vcn) = expr.dataType + TransformKeys(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) + } + + def aggregate( + expr: Expression, + zero: Expression, + merge: (Expression, Expression) => Expression, + finish: Expression => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + val zeroType = zero.dataType + ArrayAggregate( + expr, + zero, + createLambda(zeroType, true, et, cn, merge), + createLambda(zeroType, true, finish)) + .bind(validateBinding) + } + + def aggregate( + expr: Expression, + zero: Expression, + merge: (Expression, Expression) => Expression): Expression = { + aggregate(expr, zero, merge, identity) + } + + def transformValues(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val MapType(kt, vt, vcn) = expr.dataType + TransformValues(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) + } + + test("ArrayTransform") { + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val ai1 = Literal.create(Seq[Integer](1, null, 3), 
ArrayType(IntegerType, containsNull = true)) + val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false)) + + val plusOne: Expression => Expression = x => x + 1 + val plusIndex: (Expression, Expression) => Expression = (x, i) => x + i + + checkEvaluation(transform(ai0, plusOne), Seq(2, 3, 4)) + checkEvaluation(transform(ai0, plusIndex), Seq(1, 3, 5)) + checkEvaluation(transform(transform(ai0, plusIndex), plusOne), Seq(2, 4, 6)) + checkEvaluation(transform(ai1, plusOne), Seq(2, null, 4)) + checkEvaluation(transform(ai1, plusIndex), Seq(1, null, 5)) + checkEvaluation(transform(transform(ai1, plusIndex), plusOne), Seq(2, null, 6)) + checkEvaluation(transform(ain, plusOne), null) + + val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false)) + val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true)) + val asn = Literal.create(null, ArrayType(StringType, containsNull = false)) + + val repeatTwice: Expression => Expression = x => Concat(Seq(x, x)) + val repeatIndexTimes: (Expression, Expression) => Expression = (x, i) => StringRepeat(x, i) + + checkEvaluation(transform(as0, repeatTwice), Seq("aa", "bb", "cc")) + checkEvaluation(transform(as0, repeatIndexTimes), Seq("", "b", "cc")) + checkEvaluation(transform(transform(as0, repeatIndexTimes), repeatTwice), + Seq("", "bb", "cccc")) + checkEvaluation(transform(as1, repeatTwice), Seq("aa", null, "cc")) + checkEvaluation(transform(as1, repeatIndexTimes), Seq("", null, "cc")) + checkEvaluation(transform(transform(as1, repeatIndexTimes), repeatTwice), + Seq("", null, "cccc")) + checkEvaluation(transform(asn, repeatTwice), null) + + val aai = Literal.create(Seq(Seq(1, 2, 3), null, Seq(4, 5)), + ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)) + checkEvaluation(transform(aai, array => Cast(transform(array, plusOne), StringType)), + Seq("[2, 3, 4]", null, "[5, 6]")) + checkEvaluation(transform(aai, array => Cast(transform(array, plusIndex), StringType)), + Seq("[1, 3, 5]", null, "[4, 6]")) + } + + test("MapFilter") { + def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val MapType(kt, vt, vcn) = expr.dataType + MapFilter(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) + } + val mii0 = Literal.create(Map(1 -> 0, 2 -> 10, 3 -> -1), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val mii1 = Literal.create(Map(1 -> null, 2 -> 10, 3 -> null), + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val miin = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false)) + + val kGreaterThanV: (Expression, Expression) => Expression = (k, v) => k > v + + checkEvaluation(mapFilter(mii0, kGreaterThanV), Map(1 -> 0, 3 -> -1)) + checkEvaluation(mapFilter(mii1, kGreaterThanV), Map()) + checkEvaluation(mapFilter(miin, kGreaterThanV), null) + + val valueIsNull: (Expression, Expression) => Expression = (_, v) => v.isNull + + checkEvaluation(mapFilter(mii0, valueIsNull), Map()) + checkEvaluation(mapFilter(mii1, valueIsNull), Map(1 -> null, 3 -> null)) + checkEvaluation(mapFilter(miin, valueIsNull), null) + + val msi0 = Literal.create(Map("abcdf" -> 5, "abc" -> 10, "" -> 0), + MapType(StringType, IntegerType, valueContainsNull = false)) + val msi1 = Literal.create(Map("abcdf" -> 5, "abc" -> 10, "" -> null), + MapType(StringType, IntegerType, valueContainsNull = true)) + val msin = Literal.create(null, MapType(StringType, IntegerType, valueContainsNull = 
false)) + + val isLengthOfKey: (Expression, Expression) => Expression = (k, v) => Length(k) === v + + checkEvaluation(mapFilter(msi0, isLengthOfKey), Map("abcdf" -> 5, "" -> 0)) + checkEvaluation(mapFilter(msi1, isLengthOfKey), Map("abcdf" -> 5)) + checkEvaluation(mapFilter(msin, isLengthOfKey), null) + + val mia0 = Literal.create(Map(1 -> Seq(0, 1, 2), 2 -> Seq(10), -3 -> Seq(-1, 0, -2, 3)), + MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = false)) + val mia1 = Literal.create(Map(1 -> Seq(0, 1, 2), 2 -> null, -3 -> Seq(-1, 0, -2, 3)), + MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = true)) + val mian = Literal.create( + null, MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = false)) + + val customFunc: (Expression, Expression) => Expression = (k, v) => Size(v) + k > 3 + + checkEvaluation(mapFilter(mia0, customFunc), Map(1 -> Seq(0, 1, 2))) + checkEvaluation(mapFilter(mia1, customFunc), Map(1 -> Seq(0, 1, 2))) + checkEvaluation(mapFilter(mian, customFunc), null) + } + + test("ArrayFilter") { + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) + val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false)) + + val isEven: Expression => Expression = x => x % 2 === 0 + val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1 + + checkEvaluation(filter(ai0, isEven), Seq(2)) + checkEvaluation(filter(ai0, isNullOrOdd), Seq(1, 3)) + checkEvaluation(filter(ai1, isEven), Seq.empty) + checkEvaluation(filter(ai1, isNullOrOdd), Seq(1, null, 3)) + checkEvaluation(filter(ain, isEven), null) + checkEvaluation(filter(ain, isNullOrOdd), null) + + val as0 = + Literal.create(Seq("a0", "b1", "a2", "c3"), ArrayType(StringType, containsNull = false)) + val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true)) + val asn = Literal.create(null, ArrayType(StringType, containsNull = false)) + + val startsWithA: Expression => Expression = x => x.startsWith("a") + + checkEvaluation(filter(as0, startsWithA), Seq("a0", "a2")) + checkEvaluation(filter(as1, startsWithA), Seq("a")) + checkEvaluation(filter(asn, startsWithA), null) + + val aai = Literal.create(Seq(Seq(1, 2, 3), null, Seq(4, 5)), + ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)) + checkEvaluation(transform(aai, ix => filter(ix, isNullOrOdd)), + Seq(Seq(1, 3), null, Seq(5))) + } + + test("ArrayExists") { + def exists(expr: Expression, f: Expression => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + ArrayExists(expr, createLambda(et, cn, f)).bind(validateBinding) + } + + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) + val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false)) + + val isEven: Expression => Expression = x => x % 2 === 0 + val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1 + + checkEvaluation(exists(ai0, isEven), true) + checkEvaluation(exists(ai0, isNullOrOdd), true) + checkEvaluation(exists(ai1, isEven), false) + checkEvaluation(exists(ai1, isNullOrOdd), true) + checkEvaluation(exists(ain, isEven), null) + checkEvaluation(exists(ain, isNullOrOdd), null) + + val as0 = + Literal.create(Seq("a0", "b1", "a2", "c3"), ArrayType(StringType, containsNull = false)) + val as1 = 
Literal.create(Seq(null, "b", "c"), ArrayType(StringType, containsNull = true)) + val asn = Literal.create(null, ArrayType(StringType, containsNull = false)) + + val startsWithA: Expression => Expression = x => x.startsWith("a") + + checkEvaluation(exists(as0, startsWithA), true) + checkEvaluation(exists(as1, startsWithA), false) + checkEvaluation(exists(asn, startsWithA), null) + + val aai = Literal.create(Seq(Seq(1, 2, 3), null, Seq(4, 5)), + ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)) + checkEvaluation(transform(aai, ix => exists(ix, isNullOrOdd)), + Seq(true, null, true)) + } + + test("ArrayAggregate") { + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) + val ai2 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false)) + val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false)) + + checkEvaluation(aggregate(ai0, 0, (acc, elem) => acc + elem, acc => acc * 10), 60) + checkEvaluation(aggregate(ai1, 0, (acc, elem) => acc + coalesce(elem, 0), acc => acc * 10), 40) + checkEvaluation(aggregate(ai2, 0, (acc, elem) => acc + elem, acc => acc * 10), 0) + checkEvaluation(aggregate(ain, 0, (acc, elem) => acc + elem, acc => acc * 10), null) + + val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false)) + val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true)) + val as2 = Literal.create(Seq.empty[String], ArrayType(StringType, containsNull = false)) + val asn = Literal.create(null, ArrayType(StringType, containsNull = false)) + + checkEvaluation(aggregate(as0, "", (acc, elem) => Concat(Seq(acc, elem))), "abc") + checkEvaluation(aggregate(as1, "", (acc, elem) => Concat(Seq(acc, coalesce(elem, "x")))), "axc") + checkEvaluation(aggregate(as2, "", (acc, elem) => Concat(Seq(acc, elem))), "") + checkEvaluation(aggregate(asn, "", (acc, elem) => Concat(Seq(acc, elem))), null) + + val aai = Literal.create(Seq[Seq[Integer]](Seq(1, 2, 3), null, Seq(4, 5)), + ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)) + checkEvaluation( + aggregate(aai, 0, + (acc, array) => coalesce(aggregate(array, acc, (acc, elem) => acc + elem), acc)), + 15) + } + + test("TransformKeys") { + val ai0 = Literal.create( + Map(1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val ai1 = Literal.create( + Map.empty[Int, Int], + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val ai2 = Literal.create( + Map(1 -> 1, 2 -> null, 3 -> 3), + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val ai3 = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false)) + + val plusOne: (Expression, Expression) => Expression = (k, v) => k + 1 + val plusValue: (Expression, Expression) => Expression = (k, v) => k + v + val modKey: (Expression, Expression) => Expression = (k, v) => k % 3 + + checkEvaluation(transformKeys(ai0, plusOne), Map(2 -> 1, 3 -> 2, 4 -> 3, 5 -> 4)) + checkEvaluation(transformKeys(ai0, plusValue), Map(2 -> 1, 4 -> 2, 6 -> 3, 8 -> 4)) + checkEvaluation( + transformKeys(transformKeys(ai0, plusOne), plusValue), Map(3 -> 1, 5 -> 2, 7 -> 3, 9 -> 4)) + checkEvaluation(transformKeys(ai0, modKey), + ArrayBasedMapData(Array(1, 2, 0, 1), Array(1, 2, 3, 4))) + checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int]) + 
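A short aside on the modKey case just above, the one scenario here whose expected value is not a plain Scala Map: mapping k % 3 over the keys 1, 2, 3, 4 yields a repeated key, and a Scala Map literal would silently collapse the duplicate, so the test spells out the backing key and value arrays with ArrayBasedMapData instead. A minimal plain-Scala sketch of the arithmetic (an illustration only, not part of the patch):

    // k % 3 over ai0's keys repeats the key 1; Map(...) would collapse the
    // duplicate, keeping the last binding, while ArrayBasedMapData preserves it.
    Seq(1, 2, 3, 4).map(_ % 3)            // List(1, 2, 0, 1)
    Map(1 -> 1, 2 -> 2, 0 -> 3, 1 -> 4)   // Map(1 -> 4, 2 -> 2, 0 -> 3)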
checkEvaluation(transformKeys(ai1, plusValue), Map.empty[Int, Int]) + checkEvaluation( + transformKeys(transformKeys(ai1, plusOne), plusValue), Map.empty[Int, Int]) + checkEvaluation(transformKeys(ai2, plusOne), Map(2 -> 1, 3 -> null, 4 -> 3)) + checkEvaluation( + transformKeys(transformKeys(ai2, plusOne), plusOne), Map(3 -> 1, 4 -> null, 5 -> 3)) + checkEvaluation(transformKeys(ai3, plusOne), null) + + val as0 = Literal.create( + Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), + MapType(StringType, StringType, valueContainsNull = false)) + val as1 = Literal.create( + Map("a" -> "xy", "bb" -> "yz", "ccc" -> null), + MapType(StringType, StringType, valueContainsNull = true)) + val as2 = Literal.create(null, + MapType(StringType, StringType, valueContainsNull = false)) + val as3 = Literal.create(Map.empty[StringType, StringType], + MapType(StringType, StringType, valueContainsNull = true)) + + val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v)) + val convertKeyToKeyLength: (Expression, Expression) => Expression = + (k, v) => Length(k) + 1 + + checkEvaluation( + transformKeys(as0, concatValue), Map("axy" -> "xy", "bbyz" -> "yz", "ccczx" -> "zx")) + checkEvaluation( + transformKeys(transformKeys(as0, concatValue), concatValue), + Map("axyxy" -> "xy", "bbyzyz" -> "yz", "ccczxzx" -> "zx")) + checkEvaluation(transformKeys(as3, concatValue), Map.empty[String, String]) + checkEvaluation( + transformKeys(transformKeys(as3, concatValue), convertKeyToKeyLength), + Map.empty[Int, String]) + checkEvaluation(transformKeys(as0, convertKeyToKeyLength), + Map(2 -> "xy", 3 -> "yz", 4 -> "zx")) + checkEvaluation(transformKeys(as1, convertKeyToKeyLength), + Map(2 -> "xy", 3 -> "yz", 4 -> null)) + checkEvaluation(transformKeys(as2, convertKeyToKeyLength), null) + checkEvaluation(transformKeys(as3, convertKeyToKeyLength), Map.empty[Int, String]) + + val ax0 = Literal.create( + Map(1 -> "x", 2 -> "y", 3 -> "z"), + MapType(IntegerType, StringType, valueContainsNull = false)) + + checkEvaluation(transformKeys(ax0, plusOne), Map(2 -> "x", 3 -> "y", 4 -> "z")) + } + + test("TransformValues") { + val ai0 = Literal.create( + Map(1 -> 1, 2 -> 2, 3 -> 3), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val ai1 = Literal.create( + Map(1 -> 1, 2 -> null, 3 -> 3), + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val ai2 = Literal.create( + Map.empty[Int, Int], + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val ai3 = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false)) + + val plusOne: (Expression, Expression) => Expression = (k, v) => v + 1 + val valueUpdate: (Expression, Expression) => Expression = (k, v) => k * k + + checkEvaluation(transformValues(ai0, plusOne), Map(1 -> 2, 2 -> 3, 3 -> 4)) + checkEvaluation(transformValues(ai0, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) + checkEvaluation( + transformValues(transformValues(ai0, plusOne), valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) + checkEvaluation(transformValues(ai1, plusOne), Map(1 -> 2, 2 -> null, 3 -> 4)) + checkEvaluation(transformValues(ai1, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) + checkEvaluation( + transformValues(transformValues(ai1, plusOne), valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) + checkEvaluation(transformValues(ai2, plusOne), Map.empty[Int, Int]) + checkEvaluation(transformValues(ai3, plusOne), null) + + val as0 = Literal.create( + Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), + MapType(StringType, StringType, valueContainsNull = false)) 
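The string-keyed TransformValues literals continue below; first, a note on the integer cases above: the valueUpdate lambda (k, v) => k * k ignores the old value entirely, so keys are preserved and every value is recomputed from its key, even where the old value was null (ai1). A rough plain-Scala analog, for illustration only:

    // keys stay fixed; values are rebuilt from the keys, mirroring valueUpdate
    Map(1 -> 1, 2 -> 2, 3 -> 3).map { case (k, _) => k -> k * k }  // Map(1 -> 1, 2 -> 4, 3 -> 9)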
+ val as1 = Literal.create( + Map("a" -> "xy", "bb" -> null, "ccc" -> "zx"), + MapType(StringType, StringType, valueContainsNull = true)) + val as2 = Literal.create(Map.empty[StringType, StringType], + MapType(StringType, StringType, valueContainsNull = true)) + val as3 = Literal.create(null, MapType(StringType, StringType, valueContainsNull = true)) + + val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v)) + val valueTypeUpdate: (Expression, Expression) => Expression = + (k, v) => Length(v) + 1 + + checkEvaluation( + transformValues(as0, concatValue), Map("a" -> "axy", "bb" -> "bbyz", "ccc" -> "ccczx")) + checkEvaluation(transformValues(as0, valueTypeUpdate), + Map("a" -> 3, "bb" -> 3, "ccc" -> 3)) + checkEvaluation( + transformValues(transformValues(as0, concatValue), concatValue), + Map("a" -> "aaxy", "bb" -> "bbbbyz", "ccc" -> "cccccczx")) + checkEvaluation(transformValues(as1, concatValue), + Map("a" -> "axy", "bb" -> null, "ccc" -> "ccczx")) + checkEvaluation(transformValues(as1, valueTypeUpdate), + Map("a" -> 3, "bb" -> null, "ccc" -> 3)) + checkEvaluation( + transformValues(transformValues(as1, concatValue), concatValue), + Map("a" -> "aaxy", "bb" -> null, "ccc" -> "cccccczx")) + checkEvaluation(transformValues(as2, concatValue), Map.empty[String, String]) + checkEvaluation(transformValues(as2, valueTypeUpdate), Map.empty[String, Int]) + checkEvaluation( + transformValues(transformValues(as2, concatValue), valueTypeUpdate), + Map.empty[String, Int]) + checkEvaluation(transformValues(as3, concatValue), null) + + val ax0 = Literal.create( + Map(1 -> "x", 2 -> "y", 3 -> "z"), + MapType(IntegerType, StringType, valueContainsNull = false)) + + checkEvaluation(transformValues(ax0, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) + } + + test("MapZipWith") { + def map_zip_with( + left: Expression, + right: Expression, + f: (Expression, Expression, Expression) => Expression): Expression = { + val MapType(kt, vt1, _) = left.dataType + val MapType(_, vt2, _) = right.dataType + MapZipWith(left, right, createLambda(kt, false, vt1, true, vt2, true, f)) + .bind(validateBinding) + } + + val mii0 = Literal.create(Map(1 -> 10, 2 -> 20, 3 -> 30), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val mii1 = Literal.create(Map(1 -> -1, 2 -> -2, 4 -> -4), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val mii2 = Literal.create(Map(1 -> null, 2 -> -2, 3 -> null), + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val mii3 = Literal.create(Map(), MapType(IntegerType, IntegerType, valueContainsNull = false)) + val mii4 = MapFromArrays( + Literal.create(Seq(2, 2), ArrayType(IntegerType, false)), + Literal.create(Seq(20, 200), ArrayType(IntegerType, false))) + val miin = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false)) + + val multiplyKeyWithValues: (Expression, Expression, Expression) => Expression = { + (k, v1, v2) => k * v1 * v2 + } + + checkEvaluation( + map_zip_with(mii0, mii1, multiplyKeyWithValues), + Map(1 -> -10, 2 -> -80, 3 -> null, 4 -> null)) + checkEvaluation( + map_zip_with(mii0, mii2, multiplyKeyWithValues), + Map(1 -> null, 2 -> -80, 3 -> null)) + checkEvaluation( + map_zip_with(mii0, mii3, multiplyKeyWithValues), + Map(1 -> null, 2 -> null, 3 -> null)) + checkEvaluation( + map_zip_with(mii0, mii4, multiplyKeyWithValues), + Map(1 -> null, 2 -> 800, 3 -> null)) + checkEvaluation( + map_zip_with(mii4, mii0, multiplyKeyWithValues), + Map(2 -> 800, 1 -> null, 3 -> null)) + 
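The integer cases above pin down MapZipWith's alignment rules: the result's key set is the union of both inputs' keys, a key missing on one side contributes a null for that side (hence the valueContainsNull = true assertion just below), and for mii4's duplicated key only the first occurrence (2 -> 20) is consulted, giving 2 * 20 * 20 = 800. A plain-Scala sketch of the union-with-nulls alignment, for illustration only (Option stands in for the nullable per-side lookup):

    val m1 = Map(1 -> 10, 2 -> 20, 3 -> 30)  // mii0
    val m2 = Map(1 -> -1, 2 -> -2, 4 -> -4)  // mii1
    (m1.keySet ++ m2.keySet).toSeq.sorted.map(k => k -> (m1.get(k), m2.get(k)))
    // 1 -> (Some(10), Some(-1)), 2 -> (Some(20), Some(-2)),
    // 3 -> (Some(30), None),     4 -> (None, Some(-4))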
checkEvaluation( + map_zip_with(mii0, miin, multiplyKeyWithValues), + null) + assert(map_zip_with(mii0, mii1, multiplyKeyWithValues).dataType === + MapType(IntegerType, IntegerType, valueContainsNull = true)) + + val mss0 = Literal.create(Map("a" -> "x", "b" -> "y", "d" -> "z"), + MapType(StringType, StringType, valueContainsNull = false)) + val mss1 = Literal.create(Map("d" -> "b", "b" -> "d"), + MapType(StringType, StringType, valueContainsNull = false)) + val mss2 = Literal.create(Map("c" -> null, "b" -> "t", "a" -> null), + MapType(StringType, StringType, valueContainsNull = true)) + val mss3 = Literal.create(Map(), MapType(StringType, StringType, valueContainsNull = false)) + val mss4 = MapFromArrays( + Literal.create(Seq("a", "a"), ArrayType(StringType, false)), + Literal.create(Seq("a", "n"), ArrayType(StringType, false))) + val mssn = Literal.create(null, MapType(StringType, StringType, valueContainsNull = false)) + + val concat: (Expression, Expression, Expression) => Expression = { + (k, v1, v2) => Concat(Seq(k, v1, v2)) + } + + checkEvaluation( + map_zip_with(mss0, mss1, concat), + Map("a" -> null, "b" -> "byd", "d" -> "dzb")) + checkEvaluation( + map_zip_with(mss1, mss2, concat), + Map("d" -> null, "b" -> "bdt", "c" -> null, "a" -> null)) + checkEvaluation( + map_zip_with(mss0, mss3, concat), + Map("a" -> null, "b" -> null, "d" -> null)) + checkEvaluation( + map_zip_with(mss0, mss4, concat), + Map("a" -> "axa", "b" -> null, "d" -> null)) + checkEvaluation( + map_zip_with(mss4, mss0, concat), + Map("a" -> "aax", "b" -> null, "d" -> null)) + checkEvaluation( + map_zip_with(mss0, mssn, concat), + null) + assert(map_zip_with(mss0, mss1, concat).dataType === + MapType(StringType, StringType, valueContainsNull = true)) + + def b(data: Byte*): Array[Byte] = Array[Byte](data: _*) + + val mbb0 = Literal.create(Map(b(1, 2) -> b(4), b(2, 1) -> b(5), b(1, 3) -> b(8)), + MapType(BinaryType, BinaryType, valueContainsNull = false)) + val mbb1 = Literal.create(Map(b(2, 1) -> b(7), b(1, 2) -> b(3), b(1, 1) -> b(6)), + MapType(BinaryType, BinaryType, valueContainsNull = false)) + val mbb2 = Literal.create(Map(b(1, 3) -> null, b(1, 2) -> b(2), b(2, 1) -> null), + MapType(BinaryType, BinaryType, valueContainsNull = true)) + val mbb3 = Literal.create(Map(), MapType(BinaryType, BinaryType, valueContainsNull = false)) + val mbb4 = MapFromArrays( + Literal.create(Seq(b(2, 1), b(2, 1)), ArrayType(BinaryType, false)), + Literal.create(Seq(b(1), b(9)), ArrayType(BinaryType, false))) + val mbbn = Literal.create(null, MapType(BinaryType, BinaryType, valueContainsNull = false)) + + checkEvaluation( + map_zip_with(mbb0, mbb1, concat), + Map(b(1, 2) -> b(1, 2, 4, 3), b(2, 1) -> b(2, 1, 5, 7), b(1, 3) -> null, b(1, 1) -> null)) + checkEvaluation( + map_zip_with(mbb1, mbb2, concat), + Map(b(2, 1) -> null, b(1, 2) -> b(1, 2, 3, 2), b(1, 1) -> null, b(1, 3) -> null)) + checkEvaluation( + map_zip_with(mbb0, mbb3, concat), + Map(b(1, 2) -> null, b(2, 1) -> null, b(1, 3) -> null)) + checkEvaluation( + map_zip_with(mbb0, mbb4, concat), + Map(b(1, 2) -> null, b(2, 1) -> b(2, 1, 5, 1), b(1, 3) -> null)) + checkEvaluation( + map_zip_with(mbb4, mbb0, concat), + Map(b(2, 1) -> b(2, 1, 1, 5), b(1, 2) -> null, b(1, 3) -> null)) + checkEvaluation( + map_zip_with(mbb0, mbbn, concat), + null) + } + + test("ZipWith") { + def zip_with( + left: Expression, + right: Expression, + f: (Expression, Expression) => Expression): Expression = { + val ArrayType(leftT, _) = left.dataType + val ArrayType(rightT, _) = right.dataType + 
ZipWith(left, right, createLambda(leftT, true, rightT, true, f)).bind(validateBinding) + } + + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val ai1 = Literal.create(Seq(1, 2, 3, 4), ArrayType(IntegerType, containsNull = false)) + val ai2 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) + val ai3 = Literal.create(Seq[Integer](1, null), ArrayType(IntegerType, containsNull = true)) + val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false)) + + val add: (Expression, Expression) => Expression = (x, y) => x + y + val plusOne: Expression => Expression = x => x + 1 + + checkEvaluation(zip_with(ai0, ai1, add), Seq(2, 4, 6, null)) + checkEvaluation(zip_with(ai3, ai2, add), Seq(2, null, null)) + checkEvaluation(zip_with(ai2, ai3, add), Seq(2, null, null)) + checkEvaluation(zip_with(ain, ain, add), null) + checkEvaluation(zip_with(ai1, ain, add), null) + checkEvaluation(zip_with(ain, ai1, add), null) + + val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false)) + val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true)) + val as2 = Literal.create(Seq("a"), ArrayType(StringType, containsNull = true)) + val asn = Literal.create(null, ArrayType(StringType, containsNull = false)) + + val concat: (Expression, Expression) => Expression = (x, y) => Concat(Seq(x, y)) + + checkEvaluation(zip_with(as0, as1, concat), Seq("aa", null, "cc")) + checkEvaluation(zip_with(as0, as2, concat), Seq("aa", null, null)) + + val aai1 = Literal.create(Seq(Seq(1, 2, 3), null, Seq(4, 5)), + ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)) + val aai2 = Literal.create(Seq(Seq(1, 2, 3)), + ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)) + checkEvaluation( + zip_with(aai1, aai2, (a1, a2) => + Cast(zip_with(transform(a1, plusOne), transform(a2, plusOne), add), StringType)), + Seq("[4, 6, 8]", null, null)) + checkEvaluation(zip_with(aai1, aai1, (a1, a2) => Cast(transform(a1, plusOne), StringType)), + Seq("[2, 3, 4]", null, "[5, 6]")) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 00e97637eee7e..0e9c8abec33e4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -392,7 +392,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val jsonData = """{"a": 1}""" val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId, true), + JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId), InternalRow(1) ) } @@ -401,13 +401,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val jsonData = """{"a" 1}""" val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId, true), + JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId), null ) // Other modes should still return `null`. 
checkEvaluation( - JsonToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(jsonData), gmtId, true), + JsonToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(jsonData), gmtId), null ) } @@ -416,62 +416,62 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val input = """[{"a": 1}, {"a": 2}]""" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(1) :: InternalRow(2) :: Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=object, schema=array, output=array of single row") { val input = """{"a": 1}""" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(1) :: Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=empty array, schema=array, output=empty array") { val input = "[ ]" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=empty object, schema=array, output=array of single row with null") { val input = "{ }" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(null) :: Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=array of single object, schema=struct, output=single row") { val input = """[{"a": 1}]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = InternalRow(1) - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=array, schema=struct, output=null") { val input = """[{"a": 1}, {"a": 2}]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = null - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=empty array, schema=struct, output=null") { val input = """[]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = null - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=empty object, schema=struct, output=single row with null") { val input = """{ }""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = InternalRow(null) - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json null input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId, true), + JsonToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId), null ) } @@ -479,7 +479,7 @@ class 
JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("SPARK-20549: from_json bad UTF-8") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(badJson), gmtId, true), + JsonToStructs(schema, Map.empty, Literal(badJson), gmtId), null) } @@ -491,14 +491,14 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with c.set(2016, 0, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 123) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData1), gmtId, true), + JsonToStructs(schema, Map.empty, Literal(jsonData1), gmtId), InternalRow(c.getTimeInMillis * 1000L) ) // The result doesn't change because the json string includes timezone string ("Z" here), // which means the string represents the timestamp string in the timezone regardless of // the timeZoneId parameter. checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData1), Option("PST"), true), + JsonToStructs(schema, Map.empty, Literal(jsonData1), Option("PST")), InternalRow(c.getTimeInMillis * 1000L) ) @@ -512,8 +512,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with schema, Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss"), Literal(jsonData2), - Option(tz.getID), - true), + Option(tz.getID)), InternalRow(c.getTimeInMillis * 1000L) ) checkEvaluation( @@ -522,8 +521,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", DateTimeUtils.TIMEZONE_OPTION -> tz.getID), Literal(jsonData2), - gmtId, - true), + gmtId), InternalRow(c.getTimeInMillis * 1000L) ) } @@ -532,7 +530,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("SPARK-19543: from_json empty input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId, true), + JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId), null ) } @@ -687,23 +685,31 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("from_json missing fields") { for (forceJsonNullableSchema <- Seq(false, true)) { - val input = - """{ + withSQLConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA.key -> forceJsonNullableSchema.toString) { + val input = + """{ | "a": 1, | "c": "foo" |} |""".stripMargin - val jsonSchema = new StructType() - .add("a", LongType, nullable = false) - .add("b", StringType, nullable = false) - .add("c", StringType, nullable = false) - val output = InternalRow(1L, null, UTF8String.fromString("foo")) - val expr = JsonToStructs( - jsonSchema, Map.empty, Literal.create(input, StringType), gmtId, forceJsonNullableSchema) - checkEvaluation(expr, output) - val schema = expr.dataType - val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema - assert(schemaToCompare == schema) + val jsonSchema = new StructType() + .add("a", LongType, nullable = false) + .add("b", StringType, nullable = !forceJsonNullableSchema) + .add("c", StringType, nullable = false) + val output = InternalRow(1L, null, UTF8String.fromString("foo")) + val expr = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId) + checkEvaluation(expr, output) + val schema = expr.dataType + val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema + assert(schemaToCompare == schema) + } } } + + test("SPARK-24709: infer schema of json 
strings") { + checkEvaluation(SchemaOfJson(Literal.create("""{"col":0}""")), "struct") + checkEvaluation( + SchemaOfJson(Literal.create("""{"col0":["a"], "col1": {"col2": "b"}}""")), + "struct,col1:struct>") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index 86f80fe66d28b..3ea6bfac9ddca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -226,4 +226,25 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal('\u0000'), "\u0000") checkEvaluation(Literal.create('\n'), "\n") } + + test("fromString converts String/DataType input correctly") { + checkEvaluation(Literal.fromString(false.toString, BooleanType), false) + checkEvaluation(Literal.fromString(null, NullType), null) + checkEvaluation(Literal.fromString(Int.MaxValue.toByte.toString, ByteType), Int.MaxValue.toByte) + checkEvaluation(Literal.fromString(Short.MaxValue.toShort.toString, ShortType), Short.MaxValue + .toShort) + checkEvaluation(Literal.fromString(Int.MaxValue.toString, IntegerType), Int.MaxValue) + checkEvaluation(Literal.fromString(Long.MaxValue.toString, LongType), Long.MaxValue) + checkEvaluation(Literal.fromString(Float.MaxValue.toString, FloatType), Float.MaxValue) + checkEvaluation(Literal.fromString(Double.MaxValue.toString, DoubleType), Double.MaxValue) + checkEvaluation(Literal.fromString("1.23456", DecimalType(10, 5)), Decimal(1.23456)) + checkEvaluation(Literal.fromString("Databricks", StringType), "Databricks") + val dateString = "1970-01-01" + checkEvaluation(Literal.fromString(dateString, DateType), java.sql.Date.valueOf(dateString)) + val timestampString = "0000-01-01 00:00:00" + checkEvaluation(Literal.fromString(timestampString, TimestampType), + java.sql.Timestamp.valueOf(timestampString)) + val calInterval = new CalendarInterval(1, 1) + checkEvaluation(Literal.fromString(calInterval.toString, CalendarIntervalType), calInterval) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala deleted file mode 100644 index 4d69dc32ace82..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala +++ /dev/null @@ -1,236 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.types.{IntegerType, StringType} - -class MaskExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { - - test("mask") { - checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), "U", "l", "#"), "llll-UUUU-####-####") - checkEvaluation( - new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U"), Literal("l"), Literal("#")), - "llll-UUUU-####-####") - checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U"), Literal("l")), - "llll-UUUU-nnnn-nnnn") - checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U")), "xxxx-UUUU-nnnn-nnnn") - checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321")), "xxxx-XXXX-nnnn-nnnn") - checkEvaluation(new Mask(Literal(null, StringType)), null) - checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), null, "l", "#"), "llll-XXXX-####-####") - checkEvaluation(new Mask( - Literal("abcd-EFGH-8765-4321"), - Literal(null, StringType), - Literal(null, StringType), - Literal(null, StringType)), "xxxx-XXXX-nnnn-nnnn") - checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("Upper")), - "xxxx-UUUU-nnnn-nnnn") - checkEvaluation(new Mask(Literal("")), "") - checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("")), "xxxx-XXXX-nnnn-nnnn") - checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), "", "", ""), "xxxx-XXXX-nnnn-nnnn") - // scalastyle:off nonascii - checkEvaluation(Mask(Literal("Ul9U"), "\u2200", null, null), "\u2200xn\u2200") - checkEvaluation(new Mask(Literal("Hello World, こんにちは, 𠀋"), Literal("あ"), Literal("𡈽")), - "あ𡈽𡈽𡈽𡈽 あ𡈽𡈽𡈽𡈽, こんにちは, 𠀋") - // scalastyle:on nonascii - intercept[AnalysisException] { - checkEvaluation(new Mask(Literal(""), Literal(1)), "") - } - } - - test("mask_first_n") { - checkEvaluation(MaskFirstN(Literal("aB3d-EFGH-8765"), 6, "U", "l", "#"), - "lU#l-UFGH-8765") - checkEvaluation(new MaskFirstN( - Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")), - "llll-UFGH-8765-4321") - checkEvaluation( - new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")), - "llll-UFGH-8765-4321") - checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")), - "xxxx-UFGH-8765-4321") - checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6)), - "xxxx-XFGH-8765-4321") - intercept[AnalysisException] { - checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "") - } - checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321")), "xxxx-EFGH-8765-4321") - checkEvaluation(new MaskFirstN(Literal(null, StringType)), null) - checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), - "llll-EFGH-8765-4321") - checkEvaluation(new MaskFirstN( - Literal("abcd-EFGH-8765-4321"), - Literal(null, IntegerType), - Literal(null, StringType), - Literal(null, StringType), - Literal(null, StringType)), "xxxx-EFGH-8765-4321") - checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")), - "xxxx-UFGH-8765-4321") - checkEvaluation(new MaskFirstN(Literal("")), "") - checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")), - "xxxx-EFGH-8765-4321") - checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), - "xxxx-XXXX-nnnn-nnnn") - checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), - "abcd-EFGH-8765-4321") - 
// scalastyle:off nonascii - checkEvaluation(MaskFirstN(Literal("Ul9U"), 2, "\u2200", null, null), "\u2200x9U") - checkEvaluation(new MaskFirstN(Literal("あ, 𠀋, Hello World"), Literal(10)), - "あ, 𠀋, Xxxxo World") - // scalastyle:on nonascii - } - - test("mask_last_n") { - checkEvaluation(MaskLastN(Literal("abcd-EFGH-aB3d"), 6, "U", "l", "#"), - "abcd-EFGU-lU#l") - checkEvaluation(new MaskLastN( - Literal("abcd-EFGH-8765"), Literal(6), Literal("U"), Literal("l"), Literal("#")), - "abcd-EFGU-####") - checkEvaluation( - new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6), Literal("U"), Literal("l")), - "abcd-EFGU-nnnn") - checkEvaluation( - new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6), Literal("U")), - "abcd-EFGU-nnnn") - checkEvaluation( - new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6)), - "abcd-EFGX-nnnn") - intercept[AnalysisException] { - checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765"), Literal("U")), "") - } - checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321")), "abcd-EFGH-8765-nnnn") - checkEvaluation(new MaskLastN(Literal(null, StringType)), null) - checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), - "abcd-EFGH-8765-nnnn") - checkEvaluation(new MaskLastN( - Literal("abcd-EFGH-8765-4321"), - Literal(null, IntegerType), - Literal(null, StringType), - Literal(null, StringType), - Literal(null, StringType)), "abcd-EFGH-8765-nnnn") - checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321"), Literal(12), Literal("Upper")), - "abcd-EFUU-nnnn-nnnn") - checkEvaluation(new MaskLastN(Literal("")), "") - checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321"), Literal(16), Literal("")), - "abcx-XXXX-nnnn-nnnn") - checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), - "xxxx-XXXX-nnnn-nnnn") - checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), - "abcd-EFGH-8765-4321") - // scalastyle:off nonascii - checkEvaluation(MaskLastN(Literal("Ul9U"), 2, "\u2200", null, null), "Uln\u2200") - checkEvaluation(new MaskLastN(Literal("あ, 𠀋, Hello World あ 𠀋"), Literal(10)), - "あ, 𠀋, Hello Xxxxx あ 𠀋") - // scalastyle:on nonascii - } - - test("mask_show_first_n") { - checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-aB3d"), 6, "U", "l", "#"), - "abcd-EUUU-####-lU#l") - checkEvaluation(new MaskShowFirstN( - Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")), - "abcd-EUUU-####-####") - checkEvaluation( - new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")), - "abcd-EUUU-nnnn-nnnn") - checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")), - "abcd-EUUU-nnnn-nnnn") - checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6)), - "abcd-EXXX-nnnn-nnnn") - intercept[AnalysisException] { - checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "") - } - checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321")), "abcd-XXXX-nnnn-nnnn") - checkEvaluation(new MaskShowFirstN(Literal(null, StringType)), null) - checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), - "abcd-UUUU-nnnn-nnnn") - checkEvaluation(new MaskShowFirstN( - Literal("abcd-EFGH-8765-4321"), - Literal(null, IntegerType), - Literal(null, StringType), - Literal(null, StringType), - Literal(null, StringType)), "abcd-XXXX-nnnn-nnnn") - checkEvaluation( - new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")), - "abcd-EUUU-nnnn-nnnn") 
- checkEvaluation(new MaskShowFirstN(Literal("")), "") - checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")), - "abcd-XXXX-nnnn-nnnn") - checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), - "abcd-EFGH-8765-4321") - checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), - "xxxx-XXXX-nnnn-nnnn") - // scalastyle:off nonascii - checkEvaluation(MaskShowFirstN(Literal("Ul9U"), 2, "\u2200", null, null), "Uln\u2200") - checkEvaluation(new MaskShowFirstN(Literal("あ, 𠀋, Hello World"), Literal(10)), - "あ, 𠀋, Hellx Xxxxx") - // scalastyle:on nonascii - } - - test("mask_show_last_n") { - checkEvaluation(MaskShowLastN(Literal("aB3d-EFGH-8765"), 6, "U", "l", "#"), - "lU#l-UUUH-8765") - checkEvaluation(new MaskShowLastN( - Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")), - "llll-UUUU-###5-4321") - checkEvaluation( - new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")), - "llll-UUUU-nnn5-4321") - checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")), - "xxxx-UUUU-nnn5-4321") - checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6)), - "xxxx-XXXX-nnn5-4321") - intercept[AnalysisException] { - checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "") - } - checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321")), "xxxx-XXXX-nnnn-4321") - checkEvaluation(new MaskShowLastN(Literal(null, StringType)), null) - checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), - "llll-UUUU-nnnn-4321") - checkEvaluation(new MaskShowLastN( - Literal("abcd-EFGH-8765-4321"), - Literal(null, IntegerType), - Literal(null, StringType), - Literal(null, StringType), - Literal(null, StringType)), "xxxx-XXXX-nnnn-4321") - checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")), - "xxxx-UUUU-nnn5-4321") - checkEvaluation(new MaskShowLastN(Literal("")), "") - checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")), - "xxxx-XXXX-nnnn-4321") - checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), - "abcd-EFGH-8765-4321") - checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), - "xxxx-XXXX-nnnn-nnnn") - // scalastyle:off nonascii - checkEvaluation(MaskShowLastN(Literal("Ul9U"), 2, "\u2200", null, null), "\u2200x9U") - checkEvaluation(new MaskShowLastN(Literal("あ, 𠀋, Hello World"), Literal(10)), - "あ, 𠀋, Xello World") - // scalastyle:on nonascii - } - - test("mask_hash") { - checkEvaluation(MaskHash(Literal("abcd-EFGH-8765-4321")), "60c713f5ec6912229d2060df1c322776") - checkEvaluation(MaskHash(Literal("")), "d41d8cd98f00b204e9800998ecf8427e") - checkEvaluation(MaskHash(Literal(null, StringType)), null) - // scalastyle:off nonascii - checkEvaluation(MaskHash(Literal("\u2200x9U")), "f1243ef123d516b1f32a3a75309e5711") - // scalastyle:on nonascii - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index 424c3a4696077..6e07f7a59b730 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -86,6 +86,13 @@ class 
NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       checkEvaluation(Coalesce(Seq(nullLit, lit, lit)), value)
       checkEvaluation(Coalesce(Seq(nullLit, nullLit, lit)), value)
     }
+
+    val coalesce = Coalesce(Seq(
+      Literal.create(null, ArrayType(IntegerType, containsNull = false)),
+      Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)),
+      Literal.create(Seq(1, 2, 3, null), ArrayType(IntegerType, containsNull = true))))
+    assert(coalesce.dataType === ArrayType(IntegerType, containsNull = true))
+    checkEvaluation(coalesce, Seq(1, 2, 3))
   }
 
   test("SPARK-16602 Nvl should support numeric-string cases") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index 20d568c44258f..b0af9e07d1d1d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import java.sql.{Date, Timestamp}
 
 import scala.collection.JavaConverters._
+import scala.language.existentials
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.Random
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
index c48730bd9d1cc..1fa185cc77ebb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
@@ -30,7 +30,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
     }
     val b1 = a.withName("name2").withExprId(id)
     val b2 = a.withExprId(id)
-    val b3 = a.withQualifier(Some("qualifierName"))
+    val b3 = a.withQualifier(Seq("qualifierName"))
 
     assert(b1 != b2)
     assert(a != b1)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala
index 351d4d0c2eac9..d46135c02bc01 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala
@@ -77,6 +77,19 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with Priva
     }
   }
 
+  test("SPARK-21590: Start time works with negative values and returns microseconds") {
+    val validDuration = "10 minutes"
+    for ((text, microseconds) <- Seq(
+        ("-10 seconds", -10000000), // -1e7
+        ("-1 minute", -60000000), // -6e7
+        ("-1 hour", -3600000000L))) { // -3.6e9
+      assert(TimeWindow(Literal(10L), validDuration, validDuration, "interval " + text).startTime
+        === microseconds)
+      assert(TimeWindow(Literal(10L), validDuration, validDuration, text).startTime
+        === microseconds)
+    }
+  }
+
   private val parseExpression = PrivateMethod[Long]('parseExpression)
 
   test("parse sql expression for duration in microseconds - string") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala
index 85682cf6ea670..d2862c8f41d1b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions.codegen
 
-import org.scalatest.{BeforeAndAfterEach, Matchers}
+import org.scalatest.{Assertions, BeforeAndAfterEach, Matchers}
 
 import org.apache.spark.{SparkFunSuite, TestUtils}
 import org.apache.spark.deploy.SparkSubmitSuite
@@ -39,7 +39,7 @@ class BufferHolderSparkSubmitSuite
     val argsForSparkSubmit = Seq(
       "--class", BufferHolderSparkSubmitSuite.getClass.getName.stripSuffix("$"),
       "--name", "SPARK-22222",
-      "--master", "local-cluster[2,1,1024]",
+      "--master", "local-cluster[1,1,4096]",
       "--driver-memory", "4g",
       "--conf", "spark.ui.enabled=false",
       "--conf", "spark.master.rest.enabled=false",
@@ -49,28 +49,36 @@ class BufferHolderSparkSubmitSuite
   }
 }
 
-object BufferHolderSparkSubmitSuite {
+object BufferHolderSparkSubmitSuite extends Assertions {
 
   def main(args: Array[String]): Unit = {
 
     val ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
 
-    val holder = new BufferHolder(new UnsafeRow(1000))
+    val unsafeRow = new UnsafeRow(1000)
+    val holder = new BufferHolder(unsafeRow)
 
     holder.reset()
-    holder.grow(roundToWord(ARRAY_MAX / 2))
-    holder.reset()
-    holder.grow(roundToWord(ARRAY_MAX / 2 + 8))
+    assert(intercept[IllegalArgumentException] {
+      holder.grow(-1)
+    }.getMessage.contains("because the size is negative"))
 
-    holder.reset()
-    holder.grow(roundToWord(Integer.MAX_VALUE / 2))
+    // While buffer reuse may happen, this test checks whether the buffer can be grown
+    holder.grow(ARRAY_MAX / 2)
+    assert(unsafeRow.getSizeInBytes % 8 == 0)
 
-    holder.reset()
-    holder.grow(roundToWord(Integer.MAX_VALUE))
-  }
+    holder.grow(ARRAY_MAX / 2 + 7)
+    assert(unsafeRow.getSizeInBytes % 8 == 0)
+
+    holder.grow(Integer.MAX_VALUE / 2)
+    assert(unsafeRow.getSizeInBytes % 8 == 0)
+
+    holder.grow(ARRAY_MAX - holder.totalSize())
+    assert(unsafeRow.getSizeInBytes % 8 == 0)
 
-  private def roundToWord(len: Int): Int = {
-    ByteArrayMethods.roundNumberOfBytesToNearestWord(len)
+    assert(intercept[IllegalArgumentException] {
+      holder.grow(ARRAY_MAX + 1 - holder.totalSize())
+    }.getMessage.contains("because the size after growing"))
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala
index c7c386b5b838a..4e0f903a030aa 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala
@@ -23,17 +23,15 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 class BufferHolderSuite extends SparkFunSuite {
 
   test("SPARK-16071 Check the size limit to avoid integer overflow") {
-    var e = intercept[UnsupportedOperationException] {
+    assert(intercept[UnsupportedOperationException] {
       new BufferHolder(new UnsafeRow(Int.MaxValue / 8))
-    }
-    assert(e.getMessage.contains("too many fields"))
+    }.getMessage.contains("too many fields"))
 
     val holder = new BufferHolder(new UnsafeRow(1000))
     holder.reset()
     holder.grow(1000)
-    e = intercept[UnsupportedOperationException] {
+    assert(intercept[IllegalArgumentException] {
       holder.grow(Integer.MAX_VALUE)
-    }
-
assert(e.getMessage.contains("exceeds size limitation")) + }.getMessage.contains("exceeds size limitation")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala index d2c6420eadb20..55569b6f2933e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala @@ -65,7 +65,9 @@ class CodeBlockSuite extends SparkFunSuite { |boolean $isNull = false; |int $value = -1; """.stripMargin - val exprValues = code.exprValues + val exprValues = code.asInstanceOf[CodeBlock].blockInputs.collect { + case e: ExprValue => e + }.toSet assert(exprValues.size == 2) assert(exprValues === Set(value, isNull)) } @@ -94,7 +96,9 @@ class CodeBlockSuite extends SparkFunSuite { assert(code.toString == expected) - val exprValues = code.exprValues + val exprValues = code.children.flatMap(_.asInstanceOf[CodeBlock].blockInputs.collect { + case e: ExprValue => e + }).toSet assert(exprValues.size == 5) assert(exprValues === Set(isNull1, value1, isNull2, value2, literal)) } @@ -107,7 +111,7 @@ class CodeBlockSuite extends SparkFunSuite { assert(e.getMessage().contains(s"Can not interpolate ${obj.getClass.getName}")) } - test("replace expr values in code block") { + test("transform expr in code block") { val expr = JavaCode.expression("1 + 1", IntegerType) val isNull = JavaCode.isNullVariable("expr1_isNull") val exprInFunc = JavaCode.variable("expr1", IntegerType) @@ -120,11 +124,11 @@ class CodeBlockSuite extends SparkFunSuite { |}""".stripMargin val aliasedParam = JavaCode.variable("aliased", expr.javaType) - val aliasedInputs = code.asInstanceOf[CodeBlock].blockInputs.map { - case _: SimpleExprValue => aliasedParam - case other => other + + // We want to replace all occurrences of `expr` with the variable `aliasedParam`. 
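+    // `transformExprValues` rebuilds the CodeBlock by applying this partial
+    // function to each ExprValue input; inputs it does not match are kept as-is.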
+    val aliasedCode = code.transformExprValues {
+      case SimpleExprValue("1 + 1", java.lang.Integer.TYPE) => aliasedParam
     }
-    val aliasedCode = CodeBlock(code.asInstanceOf[CodeBlock].codeParts, aliasedInputs).stripMargin
 
     val expected =
       code"""
          |callFunc(int $aliasedParam) {
@@ -133,4 +137,61 @@ class CodeBlockSuite extends SparkFunSuite {
          |}""".stripMargin
 
     assert(aliasedCode.toString == expected.toString)
   }
+
+  test("transform expr in nested blocks") {
+    val expr = JavaCode.expression("1 + 1", IntegerType)
+    val isNull = JavaCode.isNullVariable("expr1_isNull")
+    val exprInFunc = JavaCode.variable("expr1", IntegerType)
+
+    val funcs = Seq("callFunc1", "callFunc2", "callFunc3")
+    val subBlocks = funcs.map { funcName =>
+      code"""
+         |$funcName(int $expr) {
+         |  boolean $isNull = false;
+         |  int $exprInFunc = $expr + 1;
+         |}""".stripMargin
+    }
+
+    val aliasedParam = JavaCode.variable("aliased", expr.javaType)
+
+    val block = code"${subBlocks(0)}\n${subBlocks(1)}\n${subBlocks(2)}"
+    val transformedBlock = block.transform {
+      case b: Block => b.transformExprValues {
+        case SimpleExprValue("1 + 1", java.lang.Integer.TYPE) => aliasedParam
+      }
+    }.asInstanceOf[CodeBlock]
+
+    val expected1 =
+      code"""
+         |callFunc1(int aliased) {
+         |  boolean expr1_isNull = false;
+         |  int expr1 = aliased + 1;
+         |}""".stripMargin
+
+    val expected2 =
+      code"""
+         |callFunc2(int aliased) {
+         |  boolean expr1_isNull = false;
+         |  int expr1 = aliased + 1;
+         |}""".stripMargin
+
+    val expected3 =
+      code"""
+         |callFunc3(int aliased) {
+         |  boolean expr1_isNull = false;
+         |  int expr1 = aliased + 1;
+         |}""".stripMargin
+
+    val exprValues = transformedBlock.children.flatMap { block =>
+      block.asInstanceOf[CodeBlock].blockInputs.collect {
+        case e: ExprValue => e
+      }
+    }.toSet
+
+    assert(transformedBlock.children(0).toString == expected1.toString)
+    assert(transformedBlock.children(1).toString == expected2.toString)
+    assert(transformedBlock.children(2).toString == expected3.toString)
+    assert(transformedBlock.toString == (expected1 + expected2 + expected3).toString)
+    assert(exprValues === Set(isNull, exprInFunc, aliasedParam))
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala
index e9d21f8a8ebcd..01aa3579aea98 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.expressions.codegen
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.BoundReference
-import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
-import org.apache.spark.sql.types.{DataType, Decimal, StringType, StructType}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, MapData}
+import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
 class GenerateUnsafeProjectionSuite extends SparkFunSuite {
@@ -33,6 +33,41 @@ class GenerateUnsafeProjectionSuite extends SparkFunSuite {
     assert(!result.isNullAt(0))
     assert(result.getStruct(0, 1).isNullAt(0))
   }
+
+  test("Test unsafe projection for array/map/struct") {
+    val dataType1 = ArrayType(StringType, false)
+    val exprs1 =
BoundReference(0, dataType1, nullable = false) :: Nil + val projection1 = GenerateUnsafeProjection.generate(exprs1) + val result1 = projection1.apply(AlwaysNonNull) + assert(!result1.isNullAt(0)) + assert(!result1.getArray(0).isNullAt(0)) + assert(!result1.getArray(0).isNullAt(1)) + assert(!result1.getArray(0).isNullAt(2)) + + val dataType2 = MapType(StringType, StringType, false) + val exprs2 = BoundReference(0, dataType2, nullable = false) :: Nil + val projection2 = GenerateUnsafeProjection.generate(exprs2) + val result2 = projection2.apply(AlwaysNonNull) + assert(!result2.isNullAt(0)) + assert(!result2.getMap(0).keyArray.isNullAt(0)) + assert(!result2.getMap(0).keyArray.isNullAt(1)) + assert(!result2.getMap(0).keyArray.isNullAt(2)) + assert(!result2.getMap(0).valueArray.isNullAt(0)) + assert(!result2.getMap(0).valueArray.isNullAt(1)) + assert(!result2.getMap(0).valueArray.isNullAt(2)) + + val dataType3 = (new StructType) + .add("a", StringType, nullable = false) + .add("b", StringType, nullable = false) + .add("c", StringType, nullable = false) + val exprs3 = BoundReference(0, dataType3, nullable = false) :: Nil + val projection3 = GenerateUnsafeProjection.generate(exprs3) + val result3 = projection3.apply(InternalRow(AlwaysNonNull)) + assert(!result3.isNullAt(0)) + assert(!result3.getStruct(0, 1).isNullAt(0)) + assert(!result3.getStruct(0, 2).isNullAt(0)) + assert(!result3.getStruct(0, 3).isNullAt(0)) + } } object AlwaysNull extends InternalRow { @@ -59,3 +94,35 @@ object AlwaysNull extends InternalRow { override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported private def notSupported: Nothing = throw new UnsupportedOperationException } + +object AlwaysNonNull extends InternalRow { + private def stringToUTF8Array(stringArray: Array[String]): ArrayData = { + val utf8Array = stringArray.map(s => UTF8String.fromString(s)).toArray + ArrayData.toArrayData(utf8Array) + } + override def numFields: Int = 1 + override def setNullAt(i: Int): Unit = {} + override def copy(): InternalRow = this + override def anyNull: Boolean = notSupported + override def isNullAt(ordinal: Int): Boolean = notSupported + override def update(i: Int, value: Any): Unit = notSupported + override def getBoolean(ordinal: Int): Boolean = notSupported + override def getByte(ordinal: Int): Byte = notSupported + override def getShort(ordinal: Int): Short = notSupported + override def getInt(ordinal: Int): Int = notSupported + override def getLong(ordinal: Int): Long = notSupported + override def getFloat(ordinal: Int): Float = notSupported + override def getDouble(ordinal: Int): Double = notSupported + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = notSupported + override def getUTF8String(ordinal: Int): UTF8String = UTF8String.fromString("test") + override def getBinary(ordinal: Int): Array[Byte] = notSupported + override def getInterval(ordinal: Int): CalendarInterval = notSupported + override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported + override def getArray(ordinal: Int): ArrayData = stringToUTF8Array(Array("1", "2", "3")) + val keyArray = stringToUTF8Array(Array("1", "2", "3")) + val valueArray = stringToUTF8Array(Array("a", "b", "c")) + override def getMap(ordinal: Int): MapData = new ArrayBasedMapData(keyArray, valueArray) + override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported + private def notSupported: Nothing = throw new UnsupportedOperationException + +} diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 653c07f1835ca..6cd1108eef333 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.BooleanType class BooleanSimplificationSuite extends PlanTest with PredicateHelper { @@ -37,6 +38,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { Batch("Constant Folding", FixedPoint(50), NullPropagation, ConstantFolding, + SimplifyConditionals, BooleanSimplification, PruneFilters) :: Nil } @@ -48,6 +50,14 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { testRelation.output, Seq(Row(1, 2, 3, "abc")) ) + val testNotNullableRelation = LocalRelation('a.int.notNull, 'b.int.notNull, 'c.int.notNull, + 'd.string.notNull, 'e.boolean.notNull, 'f.boolean.notNull, 'g.boolean.notNull, + 'h.boolean.notNull) + + val testNotNullableRelationWithData = LocalRelation.fromExternalRows( + testNotNullableRelation.output, Seq(Row(1, 2, 3, "abc")) + ) + private def checkCondition(input: Expression, expected: LogicalPlan): Unit = { val plan = testRelationWithData.where(input).analyze val actual = Optimize.execute(plan) @@ -61,6 +71,13 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { comparePlans(actual, correctAnswer) } + private def checkConditionInNotNullableRelation( + input: Expression, expected: LogicalPlan): Unit = { + val plan = testNotNullableRelationWithData.where(input).analyze + val actual = Optimize.execute(plan) + comparePlans(actual, expected) + } + test("a && a => a") { checkCondition(Literal(1) < 'a && Literal(1) < 'a, Literal(1) < 'a) checkCondition(Literal(1) < 'a && Literal(1) < 'a && Literal(1) < 'a, Literal(1) < 'a) @@ -174,10 +191,30 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { } test("Complementation Laws") { - checkCondition('a && !'a, testRelation) - checkCondition(!'a && 'a, testRelation) + checkConditionInNotNullableRelation('e && !'e, testNotNullableRelation) + checkConditionInNotNullableRelation(!'e && 'e, testNotNullableRelation) + + checkConditionInNotNullableRelation('e || !'e, testNotNullableRelationWithData) + checkConditionInNotNullableRelation(!'e || 'e, testNotNullableRelationWithData) + } + + test("Complementation Laws - null handling") { + checkCondition('e && !'e, + testRelationWithData.where(If('e.isNull, Literal.create(null, BooleanType), false)).analyze) + checkCondition(!'e && 'e, + testRelationWithData.where(If('e.isNull, Literal.create(null, BooleanType), false)).analyze) + + checkCondition('e || !'e, + testRelationWithData.where(If('e.isNull, Literal.create(null, BooleanType), true)).analyze) + checkCondition(!'e || 'e, + testRelationWithData.where(If('e.isNull, Literal.create(null, BooleanType), true)).analyze) + } + + test("Complementation Laws - negative case") { + checkCondition('e && !'f, testRelationWithData.where('e && !'f).analyze) + checkCondition(!'f && 'e, testRelationWithData.where(!'f && 'e).analyze) - checkCondition('a || !'a, 
testRelationWithData) - checkCondition(!'a || 'a, testRelationWithData) + checkCondition('e || !'f, testRelationWithData.where('e || !'f).analyze) + checkCondition(!'f || 'e, testRelationWithData.where(!'f || 'e).analyze) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 8b05ba32e6eef..8d7c9bf220bc2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -140,6 +140,30 @@ class ColumnPruningSuite extends PlanTest { comparePlans(optimized, expected) } + test("Column pruning for ScriptTransformation") { + val input = LocalRelation('a.int, 'b.string, 'c.double) + val query = + ScriptTransformation( + Seq('a, 'b), + "func", + Seq.empty, + input, + null).analyze + val optimized = Optimize.execute(query) + + val expected = + ScriptTransformation( + Seq('a, 'b), + "func", + Seq.empty, + Project( + Seq('a, 'b), + input), + null).analyze + + comparePlans(optimized, expected) + } + test("Column pruning on Filter") { val input = LocalRelation('a.int, 'b.string, 'c.double) val plan1 = Filter('a > 1, input).analyze @@ -156,10 +180,10 @@ class ColumnPruningSuite extends PlanTest { test("Column pruning on except/intersect/distinct") { val input = LocalRelation('a.int, 'b.string, 'c.double) - val query = Project('a :: Nil, Except(input, input)).analyze + val query = Project('a :: Nil, Except(input, input, isAll = false)).analyze comparePlans(Optimize.execute(query), query) - val query2 = Project('a :: Nil, Intersect(input, input)).analyze + val query2 = Project('a :: Nil, Intersect(input, input, isAll = false)).analyze comparePlans(Optimize.execute(query2), query2) val query3 = Project('a :: Nil, Distinct(input)).analyze comparePlans(Optimize.execute(query3), query3) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala index 049a19b86f7cd..0c015f88e1e84 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{LessThan, Literal} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -52,4 +53,21 @@ class ConvertToLocalRelationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("Filter on LocalRelation should be turned into a single LocalRelation") { + val testRelation = LocalRelation( + LocalRelation('a.int, 'b.int).output, + InternalRow(1, 2) :: InternalRow(4, 5) :: Nil) + + val correctAnswer = LocalRelation( + LocalRelation('a1.int, 'b1.int).output, + InternalRow(1, 3) :: Nil) + + val filterAndProjectOnLocal = testRelation + .select(UnresolvedAttribute("a").as("a1"), (UnresolvedAttribute("b") + 1).as("b1")) + 
.where(LessThan(UnresolvedAttribute("b1"), Literal.create(6))) + + val optimized = Optimize.execute(filterAndProjectOnLocal.analyze) + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index e4671f0d1cce6..a40ba2dc38b70 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -196,7 +196,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { test("constraints should be inferred from aliased literals") { val originalLeft = testRelation.subquery('left).as("left") - val optimizedLeft = testRelation.subquery('left).where(IsNotNull('a) && 'a === 2).as("left") + val optimizedLeft = testRelation.subquery('left).where(IsNotNull('a) && 'a <=> 2).as("left") val right = Project(Seq(Literal(2).as("two")), testRelation.subquery('right)).as("right") val condition = Some("left.a".attr === "right.two".attr) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 478118ed709f7..a36083b847043 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -121,6 +121,21 @@ class OptimizeInSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("OptimizedIn test: NULL IN (subquery) gets transformed to Filter(null)") { + val subquery = ListQuery(testRelation.select(UnresolvedAttribute("a"))) + val originalQuery = + testRelation + .where(InSubquery(Seq(Literal.create(null, NullType)), subquery)) + .analyze + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .where(Literal.create(null, BooleanType)) + .analyze + comparePlans(optimized, correctAnswer) + } + test("OptimizedIn test: Inset optimization disabled as " + "list expression contains attribute)") { val originalQuery = @@ -176,6 +191,21 @@ class OptimizeInSuite extends PlanTest { } } + test("OptimizedIn test: one element in list gets transformed to EqualTo.") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), Seq(Literal(1)))) + .analyze + + val optimized = Optimize.execute(originalQuery) + val correctAnswer = + testRelation + .where(EqualTo(UnresolvedAttribute("a"), Literal(1))) + .analyze + + comparePlans(optimized, correctAnswer) + } + test("OptimizedIn test: In empty list gets transformed to FalseLiteral " + "when value is not nullable") { val originalQuery = @@ -191,4 +221,21 @@ class OptimizeInSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("OptimizedIn test: In empty list gets transformed to `If` expression " + + "when value is nullable") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), Nil)) + .analyze + + val optimized = Optimize.execute(originalQuery) + val correctAnswer = + testRelation + .where(If(IsNotNull(UnresolvedAttribute("a")), + Literal(false), Literal.create(null, BooleanType))) + .analyze + + comparePlans(optimized, correctAnswer) + } } diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala index 7112c033eabce..36b083a540c3c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala @@ -47,7 +47,7 @@ class OptimizerExtendableSuite extends SparkFunSuite { DummyRule) :: Nil } - override def batches: Seq[Batch] = super.batches ++ myBatches + override def defaultBatches: Seq[Batch] = super.defaultBatches ++ myBatches } test("Extending batches possible") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerLoggingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerLoggingSuite.scala new file mode 100644 index 0000000000000..915f408089fe9 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerLoggingSuite.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.collection.mutable.ArrayBuffer + +import org.apache.log4j.{Appender, AppenderSkeleton, Level, Logger} +import org.apache.log4j.spi.LoggingEvent + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.internal.SQLConf + +class OptimizerLoggingSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("Optimizer Batch", FixedPoint(100), + PushDownPredicate, + ColumnPruning, + CollapseProject) :: Nil + } + + class MockAppender extends AppenderSkeleton { + val loggingEvents = new ArrayBuffer[LoggingEvent]() + + override def append(loggingEvent: LoggingEvent): Unit = { + if (loggingEvent.getRenderedMessage().contains("Applying Rule")) { + loggingEvents.append(loggingEvent) + } + } + + override def close(): Unit = {} + override def requiresLayout(): Boolean = false + } + + private def withLogLevelAndAppender(level: Level, appender: Appender)(f: => Unit): Unit = { + val logger = Logger.getLogger(Optimize.getClass.getName.dropRight(1)) + val restoreLevel = logger.getLevel + logger.setLevel(level) + logger.addAppender(appender) + try f finally { + logger.setLevel(restoreLevel) + logger.removeAppender(appender) + } + } + + private def verifyLog(expectedLevel: Level, expectedRules: Seq[String]): Unit = { + val logAppender = new MockAppender() + withLogLevelAndAppender(Level.TRACE, logAppender) { + val input = LocalRelation('a.int, 'b.string, 'c.double) + val query = input.select('a, 'b).select('a).where('a > 1).analyze + val expected = input.where('a > 1).select('a).analyze + comparePlans(Optimize.execute(query), expected) + } + val logMessages = logAppender.loggingEvents.map(_.getRenderedMessage) + assert(expectedRules.forall(rule => logMessages.exists(_.contains(rule)))) + assert(logAppender.loggingEvents.forall(_.getLevel == expectedLevel)) + } + + test("test log level") { + val levels = Seq( + "TRACE" -> Level.TRACE, + "trace" -> Level.TRACE, + "DEBUG" -> Level.DEBUG, + "debug" -> Level.DEBUG, + "INFO" -> Level.INFO, + "info" -> Level.INFO, + "WARN" -> Level.WARN, + "warn" -> Level.WARN, + "ERROR" -> Level.ERROR, + "error" -> Level.ERROR, + "deBUG" -> Level.DEBUG) + + levels.foreach { level => + withSQLConf(SQLConf.OPTIMIZER_PLAN_CHANGE_LOG_LEVEL.key -> level._1) { + verifyLog( + level._2, + Seq( + PushDownPredicate.ruleName, + ColumnPruning.ruleName, + CollapseProject.ruleName)) + } + } + } + + test("test invalid log level conf") { + val levels = Seq( + "", + "*d_", + "infoo") + + levels.foreach { level => + val error = intercept[IllegalArgumentException] { + withSQLConf(SQLConf.OPTIMIZER_PLAN_CHANGE_LOG_LEVEL.key -> level) {} + } + assert(error.getMessage.contains( + "Invalid value for 'spark.sql.optimizer.planChangeLog.level'.")) + } + } + + test("test log rules") { + val rulesSeq = Seq( + Seq(PushDownPredicate.ruleName, + ColumnPruning.ruleName, + CollapseProject.ruleName).reduce(_ + "," + _) -> + Seq(PushDownPredicate.ruleName, + ColumnPruning.ruleName, + CollapseProject.ruleName), + Seq(PushDownPredicate.ruleName, + ColumnPruning.ruleName).reduce(_ + "," + _) -> + Seq(PushDownPredicate.ruleName, + ColumnPruning.ruleName), + CollapseProject.ruleName -> + Seq(CollapseProject.ruleName), + 
Seq(ColumnPruning.ruleName, + "DummyRule").reduce(_ + "," + _) -> + Seq(ColumnPruning.ruleName), + "DummyRule" -> Seq(), + "" -> Seq() + ) + + rulesSeq.foreach { case (rulesConf, expectedRules) => + withSQLConf( + SQLConf.OPTIMIZER_PLAN_CHANGE_LOG_RULES.key -> rulesConf, + SQLConf.OPTIMIZER_PLAN_CHANGE_LOG_LEVEL.key -> "INFO") { + verifyLog(Level.INFO, expectedRules) + } + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala new file mode 100644 index 0000000000000..4fa4a7aadc8f2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_EXCLUDED_RULES + + +class OptimizerRuleExclusionSuite extends PlanTest { + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + private def verifyExcludedRules(optimizer: Optimizer, rulesToExclude: Seq[String]) { + val nonExcludableRules = optimizer.nonExcludableRules + + val excludedRuleNames = rulesToExclude.filter(!nonExcludableRules.contains(_)) + // Batches whose rules are all to be excluded should be removed as a whole. + val excludedBatchNames = optimizer.batches + .filter(batch => batch.rules.forall(rule => excludedRuleNames.contains(rule.ruleName))) + .map(_.name) + + withSQLConf( + OPTIMIZER_EXCLUDED_RULES.key -> excludedRuleNames.foldLeft("")((l, r) => l + "," + r)) { + val batches = optimizer.batches + // Verify removed batches. + assert(batches.forall(batch => !excludedBatchNames.contains(batch.name))) + // Verify removed rules. + assert( + batches + .forall(batch => batch.rules.forall(rule => !excludedRuleNames.contains(rule.ruleName)))) + // Verify non-excludable rules retained. 
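+      // Rules listed in `nonExcludableRules` must remain in some batch even when
+      // the OPTIMIZER_EXCLUDED_RULES conf names them for exclusion.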
+ nonExcludableRules.foreach { nonExcludableRule => + assert( + optimizer.batches + .exists(batch => batch.rules.exists(rule => rule.ruleName == nonExcludableRule))) + } + } + } + + test("Exclude a single rule from multiple batches") { + verifyExcludedRules( + new SimpleTestOptimizer(), + Seq( + PushPredicateThroughJoin.ruleName)) + } + + test("Exclude multiple rules from single or multiple batches") { + verifyExcludedRules( + new SimpleTestOptimizer(), + Seq( + CombineUnions.ruleName, + RemoveLiteralFromGroupExpressions.ruleName, + RemoveRepetitionFromGroupExpressions.ruleName)) + } + + test("Exclude non-existent rule with other valid rules") { + verifyExcludedRules( + new SimpleTestOptimizer(), + Seq( + LimitPushDown.ruleName, + InferFiltersFromConstraints.ruleName, + "DummyRuleName")) + } + + test("Try to exclude some non-excludable rules") { + verifyExcludedRules( + new SimpleTestOptimizer(), + Seq( + ReplaceIntersectWithSemiJoin.ruleName, + PullupCorrelatedPredicates.ruleName, + RewriteCorrelatedScalarSubquery.ruleName, + RewritePredicateSubquery.ruleName, + RewriteExceptAll.ruleName, + RewriteIntersectAll.ruleName)) + } + + test("Custom optimizer") { + val optimizer = new SimpleTestOptimizer() { + override def defaultBatches: Seq[Batch] = + Batch("push", Once, + PushDownPredicate, + PushPredicateThroughJoin, + PushProjectionThroughUnion) :: + Batch("pull", Once, + PullupCorrelatedPredicates) :: Nil + + override def nonExcludableRules: Seq[String] = + PushDownPredicate.ruleName :: + PullupCorrelatedPredicates.ruleName :: Nil + } + + verifyExcludedRules( + optimizer, + Seq( + PushDownPredicate.ruleName, + PushProjectionThroughUnion.ruleName, + PullupCorrelatedPredicates.ruleName)) + } + + test("Verify optimized plan after excluding CombineUnions rule") { + val excludedRules = Seq( + ConvertToLocalRelation.ruleName, + PropagateEmptyRelation.ruleName, + CombineUnions.ruleName) + + withSQLConf( + OPTIMIZER_EXCLUDED_RULES.key -> excludedRules.foldLeft("")((l, r) => l + "," + r)) { + val optimizer = new SimpleTestOptimizer() + val originalQuery = testRelation.union(testRelation.union(testRelation)).analyze + val optimized = optimizer.execute(originalQuery) + comparePlans(originalQuery, optimized) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala index 6e183d81b7265..a22a81e9844d3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala @@ -44,7 +44,7 @@ class OptimizerStructuralIntegrityCheckerSuite extends PlanTest { EmptyFunctionRegistry, new SQLConf())) { val newBatch = Batch("OptimizeRuleBreakSI", Once, OptimizeRuleBreakSI) - override def batches: Seq[Batch] = Seq(newBatch) ++ super.batches + override def defaultBatches: Seq[Batch] = Seq(newBatch) ++ super.defaultBatches } test("check for invalid plan after execution of rule") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala index f1ce7543ffdc1..d395bba105a7b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala @@ -147,7 +147,7 @@ class PropagateEmptyRelationSuite extends PlanTest { .where(false) .select('a) .where('a > 1) - .where('a != 200) + .where('a =!= 200) .orderBy('a.asc) val optimized = Optimize.execute(query.analyze) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala index 169b8737d808b..8a5a55146726e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{In, ListQuery} +import org.apache.spark.sql.catalyst.expressions.{InSubquery, ListQuery} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -42,7 +42,7 @@ class PullupCorrelatedPredicatesSuite extends PlanTest { .select('c) val outerQuery = testRelation - .where(In('a, Seq(ListQuery(correlatedSubquery)))) + .where(InSubquery(Seq('a), ListQuery(correlatedSubquery))) .select('a).analyze assert(outerQuery.resolved) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala index 52dc2e9fb076c..3b1b2d588ef67 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -42,7 +42,7 @@ class ReplaceOperatorSuite extends PlanTest { val table1 = LocalRelation('a.int, 'b.int) val table2 = LocalRelation('c.int, 'd.int) - val query = Intersect(table1, table2) + val query = Intersect(table1, table2, isAll = false) val optimized = Optimize.execute(query.analyze) val correctAnswer = @@ -60,7 +60,7 @@ class ReplaceOperatorSuite extends PlanTest { val table2 = Filter(attributeB === 2, Filter(attributeA === 1, table1)) val table3 = Filter(attributeB < 1, Filter(attributeA >= 2, table1)) - val query = Except(table2, table3) + val query = Except(table2, table3, isAll = false) val optimized = Optimize.execute(query.analyze) val correctAnswer = @@ -79,7 +79,7 @@ class ReplaceOperatorSuite extends PlanTest { val table1 = LocalRelation.fromExternalRows(Seq(attributeA, attributeB), data = Seq(Row(1, 2))) val table2 = Filter(attributeB < 1, Filter(attributeA >= 2, table1)) - val query = Except(table1, table2) + val query = Except(table1, table2, isAll = false) val optimized = Optimize.execute(query.analyze) val correctAnswer = @@ -99,7 +99,7 @@ class ReplaceOperatorSuite extends PlanTest { val table3 = Project(Seq(attributeA, attributeB), Filter(attributeB < 1, Filter(attributeA >= 2, table1))) - val query = Except(table2, table3) + val query = Except(table2, table3, isAll = false) val optimized = Optimize.execute(query.analyze) val correctAnswer = @@ -120,7 +120,7 @@ class ReplaceOperatorSuite extends PlanTest { val table3 = Project(Seq(attributeA, attributeB), 
Filter(attributeB < 1, Filter(attributeA >= 2, table1)))
 
-    val query = Except(table2, table3)
+    val query = Except(table2, table3, isAll = false)
     val optimized = Optimize.execute(query.analyze)
 
     val correctAnswer =
@@ -141,7 +141,7 @@ class ReplaceOperatorSuite extends PlanTest {
       Filter(attributeB < 1, Filter(attributeA >= 2, table1)))
     val table3 = Filter(attributeB === 2, Filter(attributeA === 1, table1))
 
-    val query = Except(table2, table3)
+    val query = Except(table2, table3, isAll = false)
     val optimized = Optimize.execute(query.analyze)
 
     val correctAnswer =
@@ -158,7 +158,7 @@ class ReplaceOperatorSuite extends PlanTest {
     val table1 = LocalRelation('a.int, 'b.int)
     val table2 = LocalRelation('c.int, 'd.int)
 
-    val query = Except(table1, table2)
+    val query = Except(table1, table2, isAll = false)
     val optimized = Optimize.execute(query.analyze)
 
     val correctAnswer =
@@ -173,7 +173,7 @@ class ReplaceOperatorSuite extends PlanTest {
     val left = table.where('b < 1).select('a).as("left")
     val right = table.where('b < 3).select('a).as("right")
 
-    val query = Except(left, right)
+    val query = Except(left, right, isAll = false)
     val optimized = Optimize.execute(query.analyze)
 
     val correctAnswer =
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
index aa8841109329c..da3923f8d6477 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
@@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, GreaterThan, GreaterThanOrEqual, If, Literal, ReplicateRows}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.types.BooleanType
 
 class SetOperationSuite extends PlanTest {
   object Optimize extends RuleExecutor[LogicalPlan] {
@@ -144,4 +145,55 @@ class SetOperationSuite extends PlanTest {
       Distinct(Union(query3 :: query4 :: Nil))).analyze
     comparePlans(distinctUnionCorrectAnswer2, optimized2)
   }
+
+  test("EXCEPT ALL rewrite") {
+    val input = Except(testRelation, testRelation2, isAll = true)
+    val rewrittenPlan = RewriteExceptAll(input)
+
+    val planFragment = testRelation.select(Literal(1L).as("vcol"), 'a, 'b, 'c)
+      .union(testRelation2.select(Literal(-1L).as("vcol"), 'd, 'e, 'f))
+      .groupBy('a, 'b, 'c)('a, 'b, 'c, sum('vcol).as("sum"))
+      .where(GreaterThan('sum, Literal(0L))).analyze
+    val multiplierAttr = planFragment.output.last
+    val output = planFragment.output.dropRight(1)
+    val expectedPlan = Project(output,
+      Generate(
+        ReplicateRows(Seq(multiplierAttr) ++ output),
+        Nil,
+        false,
+        None,
+        output,
+        planFragment
+      ))
+    comparePlans(expectedPlan, rewrittenPlan)
+  }
+
+  test("INTERSECT ALL rewrite") {
+    val input = Intersect(testRelation, testRelation2, isAll = true)
+    val rewrittenPlan = RewriteIntersectAll(input)
+    val leftRelation = testRelation
+      .select(Literal(true).as("vcol1"), Literal(null, BooleanType).as("vcol2"), 'a, 'b, 'c)
+    val rightRelation = testRelation2
+      .select(Literal(null,
BooleanType).as("vcol1"), Literal(true).as("vcol2"), 'd, 'e, 'f)
+    val planFragment = leftRelation.union(rightRelation)
+      .groupBy('a, 'b, 'c)(count('vcol1).as("vcol1_count"),
+        count('vcol2).as("vcol2_count"), 'a, 'b, 'c)
+      .where(And(GreaterThanOrEqual('vcol1_count, Literal(1L)),
+        GreaterThanOrEqual('vcol2_count, Literal(1L))))
+      .select('a, 'b, 'c,
+        If(GreaterThan('vcol1_count, 'vcol2_count), 'vcol2_count, 'vcol1_count).as("min_count"))
+      .analyze
+    val multiplierAttr = planFragment.output.last
+    val output = planFragment.output.dropRight(1)
+    val expectedPlan = Project(output,
+      Generate(
+        ReplicateRows(Seq(multiplierAttr) ++ output),
+        Nil,
+        false,
+        None,
+        output,
+        planFragment
+      ))
+    comparePlans(expectedPlan, rewrittenPlan)
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
index b597c8e162c83..8ad7c12020b82 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
@@ -29,7 +30,8 @@ import org.apache.spark.sql.types.{IntegerType, NullType}
 class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
   object Optimize extends RuleExecutor[LogicalPlan] {
-    val batches = Batch("SimplifyConditionals", FixedPoint(50), SimplifyConditionals) :: Nil
+    val batches = Batch("SimplifyConditionals", FixedPoint(50),
+      BooleanSimplification, ConstantFolding, SimplifyConditionals) :: Nil
   }
 
   protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
@@ -43,6 +45,10 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
   private val unreachableBranch = (FalseLiteral, Literal(20))
   private val nullBranch = (Literal.create(null, NullType), Literal(30))
 
+  val isNotNullCond = IsNotNull(UnresolvedAttribute(Seq("a")))
+  val isNullCond = IsNull(UnresolvedAttribute("b"))
+  val notCond = Not(UnresolvedAttribute("c"))
+
   test("simplify if") {
     assertEquivalent(
       If(TrueLiteral, Literal(10), Literal(20)),
@@ -57,6 +63,23 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
       Literal(20))
   }
 
+  test("remove unnecessary if when the outputs are semantically equivalent") {
+    assertEquivalent(
+      If(IsNotNull(UnresolvedAttribute("a")),
+        Subtract(Literal(10), Literal(1)),
+        Add(Literal(6), Literal(3))),
+      Literal(9))
+
+    // For non-deterministic condition, we don't remove the `If` statement.
+    assertEquivalent(
+      If(GreaterThan(Rand(0), Literal(0.5)),
+        Subtract(Literal(10), Literal(1)),
+        Add(Literal(6), Literal(3))),
+      If(GreaterThan(Rand(0), Literal(0.5)),
+        Literal(9),
+        Literal(9)))
+  }
+
   test("remove unreachable branches") {
     // i.e. removing branches whose conditions are always false
     assertEquivalent(
@@ -100,4 +123,47 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
       None),
       CaseWhen(normalBranch :: trueBranch :: Nil, None))
   }
+
+  test("simplify CaseWhen if all the outputs are semantically equivalent") {
+    // When the conditions in `CaseWhen` are all deterministic, `CaseWhen` can be removed.
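+    // Every branch value below constant-folds to 1 (3 - 2, 1, 6 + (-5), and the
+    // else value 2 + (-1)), so the whole CaseWhen is expected to reduce to Literal(1).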
+    assertEquivalent(
+      CaseWhen((isNotNullCond, Subtract(Literal(3), Literal(2))) ::
+        (isNullCond, Literal(1)) ::
+        (notCond, Add(Literal(6), Literal(-5))) ::
+        Nil,
+        Add(Literal(2), Literal(-1))),
+      Literal(1)
+    )
+
+    // For non-deterministic conditions, we don't remove the `CaseWhen` statement.
+    assertEquivalent(
+      CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Subtract(Literal(3), Literal(2))) ::
+        (LessThan(Rand(1), Literal(0.5)), Literal(1)) ::
+        (EqualTo(Rand(2), Literal(0.5)), Add(Literal(6), Literal(-5))) ::
+        Nil,
+        Add(Literal(2), Literal(-1))),
+      CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Literal(1)) ::
+        (LessThan(Rand(1), Literal(0.5)), Literal(1)) ::
+        (EqualTo(Rand(2), Literal(0.5)), Literal(1)) ::
+        Nil,
+        Literal(1))
+    )
+
+    // When we have a mixture of deterministic and non-deterministic conditions, we remove
+    // the deterministic conditions from the tail until a non-deterministic one is seen.
+    assertEquivalent(
+      CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Subtract(Literal(3), Literal(2))) ::
+        (NonFoldableLiteral(true), Add(Literal(2), Literal(-1))) ::
+        (LessThan(Rand(1), Literal(0.5)), Literal(1)) ::
+        (NonFoldableLiteral(true), Add(Literal(6), Literal(-5))) ::
+        (NonFoldableLiteral(false), Literal(1)) ::
+        Nil,
+        Add(Literal(2), Literal(-1))),
+      CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Literal(1)) ::
+        (NonFoldableLiteral(true), Literal(1)) ::
+        (LessThan(Rand(1), Literal(0.5)), Literal(1)) ::
+        Nil,
+        Literal(1))
+    )
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala
new file mode 100644
index 0000000000000..58b3d1c98f3cd
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.Rand +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class TransposeWindowSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("CollapseProject", FixedPoint(100), CollapseProject, RemoveRedundantProject) :: + Batch("FlipWindow", Once, CollapseWindow, TransposeWindow) :: Nil + } + + val testRelation = LocalRelation('a.string, 'b.string, 'c.int, 'd.string) + + val a = testRelation.output(0) + val b = testRelation.output(1) + val c = testRelation.output(2) + val d = testRelation.output(3) + + val partitionSpec1 = Seq(a) + val partitionSpec2 = Seq(a, b) + val partitionSpec3 = Seq(d) + val partitionSpec4 = Seq(b, a, d) + + val orderSpec1 = Seq(d.asc) + val orderSpec2 = Seq(d.desc) + + test("transpose two adjacent windows with compatible partitions") { + val query = testRelation + .window(Seq(sum(c).as('sum_a_2)), partitionSpec2, orderSpec2) + .window(Seq(sum(c).as('sum_a_1)), partitionSpec1, orderSpec1) + + val analyzed = query.analyze + val optimized = Optimize.execute(analyzed) + + val correctAnswer = testRelation + .window(Seq(sum(c).as('sum_a_1)), partitionSpec1, orderSpec1) + .window(Seq(sum(c).as('sum_a_2)), partitionSpec2, orderSpec2) + .select('a, 'b, 'c, 'd, 'sum_a_2, 'sum_a_1) + + comparePlans(optimized, correctAnswer.analyze) + } + + test("transpose two adjacent windows with differently ordered compatible partitions") { + val query = testRelation + .window(Seq(sum(c).as('sum_a_2)), partitionSpec4, Seq.empty) + .window(Seq(sum(c).as('sum_a_1)), partitionSpec2, Seq.empty) + + val analyzed = query.analyze + val optimized = Optimize.execute(analyzed) + + val correctAnswer = testRelation + .window(Seq(sum(c).as('sum_a_1)), partitionSpec2, Seq.empty) + .window(Seq(sum(c).as('sum_a_2)), partitionSpec4, Seq.empty) + .select('a, 'b, 'c, 'd, 'sum_a_2, 'sum_a_1) + + comparePlans(optimized, correctAnswer.analyze) + } + + test("don't transpose two adjacent windows with incompatible partitions") { + val query = testRelation + .window(Seq(sum(c).as('sum_a_2)), partitionSpec3, Seq.empty) + .window(Seq(sum(c).as('sum_a_1)), partitionSpec1, Seq.empty) + + val analyzed = query.analyze + val optimized = Optimize.execute(analyzed) + + comparePlans(optimized, analyzed) + } + + test("don't transpose two adjacent windows with intersection of partition and output set") { + val query = testRelation + .window(Seq(('a + 'b).as('e), sum(c).as('sum_a_2)), partitionSpec3, Seq.empty) + .window(Seq(sum(c).as('sum_a_1)), Seq(a, 'e), Seq.empty) + + val analyzed = query.analyze + val optimized = Optimize.execute(analyzed) + + comparePlans(optimized, analyzed) + } + + test("don't transpose two adjacent windows with non-deterministic expressions") { + val query = testRelation + .window(Seq(Rand(0).as('e), sum(c).as('sum_a_2)), partitionSpec3, Seq.empty) + .window(Seq(sum(c).as('sum_a_1)), partitionSpec1, Seq.empty) + + val analyzed = query.analyze + val optimized = Optimize.execute(analyzed) + + comparePlans(optimized, analyzed) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala index f67697eb86c26..baaf01800b33b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala @@ -58,8 +58,5 @@ class ErrorParserSuite extends SparkFunSuite { intercept("select *\nfrom r\norder by q\ncluster by q", 3, 0, "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", "^^^") - intercept("select * from r except all select * from t", 1, 0, - "EXCEPT ALL is not supported", - "^^^") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index cb8a1fecb80a7..781fc1e957ae0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -154,7 +154,19 @@ class ExpressionParserSuite extends PlanTest { test("in sub-query") { assertEqual( "a in (select b from c)", - In('a, Seq(ListQuery(table("c").select('b))))) + InSubquery(Seq('a), ListQuery(table("c").select('b)))) + + assertEqual( + "(a, b, c) in (select d, e, f from g)", + InSubquery(Seq('a, 'b, 'c), ListQuery(table("g").select('d, 'e, 'f)))) + + assertEqual( + "(a, b) in (select c from d)", + InSubquery(Seq('a, 'b), ListQuery(table("d").select('c)))) + + assertEqual( + "(a) in (select b from c)", + InSubquery(Seq('a), ListQuery(table("c").select('b)))) } test("like expressions") { @@ -234,6 +246,11 @@ class ExpressionParserSuite extends PlanTest { intercept("foo(a x)", "extraneous input 'x'") } + test("lambda functions") { + assertEqual("x -> x + 1", LambdaFunction('x + 1, Seq('x.attr))) + assertEqual("(x, y) -> x + y", LambdaFunction('x + 'y, Seq('x.attr, 'y.attr))) + } + test("window function expressions") { val func = 'foo.function(star()) def windowed( @@ -469,7 +486,7 @@ class ExpressionParserSuite extends PlanTest { Literal(BigDecimal("90912830918230182310293801923652346786").underlying())) assertEqual("123.0E-28BD", Literal(BigDecimal("123.0E-28").underlying())) assertEqual("123.08BD", Literal(BigDecimal("123.08").underlying())) - intercept("1.20E-38BD", "DecimalType can only support precision up to 38") + intercept("1.20E-38BD", "decimal can only support precision up to 38") } test("strings") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index fb51376c6163f..422bf97e30e7e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.IntegerType /** @@ -64,15 +65,16 @@ class PlanParserSuite extends AnalysisTest { assertEqual("select * from a union select * from b", Distinct(a.union(b))) assertEqual("select * from a union distinct select * from b", Distinct(a.union(b))) assertEqual("select * from a union 
all select * from b", a.union(b)) - assertEqual("select * from a except select * from b", a.except(b)) - intercept("select * from a except all select * from b", "EXCEPT ALL is not supported.") - assertEqual("select * from a except distinct select * from b", a.except(b)) - assertEqual("select * from a minus select * from b", a.except(b)) - intercept("select * from a minus all select * from b", "MINUS ALL is not supported.") - assertEqual("select * from a minus distinct select * from b", a.except(b)) - assertEqual("select * from a intersect select * from b", a.intersect(b)) - intercept("select * from a intersect all select * from b", "INTERSECT ALL is not supported.") - assertEqual("select * from a intersect distinct select * from b", a.intersect(b)) + assertEqual("select * from a except select * from b", a.except(b, isAll = false)) + assertEqual("select * from a except distinct select * from b", a.except(b, isAll = false)) + assertEqual("select * from a except all select * from b", a.except(b, isAll = true)) + assertEqual("select * from a minus select * from b", a.except(b, isAll = false)) + assertEqual("select * from a minus all select * from b", a.except(b, isAll = true)) + assertEqual("select * from a minus distinct select * from b", a.except(b, isAll = false)) + assertEqual("select * from a " + + "intersect select * from b", a.intersect(b, isAll = false)) + assertEqual("select * from a intersect distinct select * from b", a.intersect(b, isAll = false)) + assertEqual("select * from a intersect all select * from b", a.intersect(b, isAll = true)) } test("common table expressions") { @@ -592,6 +594,33 @@ class PlanParserSuite extends AnalysisTest { parsePlan("SELECT /*+ MAPJOIN(t) */ a from t where true group by a order by a"), UnresolvedHint("MAPJOIN", Seq($"t"), table("t").where(Literal(true)).groupBy('a)('a)).orderBy('a.asc)) + + comparePlans( + parsePlan("SELECT /*+ COALESCE(10) */ * FROM t"), + UnresolvedHint("COALESCE", Seq(Literal(10)), + table("t").select(star()))) + + comparePlans( + parsePlan("SELECT /*+ REPARTITION(100) */ * FROM t"), + UnresolvedHint("REPARTITION", Seq(Literal(100)), + table("t").select(star()))) + + comparePlans( + parsePlan( + "INSERT INTO s SELECT /*+ REPARTITION(100), COALESCE(500), COALESCE(10) */ * FROM t"), + InsertIntoTable(table("s"), Map.empty, + UnresolvedHint("REPARTITION", Seq(Literal(100)), + UnresolvedHint("COALESCE", Seq(Literal(500)), + UnresolvedHint("COALESCE", Seq(Literal(10)), + table("t").select(star())))), overwrite = false, ifPartitionNotExists = false)) + + comparePlans( + parsePlan("SELECT /*+ BROADCASTJOIN(u), REPARTITION(100) */ * FROM t"), + UnresolvedHint("BROADCASTJOIN", Seq($"u"), + UnresolvedHint("REPARTITION", Seq(Literal(100)), + table("t").select(star())))) + + intercept("SELECT /*+ COALESCE(30 + 50) */ * FROM t", "mismatched input") } test("SPARK-20854: select hint syntax with expressions") { @@ -678,4 +707,50 @@ class PlanParserSuite extends AnalysisTest { OneRowRelation().select('rtrim.function("c&^,.", "bc...,,,&&&ccc")) ) } + + test("precedence of set operations") { + val a = table("a").select(star()) + val b = table("b").select(star()) + val c = table("c").select(star()) + val d = table("d").select(star()) + + val query1 = + """ + |SELECT * FROM a + |UNION + |SELECT * FROM b + |EXCEPT + |SELECT * FROM c + |INTERSECT + |SELECT * FROM d + """.stripMargin + + val query2 = + """ + |SELECT * FROM a + |UNION + |SELECT * FROM b + |EXCEPT ALL + |SELECT * FROM c + |INTERSECT ALL + |SELECT * FROM d + """.stripMargin + + 
assertEqual(query1, Distinct(a.union(b)).except(c.intersect(d, isAll = false), isAll = false)) + assertEqual(query2, Distinct(a.union(b)).except(c.intersect(d, isAll = true), isAll = true)) + + // Now disable precedence enforcement to verify the old behaviour. + withSQLConf(SQLConf.LEGACY_SETOPS_PRECEDENCE_ENABLED.key -> "true") { + assertEqual(query1, + Distinct(a.union(b)).except(c, isAll = false).intersect(d, isAll = false)) + assertEqual(query2, Distinct(a.union(b)).except(c, isAll = true).intersect(d, isAll = true)) + } + + // Explicitly enable the precedence enforcement + withSQLConf(SQLConf.LEGACY_SETOPS_PRECEDENCE_ENABLED.key -> "false") { + assertEqual(query1, + Distinct(a.union(b)).except(c.intersect(d, isAll = false), isAll = false)) + assertEqual(query2, Distinct(a.union(b)).except(c.intersect(d, isAll = true), isAll = true)) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index a37e06d922642..5ad748b6113d6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -187,7 +187,7 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { verifyConstraints(tr1 .where('a.attr > 10) - .intersect(tr2.where('b.attr < 100)) + .intersect(tr2.where('b.attr < 100), isAll = false) .analyze.constraints, ExpressionSet(Seq(resolveColumn(tr1, "a") > 10, resolveColumn(tr1, "b") < 100, @@ -200,7 +200,7 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { val tr2 = LocalRelation('a.int, 'b.int, 'c.int) verifyConstraints(tr1 .where('a.attr > 10) - .except(tr2.where('b.attr < 100)) + .except(tr2.where('b.attr < 100), isAll = false) .analyze.constraints, ExpressionSet(Seq(resolveColumn(tr1, "a") > 10, IsNotNull(resolveColumn(tr1, "a"))))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index bf569cb869428..aaab3ff1bf128 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -18,13 +18,12 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Coalesce, Literal, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Literal, NamedExpression} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.IntegerType /** - * This suite is used to test [[LogicalPlan]]'s `transformUp/transformDown` plus analysis barrier - * and make sure it can correctly skip sub-trees that have already been analyzed. + * This suite is used to test [[LogicalPlan]]'s `transformUp/transformDown`. 
*/ class LogicalPlanSuite extends SparkFunSuite { private var invocationCount = 0 @@ -60,31 +59,6 @@ class LogicalPlanSuite extends SparkFunSuite { assert(invocationCount === 2) } - test("transformUp skips all ready resolved plans wrapped in analysis barrier") { - invocationCount = 0 - val plan = AnalysisBarrier(Project(Nil, Project(Nil, testRelation))) - plan transformUp function - - assert(invocationCount === 0) - - invocationCount = 0 - plan transformDown function - assert(invocationCount === 0) - } - - test("transformUp skips partially resolved plans wrapped in analysis barrier") { - invocationCount = 0 - val plan1 = AnalysisBarrier(Project(Nil, testRelation)) - val plan2 = Project(Nil, plan1) - plan2 transformUp function - - assert(invocationCount === 1) - - invocationCount = 0 - plan2 transformDown function - assert(invocationCount === 1) - } - test("isStreaming") { val relation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) val incrementalRelation = LocalRelation( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 6241d5cbb1d25..67740c3166471 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.catalyst.plans +import org.scalactic.source import org.scalatest.Suite +import org.scalatest.Tag import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ @@ -33,6 +36,23 @@ import org.apache.spark.sql.internal.SQLConf */ trait PlanTest extends SparkFunSuite with PlanTestBase +trait CodegenInterpretedPlanTest extends PlanTest { + + override protected def test( + testName: String, + testTags: Tag*)(testFun: => Any)(implicit pos: source.Position): Unit = { + val codegenMode = CodegenObjectFactoryMode.CODEGEN_ONLY.toString + val interpretedMode = CodegenObjectFactoryMode.NO_CODEGEN.toString + + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) { + super.test(testName + " (codegen path)", testTags: _*)(testFun)(pos) + } + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> interpretedMode) { + super.test(testName + " (interpreted path)", testTags: _*)(testFun)(pos) + } + } +} + /** * Provides helper methods for comparing plans, but without the overhead of * mandating a FunSuite. 
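The overridden `test` above registers every test body twice, once per value of `SQLConf.CODEGEN_FACTORY_MODE`, so a consuming suite exercises both the codegen and the interpreted expression-evaluation paths. A minimal sketch of such a suite (the suite name and test body are hypothetical, for illustration only):

    import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest

    class ExampleDualPathSuite extends CodegenInterpretedPlanTest {
      // Registered twice by the overridden `test`: once as
      // "trivial check (codegen path)" and once as "trivial check (interpreted path)".
      test("trivial check") {
        assert(Seq(1, 2, 3).sum === 6)
      }
    }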
@@ -60,6 +80,8 @@ trait PlanTestBase extends PredicateHelper { self: Suite => Alias(a.child, a.name)(exprId = ExprId(0)) case ae: AggregateExpression => ae.copy(resultId = ExprId(0)) + case lv: NamedLambdaVariable => + lv.copy(exprId = ExprId(0), value = null) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelperSuite.scala new file mode 100644 index 0000000000000..9100e10ca0c09 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelperSuite.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, Literal, NamedExpression} + + +class AnalysisHelperSuite extends SparkFunSuite { + + private var invocationCount = 0 + private val function: PartialFunction[LogicalPlan, LogicalPlan] = { + case p: Project => + invocationCount += 1 + p + } + + private val exprFunction: PartialFunction[Expression, Expression] = { + case e: Literal => + invocationCount += 1 + Literal.TrueLiteral + } + + private def projectExprs: Seq[NamedExpression] = Alias(Literal.TrueLiteral, "A")() :: Nil + + test("setAnalyzed is recursive") { + val plan = Project(Nil, LocalRelation()) + plan.setAnalyzed() + assert(plan.find(!_.analyzed).isEmpty) + } + + test("resolveOperator runs on operators recursively") { + invocationCount = 0 + val plan = Project(Nil, Project(Nil, LocalRelation())) + plan.resolveOperators(function) + assert(invocationCount === 2) + } + + test("resolveOperatorsDown runs on operators recursively") { + invocationCount = 0 + val plan = Project(Nil, Project(Nil, LocalRelation())) + plan.resolveOperatorsDown(function) + assert(invocationCount === 2) + } + + test("resolveExpressions runs on operators recursively") { + invocationCount = 0 + val plan = Project(projectExprs, Project(projectExprs, LocalRelation())) + plan.resolveExpressions(exprFunction) + assert(invocationCount === 2) + } + + test("resolveOperator skips already resolved plans") { + invocationCount = 0 + val plan = Project(Nil, Project(Nil, LocalRelation())) + plan.setAnalyzed() + plan.resolveOperators(function) + assert(invocationCount === 0) + } + + test("resolveOperatorsDown skips already resolved plans") { + invocationCount = 0 + val plan = Project(Nil, Project(Nil, LocalRelation())) + plan.setAnalyzed() + plan.resolveOperatorsDown(function) + assert(invocationCount === 0) + } + + test("resolveExpressions skips already resolved plans") { + invocationCount = 0 + val plan = Project(projectExprs, Project(projectExprs,
LocalRelation())) + plan.setAnalyzed() + plan.resolveExpressions(exprFunction) + assert(invocationCount === 0) + } + + test("resolveOperator skips partially resolved plans") { + invocationCount = 0 + val plan1 = Project(Nil, LocalRelation()) + val plan2 = Project(Nil, plan1) + plan1.setAnalyzed() + plan2.resolveOperators(function) + assert(invocationCount === 1) + } + + test("resolveOperatorsDown skips partially resolved plans") { + invocationCount = 0 + val plan1 = Project(Nil, LocalRelation()) + val plan2 = Project(Nil, plan1) + plan1.setAnalyzed() + plan2.resolveOperatorsDown(function) + assert(invocationCount === 1) + } + + test("resolveExpressions skips partially resolved plans") { + invocationCount = 0 + val plan1 = Project(projectExprs, LocalRelation()) + val plan2 = Project(projectExprs, plan1) + plan1.setAnalyzed() + plan2.resolveExpressions(exprFunction) + assert(invocationCount === 1) + } + + test("do not allow transform in analyzer") { + val plan = Project(Nil, LocalRelation()) + // These should be OK since we are not in the analyzer + plan.transform { case p: Project => p } + plan.transformUp { case p: Project => p } + plan.transformDown { case p: Project => p } + plan.transformAllExpressions { case lit: Literal => lit } + + // The following should fail in the analyzer scope + AnalysisHelper.markInAnalyzer { + intercept[RuntimeException] { plan.transform { case p: Project => p } } + intercept[RuntimeException] { plan.transformUp { case p: Project => p } } + intercept[RuntimeException] { plan.transformDown { case p: Project => p } } + intercept[RuntimeException] { plan.transformAllExpressions { case lit: Literal => lit } } + } + } + + test("allow transform in resolveOperators in the analyzer") { + val plan = Project(Nil, LocalRelation()) + AnalysisHelper.markInAnalyzer { + plan.resolveOperators { case p: Project => p.transform { case p: Project => p } } + plan.resolveOperatorsDown { case p: Project => p.transform { case p: Project => p } } + plan.resolveExpressions { case lit: Literal => + Project(Nil, LocalRelation()).transform { case p: Project => p } + lit + } + } + } + + test("allow transform with allowInvokingTransformsInAnalyzer in the analyzer") { + val plan = Project(Nil, LocalRelation()) + AnalysisHelper.markInAnalyzer { + AnalysisHelper.allowInvokingTransformsInAnalyzer { + plan.transform { case p: Project => p } + plan.transformUp { case p: Project => p } + plan.transformDown { case p: Project => p } + plan.transformAllExpressions { case lit: Literal => lit } + } + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index cbf6106697f30..2423668392231 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -662,18 +662,18 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(daysToMillis(16800, TimeZoneGMT) === c.getTimeInMillis) // There are some days that are skipped entirely in some time zones; skip them here.
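// (A local day disappears when a zone jumps across the International Date Line:
// Pacific/Apia skipped 2011-12-30, which is epoch day 15338, and the Kiribati
// zones skipped 1994-12-31. The change below widens the entries to Sets,
// presumably because different tzdata versions place the Kiribati transition
// on epoch day 9130 or 9131.)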
- val skipped_days = Map[String, Int]( - "Kwajalein" -> 8632, - "Pacific/Apia" -> 15338, - "Pacific/Enderbury" -> 9131, - "Pacific/Fakaofo" -> 15338, - "Pacific/Kiritimati" -> 9131, - "Pacific/Kwajalein" -> 8632, - "MIT" -> 15338) + val skipped_days = Map[String, Set[Int]]( + "Kwajalein" -> Set(8632), + "Pacific/Apia" -> Set(15338), + "Pacific/Enderbury" -> Set(9130, 9131), + "Pacific/Fakaofo" -> Set(15338), + "Pacific/Kiritimati" -> Set(9130, 9131), + "Pacific/Kwajalein" -> Set(8632), + "MIT" -> Set(15338)) for (tz <- DateTimeTestUtils.ALL_TIMEZONES) { - val skipped = skipped_days.getOrElse(tz.getID, Int.MinValue) + val skipped = skipped_days.getOrElse(tz.getID, Set.empty) (-20000 to 20000).foreach { d => - if (d != skipped) { + if (!skipped.contains(d)) { assert(millisToDays(daysToMillis(d, tz), tz) === d, s"Round trip of ${d} did not work in tz ${tz}") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala index 8f75c14192c9b..755c8897cada2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala @@ -114,7 +114,7 @@ class UnsafeArraySuite extends SparkFunSuite { assert(unsafeDate.isInstanceOf[UnsafeArrayData]) assert(unsafeDate.numElements == dateArray.length) dateArray.zipWithIndex.map { case (e, i) => - assert(unsafeDate.get(i, DateType) == e) + assert(unsafeDate.get(i, DateType).asInstanceOf[Int] == e) } val unsafeTimestamp = ExpressionEncoder[Array[Long]].resolveAndBind(). @@ -122,7 +122,7 @@ class UnsafeArraySuite extends SparkFunSuite { assert(unsafeTimestamp.isInstanceOf[UnsafeArrayData]) assert(unsafeTimestamp.numElements == timestampArray.length) timestampArray.zipWithIndex.map { case (e, i) => - assert(unsafeTimestamp.get(i, TimestampType) == e) + assert(unsafeTimestamp.get(i, TimestampType).asInstanceOf[Long] == e) } Seq(decimalArray4_1, decimalArray20_20).map { decimalArray => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 5a86f4055dce7..122a3125ee2c4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -154,7 +154,7 @@ class DataTypeSuite extends SparkFunSuite { left.merge(right) }.getMessage assert(message.equals("Failed to merge fields 'b' and 'b'. 
" + - "Failed to merge incompatible data types FloatType and LongType")) + "Failed to merge incompatible data types float and bigint")) } test("existsRecursively") { @@ -452,4 +452,30 @@ class DataTypeSuite extends SparkFunSuite { new StructType().add("f1", IntegerType).add("f", new StructType().add("f2", StringType, false)), new StructType().add("f2", IntegerType).add("g", new StructType().add("f1", StringType)), false) + + test("SPARK-25031: MapType should produce current formatted string for complex types") { + val keyType: DataType = StructType(Seq( + StructField("a", DataTypes.IntegerType), + StructField("b", DataTypes.IntegerType))) + + val valueType: DataType = StructType(Seq( + StructField("c", DataTypes.IntegerType), + StructField("d", DataTypes.IntegerType))) + + val builder = new StringBuilder + + MapType(keyType, valueType).buildFormattedString(prefix = "", builder = builder) + + val result = builder.toString() + val expected = + """-- key: struct + | |-- a: integer (nullable = true) + | |-- b: integer (nullable = true) + |-- value: struct (valueContainsNull = true) + | |-- c: integer (nullable = true) + | |-- d: integer (nullable = true) + |""".stripMargin + + assert(result === expected) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala new file mode 100644 index 0000000000000..d92f52f3248aa --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala @@ -0,0 +1,404 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.types + +import scala.collection.mutable + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.expressions.Cast + +class DataTypeWriteCompatibilitySuite extends SparkFunSuite { + private val atomicTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, + DoubleType, DateType, TimestampType, StringType, BinaryType) + + private val point2 = StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", FloatType, nullable = false))) + + private val widerPoint2 = StructType(Seq( + StructField("x", DoubleType, nullable = false), + StructField("y", DoubleType, nullable = false))) + + private val point3 = StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", FloatType, nullable = false), + StructField("z", FloatType))) + + private val simpleContainerTypes = Seq( + ArrayType(LongType), ArrayType(LongType, containsNull = false), MapType(StringType, DoubleType), + MapType(StringType, DoubleType, valueContainsNull = false), point2, point3) + + private val nestedContainerTypes = Seq(ArrayType(point2, containsNull = false), + MapType(StringType, point3, valueContainsNull = false)) + + private val allNonNullTypes = Seq( + atomicTypes, simpleContainerTypes, nestedContainerTypes, Seq(CalendarIntervalType)).flatten + + test("Check NullType is incompatible with all other types") { + allNonNullTypes.foreach { t => + assertSingleError(NullType, t, "nulls", s"Should not allow writing None to type $t") { err => + assert(err.contains(s"incompatible with $t")) + } + } + } + + test("Check each type with itself") { + allNonNullTypes.foreach { t => + assertAllowed(t, t, "t", s"Should allow writing type to itself $t") + } + } + + test("Check atomic types: write allowed only when casting is safe") { + atomicTypes.foreach { w => + atomicTypes.foreach { r => + if (Cast.canSafeCast(w, r)) { + assertAllowed(w, r, "t", s"Should allow writing $w to $r because cast is safe") + + } else { + assertSingleError(w, r, "t", + s"Should not allow writing $w to $r because cast is not safe") { err => + assert(err.contains("'t'"), "Should include the field name context") + assert(err.contains("Cannot safely cast"), "Should identify unsafe cast") + assert(err.contains(s"$w"), "Should include write type") + assert(err.contains(s"$r"), "Should include read type") + } + } + } + } + } + + test("Check struct types: missing required field") { + val missingRequiredField = StructType(Seq(StructField("x", FloatType, nullable = false))) + assertSingleError(missingRequiredField, point2, "t", + "Should fail because required field 'y' is missing") { err => + assert(err.contains("'t'"), "Should include the struct name for context") + assert(err.contains("'y'"), "Should include the nested field name") + assert(err.contains("missing field"), "Should call out field missing") + } + } + + test("Check struct types: missing starting field, matched by position") { + val missingRequiredField = StructType(Seq(StructField("y", FloatType, nullable = false))) + + // should have 2 errors: names x and y don't match, and field y is missing + assertNumErrors(missingRequiredField, point2, "t", + "Should fail because field 'x' is matched to field 'y' and required field 'y' is missing", 2) + { errs => + assert(errs(0).contains("'t'"), "Should include the struct name for context") + assert(errs(0).contains("expected 'x', found 'y'"), "Should detect name mismatch") + 
assert(errs(0).contains("field name does not match"), "Should identify name problem") + + assert(errs(1).contains("'t'"), "Should include the struct name for context") + assert(errs(1).contains("'y'"), "Should include the _last_ nested fields of the read schema") + assert(errs(1).contains("missing field"), "Should call out field missing") + } + } + + test("Check struct types: missing middle field, matched by position") { + val missingMiddleField = StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("z", FloatType, nullable = false))) + + val expectedStruct = StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", FloatType, nullable = false), + StructField("z", FloatType, nullable = true))) + + // types are compatible: (req int, req int) => (req int, req int, opt int) + // but this should still fail because the names do not match. + + assertNumErrors(missingMiddleField, expectedStruct, "t", + "Should fail because field 'y' is matched to field 'z'", 2) { errs => + assert(errs(0).contains("'t'"), "Should include the struct name for context") + assert(errs(0).contains("expected 'y', found 'z'"), "Should detect name mismatch") + assert(errs(0).contains("field name does not match"), "Should identify name problem") + + assert(errs(1).contains("'t'"), "Should include the struct name for context") + assert(errs(1).contains("'z'"), "Should include the nested field name") + assert(errs(1).contains("missing field"), "Should call out field missing") + } + } + + test("Check struct types: generic colN names are ignored") { + val missingMiddleField = StructType(Seq( + StructField("col1", FloatType, nullable = false), + StructField("col2", FloatType, nullable = false))) + + val expectedStruct = StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", FloatType, nullable = false))) + + // types are compatible: (req int, req int) => (req int, req int) + // names don't match, but match the naming convention used by Spark to fill in names + + assertAllowed(missingMiddleField, expectedStruct, "t", + "Should succeed because column names are ignored") + } + + test("Check struct types: required field is optional") { + val requiredFieldIsOptional = StructType(Seq( + StructField("x", FloatType), + StructField("y", FloatType, nullable = false))) + + assertSingleError(requiredFieldIsOptional, point2, "t", + "Should fail because required field 'x' is optional") { err => + assert(err.contains("'t.x'"), "Should include the nested field name context") + assert(err.contains("Cannot write nullable values to non-null field")) + } + } + + test("Check struct types: data field would be dropped") { + assertSingleError(point3, point2, "t", + "Should fail because field 'z' would be dropped") { err => + assert(err.contains("'t'"), "Should include the struct name for context") + assert(err.contains("'z'"), "Should include the extra field name") + assert(err.contains("Cannot write extra fields")) + } + } + + test("Check struct types: unsafe casts are not allowed") { + assertNumErrors(widerPoint2, point2, "t", + "Should fail because types require unsafe casts", 2) { errs => + + assert(errs(0).contains("'t.x'"), "Should include the nested field name context") + assert(errs(0).contains("Cannot safely cast")) + + assert(errs(1).contains("'t.y'"), "Should include the nested field name context") + assert(errs(1).contains("Cannot safely cast")) + } + } + + test("Check struct types: type promotion is allowed") { + assertAllowed(point2, widerPoint2, "t", + 
"Should allow widening float fields x and y to double") + } + + ignore("Check struct types: missing optional field is allowed") { + // built-in data sources do not yet support missing fields when optional + assertAllowed(point2, point3, "t", + "Should allow writing point (x,y) to point(x,y,z=null)") + } + + test("Check array types: unsafe casts are not allowed") { + val arrayOfLong = ArrayType(LongType) + val arrayOfInt = ArrayType(IntegerType) + + assertSingleError(arrayOfLong, arrayOfInt, "arr", + "Should not allow array of longs to array of ints") { err => + assert(err.contains("'arr.element'"), + "Should identify problem with named array's element type") + assert(err.contains("Cannot safely cast")) + } + } + + test("Check array types: type promotion is allowed") { + val arrayOfLong = ArrayType(LongType) + val arrayOfInt = ArrayType(IntegerType) + assertAllowed(arrayOfInt, arrayOfLong, "arr", + "Should allow array of int written to array of long column") + } + + test("Check array types: cannot write optional to required elements") { + val arrayOfRequired = ArrayType(LongType, containsNull = false) + val arrayOfOptional = ArrayType(LongType) + + assertSingleError(arrayOfOptional, arrayOfRequired, "arr", + "Should not allow array of optional elements to array of required elements") { err => + assert(err.contains("'arr'"), "Should include type name context") + assert(err.contains("Cannot write nullable elements to array of non-nulls")) + } + } + + test("Check array types: writing required to optional elements is allowed") { + val arrayOfRequired = ArrayType(LongType, containsNull = false) + val arrayOfOptional = ArrayType(LongType) + + assertAllowed(arrayOfRequired, arrayOfOptional, "arr", + "Should allow array of required elements to array of optional elements") + } + + test("Check map value types: unsafe casts are not allowed") { + val mapOfLong = MapType(StringType, LongType) + val mapOfInt = MapType(StringType, IntegerType) + + assertSingleError(mapOfLong, mapOfInt, "m", + "Should not allow map of longs to map of ints") { err => + assert(err.contains("'m.value'"), "Should identify problem with named map's value type") + assert(err.contains("Cannot safely cast")) + } + } + + test("Check map value types: type promotion is allowed") { + val mapOfLong = MapType(StringType, LongType) + val mapOfInt = MapType(StringType, IntegerType) + + assertAllowed(mapOfInt, mapOfLong, "m", "Should allow map of int written to map of long column") + } + + test("Check map value types: cannot write optional to required values") { + val mapOfRequired = MapType(StringType, LongType, valueContainsNull = false) + val mapOfOptional = MapType(StringType, LongType) + + assertSingleError(mapOfOptional, mapOfRequired, "m", + "Should not allow map of optional values to map of required values") { err => + assert(err.contains("'m'"), "Should include type name context") + assert(err.contains("Cannot write nullable values to map of non-nulls")) + } + } + + test("Check map value types: writing required to optional values is allowed") { + val mapOfRequired = MapType(StringType, LongType, valueContainsNull = false) + val mapOfOptional = MapType(StringType, LongType) + + assertAllowed(mapOfRequired, mapOfOptional, "m", + "Should allow map of required elements to map of optional elements") + } + + test("Check map key types: unsafe casts are not allowed") { + val mapKeyLong = MapType(LongType, StringType) + val mapKeyInt = MapType(IntegerType, StringType) + + assertSingleError(mapKeyLong, mapKeyInt, "m", + "Should not allow map 
of long keys to map of int keys") { err => + assert(err.contains("'m.key'"), "Should identify problem with named map's key type") + assert(err.contains("Cannot safely cast")) + } + } + + test("Check map key types: type promotion is allowed") { + val mapKeyLong = MapType(LongType, StringType) + val mapKeyInt = MapType(IntegerType, StringType) + + assertAllowed(mapKeyInt, mapKeyLong, "m", + "Should allow map of int written to map of long column") + } + + test("Check types with multiple errors") { + val readType = StructType(Seq( + StructField("a", ArrayType(DoubleType, containsNull = false)), + StructField("arr_of_structs", ArrayType(point2, containsNull = false)), + StructField("bad_nested_type", ArrayType(StringType)), + StructField("m", MapType(LongType, FloatType, valueContainsNull = false)), + StructField("map_of_structs", MapType(StringType, point3, valueContainsNull = false)), + StructField("x", IntegerType, nullable = false), + StructField("missing1", StringType, nullable = false), + StructField("missing2", StringType) + )) + + val missingMiddleField = StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("z", FloatType, nullable = false))) + + val writeType = StructType(Seq( + StructField("a", ArrayType(StringType)), + StructField("arr_of_structs", ArrayType(point3)), + StructField("bad_nested_type", point3), + StructField("m", MapType(DoubleType, DoubleType)), + StructField("map_of_structs", MapType(StringType, missingMiddleField)), + StructField("y", LongType) + )) + + assertNumErrors(writeType, readType, "top", "Should catch 14 errors", 14) { errs => + assert(errs(0).contains("'top.a.element'"), "Should identify bad type") + assert(errs(0).contains("Cannot safely cast")) + assert(errs(0).contains("StringType to DoubleType")) + + assert(errs(1).contains("'top.a'"), "Should identify bad type") + assert(errs(1).contains("Cannot write nullable elements to array of non-nulls")) + + assert(errs(2).contains("'top.arr_of_structs.element'"), "Should identify bad type") + assert(errs(2).contains("'z'"), "Should identify bad field") + assert(errs(2).contains("Cannot write extra fields to struct")) + + assert(errs(3).contains("'top.arr_of_structs'"), "Should identify bad type") + assert(errs(3).contains("Cannot write nullable elements to array of non-nulls")) + + assert(errs(4).contains("'top.bad_nested_type'"), "Should identify bad type") + assert(errs(4).contains("is incompatible with")) + + assert(errs(5).contains("'top.m.key'"), "Should identify bad type") + assert(errs(5).contains("Cannot safely cast")) + assert(errs(5).contains("DoubleType to LongType")) + + assert(errs(6).contains("'top.m.value'"), "Should identify bad type") + assert(errs(6).contains("Cannot safely cast")) + assert(errs(6).contains("DoubleType to FloatType")) + + assert(errs(7).contains("'top.m'"), "Should identify bad type") + assert(errs(7).contains("Cannot write nullable values to map of non-nulls")) + + assert(errs(8).contains("'top.map_of_structs.value'"), "Should identify bad type") + assert(errs(8).contains("expected 'y', found 'z'"), "Should detect name mismatch") + assert(errs(8).contains("field name does not match"), "Should identify name problem") + + assert(errs(9).contains("'top.map_of_structs.value'"), "Should identify bad type") + assert(errs(9).contains("'z'"), "Should identify missing field") + assert(errs(9).contains("missing fields"), "Should detect missing field") + + assert(errs(10).contains("'top.map_of_structs'"), "Should identify bad type") + 
assert(errs(10).contains("Cannot write nullable values to map of non-nulls")) + + assert(errs(11).contains("'top.x'"), "Should identify bad type") + assert(errs(11).contains("Cannot safely cast")) + assert(errs(11).contains("LongType to IntegerType")) + + assert(errs(12).contains("'top'"), "Should identify bad type") + assert(errs(12).contains("expected 'x', found 'y'"), "Should detect name mismatch") + assert(errs(12).contains("field name does not match"), "Should identify name problem") + + assert(errs(13).contains("'top'"), "Should identify bad type") + assert(errs(13).contains("'missing1'"), "Should identify missing field") + assert(errs(13).contains("missing fields"), "Should detect missing field") + } + } + + // Helper functions + + def assertAllowed(writeType: DataType, readType: DataType, name: String, desc: String): Unit = { + assert( + DataType.canWrite(writeType, readType, analysis.caseSensitiveResolution, name, + errMsg => fail(s"Should not produce errors but was called with: $errMsg")) === true, desc) + } + + def assertSingleError( + writeType: DataType, + readType: DataType, + name: String, + desc: String) + (errFunc: String => Unit): Unit = { + assertNumErrors(writeType, readType, name, desc, 1) { errs => + errFunc(errs.head) + } + } + + def assertNumErrors( + writeType: DataType, + readType: DataType, + name: String, + desc: String, + numErrs: Int) + (errFunc: Seq[String] => Unit): Unit = { + val errs = new mutable.ArrayBuffer[String]() + assert( + DataType.canWrite(writeType, readType, analysis.caseSensitiveResolution, name, + errMsg => errs += errMsg) === false, desc) + assert(errs.size === numErrs, s"Should produce $numErrs error messages") + errFunc(errs) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/MetadataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/MetadataSuite.scala new file mode 100644 index 0000000000000..210e65708170f --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/MetadataSuite.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.types + +import org.apache.spark.SparkFunSuite + +class MetadataSuite extends SparkFunSuite { + test("String Metadata") { + val meta = new MetadataBuilder().putString("key", "value").build() + assert(meta === meta) + assert(meta.## !== 0) + assert(meta.getString("key") === "value") + assert(meta.contains("key")) + intercept[NoSuchElementException](meta.getString("no_such_key")) + intercept[ClassCastException](meta.getBoolean("key")) + } + + test("Long Metadata") { + val meta = new MetadataBuilder().putLong("key", 12).build() + assert(meta === meta) + assert(meta.## !== 0) + assert(meta.getLong("key") === 12) + assert(meta.contains("key")) + intercept[NoSuchElementException](meta.getLong("no_such_key")) + intercept[ClassCastException](meta.getBoolean("key")) + } + + test("Double Metadata") { + val meta = new MetadataBuilder().putDouble("key", 12).build() + assert(meta === meta) + assert(meta.## !== 0) + assert(meta.getDouble("key") === 12) + assert(meta.contains("key")) + intercept[NoSuchElementException](meta.getDouble("no_such_key")) + intercept[ClassCastException](meta.getBoolean("key")) + } + + test("Boolean Metadata") { + val meta = new MetadataBuilder().putBoolean("key", true).build() + assert(meta === meta) + assert(meta.## !== 0) + assert(meta.getBoolean("key") === true) + assert(meta.contains("key")) + intercept[NoSuchElementException](meta.getBoolean("no_such_key")) + intercept[ClassCastException](meta.getString("key")) + } + + test("Null Metadata") { + val meta = new MetadataBuilder().putNull("key").build() + assert(meta === meta) + assert(meta.## !== 0) + assert(meta.getString("key") === null) + assert(meta.getDouble("key") === 0) + assert(meta.getLong("key") === 0) + assert(meta.getBoolean("key") === false) + assert(meta.contains("key")) + intercept[NoSuchElementException](meta.getLong("no_such_key")) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala index c6ca8bb005429..53a78c94aa6fb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.types import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.StructType.fromDDL class StructTypeSuite extends SparkFunSuite { @@ -37,4 +38,36 @@ class StructTypeSuite extends SparkFunSuite { val e = intercept[IllegalArgumentException](s.fieldIndex("c")).getMessage assert(e.contains("Available fields: a, b")) } + + test("SPARK-24849: toDDL - simple struct") { + val struct = StructType(Seq(StructField("a", IntegerType))) + + assert(struct.toDDL == "`a` INT") + } + + test("SPARK-24849: round trip toDDL - fromDDL") { + val struct = new StructType().add("a", IntegerType).add("b", StringType) + + assert(fromDDL(struct.toDDL) === struct) + } + + test("SPARK-24849: round trip fromDDL - toDDL") { + val struct = "`a` MAP<INT, STRING>,`b` INT" + + assert(fromDDL(struct).toDDL === struct) + } + + test("SPARK-24849: toDDL must take into account case of fields.") { + val struct = new StructType() + .add("metaData", new StructType().add("eventId", StringType)) + + assert(struct.toDDL == "`metaData` STRUCT<`eventId`: STRING>") + } + + test("SPARK-24849: toDDL should output field's comment") { + val struct = StructType(Seq( + StructField("b", BooleanType).withComment("Field's comment"))) + + assert(struct.toDDL == """`b` BOOLEAN
COMMENT 'Field\'s comment'""") + } } diff --git a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt new file mode 100644 index 0000000000000..a75a15c99328a --- /dev/null +++ b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt @@ -0,0 +1,738 @@ +================================================================================================ +Pushdown for many distinct value case +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 string row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8970 / 9122 1.8 570.3 1.0X +Parquet Vectorized (Pushdown) 471 / 491 33.4 30.0 19.0X +Native ORC Vectorized 7661 / 7853 2.1 487.0 1.2X +Native ORC Vectorized (Pushdown) 1134 / 1161 13.9 72.1 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 string row ('7864320' < value < '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9246 / 9297 1.7 587.8 1.0X +Parquet Vectorized (Pushdown) 480 / 488 32.8 30.5 19.3X +Native ORC Vectorized 7838 / 7850 2.0 498.3 1.2X +Native ORC Vectorized (Pushdown) 1054 / 1118 14.9 67.0 8.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 string row (value = '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8989 / 9100 1.7 571.5 1.0X +Parquet Vectorized (Pushdown) 448 / 467 35.1 28.5 20.1X +Native ORC Vectorized 7680 / 7768 2.0 488.3 1.2X +Native ORC Vectorized (Pushdown) 1067 / 1118 14.7 67.8 8.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 string row (value <=> '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9115 / 9266 1.7 579.5 1.0X +Parquet Vectorized (Pushdown) 466 / 492 33.7 29.7 19.5X +Native ORC Vectorized 7800 / 7914 2.0 495.9 1.2X +Native ORC Vectorized (Pushdown) 1075 / 1102 14.6 68.4 8.5X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 string row ('7864320' <= value <= '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9099 / 9237 1.7 578.5 1.0X +Parquet Vectorized (Pushdown) 462 / 475 34.1 29.3 19.7X +Native ORC Vectorized 7847 / 7925 2.0 498.9 1.2X +Native ORC Vectorized (Pushdown) 1078 / 1114 14.6 68.5 8.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all string rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 19303 / 19547 0.8 1227.3 1.0X +Parquet Vectorized 
(Pushdown) 19924 / 20089 0.8 1266.7 1.0X +Native ORC Vectorized 18725 / 19079 0.8 1190.5 1.0X +Native ORC Vectorized (Pushdown) 19310 / 19492 0.8 1227.7 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 int row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8117 / 8323 1.9 516.1 1.0X +Parquet Vectorized (Pushdown) 484 / 494 32.5 30.8 16.8X +Native ORC Vectorized 6811 / 7036 2.3 433.0 1.2X +Native ORC Vectorized (Pushdown) 1061 / 1082 14.8 67.5 7.6X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 int row (7864320 < value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8105 / 8140 1.9 515.3 1.0X +Parquet Vectorized (Pushdown) 478 / 505 32.9 30.4 17.0X +Native ORC Vectorized 6914 / 7211 2.3 439.6 1.2X +Native ORC Vectorized (Pushdown) 1044 / 1064 15.1 66.4 7.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 int row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7983 / 8116 2.0 507.6 1.0X +Parquet Vectorized (Pushdown) 464 / 487 33.9 29.5 17.2X +Native ORC Vectorized 6703 / 6774 2.3 426.1 1.2X +Native ORC Vectorized (Pushdown) 1017 / 1058 15.5 64.6 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 int row (value <=> 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7942 / 7983 2.0 504.9 1.0X +Parquet Vectorized (Pushdown) 468 / 479 33.6 29.7 17.0X +Native ORC Vectorized 6677 / 6779 2.4 424.5 1.2X +Native ORC Vectorized (Pushdown) 1021 / 1068 15.4 64.9 7.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 int row (7864320 <= value <= 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7909 / 7958 2.0 502.8 1.0X +Parquet Vectorized (Pushdown) 485 / 494 32.4 30.8 16.3X +Native ORC Vectorized 6751 / 6846 2.3 429.2 1.2X +Native ORC Vectorized (Pushdown) 1043 / 1077 15.1 66.3 7.6X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 int row (7864319 < value < 7864321): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8010 / 8033 2.0 509.2 1.0X +Parquet Vectorized (Pushdown) 472 / 489 33.3 30.0 17.0X +Native ORC Vectorized 6655 / 6808 2.4 423.1 1.2X +Native ORC Vectorized (Pushdown) 1015 / 1067 15.5 64.5 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% int rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative 
+------------------------------------------------------------------------------------------------ +Parquet Vectorized 8983 / 9035 1.8 571.1 1.0X +Parquet Vectorized (Pushdown) 2204 / 2231 7.1 140.1 4.1X +Native ORC Vectorized 7864 / 8011 2.0 500.0 1.1X +Native ORC Vectorized (Pushdown) 2674 / 2789 5.9 170.0 3.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% int rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 12723 / 12903 1.2 808.9 1.0X +Parquet Vectorized (Pushdown) 9112 / 9282 1.7 579.3 1.4X +Native ORC Vectorized 12090 / 12230 1.3 768.7 1.1X +Native ORC Vectorized (Pushdown) 9242 / 9372 1.7 587.6 1.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% int rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 16453 / 16678 1.0 1046.1 1.0X +Parquet Vectorized (Pushdown) 15997 / 16262 1.0 1017.0 1.0X +Native ORC Vectorized 16652 / 17070 0.9 1058.7 1.0X +Native ORC Vectorized (Pushdown) 15843 / 16112 1.0 1007.2 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all int rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 17098 / 17254 0.9 1087.1 1.0X +Parquet Vectorized (Pushdown) 17302 / 17529 0.9 1100.1 1.0X +Native ORC Vectorized 16790 / 17098 0.9 1067.5 1.0X +Native ORC Vectorized (Pushdown) 17329 / 17914 0.9 1101.7 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all int rows (value > -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 17088 / 17392 0.9 1086.4 1.0X +Parquet Vectorized (Pushdown) 17609 / 17863 0.9 1119.5 1.0X +Native ORC Vectorized 18334 / 69831 0.9 1165.7 0.9X +Native ORC Vectorized (Pushdown) 17465 / 17629 0.9 1110.4 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all int rows (value != -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 16903 / 17233 0.9 1074.6 1.0X +Parquet Vectorized (Pushdown) 16945 / 17032 0.9 1077.3 1.0X +Native ORC Vectorized 16377 / 16762 1.0 1041.2 1.0X +Native ORC Vectorized (Pushdown) 16950 / 17212 0.9 1077.7 1.0X + + +================================================================================================ +Pushdown for few distinct value case (use dictionary encoding) +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 distinct string row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7245 
/ 7322 2.2 460.7 1.0X +Parquet Vectorized (Pushdown) 378 / 389 41.6 24.0 19.2X +Native ORC Vectorized 6720 / 6778 2.3 427.2 1.1X +Native ORC Vectorized (Pushdown) 1009 / 1032 15.6 64.2 7.2X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 distinct string row ('100' < value < '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7627 / 7795 2.1 484.9 1.0X +Parquet Vectorized (Pushdown) 384 / 406 41.0 24.4 19.9X +Native ORC Vectorized 6724 / 7824 2.3 427.5 1.1X +Native ORC Vectorized (Pushdown) 968 / 986 16.3 61.5 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 distinct string row (value = '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7157 / 7534 2.2 455.0 1.0X +Parquet Vectorized (Pushdown) 542 / 565 29.0 34.5 13.2X +Native ORC Vectorized 6716 / 7214 2.3 427.0 1.1X +Native ORC Vectorized (Pushdown) 1212 / 1288 13.0 77.0 5.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 distinct string row (value <=> '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7368 / 7552 2.1 468.4 1.0X +Parquet Vectorized (Pushdown) 544 / 556 28.9 34.6 13.5X +Native ORC Vectorized 6740 / 6867 2.3 428.5 1.1X +Native ORC Vectorized (Pushdown) 1230 / 1426 12.8 78.2 6.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 distinct string row ('100' <= value <= '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7427 / 7734 2.1 472.2 1.0X +Parquet Vectorized (Pushdown) 556 / 568 28.3 35.4 13.3X +Native ORC Vectorized 6847 / 7059 2.3 435.3 1.1X +Native ORC Vectorized (Pushdown) 1226 / 1230 12.8 77.9 6.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all distinct string rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 16998 / 17311 0.9 1080.7 1.0X +Parquet Vectorized (Pushdown) 16977 / 17250 0.9 1079.4 1.0X +Native ORC Vectorized 18447 / 19852 0.9 1172.8 0.9X +Native ORC Vectorized (Pushdown) 16614 / 17102 0.9 1056.3 1.0X + + +================================================================================================ +Pushdown benchmark for StringStartsWith +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +StringStartsWith filter: (value like '10%'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9705 / 10814 1.6 617.0 1.0X +Parquet Vectorized (Pushdown) 3086 / 3574 5.1 196.2 3.1X +Native ORC Vectorized 10094 / 10695 1.6 
641.8 1.0X +Native ORC Vectorized (Pushdown) 9611 / 9999 1.6 611.0 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +StringStartsWith filter: (value like '1000%'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8016 / 8183 2.0 509.7 1.0X +Parquet Vectorized (Pushdown) 444 / 457 35.4 28.2 18.0X +Native ORC Vectorized 6970 / 7169 2.3 443.2 1.2X +Native ORC Vectorized (Pushdown) 7447 / 7503 2.1 473.5 1.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +StringStartsWith filter: (value like '786432%'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7908 / 8046 2.0 502.8 1.0X +Parquet Vectorized (Pushdown) 408 / 429 38.6 25.9 19.4X +Native ORC Vectorized 7021 / 7100 2.2 446.4 1.1X +Native ORC Vectorized (Pushdown) 7310 / 7490 2.2 464.8 1.1X + + +================================================================================================ +Pushdown benchmark for decimal +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 decimal(9, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4546 / 4743 3.5 289.0 1.0X +Parquet Vectorized (Pushdown) 161 / 175 98.0 10.2 28.3X +Native ORC Vectorized 5721 / 5842 2.7 363.7 0.8X +Native ORC Vectorized (Pushdown) 1019 / 1070 15.4 64.8 4.5X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% decimal(9, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 6340 / 7236 2.5 403.1 1.0X +Parquet Vectorized (Pushdown) 3052 / 3164 5.2 194.1 2.1X +Native ORC Vectorized 8370 / 9214 1.9 532.1 0.8X +Native ORC Vectorized (Pushdown) 4137 / 4242 3.8 263.0 1.5X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% decimal(9, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 12976 / 13249 1.2 825.0 1.0X +Parquet Vectorized (Pushdown) 12655 / 13570 1.2 804.6 1.0X +Native ORC Vectorized 15562 / 15950 1.0 989.4 0.8X +Native ORC Vectorized (Pushdown) 15042 / 15668 1.0 956.3 0.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% decimal(9, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 14303 / 14616 1.1 909.3 1.0X +Parquet Vectorized (Pushdown) 14380 / 14649 1.1 914.3 1.0X +Native ORC Vectorized 16964 / 17358 0.9 1078.5 0.8X +Native ORC Vectorized (Pushdown) 17255 / 17874 0.9 1097.0 0.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 
10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 decimal(18, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4701 / 6416 3.3 298.9 1.0X +Parquet Vectorized (Pushdown) 128 / 164 122.8 8.1 36.7X +Native ORC Vectorized 5698 / 7904 2.8 362.3 0.8X +Native ORC Vectorized (Pushdown) 913 / 942 17.2 58.0 5.2X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% decimal(18, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 5376 / 5461 2.9 341.8 1.0X +Parquet Vectorized (Pushdown) 1479 / 1543 10.6 94.0 3.6X +Native ORC Vectorized 6640 / 6748 2.4 422.2 0.8X +Native ORC Vectorized (Pushdown) 2438 / 2479 6.5 155.0 2.2X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% decimal(18, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9224 / 9356 1.7 586.5 1.0X +Parquet Vectorized (Pushdown) 7172 / 7415 2.2 456.0 1.3X +Native ORC Vectorized 11017 / 11408 1.4 700.4 0.8X +Native ORC Vectorized (Pushdown) 8771 / 10218 1.8 557.7 1.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% decimal(18, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 13933 / 15990 1.1 885.8 1.0X +Parquet Vectorized (Pushdown) 12683 / 12942 1.2 806.4 1.1X +Native ORC Vectorized 16344 / 20196 1.0 1039.1 0.9X +Native ORC Vectorized (Pushdown) 15162 / 16627 1.0 964.0 0.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 decimal(38, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7102 / 8282 2.2 451.5 1.0X +Parquet Vectorized (Pushdown) 124 / 150 126.4 7.9 57.1X +Native ORC Vectorized 5811 / 6883 2.7 369.5 1.2X +Native ORC Vectorized (Pushdown) 1121 / 1502 14.0 71.3 6.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% decimal(38, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 6894 / 7562 2.3 438.3 1.0X +Parquet Vectorized (Pushdown) 1863 / 1980 8.4 118.4 3.7X +Native ORC Vectorized 6812 / 6848 2.3 433.1 1.0X +Native ORC Vectorized (Pushdown) 2511 / 2598 6.3 159.7 2.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% decimal(38, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 11732 / 12183 1.3 745.9 1.0X +Parquet Vectorized (Pushdown) 8912 / 9945 1.8 566.6 1.3X +Native ORC 
Vectorized 11499 / 12387 1.4 731.1 1.0X +Native ORC Vectorized (Pushdown) 9328 / 9382 1.7 593.1 1.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% decimal(38, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 16272 / 16328 1.0 1034.6 1.0X +Parquet Vectorized (Pushdown) 15714 / 18100 1.0 999.1 1.0X +Native ORC Vectorized 16539 / 18897 1.0 1051.5 1.0X +Native ORC Vectorized (Pushdown) 16328 / 17306 1.0 1038.1 1.0X + + +================================================================================================ +Pushdown benchmark for InSet -> InFilters +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 5, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7993 / 8104 2.0 508.2 1.0X +Parquet Vectorized (Pushdown) 507 / 532 31.0 32.2 15.8X +Native ORC Vectorized 6922 / 7163 2.3 440.1 1.2X +Native ORC Vectorized (Pushdown) 1017 / 1058 15.5 64.6 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 5, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7855 / 7963 2.0 499.4 1.0X +Parquet Vectorized (Pushdown) 503 / 516 31.3 32.0 15.6X +Native ORC Vectorized 6825 / 6954 2.3 433.9 1.2X +Native ORC Vectorized (Pushdown) 1019 / 1044 15.4 64.8 7.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 5, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7858 / 7928 2.0 499.6 1.0X +Parquet Vectorized (Pushdown) 490 / 519 32.1 31.1 16.0X +Native ORC Vectorized 7079 / 7966 2.2 450.1 1.1X +Native ORC Vectorized (Pushdown) 1276 / 1673 12.3 81.1 6.2X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 10, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8007 / 11155 2.0 509.0 1.0X +Parquet Vectorized (Pushdown) 519 / 540 30.3 33.0 15.4X +Native ORC Vectorized 6848 / 7072 2.3 435.4 1.2X +Native ORC Vectorized (Pushdown) 1026 / 1050 15.3 65.2 7.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 10, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7876 / 7956 2.0 500.7 1.0X +Parquet Vectorized (Pushdown) 521 / 535 30.2 33.1 15.1X +Native ORC Vectorized 7051 / 7368 2.2 448.3 1.1X +Native ORC Vectorized (Pushdown) 1014 / 1035 15.5 
64.5 7.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 10, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7897 / 8229 2.0 502.1 1.0X +Parquet Vectorized (Pushdown) 513 / 530 30.7 32.6 15.4X +Native ORC Vectorized 6730 / 6990 2.3 427.9 1.2X +Native ORC Vectorized (Pushdown) 1003 / 1036 15.7 63.8 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 50, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7967 / 8175 2.0 506.5 1.0X +Parquet Vectorized (Pushdown) 8155 / 8434 1.9 518.5 1.0X +Native ORC Vectorized 7002 / 7107 2.2 445.2 1.1X +Native ORC Vectorized (Pushdown) 1092 / 1139 14.4 69.4 7.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 50, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8032 / 8122 2.0 510.7 1.0X +Parquet Vectorized (Pushdown) 8141 / 8908 1.9 517.6 1.0X +Native ORC Vectorized 7140 / 7387 2.2 454.0 1.1X +Native ORC Vectorized (Pushdown) 1156 / 1220 13.6 73.5 6.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 50, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8088 / 8350 1.9 514.2 1.0X +Parquet Vectorized (Pushdown) 8629 / 8702 1.8 548.6 0.9X +Native ORC Vectorized 7480 / 7886 2.1 475.6 1.1X +Native ORC Vectorized (Pushdown) 1106 / 1145 14.2 70.3 7.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 100, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8028 / 8165 2.0 510.4 1.0X +Parquet Vectorized (Pushdown) 8349 / 8674 1.9 530.8 1.0X +Native ORC Vectorized 7107 / 7354 2.2 451.8 1.1X +Native ORC Vectorized (Pushdown) 1175 / 1207 13.4 74.7 6.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 100, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8041 / 8195 2.0 511.2 1.0X +Parquet Vectorized (Pushdown) 8466 / 8604 1.9 538.2 0.9X +Native ORC Vectorized 7116 / 7286 2.2 452.4 1.1X +Native ORC Vectorized (Pushdown) 1197 / 1214 13.1 76.1 6.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 100, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ 
+Parquet Vectorized 7998 / 8311 2.0 508.5 1.0X +Parquet Vectorized (Pushdown) 9366 / 11257 1.7 595.5 0.9X +Native ORC Vectorized 7856 / 9273 2.0 499.5 1.0X +Native ORC Vectorized (Pushdown) 1350 / 1747 11.7 85.8 5.9X + + +================================================================================================ +Pushdown benchmark for tinyint +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 tinyint row (value = CAST(63 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 3461 / 3997 4.5 220.1 1.0X +Parquet Vectorized (Pushdown) 270 / 315 58.4 17.1 12.8X +Native ORC Vectorized 4107 / 5372 3.8 261.1 0.8X +Native ORC Vectorized (Pushdown) 778 / 1553 20.2 49.5 4.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% tinyint rows (value < CAST(12 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4771 / 6655 3.3 303.3 1.0X +Parquet Vectorized (Pushdown) 1322 / 1606 11.9 84.0 3.6X +Native ORC Vectorized 4437 / 4572 3.5 282.1 1.1X +Native ORC Vectorized (Pushdown) 1781 / 1976 8.8 113.2 2.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% tinyint rows (value < CAST(63 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7433 / 7752 2.1 472.6 1.0X +Parquet Vectorized (Pushdown) 5863 / 5913 2.7 372.8 1.3X +Native ORC Vectorized 7986 / 8084 2.0 507.7 0.9X +Native ORC Vectorized (Pushdown) 6522 / 6608 2.4 414.6 1.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% tinyint rows (value < CAST(114 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 11190 / 11519 1.4 711.4 1.0X +Parquet Vectorized (Pushdown) 10861 / 11206 1.4 690.5 1.0X +Native ORC Vectorized 11622 / 12196 1.4 738.9 1.0X +Native ORC Vectorized (Pushdown) 11377 / 11654 1.4 723.3 1.0X + + +================================================================================================ +Pushdown benchmark for Timestamp +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 timestamp stored as INT96 row (value = CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4784 / 4956 3.3 304.2 1.0X +Parquet Vectorized (Pushdown) 4838 / 4917 3.3 307.6 1.0X +Native ORC Vectorized 3923 / 4173 4.0 249.4 1.2X +Native ORC Vectorized (Pushdown) 894 / 943 17.6 56.8 5.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% timestamp stored as INT96 rows (value < 
CAST(1572864 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 5686 / 5901 2.8 361.5 1.0X +Parquet Vectorized (Pushdown) 5555 / 5895 2.8 353.2 1.0X +Native ORC Vectorized 4844 / 4957 3.2 308.0 1.2X +Native ORC Vectorized (Pushdown) 2141 / 2230 7.3 136.1 2.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% timestamp stored as INT96 rows (value < CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9100 / 9421 1.7 578.6 1.0X +Parquet Vectorized (Pushdown) 9122 / 9496 1.7 580.0 1.0X +Native ORC Vectorized 8365 / 8874 1.9 531.9 1.1X +Native ORC Vectorized (Pushdown) 7128 / 7376 2.2 453.2 1.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% timestamp stored as INT96 rows (value < CAST(14155776 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 12764 / 13120 1.2 811.5 1.0X +Parquet Vectorized (Pushdown) 12656 / 13003 1.2 804.7 1.0X +Native ORC Vectorized 13096 / 13233 1.2 832.6 1.0X +Native ORC Vectorized (Pushdown) 12710 / 15611 1.2 808.1 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 timestamp stored as TIMESTAMP_MICROS row (value = CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4381 / 4796 3.6 278.5 1.0X +Parquet Vectorized (Pushdown) 122 / 137 129.3 7.7 36.0X +Native ORC Vectorized 3913 / 3988 4.0 248.8 1.1X +Native ORC Vectorized (Pushdown) 905 / 945 17.4 57.6 4.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% timestamp stored as TIMESTAMP_MICROS rows (value < CAST(1572864 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 5145 / 5184 3.1 327.1 1.0X +Parquet Vectorized (Pushdown) 1426 / 1519 11.0 90.7 3.6X +Native ORC Vectorized 4827 / 4901 3.3 306.9 1.1X +Native ORC Vectorized (Pushdown) 2133 / 2210 7.4 135.6 2.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% timestamp stored as TIMESTAMP_MICROS rows (value < CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9234 / 9516 1.7 587.1 1.0X +Parquet Vectorized (Pushdown) 6752 / 7046 2.3 429.3 1.4X +Native ORC Vectorized 8418 / 8998 1.9 535.2 1.1X +Native ORC Vectorized (Pushdown) 7199 / 7314 2.2 457.7 1.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% timestamp stored as TIMESTAMP_MICROS rows (value < CAST(14155776 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative 
+------------------------------------------------------------------------------------------------ +Parquet Vectorized 12414 / 12458 1.3 789.2 1.0X +Parquet Vectorized (Pushdown) 12094 / 12249 1.3 768.9 1.0X +Native ORC Vectorized 12198 / 13755 1.3 775.5 1.0X +Native ORC Vectorized (Pushdown) 12205 / 12431 1.3 776.0 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 timestamp stored as TIMESTAMP_MILLIS row (value = CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4369 / 4515 3.6 277.8 1.0X +Parquet Vectorized (Pushdown) 116 / 125 136.2 7.3 37.8X +Native ORC Vectorized 3965 / 4703 4.0 252.1 1.1X +Native ORC Vectorized (Pushdown) 892 / 1162 17.6 56.7 4.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% timestamp stored as TIMESTAMP_MILLIS rows (value < CAST(1572864 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 5211 / 5409 3.0 331.3 1.0X +Parquet Vectorized (Pushdown) 1427 / 1438 11.0 90.7 3.7X +Native ORC Vectorized 4719 / 4883 3.3 300.1 1.1X +Native ORC Vectorized (Pushdown) 2191 / 2228 7.2 139.3 2.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% timestamp stored as TIMESTAMP_MILLIS rows (value < CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8716 / 8953 1.8 554.2 1.0X +Parquet Vectorized (Pushdown) 6632 / 6968 2.4 421.7 1.3X +Native ORC Vectorized 8376 / 9118 1.9 532.5 1.0X +Native ORC Vectorized (Pushdown) 7218 / 7609 2.2 458.9 1.2X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% timestamp stored as TIMESTAMP_MILLIS rows (value < CAST(14155776 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 12264 / 12452 1.3 779.7 1.0X +Parquet Vectorized (Pushdown) 11766 / 11927 1.3 748.0 1.0X +Native ORC Vectorized 12101 / 12301 1.3 769.3 1.0X +Native ORC Vectorized (Pushdown) 11983 / 12651 1.3 761.9 1.0X + + +================================================================================================ +Pushdown benchmark with many filters +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.13.6 +Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + +Select 1 row with 1 filters: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 158 / 182 0.0 158442969.0 1.0X +Parquet Vectorized (Pushdown) 150 / 158 0.0 149718289.0 1.1X +Native ORC Vectorized 141 / 148 0.0 141259852.0 1.1X +Native ORC Vectorized (Pushdown) 142 / 147 0.0 142016472.0 1.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.13.6 +Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + +Select 1 row with 250 filters: Best/Avg Time(ms) Rate(M/s) Per 
Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 1013 / 1026 0.0 1013194322.0 1.0X +Parquet Vectorized (Pushdown) 1326 / 1332 0.0 1326301956.0 0.8X +Native ORC Vectorized 1005 / 1010 0.0 1005266379.0 1.0X +Native ORC Vectorized (Pushdown) 1068 / 1071 0.0 1067964993.0 0.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.13.6 +Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + +Select 1 row with 500 filters: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 3598 / 3614 0.0 3598001202.0 1.0X +Parquet Vectorized (Pushdown) 4282 / 4333 0.0 4281849770.0 0.8X +Native ORC Vectorized 3594 / 3619 0.0 3593551548.0 1.0X +Native ORC Vectorized (Pushdown) 3834 / 3840 0.0 3834240570.0 0.9X diff --git a/sql/core/pom.xml b/sql/core/pom.xml index f270c70fbfcf0..ba17f5f33f2b6 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -38,7 +38,7 @@ com.univocity univocity-parsers - 2.6.3 + 2.7.3 jar @@ -118,7 +118,7 @@ org.apache.xbean - xbean-asm5-shaded + xbean-asm6-shaded org.scalacheck @@ -146,19 +146,6 @@ parquet-avro test - - - org.apache.avro - avro - 1.8.1 - test - org.mockito mockito-core diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index c7c4c7b3e7715..c8cf44b51df77 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -20,8 +20,8 @@ import java.io.IOException; import org.apache.spark.SparkEnv; +import org.apache.spark.TaskContext; import org.apache.spark.internal.config.package$; -import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; @@ -82,7 +82,7 @@ public static boolean supportsAggregationBufferSchema(StructType schema) { * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function) * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion. * @param groupingKeySchema the schema of the grouping key, used for row conversion. - * @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures. + * @param taskContext the current task context. * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing). * @param pageSizeBytes the data page size, in bytes; limits the maximum record size. 
*/ @@ -90,19 +90,26 @@ public UnsafeFixedWidthAggregationMap( InternalRow emptyAggregationBuffer, StructType aggregationBufferSchema, StructType groupingKeySchema, - TaskMemoryManager taskMemoryManager, + TaskContext taskContext, int initialCapacity, long pageSizeBytes) { this.aggregationBufferSchema = aggregationBufferSchema; this.currentAggregationBuffer = new UnsafeRow(aggregationBufferSchema.length()); this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema); this.groupingKeySchema = groupingKeySchema; - this.map = - new BytesToBytesMap(taskMemoryManager, initialCapacity, pageSizeBytes, true); + this.map = new BytesToBytesMap( + taskContext.taskMemoryManager(), initialCapacity, pageSizeBytes, true); // Initialize the buffer for aggregation value final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema); this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); + + // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at + // the end of the task. This is necessary to avoid memory leaks when the downstream operator + // does not fully consume the aggregation map's output (e.g. aggregate followed by limit). + taskContext.addTaskCompletionListener(context -> { + free(); + }); } /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index d5969b55eef96..ba26b57567e64 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -167,6 +167,8 @@ void readBatch(int total, WritableColumnVector column) throws IOException { leftInPage = (int) (endOfPageValueCount - valuesRead); } int num = Math.min(total, leftInPage); + PrimitiveType.PrimitiveTypeName typeName = + descriptor.getPrimitiveType().getPrimitiveTypeName(); if (isCurrentPageDictionaryEncoded) { // Read and decode dictionary ids. defColumn.readIntegers( @@ -175,12 +177,12 @@ void readBatch(int total, WritableColumnVector column) throws IOException { // TIMESTAMP_MILLIS encoded as INT64 can't be lazily decoded as we need to post process // the values to add microseconds precision. if (column.hasDictionary() || (rowId == 0 && - (descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT32 || - (descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT64 && + (typeName == PrimitiveType.PrimitiveTypeName.INT32 || + (typeName == PrimitiveType.PrimitiveTypeName.INT64 && originalType != OriginalType.TIMESTAMP_MILLIS) || - descriptor.getType() == PrimitiveType.PrimitiveTypeName.FLOAT || - descriptor.getType() == PrimitiveType.PrimitiveTypeName.DOUBLE || - descriptor.getType() == PrimitiveType.PrimitiveTypeName.BINARY))) { + typeName == PrimitiveType.PrimitiveTypeName.FLOAT || + typeName == PrimitiveType.PrimitiveTypeName.DOUBLE || + typeName == PrimitiveType.PrimitiveTypeName.BINARY))) { // Column vector supports lazy decoding of dictionary values so just set the dictionary. // We can't do this if rowId != 0 AND the column doesn't have a dictionary (i.e. some // non-dictionary encoded values have already been added).
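The aggregation-map constructor change above moves ownership of cleanup into the map itself: instead of trusting downstream code to call free(), the map registers a completion listener on the task that created it. A minimal sketch of the same pattern, assuming it is constructed inside a running task; TaskScopedMap is an illustrative name, not part of this patch:

import org.apache.spark.TaskContext;
import org.apache.spark.unsafe.map.BytesToBytesMap;

// Sketch only: mirrors how the patched constructor ties a BytesToBytesMap
// to the lifetime of the owning task.
final class TaskScopedMap {
  private final BytesToBytesMap map;

  TaskScopedMap(TaskContext taskContext, int initialCapacity, long pageSizeBytes) {
    this.map = new BytesToBytesMap(
      taskContext.taskMemoryManager(), initialCapacity, pageSizeBytes, true);
    // The listener fires when the task finishes, whether it succeeded or
    // failed, so the map's pages are returned to the memory manager even if
    // a downstream operator (e.g. a LIMIT) abandons the iterator early.
    taskContext.addTaskCompletionListener(context -> {
      map.free();
    });
  }

  BytesToBytesMap map() {
    return map;
  }
}

Registering the listener at construction time, rather than at the point of consumption, is what closes the aggregate-followed-by-limit leak described in the comment above: the consuming code may never run, but the task always completes.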
@@ -195,7 +197,7 @@ void readBatch(int total, WritableColumnVector column) throws IOException { decodeDictionaryIds(0, rowId, column, column.getDictionaryIds()); } column.setDictionary(null); - switch (descriptor.getType()) { + switch (typeName) { case BOOLEAN: readBooleanBatch(rowId, num, column); break; @@ -218,10 +220,11 @@ void readBatch(int total, WritableColumnVector column) throws IOException { readBinaryBatch(rowId, num, column); break; case FIXED_LEN_BYTE_ARRAY: - readFixedLenByteArrayBatch(rowId, num, column, descriptor.getTypeLength()); + readFixedLenByteArrayBatch( + rowId, num, column, descriptor.getPrimitiveType().getTypeLength()); break; default: - throw new IOException("Unsupported type: " + descriptor.getType()); + throw new IOException("Unsupported type: " + typeName); } } @@ -243,8 +246,8 @@ private SchemaColumnConvertNotSupportedException constructConvertNotSupportedExc WritableColumnVector column) { return new SchemaColumnConvertNotSupportedException( Arrays.toString(descriptor.getPath()), - descriptor.getType().toString(), - column.dataType().toString()); + descriptor.getPrimitiveType().getPrimitiveTypeName().toString(), + column.dataType().catalogString()); } /** @@ -255,7 +258,7 @@ private void decodeDictionaryIds( int num, WritableColumnVector column, WritableColumnVector dictionaryIds) { - switch (descriptor.getType()) { + switch (descriptor.getPrimitiveType().getPrimitiveTypeName()) { case INT32: if (column.dataType() == DataTypes.IntegerType || DecimalType.is32BitDecimalType(column.dataType())) { @@ -381,7 +384,8 @@ private void decodeDictionaryIds( break; default: - throw new UnsupportedOperationException("Unsupported type: " + descriptor.getType()); + throw new UnsupportedOperationException( + "Unsupported type: " + descriptor.getPrimitiveType().getPrimitiveTypeName()); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 5934a23db8af1..f02861355c404 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -270,21 +270,23 @@ public boolean nextBatch() throws IOException { private void initializeInternal() throws IOException, UnsupportedOperationException { // Check that the requested schema is supported. missingColumns = new boolean[requestedSchema.getFieldCount()]; + List columns = requestedSchema.getColumns(); + List paths = requestedSchema.getPaths(); for (int i = 0; i < requestedSchema.getFieldCount(); ++i) { Type t = requestedSchema.getFields().get(i); if (!t.isPrimitive() || t.isRepetition(Type.Repetition.REPEATED)) { throw new UnsupportedOperationException("Complex types not supported."); } - String[] colPath = requestedSchema.getPaths().get(i); + String[] colPath = paths.get(i); if (fileSchema.containsPath(colPath)) { ColumnDescriptor fd = fileSchema.getColumnDescription(colPath); - if (!fd.equals(requestedSchema.getColumns().get(i))) { + if (!fd.equals(columns.get(i))) { throw new UnsupportedOperationException("Schema evolution not supported."); } missingColumns[i] = false; } else { - if (requestedSchema.getColumns().get(i).getMaxDefinitionLevel() == 0) { + if (columns.get(i).getMaxDefinitionLevel() == 0) { // Column is missing in data but the required data is non-nullable. 
This file is invalid. throw new IOException("Required column is missing in data file. Col: " + Arrays.toString(colPath)); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 6fdadde628551..5e0cf7d370dd1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -23,7 +23,6 @@ import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.memory.OffHeapMemoryBlock; import org.apache.spark.unsafe.types.UTF8String; /** @@ -207,7 +206,7 @@ public byte[] getBytes(int rowId, int count) { @Override protected UTF8String getBytesAsUTF8String(int rowId, int count) { - return new UTF8String(new OffHeapMemoryBlock(data + rowId, count)); + return UTF8String.fromAddress(null, data + rowId, count); } // diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java new file mode 100644 index 0000000000000..f403dc619e86c --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils; +import org.apache.spark.sql.sources.v2.reader.BatchReadSupport; +import org.apache.spark.sql.types.StructType; + +/** + * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * provide data reading ability for batch processing. + * + * This interface is used to create {@link BatchReadSupport} instances when end users run + * {@code SparkSession.read.format(...).option(...).load()}. + */ +@InterfaceStability.Evolving +public interface BatchReadSupportProvider extends DataSourceV2 { + + /** + * Creates a {@link BatchReadSupport} instance to load the data from this data source with a user + * specified schema, which is called by Spark at the beginning of each batch query. + * + * Spark will call this method at the beginning of each batch query to create a + * {@link BatchReadSupport} instance. + * + * By default this method throws {@link UnsupportedOperationException}, implementations should + * override this method to handle user specified schema. + * + * @param schema the user specified schema. 
+ * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. + */ + default BatchReadSupport createBatchReadSupport(StructType schema, DataSourceOptions options) { + return DataSourceV2Utils.failForUserSpecifiedSchema(this); + } + + /** + * Creates a {@link BatchReadSupport} instance to scan the data from this data source, which is + * called by Spark at the beginning of each batch query. + * + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. + */ + BatchReadSupport createBatchReadSupport(DataSourceOptions options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java similarity index 58% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java index 83aeec0c47853..bd10c3353bf12 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java @@ -21,32 +21,39 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.SaveMode; -import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; +import org.apache.spark.sql.sources.v2.writer.BatchWriteSupport; import org.apache.spark.sql.types.StructType; /** * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data writing ability and save the data to the data source. + * provide data writing ability for batch processing. + * + * This interface is used to create {@link BatchWriteSupport} instances when end users run + * {@code Dataset.write.format(...).option(...).save()}. */ @InterfaceStability.Evolving -public interface WriteSupport extends DataSourceV2 { +public interface BatchWriteSupportProvider extends DataSourceV2 { /** - * Creates an optional {@link DataSourceWriter} to save the data to this data source. Data - * sources can return None if there is no writing needed to be done according to the save mode. + * Creates an optional {@link BatchWriteSupport} instance to save the data to this data source, + * which is called by Spark at the beginning of each batch query. * - * If this method fails (by throwing an exception), the action will fail and no Spark job will be - * submitted. + * Data sources can return None if no writing is needed for the given save + * mode. * - * @param jobId A unique string for the writing job. It's possible that there are many writing - * jobs running at the same time, and the returned {@link DataSourceWriter} can - * use this job id to distinguish itself from other jobs. + * @param queryId A unique string for the writing query. It's possible that there are many + * writing queries running at the same time, and the returned + * {@link BatchWriteSupport} can use this id to distinguish itself from others. * @param schema the schema of the data to be written. * @param mode the save mode which determines what to do when the data are already in this data * source, please refer to {@link SaveMode} for more details. * @param options the options for the returned data source writer, which is an immutable * case-insensitive string-to-string map. + * @return an optional write support to write data to this data source, or None if no writing + * is needed.
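To make the two provider contracts concrete, a rough sketch of a connector that mixes both into one class follows; MyReadSupport and MyWriteSupport are hypothetical stand-ins for real BatchReadSupport/BatchWriteSupport implementations, and targetExists is a placeholder:

import java.util.Optional;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.sources.v2.BatchReadSupportProvider;
import org.apache.spark.sql.sources.v2.BatchWriteSupportProvider;
import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.reader.BatchReadSupport;
import org.apache.spark.sql.sources.v2.writer.BatchWriteSupport;
import org.apache.spark.sql.types.StructType;

// Sketch of a v2 connector; MyReadSupport and MyWriteSupport are assumed,
// hypothetical classes, not types provided by Spark.
public class ExampleSource implements BatchReadSupportProvider, BatchWriteSupportProvider {

  @Override
  public BatchReadSupport createBatchReadSupport(DataSourceOptions options) {
    // Schema-inference path; the user-specified-schema overload keeps its
    // default behavior of failing for this source.
    return new MyReadSupport(options);
  }

  @Override
  public Optional<BatchWriteSupport> createBatchWriteSupport(
      String queryId, StructType schema, SaveMode mode, DataSourceOptions options) {
    if (mode == SaveMode.Ignore && targetExists(options)) {
      // Returning None tells Spark there is nothing to write for this save
      // mode, so no job is submitted.
      return Optional.empty();
    }
    return Optional.of(new MyWriteSupport(queryId, schema, mode, options));
  }

  // Placeholder: a real connector would probe its underlying storage here.
  private boolean targetExists(DataSourceOptions options) {
    return false;
  }
}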
*/ - Optional createWriter( - String jobId, StructType schema, SaveMode mode, DataSourceOptions options); + Optional createBatchWriteSupport( + String queryId, + StructType schema, + SaveMode mode, + DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java deleted file mode 100644 index 7df5a451ae5f3..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import java.util.Optional; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader; -import org.apache.spark.sql.types.StructType; - -/** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data reading ability for continuous stream processing. - */ -@InterfaceStability.Evolving -public interface ContinuousReadSupport extends DataSourceV2 { - /** - * Creates a {@link ContinuousReader} to scan the data from this data source. - * - * @param schema the user provided schema, or empty() if none was provided - * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure - * recovery. Readers for the same logical source in the same query - * will be given the same checkpointLocation. - * @param options the options for the returned data source reader, which is an immutable - * case-insensitive string-to-string map. - */ - ContinuousReader createContinuousReader( - Optional schema, - String checkpointLocation, - DataSourceOptions options); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java new file mode 100644 index 0000000000000..824c290518acf --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils; +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport; +import org.apache.spark.sql.types.StructType; + +/** + * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * provide data reading ability for continuous stream processing. + * + * This interface is used to create {@link ContinuousReadSupport} instances when end users run + * {@code SparkSession.readStream.format(...).option(...).load()} with a continuous trigger. + */ +@InterfaceStability.Evolving +public interface ContinuousReadSupportProvider extends DataSourceV2 { + + /** + * Creates a {@link ContinuousReadSupport} instance to scan the data from this streaming data + * source with a user specified schema, which is called by Spark at the beginning of each + * continuous streaming query. + * + * By default this method throws {@link UnsupportedOperationException}, implementations should + * override this method to handle user specified schema. + * + * @param schema the user provided schema. + * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure + * recovery. Readers for the same logical source in the same query + * will be given the same checkpointLocation. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. + */ + default ContinuousReadSupport createContinuousReadSupport( + StructType schema, + String checkpointLocation, + DataSourceOptions options) { + return DataSourceV2Utils.failForUserSpecifiedSchema(this); + } + + /** + * Creates a {@link ContinuousReadSupport} instance to scan the data from this streaming data + * source, which is called by Spark at the beginning of each continuous streaming query. + * + * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure + * recovery. Readers for the same logical source in the same query + * will be given the same checkpointLocation. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. + */ + ContinuousReadSupport createContinuousReadSupport( + String checkpointLocation, + DataSourceOptions options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java index 6234071320dc9..6e31e84bf6c72 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java @@ -22,9 +22,13 @@ /** * The base interface for data source v2. Implementations must have a public, 0-arg constructor. * - * Note that this is an empty interface. Data source implementations should mix-in at least one of - * the plug-in interfaces like {@link ReadSupport} and {@link WriteSupport}. 
Otherwise it's just - * a dummy data source which is un-readable/writable. + * Note that this is an empty interface. Data source implementations must mix in interfaces such as + * {@link BatchReadSupportProvider} or {@link BatchWriteSupportProvider}, which can provide + * batch or streaming read/write support instances. Otherwise it's just a dummy data source which + * is un-readable/writable. + * + * If Spark fails to execute any methods in the implementations of this interface (by throwing an + * exception), the read action will fail and no Spark job will be submitted. */ @InterfaceStability.Evolving public interface DataSourceV2 {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java deleted file mode 100644 index 7f4a2c9593c76..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import java.util.Optional; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader; -import org.apache.spark.sql.types.StructType; - -/** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide streaming micro-batch data reading ability. - */ -@InterfaceStability.Evolving -public interface MicroBatchReadSupport extends DataSourceV2 { - /** - * Creates a {@link MicroBatchReader} to read batches of data from this data source in a - * streaming query. - * - * The execution engine will create a micro-batch reader at the start of a streaming query, - * alternate calls to setOffsetRange and planInputPartitions for each batch to process, and - * then call stop() when the execution is complete. Note that a single query may have multiple - * executions due to restart or failure recovery. - * - * @param schema the user provided schema, or empty() if none was provided - * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure - * recovery. Readers for the same logical source in the same query - * will be given the same checkpointLocation. - * @param options the options for the returned data source reader, which is an immutable - * case-insensitive string-to-string map. 
- */ - MicroBatchReader createMicroBatchReader( - Optional schema, - String checkpointLocation, - DataSourceOptions options); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java new file mode 100644 index 0000000000000..61c08e7fa89df --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils; +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport; +import org.apache.spark.sql.types.StructType; + +/** + * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * provide data reading ability for micro-batch stream processing. + * + * This interface is used to create {@link MicroBatchReadSupport} instances when end users run + * {@code SparkSession.readStream.format(...).option(...).load()} with a micro-batch trigger. + */ +@InterfaceStability.Evolving +public interface MicroBatchReadSupportProvider extends DataSourceV2 { + + /** + * Creates a {@link MicroBatchReadSupport} instance to scan the data from this streaming data + * source with a user specified schema, which is called by Spark at the beginning of each + * micro-batch streaming query. + * + * By default this method throws {@link UnsupportedOperationException}, implementations should + * override this method to handle user specified schema. + * + * @param schema the user provided schema. + * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure + * recovery. Readers for the same logical source in the same query + * will be given the same checkpointLocation. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. + */ + default MicroBatchReadSupport createMicroBatchReadSupport( + StructType schema, + String checkpointLocation, + DataSourceOptions options) { + return DataSourceV2Utils.failForUserSpecifiedSchema(this); + } + + /** + * Creates a {@link MicroBatchReadSupport} instance to scan the data from this streaming data + * source, which is called by Spark at the beginning of each micro-batch streaming query. + * + * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure + * recovery. Readers for the same logical source in the same query + * will be given the same checkpointLocation. 
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java
deleted file mode 100644
index f31659904cc53..0000000000000
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.sources.v2;
-
-import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
-import org.apache.spark.sql.types.StructType;
-
-/**
- * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to
- * provide data reading ability and scan the data from the data source.
- *
- * This is a variant of {@link ReadSupport} that accepts user-specified schema when reading data.
- * A data source can implement both {@link ReadSupport} and {@link ReadSupportWithSchema} if it
- * supports both schema inference and user-specified schema.
- */
-@InterfaceStability.Evolving
-public interface ReadSupportWithSchema extends DataSourceV2 {
-
-  /**
-   * Create a {@link DataSourceReader} to scan the data from this data source.
-   *
-   * If this method fails (by throwing an exception), the action will fail and no Spark job will be
-   * submitted.
-   *
-   * @param schema the full schema of this data source reader. Full schema usually maps to the
-   *               physical schema of the underlying storage of this data source reader, e.g.
-   *               CSV files, JSON files, etc, while this reader may not read data with full
-   *               schema, as column pruning or other optimizations may happen.
-   * @param options the options for the returned data source reader, which is an immutable
-   *                case-insensitive string-to-string map.
-   */
-  DataSourceReader createReader(StructType schema, DataSourceOptions options);
-}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java
index 9d66805d79b9e..bbe430e299261 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java
@@ -27,10 +27,11 @@
 @InterfaceStability.Evolving
 public interface SessionConfigSupport extends DataSourceV2 {
 
-  /**
-   * Key prefix of the session configs to propagate. Spark will extract all session configs that
-   * starts with `spark.datasource.$keyPrefix`, turn `spark.datasource.$keyPrefix.xxx -> yyy`
-   * into `xxx -> yyy`, and propagate them to all data source operations in this session.
-   */
-  String keyPrefix();
+  /**
+   * Key prefix of the session configs to propagate, which is usually the data source name. Spark
+   * will extract all session configs that start with `spark.datasource.$keyPrefix`, turn
+   * `spark.datasource.$keyPrefix.xxx -> yyy` into `xxx -> yyy`, and propagate them to all
+   * data source operations in this session.
+   */
+  String keyPrefix();
 }
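A minimal hypothetical sketch of the prefix mapping just described; the source name `mysource` and the `endpoint` config are invented for illustration.

import org.apache.spark.sql.sources.v2.DataSourceV2;
import org.apache.spark.sql.sources.v2.SessionConfigSupport;

public class MySource implements DataSourceV2, SessionConfigSupport {

  // With this prefix, a session config set as
  //   spark.conf.set("spark.datasource.mysource.endpoint", "http://host:8080")
  // reaches every read/write of this source in the session as the option
  //   endpoint -> http://host:8080
  @Override
  public String keyPrefix() {
    return "mysource";
  }
}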
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java
deleted file mode 100644
index a77b01497269e..0000000000000
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java
+++ /dev/null
@@ -1,52 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.sources.v2;
-
-import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.execution.streaming.BaseStreamingSink;
-import org.apache.spark.sql.sources.v2.writer.DataSourceWriter;
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter;
-import org.apache.spark.sql.streaming.OutputMode;
-import org.apache.spark.sql.types.StructType;
-
-/**
- * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to
- * provide data writing ability for structured streaming.
- */
-@InterfaceStability.Evolving
-public interface StreamWriteSupport extends DataSourceV2, BaseStreamingSink {
-
-  /**
-   * Creates an optional {@link StreamWriter} to save the data to this data source. Data
-   * sources can return None if there is no writing needed to be done.
-   *
-   * @param queryId A unique string for the writing query. It's possible that there are many
-   *                writing queries running at the same time, and the returned
-   *                {@link DataSourceWriter} can use this id to distinguish itself from others.
-   * @param schema the schema of the data to be written.
-   * @param mode the output mode which determines what successive epoch output means to this
-   *             sink, please refer to {@link OutputMode} for more details.
-   * @param options the options for the returned data source writer, which is an immutable
-   *                case-insensitive string-to-string map.
-   */
-  StreamWriter createStreamWriter(
-      String queryId,
-      StructType schema,
-      OutputMode mode,
-      DataSourceOptions options);
-}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java
new file mode 100644
index 0000000000000..f9ca85d8089b4
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2;
+
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.execution.streaming.BaseStreamingSink;
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport;
+import org.apache.spark.sql.streaming.OutputMode;
+import org.apache.spark.sql.types.StructType;
+
+/**
+ * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to
+ * provide data writing ability for structured streaming.
+ *
+ * This interface is used to create {@link StreamingWriteSupport} instances when end users run
+ * {@code Dataset.writeStream.format(...).option(...).start()}.
+ */
+@InterfaceStability.Evolving
+public interface StreamingWriteSupportProvider extends DataSourceV2, BaseStreamingSink {
+
+  /**
+   * Creates a {@link StreamingWriteSupport} instance to save the data to this data source. Spark
+   * calls this method at the beginning of each streaming query.
+   *
+   * @param queryId A unique string for the writing query. It's possible that there are many
+   *                writing queries running at the same time, and the returned
+   *                {@link StreamingWriteSupport} can use this id to distinguish itself from others.
+   * @param schema the schema of the data to be written.
+   * @param mode the output mode which determines what successive epoch output means to this
+   *             sink, please refer to {@link OutputMode} for more details.
+   * @param options the options for the returned data source writer, which is an immutable
+   *                case-insensitive string-to-string map.
+   */
+  StreamingWriteSupport createStreamingWriteSupport(
+      String queryId,
+      StructType schema,
+      OutputMode mode,
+      DataSourceOptions options);
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java
new file mode 100644
index 0000000000000..452ee86675b42
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.reader;
+
+import org.apache.spark.annotation.InterfaceStability;
+
+/**
+ * An interface that defines how to load the data from the data source for batch processing.
+ *
+ * The execution engine will get an instance of this interface from a data source provider
+ * (e.g. {@link org.apache.spark.sql.sources.v2.BatchReadSupportProvider}) at the start of a batch
+ * query, then call {@link #newScanConfigBuilder()} and create an instance of {@link ScanConfig}.
+ * The {@link ScanConfigBuilder} can apply operator pushdown and keep the pushdown result in
+ * {@link ScanConfig}. The {@link ScanConfig} will be used to create input partitions and a reader
+ * factory to scan data from the data source with a Spark job.
+ */
+@InterfaceStability.Evolving
+public interface BatchReadSupport extends ReadSupport {
+
+  /**
+   * Returns a builder of {@link ScanConfig}. Spark will call this method and create a
+   * {@link ScanConfig} for each data scanning job.
+   *
+   * The builder can take some query-specific information to do operator pushdown, and keep this
+   * information in the created {@link ScanConfig}.
+   *
+   * This is the first step of the data scan. All other methods in {@link BatchReadSupport} need
+   * to take {@link ScanConfig} as an input.
+   */
+  ScanConfigBuilder newScanConfigBuilder();
+
+  /**
+   * Returns a factory, which produces one {@link PartitionReader} for one {@link InputPartition}.
+   */
+  PartitionReaderFactory createReaderFactory(ScanConfig config);
+}
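To illustrate the two-step flow just described, here is a hypothetical `RangeBatchReadSupport` (continuing the invented range source from earlier): the builder freezes a ScanConfig, which then drives partition planning and reader-factory creation. The reader-side classes are sketched after `PartitionReaderFactory` below.

import org.apache.spark.sql.sources.v2.reader.*;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

// Hypothetical: emits the integers [0, end) in two partitions.
public class RangeBatchReadSupport implements BatchReadSupport {
  private final int end;

  public RangeBatchReadSupport(int end) {
    this.end = end;
  }

  @Override
  public StructType fullSchema() {
    // Physical schema of the source, unaffected by pruning.
    return new StructType().add("i", DataTypes.IntegerType);
  }

  @Override
  public ScanConfigBuilder newScanConfigBuilder() {
    // Step 1: no pushdown supported here, so the builder just freezes the schema.
    return () -> new RangeScanConfig(fullSchema());
  }

  @Override
  public InputPartition[] planInputPartitions(ScanConfig config) {
    // Step 2: one InputPartition per Spark task.
    return new InputPartition[] {
        new RangePartition(0, end / 2),
        new RangePartition(end / 2, end)
    };
  }

  @Override
  public PartitionReaderFactory createReaderFactory(ScanConfig config) {
    // Step 3: the factory is shipped to executors to create the actual readers.
    return new RangeReaderFactory();
  }
}

// The ScanConfig only needs to report the (possibly pruned) read schema.
class RangeScanConfig implements ScanConfig {
  private final StructType schema;

  RangeScanConfig(StructType schema) {
    this.schema = schema;
  }

  @Override
  public StructType readSchema() {
    return schema;
  }
}

// Serializable description of one split; sent to executors.
class RangePartition implements InputPartition {
  final int start;
  final int end;

  RangePartition(int start, int end) {
    this.start = start;
    this.end = end;
  }
}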
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java
deleted file mode 100644
index 36a3e542b5a11..0000000000000
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java
+++ /dev/null
@@ -1,80 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.sources.v2.reader;
-
-import java.util.List;
-
-import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.sources.v2.DataSourceOptions;
-import org.apache.spark.sql.sources.v2.ReadSupport;
-import org.apache.spark.sql.sources.v2.ReadSupportWithSchema;
-import org.apache.spark.sql.types.StructType;
-
-/**
- * A data source reader that is returned by
- * {@link ReadSupport#createReader(DataSourceOptions)} or
- * {@link ReadSupportWithSchema#createReader(StructType, DataSourceOptions)}.
- * It can mix in various query optimization interfaces to speed up the data scan. The actual scan
- * logic is delegated to {@link InputPartition}s, which are returned by
- * {@link #planInputPartitions()}.
- *
- * There are mainly 3 kinds of query optimizations:
- *   1. Operators push-down. E.g., filter push-down, required columns push-down(aka column
- *      pruning), etc. Names of these interfaces start with `SupportsPushDown`.
- *   2. Information Reporting. E.g., statistics reporting, ordering reporting, etc.
- *      Names of these interfaces start with `SupportsReporting`.
- *   3. Special scans. E.g, columnar scan, unsafe row scan, etc.
- *      Names of these interfaces start with `SupportsScan`. Note that a reader should only
- *      implement at most one of the special scans, if more than one special scans are implemented,
- *      only one of them would be respected, according to the priority list from high to low:
- *      {@link SupportsScanColumnarBatch}, {@link SupportsScanUnsafeRow}.
- *
- * If an exception was throw when applying any of these query optimizations, the action will fail
- * and no Spark job will be submitted.
- *
- * Spark first applies all operator push-down optimizations that this data source supports. Then
- * Spark collects information this data source reported for further optimizations. Finally Spark
- * issues the scan request and does the actual data reading.
- */
-@InterfaceStability.Evolving
-public interface DataSourceReader {
-
-  /**
-   * Returns the actual schema of this data source reader, which may be different from the physical
-   * schema of the underlying storage, as column pruning or other optimizations may happen.
-   *
-   * If this method fails (by throwing an exception), the action will fail and no Spark job will be
-   * submitted.
-   */
-  StructType readSchema();
-
-  /**
-   * Returns a list of {@link InputPartition}s. Each {@link InputPartition} is responsible for
-   * creating a data reader to output data of one RDD partition. The number of input partitions
-   * returned here is the same as the number of RDD partitions this scan outputs.
-   *
-   * Note that, this may not be a full scan if the data source reader mixes in other optimization
-   * interfaces like column pruning, filter push-down, etc. These optimizations are applied before
-   * Spark issues the scan request.
-   *
-   * If this method fails (by throwing an exception), the action will fail and no Spark job will be
-   * submitted.
-   */
-  List<InputPartition<Row>> planInputPartitions();
-}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java
index f2038d0de3ffe..95c30de907e44 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java
@@ -22,18 +22,18 @@
 import org.apache.spark.annotation.InterfaceStability;
 
 /**
- * An input partition returned by {@link DataSourceReader#planInputPartitions()} and is
- * responsible for creating the actual data reader of one RDD partition.
- * The relationship between {@link InputPartition} and {@link InputPartitionReader}
- * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}.
+ * A serializable representation of an input partition returned by
+ * {@link ReadSupport#planInputPartitions(ScanConfig)}.
  *
- * Note that {@link InputPartition}s will be serialized and sent to executors, then
- * {@link InputPartitionReader}s will be created on executors to do the actual reading. So
- * {@link InputPartition} must be serializable while {@link InputPartitionReader} doesn't need to
- * be.
+ * Note that {@link InputPartition} will be serialized and sent to executors, then
+ * {@link PartitionReader} will be created by
+ * {@link PartitionReaderFactory#createReader(InputPartition)} or
+ * {@link PartitionReaderFactory#createColumnarReader(InputPartition)} on executors to do
+ * the actual reading. So {@link InputPartition} must be serializable while {@link PartitionReader}
+ * doesn't need to be.
  */
 @InterfaceStability.Evolving
-public interface InputPartition<T> extends Serializable {
+public interface InputPartition extends Serializable {
 
   /**
    * The preferred locations where the input partition reader returned by this partition can run
@@ -51,12 +51,4 @@ public interface InputPartition<T> extends Serializable {
   default String[] preferredLocations() {
     return new String[0];
   }
-
-  /**
-   * Returns an input partition reader to do the actual reading work.
-   *
-   * If this method fails (by throwing an exception), the corresponding Spark task would fail and
-   * get retried until hitting the maximum retry times.
-   */
-  InputPartitionReader<T> createPartitionReader();
 }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java
similarity index 65%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java
rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java
index 33fa7be4c1b20..04ff8d0a19fc3 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java
@@ -23,31 +23,27 @@
 import org.apache.spark.annotation.InterfaceStability;
 
 /**
- * An input partition reader returned by {@link InputPartition#createPartitionReader()} and is
- * responsible for outputting data for a RDD partition.
+ * A partition reader returned by {@link PartitionReaderFactory#createReader(InputPartition)} or
+ * {@link PartitionReaderFactory#createColumnarReader(InputPartition)}. It's responsible for
+ * outputting data for an RDD partition.
  *
- * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal input
- * partition readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for input
- * partition readers that mix in {@link SupportsScanUnsafeRow}.
+ * Note that currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow}
+ * for normal data sources, or {@link org.apache.spark.sql.vectorized.ColumnarBatch} for columnar
+ * data sources (whose {@link PartitionReaderFactory#supportColumnarReads(InputPartition)}
+ * returns true).
  */
 @InterfaceStability.Evolving
-public interface InputPartitionReader<T> extends Closeable {
+public interface PartitionReader<T> extends Closeable {
 
   /**
    * Proceed to next record, returns false if there is no more records.
    *
-   * If this method fails (by throwing an exception), the corresponding Spark task would fail and
-   * get retried until hitting the maximum retry times.
-   *
    * @throws IOException if failure happens during disk/network IO like reading files.
    */
   boolean next() throws IOException;
 
   /**
    * Return the current record. This method should return same value until `next` is called.
-   *
-   * If this method fails (by throwing an exception), the corresponding Spark task would fail and
-   * get retried until hitting the maximum retry times.
    */
   T get();
 }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java
new file mode 100644
index 0000000000000..f35de9310eee3
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.reader;
+
+import java.io.Serializable;
+
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.vectorized.ColumnarBatch;
+
+/**
+ * A factory used to create {@link PartitionReader} instances.
+ *
+ * If Spark fails to execute any methods in the implementations of this interface or in the
+ * returned {@link PartitionReader} (by throwing an exception), the corresponding Spark task will
+ * fail and be retried until the maximum number of task retries is reached.
+ */
+@InterfaceStability.Evolving
+public interface PartitionReaderFactory extends Serializable {
+
+  /**
+   * Returns a row-based partition reader to read data from the given {@link InputPartition}.
+   *
+   * Implementations probably need to cast the input partition to the concrete
+   * {@link InputPartition} class defined for the data source.
+   */
+  PartitionReader<InternalRow> createReader(InputPartition partition);
+
+  /**
+   * Returns a columnar partition reader to read data from the given {@link InputPartition}.
+   *
+   * Implementations probably need to cast the input partition to the concrete
+   * {@link InputPartition} class defined for the data source.
+   */
+  default PartitionReader<ColumnarBatch> createColumnarReader(InputPartition partition) {
+    throw new UnsupportedOperationException("Cannot create columnar reader.");
+  }
+
+  /**
+   * Returns true if the given {@link InputPartition} should be read by Spark in a columnar way.
+   * This means, implementations must also implement {@link #createColumnarReader(InputPartition)}
+   * for the input partitions for which this method returns true.
+   *
+   * As of Spark 2.4, Spark can only read all input partitions in a columnar way, or none of them.
+   * Data sources can't mix columnar and row-based partitions. This may be relaxed in future
+   * versions.
+   */
+  default boolean supportColumnarReads(InputPartition partition) {
+    return false;
+  }
+}
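Continuing the hypothetical range source, the executor-side half might look like this: the factory casts the generic `InputPartition` back to the concrete class it planned, as the javadoc above suggests, and the reader iterates its range.

import java.io.IOException;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.PartitionReader;
import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory;

// Hypothetical row-based factory for the range source above.
class RangeReaderFactory implements PartitionReaderFactory {

  @Override
  public PartitionReader<InternalRow> createReader(InputPartition partition) {
    // Cast back to the concrete partition class this source planned.
    RangePartition p = (RangePartition) partition;
    return new RangeReader(p.start, p.end);
  }
  // supportColumnarReads keeps its default of false, so createColumnarReader
  // is never called for this source.
}

// Iterates over [start, end), one int per record.
class RangeReader implements PartitionReader<InternalRow> {
  private final int end;
  private int current;

  RangeReader(int start, int end) {
    this.current = start - 1;
    this.end = end;
  }

  @Override
  public boolean next() throws IOException {
    return ++current < end;
  }

  @Override
  public InternalRow get() {
    return new GenericInternalRow(new Object[] { current });
  }

  @Override
  public void close() {
    // Nothing to release for this in-memory example.
  }
}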
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java
new file mode 100644
index 0000000000000..a58ddb288f1ed
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.reader;
+
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.types.StructType;
+
+/**
+ * The base interface for all the batch and streaming read supports. Data sources should implement
+ * concrete read support interfaces like {@link BatchReadSupport}.
+ *
+ * If Spark fails to execute any methods in the implementations of this interface (by throwing an
+ * exception), the read action will fail and no Spark job will be submitted.
+ */
+@InterfaceStability.Evolving
+public interface ReadSupport {
+
+  /**
+   * Returns the full schema of this data source, which is usually the physical schema of the
+   * underlying storage. This full schema should not be affected by column pruning or other
+   * optimizations.
+   */
+  StructType fullSchema();
+
+  /**
+   * Returns a list of {@link InputPartition input partitions}. Each {@link InputPartition}
+   * represents a data split that can be processed by one Spark task. The number of input
+   * partitions returned here is the same as the number of RDD partitions this scan outputs.
+   *
+   * Note that, this may not be a full scan if the data source supports optimizations like filter
+   * push-down. Implementations should check the input {@link ScanConfig} and adjust the resulting
+   * {@link InputPartition input partitions}.
+   */
+  InputPartition[] planInputPartitions(ScanConfig config);
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java
similarity index 51%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java
rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java
index b2526ded53d92..7462ce2820585 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java
@@ -15,26 +15,31 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.sources.v2;
+package org.apache.spark.sql.sources.v2.reader;
 
 import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
+import org.apache.spark.sql.types.StructType;
 
 /**
- * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to
- * provide data reading ability and scan the data from the data source.
+ * An interface that carries query-specific information for the data scanning job, like operator
+ * pushdown information and streaming query offsets. This is defined as an empty interface, and data
+ * sources should define their own {@link ScanConfig} classes.
+ *
+ * For APIs that take a {@link ScanConfig} as input, like
+ * {@link ReadSupport#planInputPartitions(ScanConfig)},
+ * {@link BatchReadSupport#createReaderFactory(ScanConfig)} and
+ * {@link SupportsReportStatistics#estimateStatistics(ScanConfig)}, implementations mostly need to
+ * cast the input {@link ScanConfig} to the concrete {@link ScanConfig} class of the data source.
  */
 @InterfaceStability.Evolving
-public interface ReadSupport extends DataSourceV2 {
+public interface ScanConfig {
 
   /**
-   * Creates a {@link DataSourceReader} to scan the data from this data source.
+   * Returns the actual schema of this data source reader, which may be different from the physical
+   * schema of the underlying storage, as column pruning or other optimizations may happen.
    *
    * If this method fails (by throwing an exception), the action will fail and no Spark job will be
    * submitted.
-   *
-   * @param options the options for the returned data source reader, which is an immutable
-   *                case-insensitive string-to-string map.
    */
-  DataSourceReader createReader(DataSourceOptions options);
+  StructType readSchema();
 }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java
similarity index 61%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java
rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java
index dcb87715d0b6f..4c0eedfddfe22 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java
@@ -18,18 +18,13 @@
 package org.apache.spark.sql.sources.v2.reader;
 
 import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset;
 
 /**
- * A mix-in interface for {@link InputPartition}. Continuous input partitions can
- * implement this interface to provide creating {@link InputPartitionReader} with particular offset.
+ * An interface for building the {@link ScanConfig}. Implementations can mix in the
+ * SupportsPushDownXYZ interfaces to do operator pushdown, and keep the operator pushdown result in
+ * the returned {@link ScanConfig}.
  */
 @InterfaceStability.Evolving
-public interface ContinuousInputPartition<T> extends InputPartition<T> {
-  /**
-   * Create an input partition reader with particular offset as its startOffset.
-   *
-   * @param offset offset want to set as the input partition reader's startOffset.
-   */
-  InputPartitionReader<T> createContinuousReader(PartitionOffset offset);
+public interface ScanConfigBuilder {
+  ScanConfig build();
 }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java
index e8cd7adbca071..44799c7d49137 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java
@@ -23,7 +23,7 @@
 
 /**
  * An interface to represent statistics for a data source, which is returned by
- * {@link SupportsReportStatistics#getStatistics()}.
+ * {@link SupportsReportStatistics#estimateStatistics(ScanConfig)}.
  */
 @InterfaceStability.Evolving
 public interface Statistics {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java
deleted file mode 100644
index 4543c143a9aca..0000000000000
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.sources.v2.reader;
-
-import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.catalyst.expressions.Expression;
-
-/**
- * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this
- * interface to push down arbitrary expressions as predicates to the data source.
- * This is an experimental and unstable interface as {@link Expression} is not public and may get
- * changed in the future Spark versions.
- *
- * Note that, if data source readers implement both this interface and
- * {@link SupportsPushDownFilters}, Spark will ignore {@link SupportsPushDownFilters} and only
- * process this interface.
- */
-@InterfaceStability.Unstable
-public interface SupportsPushDownCatalystFilters extends DataSourceReader {
-
-  /**
-   * Pushes down filters, and returns filters that need to be evaluated after scanning.
-   */
-  Expression[] pushCatalystFilters(Expression[] filters);
-
-  /**
-   * Returns the catalyst filters that are pushed to the data source via
-   * {@link #pushCatalystFilters(Expression[])}.
-   *
-   * There are 3 kinds of filters:
-   *   1. pushable filters which don't need to be evaluated again after scanning.
-   *   2. pushable filters which still need to be evaluated after scanning, e.g. parquet
-   *      row group filter.
-   *   3. non-pushable filters.
-   * Both case 1 and 2 should be considered as pushed filters and should be returned by this method.
-   *
-   * It's possible that there is no filters in the query and
-   * {@link #pushCatalystFilters(Expression[])} is never called, empty array should be returned for
-   * this case.
-   */
-  Expression[] pushedCatalystFilters();
-}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java
index b6a90a3d0b681..5e7985f645a06 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java
@@ -21,15 +21,11 @@
 import org.apache.spark.sql.sources.Filter;
 
 /**
- * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this
- * interface to push down filters to the data source and reduce the size of the data to be read.
- *
- * Note that, if data source readers implement both this interface and
- * {@link SupportsPushDownCatalystFilters}, Spark will ignore this interface and only process
- * {@link SupportsPushDownCatalystFilters}.
+ * A mix-in interface for {@link ScanConfigBuilder}. Data sources can implement this interface to
+ * push down filters to the data source and reduce the size of the data to be read.
  */
 @InterfaceStability.Evolving
-public interface SupportsPushDownFilters extends DataSourceReader {
+public interface SupportsPushDownFilters extends ScanConfigBuilder {
 
   /**
    * Pushes down filters, and returns filters that need to be evaluated after scanning.
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java
index 427b4d00a1128..edb164937d6ef 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java
@@ -21,12 +21,12 @@
 import org.apache.spark.sql.types.StructType;
 
 /**
- * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this
+ * A mix-in interface for {@link ScanConfigBuilder}. Data sources can implement this
  * interface to push down required columns to the data source and only read these columns during
  * scan to reduce the size of the data to be read.
  */
 @InterfaceStability.Evolving
-public interface SupportsPushDownRequiredColumns extends DataSourceReader {
+public interface SupportsPushDownRequiredColumns extends ScanConfigBuilder {
 
   /**
    * Applies column pruning w.r.t. the given requiredSchema.
@@ -35,8 +35,8 @@ public interface SupportsPushDownRequiredColumns extends DataSourceReader {
    * also OK to do the pruning partially, e.g., a data source may not be able to prune nested
    * fields, and only prune top-level columns.
    *
-   * Note that, data source readers should update {@link DataSourceReader#readSchema()} after
-   * applying column pruning.
+   * Note that the {@link ScanConfig#readSchema()} implementation should take care of the column
+   * pruning applied here.
    */
   void pruneColumns(StructType requiredSchema);
 }
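For a source that does support pushdown, the builder itself carries the negotiation. The sketch below is hypothetical; it assumes only `GreaterThan` filters can be consumed by the source, and uses the `pushFilters`/`pushedFilters` methods of `SupportsPushDownFilters` shown above.

import java.util.ArrayList;
import java.util.List;

import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.sources.GreaterThan;
import org.apache.spark.sql.sources.v2.reader.ScanConfig;
import org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder;
import org.apache.spark.sql.sources.v2.reader.SupportsPushDownFilters;
import org.apache.spark.sql.sources.v2.reader.SupportsPushDownRequiredColumns;
import org.apache.spark.sql.types.StructType;

// Hypothetical builder that accepts GreaterThan filters and column pruning,
// and freezes both into the ScanConfig it builds.
class MyScanConfigBuilder implements ScanConfigBuilder,
    SupportsPushDownFilters, SupportsPushDownRequiredColumns {

  private StructType requiredSchema;
  private final List<Filter> pushed = new ArrayList<>();

  MyScanConfigBuilder(StructType fullSchema) {
    this.requiredSchema = fullSchema;
  }

  @Override
  public Filter[] pushFilters(Filter[] filters) {
    List<Filter> postScan = new ArrayList<>();
    for (Filter f : filters) {
      if (f instanceof GreaterThan) {
        pushed.add(f);   // evaluated by the source during the scan
      } else {
        postScan.add(f); // left for Spark to evaluate after the scan
      }
    }
    return postScan.toArray(new Filter[0]);
  }

  @Override
  public Filter[] pushedFilters() {
    return pushed.toArray(new Filter[0]);
  }

  @Override
  public void pruneColumns(StructType requiredSchema) {
    // The built ScanConfig's readSchema() must reflect this pruning.
    this.requiredSchema = requiredSchema;
  }

  @Override
  public ScanConfig build() {
    return new MyScanConfig(requiredSchema, pushed.toArray(new Filter[0]));
  }
}

// Carries the pushdown results to planInputPartitions/createReaderFactory.
class MyScanConfig implements ScanConfig {
  private final StructType schema;
  final Filter[] pushedFilters;

  MyScanConfig(StructType schema, Filter[] pushedFilters) {
    this.schema = schema;
    this.pushedFilters = pushedFilters;
  }

  @Override
  public StructType readSchema() {
    return schema;
  }
}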
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java
index 6b60da7c4dc1d..db62cd4515362 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java
@@ -21,17 +21,17 @@
 import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning;
 
 /**
- * A mix in interface for {@link DataSourceReader}. Data source readers can implement this
- * interface to report data partitioning and try to avoid shuffle at Spark side.
+ * A mix-in interface for {@link BatchReadSupport}. Data sources can implement this interface to
+ * report data partitioning and try to avoid shuffle on the Spark side.
  *
- * Note that, when the reader creates exactly one {@link InputPartition}, Spark may avoid
- * adding a shuffle even if the reader does not implement this interface.
+ * Note that, when a {@link ReadSupport} implementation creates exactly one {@link InputPartition},
+ * Spark may avoid adding a shuffle even if the reader does not implement this interface.
  */
 @InterfaceStability.Evolving
-public interface SupportsReportPartitioning extends DataSourceReader {
+public interface SupportsReportPartitioning extends ReadSupport {
 
   /**
    * Returns the output data partitioning that this reader guarantees.
    */
-  Partitioning outputPartitioning();
+  Partitioning outputPartitioning(ScanConfig config);
 }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java
index 926396414816c..1831488ba096f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java
@@ -20,18 +20,18 @@
 import org.apache.spark.annotation.InterfaceStability;
 
 /**
- * A mix in interface for {@link DataSourceReader}. Data source readers can implement this
- * interface to report statistics to Spark.
+ * A mix-in interface for {@link BatchReadSupport}. Data sources can implement this interface to
+ * report statistics to Spark.
 *
- * Statistics are reported to the optimizer before any operator is pushed to the DataSourceReader.
- * Implementations that return more accurate statistics based on pushed operators will not improve
- * query performance until the planner can push operators before getting stats.
+ * As of Spark 2.4, statistics are reported to the optimizer before any operator is pushed to the
+ * data source. Implementations that return more accurate statistics based on pushed operators will
+ * not improve query performance until the planner can push operators before getting stats.
 */
 @InterfaceStability.Evolving
-public interface SupportsReportStatistics extends DataSourceReader {
+public interface SupportsReportStatistics extends ReadSupport {
 
   /**
-   * Returns the basic statistics of this data source.
+   * Returns the estimated statistics of this data source scan.
   */
-  Statistics getStatistics();
+  Statistics estimateStatistics(ScanConfig config);
 }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java
deleted file mode 100644
index 0faf81db24605..0000000000000
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.sources.v2.reader;
-
-import java.util.List;
-
-import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.vectorized.ColumnarBatch;
-
-/**
- * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this
- * interface to output {@link ColumnarBatch} and make the scan faster.
- */
-@InterfaceStability.Evolving
-public interface SupportsScanColumnarBatch extends DataSourceReader {
-  @Override
-  default List<InputPartition<Row>> planInputPartitions() {
-    throw new IllegalStateException(
-      "planInputPartitions not supported by default within SupportsScanColumnarBatch.");
-  }
-
-  /**
-   * Similar to {@link DataSourceReader#planInputPartitions()}, but returns columnar data
-   * in batches.
-   */
-  List<InputPartition<ColumnarBatch>> planBatchInputPartitions();
-
-  /**
-   * Returns true if the concrete data source reader can read data in batch according to the scan
-   * properties like required columns, pushes filters, etc. It's possible that the implementation
-   * can only support some certain columns with certain types. Users can overwrite this method and
-   * {@link #planInputPartitions()} to fallback to normal read path under some conditions.
-   */
-  default boolean enableBatchRead() {
-    return true;
-  }
-}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java
deleted file mode 100644
index f2220f6d31093..0000000000000
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java
+++ /dev/null
@@ -1,46 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.sources.v2.reader;
-
-import java.util.List;
-
-import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
-
-/**
- * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this
- * interface to output {@link UnsafeRow} directly and avoid the row copy at Spark side.
- * This is an experimental and unstable interface, as {@link UnsafeRow} is not public and may get
- * changed in the future Spark versions.
- */
-@InterfaceStability.Unstable
-public interface SupportsScanUnsafeRow extends DataSourceReader {
-
-  @Override
-  default List<InputPartition<Row>> planInputPartitions() {
-    throw new IllegalStateException(
-      "planInputPartitions not supported by default within SupportsScanUnsafeRow");
-  }
-
-  /**
-   * Similar to {@link DataSourceReader#planInputPartitions()},
-   * but returns data in unsafe row format.
-   */
-  List<InputPartition<UnsafeRow>> planUnsafeInputPartitions();
-}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java
index 38ca5fc6387b2..6764d4b7665c7 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java
@@ -18,12 +18,12 @@
 package org.apache.spark.sql.sources.v2.reader.partitioning;
 
 import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;
+import org.apache.spark.sql.sources.v2.reader.PartitionReader;
 
 /**
  * A concrete implementation of {@link Distribution}. Represents a distribution where records that
  * share the same values for the {@link #clusteredColumns} will be produced by the same
- * {@link InputPartitionReader}.
+ * {@link PartitionReader}.
  */
 @InterfaceStability.Evolving
 public class ClusteredDistribution implements Distribution {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java
index 5e32ba6952e1c..364a3f553923c 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java
@@ -18,14 +18,14 @@
 package org.apache.spark.sql.sources.v2.reader.partitioning;
 
 import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;
+import org.apache.spark.sql.sources.v2.reader.PartitionReader;
 
 /**
  * An interface to represent data distribution requirement, which specifies how the records should
- * be distributed among the data partitions (one {@link InputPartitionReader} outputs data for one
+ * be distributed among the data partitions (one {@link PartitionReader} outputs data for one
  * partition).
  * Note that this interface has nothing to do with the data ordering inside one
- * partition(the output records of a single {@link InputPartitionReader}).
+ * partition (the output records of a single {@link PartitionReader}).
  *
  * The instance of this interface is created and provided by Spark, then consumed by
  * {@link Partitioning#satisfy(Distribution)}. This means data source developers don't need to
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java
index f460f6bfe3bb9..fb0b6f1df43bb 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java
@@ -19,12 +19,13 @@
 
 import org.apache.spark.annotation.InterfaceStability;
 import org.apache.spark.sql.sources.v2.reader.InputPartition;
+import org.apache.spark.sql.sources.v2.reader.ScanConfig;
 import org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning;
 
 /**
  * An interface to represent the output data partitioning for a data source, which is returned by
- * {@link SupportsReportPartitioning#outputPartitioning()}. Note that this should work like a
- * snapshot. Once created, it should be deterministic and always report the same number of
+ * {@link SupportsReportPartitioning#outputPartitioning(ScanConfig)}. Note that this should work
+ * like a snapshot. Once created, it should be deterministic and always report the same number of
  * partitions and the same "satisfy" result for a certain distribution.
  */
 @InterfaceStability.Evolving
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java
similarity index 60%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java
rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java
index 7b0ba0bbdda90..9101c8a44d34e 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java
@@ -18,19 +18,20 @@
 package org.apache.spark.sql.sources.v2.reader.streaming;
 
 import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;
+import org.apache.spark.sql.sources.v2.reader.PartitionReader;
 
 /**
- * A variation on {@link InputPartitionReader} for use with streaming in continuous processing mode.
+ * A variation on {@link PartitionReader} for use with continuous streaming processing.
 */
 @InterfaceStability.Evolving
-public interface ContinuousInputPartitionReader<T> extends InputPartitionReader<T> {
-  /**
-   * Get the offset of the current record, or the start offset if no records have been read.
-   *
-   * The execution engine will call this method along with get() to keep track of the current
-   * offset. When an epoch ends, the offset of the previous record in each partition will be saved
-   * as a restart checkpoint.
-   */
-  PartitionOffset getOffset();
+public interface ContinuousPartitionReader<T> extends PartitionReader<T> {
+
+  /**
+   * Get the offset of the current record, or the start offset if no records have been read.
+   *
+   * The execution engine will call this method along with get() to keep track of the current
+   * offset. When an epoch ends, the offset of the previous record in each partition will be saved
+   * as a restart checkpoint.
+   */
+  PartitionOffset getOffset();
 }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java
similarity index 52%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java
rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java
index d2cf7e01c08c8..2d9f1ca1686a1 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java
@@ -15,27 +15,26 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.sources.v2.writer;
+package org.apache.spark.sql.sources.v2.reader.streaming;
 
 import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.Row;
 import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.sources.v2.reader.InputPartition;
+import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory;
+import org.apache.spark.sql.vectorized.ColumnarBatch;
 
 /**
- * A mix-in interface for {@link DataSourceWriter}. Data source writers can implement this
- * interface to write {@link InternalRow} directly and avoid the row conversion at Spark side.
- * This is an experimental and unstable interface, as {@link InternalRow} is not public and may get
- * changed in the future Spark versions.
+ * A variation on {@link PartitionReaderFactory} that returns {@link ContinuousPartitionReader}
+ * instead of {@link org.apache.spark.sql.sources.v2.reader.PartitionReader}. It's used for
+ * continuous streaming processing.
 */
-
-@InterfaceStability.Unstable
-public interface SupportsWriteInternalRow extends DataSourceWriter {
+@InterfaceStability.Evolving
+public interface ContinuousPartitionReaderFactory extends PartitionReaderFactory {
+  @Override
+  ContinuousPartitionReader<InternalRow> createReader(InputPartition partition);
 
   @Override
-  default DataWriterFactory<Row> createWriterFactory() {
-    throw new IllegalStateException(
-      "createWriterFactory should not be called with SupportsWriteInternalRow.");
+  default ContinuousPartitionReader<ColumnarBatch> createColumnarReader(InputPartition partition) {
+    throw new UnsupportedOperationException("Cannot create columnar reader.");
   }
-
-  DataWriterFactory<InternalRow> createInternalRowWriterFactory();
 }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java
new file mode 100644
index 0000000000000..9a3ad2eb8a801
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.reader.streaming;
+
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.execution.streaming.BaseStreamingSource;
+import org.apache.spark.sql.sources.v2.reader.InputPartition;
+import org.apache.spark.sql.sources.v2.reader.ScanConfig;
+import org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder;
+
+/**
+ * An interface that defines how to load the data from the data source for continuous streaming
+ * processing.
+ *
+ * The execution engine will get an instance of this interface from a data source provider
+ * (e.g. {@link org.apache.spark.sql.sources.v2.ContinuousReadSupportProvider}) at the start of a
+ * streaming query, then call {@link #newScanConfigBuilder(Offset)} and create an instance of
+ * {@link ScanConfig} for the duration of the streaming query or until
+ * {@link #needsReconfiguration(ScanConfig)} is true. The {@link ScanConfig} will be used to create
+ * input partitions and a reader factory to scan data with a Spark job for its duration. At the
+ * end, {@link #stop()} will be called when the streaming execution is completed. Note that a
+ * single query may have multiple executions due to restart or failure recovery.
+ */
+@InterfaceStability.Evolving
+public interface ContinuousReadSupport extends StreamingReadSupport, BaseStreamingSource {
+
+  /**
+   * Returns a builder of {@link ScanConfig}. Spark will call this method and create a
+   * {@link ScanConfig} for each data scanning job.
+   *
+   * The builder can take some query-specific information to do operator pushdown, store streaming
+   * offsets, etc., and keep this information in the created {@link ScanConfig}.
+   *
+   * This is the first step of the data scan. All other methods in {@link ContinuousReadSupport}
+   * need to take {@link ScanConfig} as an input.
+   */
+  ScanConfigBuilder newScanConfigBuilder(Offset start);
+
+  /**
+   * Returns a factory, which produces one {@link ContinuousPartitionReader} for one
+   * {@link InputPartition}.
+   */
+  ContinuousPartitionReaderFactory createContinuousReaderFactory(ScanConfig config);
+
+  /**
+   * Merge partitioned offsets coming from {@link ContinuousPartitionReader} instances
+   * for each partition to a single global offset.
+   */
+  Offset mergeOffsets(PartitionOffset[] offsets);
+
+  /**
+   * The execution engine will call this method in every epoch to determine if new input
+   * partitions need to be generated, which may be required if for example the underlying
+   * source system has had partitions added or removed.
+   *
+   * If true, the query will be shut down and restarted with a new {@link ContinuousReadSupport}
+   * instance.
+   */
+  default boolean needsReconfiguration(ScanConfig config) {
+    return false;
+  }
+}
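As a hypothetical illustration of offset handling in this interface, a partitioned source might merge per-partition offsets into one global, JSON-serializable offset like this; all class names are invented, and the remaining `StreamingReadSupport` methods are left abstract.

import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport;
import org.apache.spark.sql.sources.v2.reader.streaming.Offset;
import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset;

// Hypothetical per-partition offset: a position within one source partition.
class MyPartitionOffset implements PartitionOffset {
  final int partition;
  final long position;

  MyPartitionOffset(int partition, long position) {
    this.partition = partition;
    this.position = position;
  }
}

// Hypothetical global offset: one position per partition, serialized as JSON
// for the checkpoint log.
class MyGlobalOffset extends Offset {
  private final long[] positions;

  MyGlobalOffset(long[] positions) {
    this.positions = positions;
  }

  @Override
  public String json() {
    StringBuilder sb = new StringBuilder("[");
    for (int i = 0; i < positions.length; i++) {
      if (i > 0) sb.append(',');
      sb.append(positions[i]);
    }
    return sb.append(']').toString();
  }
}

// Only the offset-merging logic is sketched; everything else stays abstract.
abstract class MyContinuousReadSupport implements ContinuousReadSupport {
  @Override
  public Offset mergeOffsets(PartitionOffset[] offsets) {
    long[] positions = new long[offsets.length];
    for (PartitionOffset po : offsets) {
      MyPartitionOffset mpo = (MyPartitionOffset) po;
      positions[mpo.partition] = mpo.position;
    }
    return new MyGlobalOffset(positions);
  }
}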
+ */ + default boolean needsReconfiguration(ScanConfig config) { + return false; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java deleted file mode 100644 index 6e960bedf8020..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader.streaming; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.execution.streaming.BaseStreamingSource; -import org.apache.spark.sql.sources.v2.reader.DataSourceReader; - -import java.util.Optional; - -/** - * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to allow reading in a continuous processing mode stream. - * - * Implementations must ensure each partition reader is a {@link ContinuousInputPartitionReader}. - * - * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with - * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. - */ -@InterfaceStability.Evolving -public interface ContinuousReader extends BaseStreamingSource, DataSourceReader { - /** - * Merge partitioned offsets coming from {@link ContinuousInputPartitionReader} instances - * for each partition to a single global offset. - */ - Offset mergeOffsets(PartitionOffset[] offsets); - - /** - * Deserialize a JSON string into an Offset of the implementation-defined offset type. - * @throws IllegalArgumentException if the JSON does not encode a valid offset for this reader - */ - Offset deserializeOffset(String json); - - /** - * Set the desired start offset for partitions created from this reader. The scan will - * start from the first record after the provided offset, or from an implementation-defined - * inferred starting point if no offset is provided. - */ - void setStartOffset(Optional<Offset> start); - - /** - * Return the specified or inferred start offset for this reader. - * - * @throws IllegalStateException if setStartOffset has not been called - */ - Offset getStartOffset(); - - /** - * The execution engine will call this method in every epoch to determine if new input - * partitions need to be generated, which may be required if for example the underlying - * source system has had partitions added or removed. - * - * If true, the query will be shut down and restarted with a new reader.
- */ - default boolean needsReconfiguration() { - return false; - } - - /** - * Informs the source that Spark has completed processing all data for offsets less than or - * equal to `end` and will only request offsets greater than `end` in the future. - */ - void commit(Offset end); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java new file mode 100644 index 0000000000000..edb0db11bff2c --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader.streaming; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.streaming.BaseStreamingSource; +import org.apache.spark.sql.sources.v2.reader.*; + +/** + * An interface that defines how to scan the data from the data source for micro-batch streaming + * processing. + * + * The execution engine will get an instance of this interface from a data source provider + * (e.g. {@link org.apache.spark.sql.sources.v2.MicroBatchReadSupportProvider}) at the start of a + * streaming query, then call {@link #newScanConfigBuilder(Offset, Offset)} and create an instance + * of {@link ScanConfig} for each micro-batch. The {@link ScanConfig} will be used to create input + * partitions and a reader factory to scan a micro-batch with a Spark job. At the end {@link #stop()} + * will be called when the streaming execution is completed. Note that a single query may have + * multiple executions due to restart or failure recovery. + */ +@InterfaceStability.Evolving +public interface MicroBatchReadSupport extends StreamingReadSupport, BaseStreamingSource { + + /** + * Returns a builder of {@link ScanConfig}. Spark will call this method and create a + * {@link ScanConfig} for each data scanning job. + * + * The builder can take some query-specific information to do operator pushdown, store streaming + * offsets, etc., and keep this information in the created {@link ScanConfig}. + * + * This is the first step of the data scan. All other methods in {@link MicroBatchReadSupport} + * need to take {@link ScanConfig} as an input. + */ + ScanConfigBuilder newScanConfigBuilder(Offset start, Offset end); + + /** + * Returns a factory that produces one {@link PartitionReader} for one {@link InputPartition}. + */ + PartitionReaderFactory createReaderFactory(ScanConfig config); + + /** + * Returns the most recent offset available.
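The micro-batch counterpart follows the same shape, once per batch. Again an illustrative sketch, not part of the patch: the class name is hypothetical and `ScanConfigBuilder#build()` is assumed.

```java
import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory;
import org.apache.spark.sql.sources.v2.reader.ScanConfig;
import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport;
import org.apache.spark.sql.sources.v2.reader.streaming.Offset;

public class MicroBatchLifecycleSketch {
  static void runOneBatch(MicroBatchReadSupport readSupport, Offset start) {
    Offset end = readSupport.latestOffset();  // upper bound for this micro-batch
    ScanConfig config = readSupport.newScanConfigBuilder(start, end).build();  // assumed build()
    PartitionReaderFactory factory = readSupport.createReaderFactory(config);
    // ... run a Spark job that scans (start, end] using `factory` ...
    readSupport.commit(end);  // offsets up to `end` will not be requested again
  }
}
```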
+ */ + Offset latestOffset(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java deleted file mode 100644 index 0159c731762d9..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader.streaming; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataSourceReader; -import org.apache.spark.sql.execution.streaming.BaseStreamingSource; - -import java.util.Optional; - -/** - * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to indicate they allow micro-batch streaming reads. - * - * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with - * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. - */ -@InterfaceStability.Evolving -public interface MicroBatchReader extends DataSourceReader, BaseStreamingSource { - /** - * Set the desired offset range for input partitions created from this reader. Partition readers - * will generate only data within (`start`, `end`]; that is, from the first record after `start` - * to the record with offset `end`. - * - * @param start The initial offset to scan from. If not specified, scan from an - * implementation-specified start point, such as the earliest available record. - * @param end The last offset to include in the scan. If not specified, scan up to an - * implementation-defined endpoint, such as the last available offset - * or the start offset plus a target batch size. - */ - void setOffsetRange(Optional<Offset> start, Optional<Offset> end); - - /** - * Returns the specified (if explicitly set through setOffsetRange) or inferred start offset - * for this reader. - * - * @throws IllegalStateException if setOffsetRange has not been called - */ - Offset getStartOffset(); - - /** - * Return the specified (if explicitly set through setOffsetRange) or inferred end offset - * for this reader. - * - * @throws IllegalStateException if setOffsetRange has not been called - */ - Offset getEndOffset(); - - /** - * Deserialize a JSON string into an Offset of the implementation-defined offset type. - * @throws IllegalArgumentException if the JSON does not encode a valid offset for this reader - */ - Offset deserializeOffset(String json); - - /** - * Informs the source that Spark has completed processing all data for offsets less than or - * equal to `end` and will only request offsets greater than `end` in the future.
- */ - void commit(Offset end); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java index e41c0351edc82..6cf27734867cb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java @@ -20,8 +20,8 @@ import org.apache.spark.annotation.InterfaceStability; /** - * An abstract representation of progress through a {@link MicroBatchReader} or - * {@link ContinuousReader}. + * An abstract representation of progress through a {@link MicroBatchReadSupport} or + * {@link ContinuousReadSupport}. * During execution, offsets provided by the data source implementation will be logged and used as * restart checkpoints. Each source should provide an offset implementation which the source can use * to reconstruct a position in the stream up to which data has been seen/processed. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java new file mode 100644 index 0000000000000..84872d1ebc26e --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader.streaming; + +import org.apache.spark.sql.sources.v2.reader.ReadSupport; + +/** + * A base interface for streaming read support. This is package private and is invisible to data + * sources. Data sources should implement concrete streaming read support interfaces: + * {@link MicroBatchReadSupport} or {@link ContinuousReadSupport}. + */ +interface StreamingReadSupport extends ReadSupport { + + /** + * Returns the initial offset for a streaming query to start reading from. Note that the + * streaming data source should not assume that it will start reading from its initial offset: + * if Spark is restarting an existing query, it will restart from the check-pointed offset rather + * than the initial one. + */ + Offset initialOffset(); + + /** + * Deserialize a JSON string into an Offset of the implementation-defined offset type. + * + * @throws IllegalArgumentException if the JSON does not encode a valid offset for this reader + */ + Offset deserializeOffset(String json); + + /** + * Informs the source that Spark has completed processing all data for offsets less than or + * equal to `end` and will only request offsets greater than `end` in the future. 
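Since offsets travel through the checkpoint log as JSON, a source typically pairs a small offset class with its `deserializeOffset` implementation. A hypothetical long-based sketch, assuming `Offset` is an abstract class exposing an abstract `json()` method:

```java
import org.apache.spark.sql.sources.v2.reader.streaming.Offset;

public class LongOffsetSketch extends Offset {
  private final long value;

  public LongOffsetSketch(long value) { this.value = value; }

  @Override
  public String json() { return Long.toString(value); }  // logged as a restart checkpoint

  // The matching StreamingReadSupport#deserializeOffset would parse it back:
  //   public Offset deserializeOffset(String json) {
  //     return new LongOffsetSketch(Long.parseLong(json));
  //   }
}
```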
+ */ + void commit(Offset end); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java similarity index 80% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java index 7eedc85a5d6f3..0ec9e05d6a02b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java @@ -18,28 +18,13 @@ package org.apache.spark.sql.sources.v2.writer; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.SaveMode; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.StreamWriteSupport; -import org.apache.spark.sql.sources.v2.WriteSupport; -import org.apache.spark.sql.streaming.OutputMode; -import org.apache.spark.sql.types.StructType; /** - * A data source writer that is returned by - * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceOptions)}/ - * {@link StreamWriteSupport#createStreamWriter( * String, StructType, OutputMode, DataSourceOptions)}. - * It can mix in various writing optimization interfaces to speed up the data saving. The actual - * writing logic is delegated to {@link DataWriter}. - * - * If an exception was throw when applying any of these writing optimizations, the action will fail - * and no Spark job will be submitted. + * An interface that defines how to write the data to the data source for batch processing. * * The writing procedure is: - * 1. Create a writer factory by {@link #createWriterFactory()}, serialize and send it to all the - * partitions of the input data(RDD). + * 1. Create a writer factory by {@link #createBatchWriterFactory()}, serialize and send it to all + * the partitions of the input data (RDD). * 2. For each partition, create the data writer, and write the data of the partition with this * writer. If all the data are written successfully, call {@link DataWriter#commit()}. If * exception happens during the writing, call {@link DataWriter#abort()}. @@ -53,7 +38,7 @@ * Please refer to the documentation of commit/abort methods for detailed specifications. */ @InterfaceStability.Evolving -public interface DataSourceWriter { +public interface BatchWriteSupport { /** * Creates a writer factory which will be serialized and sent to executors. @@ -61,7 +46,7 @@ public interface DataSourceWriter { * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted.
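The driver-side half of that procedure can be sketched as follows. This is illustration only; the class name is hypothetical, and the pre-collected `messages` array stands in for what the engine actually gathers from the per-partition write tasks.

```java
import org.apache.spark.sql.sources.v2.writer.BatchWriteSupport;
import org.apache.spark.sql.sources.v2.writer.DataWriterFactory;
import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage;

public class BatchWriteProtocolSketch {
  static void write(BatchWriteSupport writeSupport, WriterCommitMessage[] messages) {
    // Step 1: the factory is serialized and shipped to executors.
    DataWriterFactory factory = writeSupport.createBatchWriterFactory();
    try {
      // Steps 2-3 happen in tasks; `messages` are the DataWriter#commit() results.
      writeSupport.commit(messages);   // job-level commit on the driver
    } catch (Exception e) {
      writeSupport.abort(messages);    // best-effort cleanup of written data
    }
  }
}
```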
*/ - DataWriterFactory<Row> createWriterFactory(); + DataWriterFactory createBatchWriterFactory(); /** * Returns whether Spark should use the commit coordinator to ensure that at most one task for diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index 1626c0013e4e7..5fb067966ee67 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data writer returned by {@link DataWriterFactory#createDataWriter(int, long, long)} and is + * A data writer returned by {@link DataWriterFactory#createWriter(int, long)} and is * responsible for writing data for an input RDD partition. * * One Spark task has one exclusive data writer, so there is no thread-safe concern. @@ -36,11 +36,11 @@ * * If this data writer succeeds(all records are successfully written and {@link #commit()} * succeeds), a {@link WriterCommitMessage} will be sent to the driver side and pass to - * {@link DataSourceWriter#commit(WriterCommitMessage[])} with commit messages from other data + * {@link BatchWriteSupport#commit(WriterCommitMessage[])} with commit messages from other data * writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an * exception will be sent to the driver side, and Spark may retry this writing task a few times. - * In each retry, {@link DataWriterFactory#createDataWriter(int, long, long)} will receive a - * different `taskId`. Spark will call {@link DataSourceWriter#abort(WriterCommitMessage[])} + * In each retry, {@link DataWriterFactory#createWriter(int, long)} will receive a + * different `taskId`. Spark will call {@link BatchWriteSupport#abort(WriterCommitMessage[])} * when the configured number of retries is exhausted. * * Besides the retry mechanism, Spark may launch speculative tasks if the existing writing task @@ -53,9 +53,7 @@ * successfully, and have a way to revert committed data writers without the commit message, because * Spark only accepts the commit message that arrives first and ignore others. * - * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data - * source writers, or {@link org.apache.spark.sql.catalyst.InternalRow} for data source writers - * that mix in {@link SupportsWriteInternalRow}. + * Note that, currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow}. */ @InterfaceStability.Evolving public interface DataWriter<T> { @@ -73,11 +71,11 @@ public interface DataWriter<T> { /** * Commits this writer after all records are written successfully, returns a commit message which * will be sent back to driver side and passed to - * {@link DataSourceWriter#commit(WriterCommitMessage[])}. + * {@link BatchWriteSupport#commit(WriterCommitMessage[])}. * * The written data should only be visible to data source readers after - * {@link DataSourceWriter#commit(WriterCommitMessage[])} succeeds, which means this method - * should still "hide" the written data and ask the {@link DataSourceWriter} at driver side to + * {@link BatchWriteSupport#commit(WriterCommitMessage[])} succeeds, which means this method + * should still "hide" the written data and ask the {@link BatchWriteSupport} at driver side to + * do the final commit via {@link WriterCommitMessage}.
* * If this method fails (by throwing an exception), {@link #abort()} will be called and this @@ -95,7 +93,7 @@ public interface DataWriter<T> { * failed. * * If this method fails(by throwing an exception), the underlying data source may have garbage - * that need to be cleaned by {@link DataSourceWriter#abort(WriterCommitMessage[])} or manually, + * that needs to be cleaned by {@link BatchWriteSupport#abort(WriterCommitMessage[])} or manually, * but these garbage should not be visible to data source readers. * * @throws IOException if failure happens during disk/network IO like writing files. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index 0932ff8f8f8a7..19a36dd232456 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -19,35 +19,37 @@ import java.io.Serializable; +import org.apache.spark.TaskContext; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.catalyst.InternalRow; /** - * A factory of {@link DataWriter} returned by {@link DataSourceWriter#createWriterFactory()}, + * A factory of {@link DataWriter} returned by {@link BatchWriteSupport#createBatchWriterFactory()}, * which is responsible for creating and initializing the actual data writer at executor side. * * Note that, the writer factory will be serialized and sent to executors, then the data writer - * will be created on executors and do the actual writing. So {@link DataWriterFactory} must be + * will be created on executors and do the actual writing. So this interface must be * serializable and {@link DataWriter} doesn't need to be. */ @InterfaceStability.Evolving -public interface DataWriterFactory<T> extends Serializable { +public interface DataWriterFactory extends Serializable { /** - * Returns a data writer to do the actual writing work. + * Returns a data writer to do the actual writing work. Note that, Spark will reuse the same data + * object instance when sending data to the data writer, for better performance. Data writers + * are responsible for defensive copies if necessary, e.g. copy the data before buffering it in a + * list. * - * If this method fails (by throwing an exception), the action will fail and no Spark job will be - * submitted. + * If this method fails (by throwing an exception), the corresponding Spark write task will fail + * and be retried until the maximum number of retries is reached. * * @param partitionId A unique id of the RDD partition that the returned writer will process. * Usually Spark processes many RDD partitions at the same time, * implementations should use the partition id to distinguish writers for * different partitions. - * @param taskId A unique identifier for a task that is performing the write of the partition - * data. Spark may run multiple tasks for the same partition (due to speculation - * or task failures, for example). - * @param epochId A monotonically increasing id for streaming queries that are split in to - * discrete periods of execution. For non-streaming queries, - * this ID will always be 0. + * @param taskId The task id returned by {@link TaskContext#taskAttemptId()}. Spark may run + * multiple tasks for the same partition (due to speculation or task failures, + * for example).
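The object-reuse note above is easy to get wrong, so here is a sketch of a buffering writer that takes the required defensive copy. The class name is hypothetical and the commit message is a placeholder; a real sink would define its own message type.

```java
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.sources.v2.writer.DataWriter;
import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage;

public class BufferingDataWriterSketch implements DataWriter<InternalRow> {
  private final List<InternalRow> buffer = new ArrayList<>();

  @Override
  public void write(InternalRow record) throws IOException {
    buffer.add(record.copy());  // Spark reuses `record`: buffering it without copy() is a bug
  }

  @Override
  public WriterCommitMessage commit() throws IOException {
    // ... stage `buffer` in the sink, hidden until the driver-side job commit ...
    return new WriterCommitMessage() {};  // placeholder; real sinks define their own class
  }

  @Override
  public void abort() throws IOException {
    buffer.clear();  // drop anything not yet committed
  }
}
```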
*/ - DataWriter<T> createDataWriter(int partitionId, long taskId, long epochId); + DataWriter<InternalRow> createWriter(int partitionId, long taskId); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java index 9e38836c0edf9..123335c414e9f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java @@ -19,15 +19,16 @@ import java.io.Serializable; +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport; import org.apache.spark.annotation.InterfaceStability; /** * A commit message returned by {@link DataWriter#commit()} and will be sent back to the driver side - * as the input parameter of {@link DataSourceWriter#commit(WriterCommitMessage[])}. + * as the input parameter of {@link BatchWriteSupport#commit(WriterCommitMessage[])} or + * {@link StreamingWriteSupport#commit(long, WriterCommitMessage[])}. * - * This is an empty interface, data sources should define their own message class and use it in - * their {@link DataWriter#commit()} and {@link DataSourceWriter#commit(WriterCommitMessage[])} - * implementations. + * This is an empty interface, data sources should define their own message class and use it when + * generating messages at the executor side and handling the messages at the driver side. */ @InterfaceStability.Evolving public interface WriterCommitMessage extends Serializable {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java new file mode 100644 index 0000000000000..a4da24fc5ae68 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer.streaming; + +import java.io.Serializable; + +import org.apache.spark.TaskContext; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.sources.v2.writer.DataWriter; + +/** + * A factory of {@link DataWriter} returned by + * {@link StreamingWriteSupport#createStreamingWriterFactory()}, which is responsible for creating + * and initializing the actual data writer at executor side. + * + * Note that, the writer factory will be serialized and sent to executors, then the data writer + * will be created on executors and do the actual writing.
So this interface must be + * serializable and {@link DataWriter} doesn't need to be. + */ +@InterfaceStability.Evolving +public interface StreamingDataWriterFactory extends Serializable { + + /** + * Returns a data writer to do the actual writing work. Note that, Spark will reuse the same data + * object instance when sending data to the data writer, for better performance. Data writers + * are responsible for defensive copies if necessary, e.g. copy the data before buffering it in a + * list. + * + * If this method fails (by throwing an exception), the corresponding Spark write task will fail + * and be retried until the maximum number of retries is reached. + * + * @param partitionId A unique id of the RDD partition that the returned writer will process. + * Usually Spark processes many RDD partitions at the same time, + * implementations should use the partition id to distinguish writers for + * different partitions. + * @param taskId The task id returned by {@link TaskContext#taskAttemptId()}. Spark may run + * multiple tasks for the same partition (due to speculation or task failures, + * for example). + * @param epochId A monotonically increasing id for streaming queries that are split into + * discrete periods of execution. + */ + DataWriter<InternalRow> createWriter(int partitionId, long taskId, long epochId); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java similarity index 78% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java index a316b2a4c1d82..3fdfac5e1c84a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java @@ -18,27 +18,36 @@ package org.apache.spark.sql.sources.v2.writer.streaming; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; import org.apache.spark.sql.sources.v2.writer.DataWriter; import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; /** - * A {@link DataSourceWriter} for use with structured streaming. + * An interface that defines how to write the data to the data source for streaming processing. * * Streaming queries are divided into intervals of data called epochs, with a monotonically * increasing numeric ID. This writer handles commits and aborts for each successive epoch. */ @InterfaceStability.Evolving -public interface StreamWriter extends DataSourceWriter { +public interface StreamingWriteSupport { + + /** + * Creates a writer factory which will be serialized and sent to executors. + * + * If this method fails (by throwing an exception), the action will fail and no Spark job will be + * submitted. + */ + StreamingDataWriterFactory createStreamingWriterFactory(); + /** * Commits this writing job for the specified epoch with a list of commit messages. The commit * messages are collected from successful data writers and are produced by * {@link DataWriter#commit()}. * * If this method fails (by throwing an exception), this writing job is considered to have been - * failed, and the execution engine will attempt to call {@link #abort(WriterCommitMessage[])}.
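The idempotence requirement on epoch commits can be approached by tracking the last committed epoch, so that a re-delivered commit becomes a no-op. A partial, hypothetical sketch (factory creation and abort omitted):

```java
import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage;
import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport;

public abstract class IdempotentWriteSupportSketch implements StreamingWriteSupport {
  private volatile long lastCommittedEpoch = -1L;

  @Override
  public void commit(long epochId, WriterCommitMessage[] messages) {
    if (epochId <= lastCommittedEpoch) {
      return;  // the engine may re-deliver a commit for the same epoch; make it a no-op
    }
    // ... atomically publish the epoch's data ...
    lastCommittedEpoch = epochId;
  }
}
```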
+ * failed, and the execution engine will attempt to call + * {@link #abort(long, WriterCommitMessage[])}. * - * The execution engine may call commit() multiple times for the same epoch in some circumstances. + * The execution engine may call `commit` multiple times for the same epoch in some circumstances. * To support exactly-once data semantics, implementations must ensure that multiple commits for * the same epoch are idempotent. */ @@ -46,7 +55,8 @@ public interface StreamWriter extends DataSourceWriter { /** * Aborts this writing job because some data writers are failed and keep failing when retried, or - * the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} fails. + * the Spark job fails with some unknown reasons, or {@link #commit(long, WriterCommitMessage[])} + * fails. * * If this method fails (by throwing an exception), the underlying data source may require manual * cleanup. @@ -58,14 +68,4 @@ public interface StreamWriter extends DataSourceWriter { * clean up the data left by data writers. */ void abort(long epochId, WriterCommitMessage[] messages); - - default void commit(WriterCommitMessage[] messages) { - throw new UnsupportedOperationException( - "Commit without epoch should not be called with StreamWriter"); - } - - default void abort(WriterCommitMessage[] messages) { - throw new UnsupportedOperationException( - "Abort without epoch should not be called with StreamWriter"); - } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index 227a16f7e69e9..5f58b031f6aef 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -25,7 +25,6 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.arrow.ArrowUtils; import org.apache.spark.sql.types.*; -import org.apache.spark.unsafe.memory.OffHeapMemoryBlock; import org.apache.spark.unsafe.types.UTF8String; /** @@ -162,13 +161,13 @@ public ArrowColumnVector(ValueVector vector) { } else if (vector instanceof ListVector) { ListVector listVector = (ListVector) vector; accessor = new ArrayAccessor(listVector); - } else if (vector instanceof NullableMapVector) { - NullableMapVector mapVector = (NullableMapVector) vector; - accessor = new StructAccessor(mapVector); + } else if (vector instanceof StructVector) { + StructVector structVector = (StructVector) vector; + accessor = new StructAccessor(structVector); - childColumns = new ArrowColumnVector[mapVector.size()]; + childColumns = new ArrowColumnVector[structVector.size()]; for (int i = 0; i < childColumns.length; ++i) { - childColumns[i] = new ArrowColumnVector(mapVector.getVectorById(i)); + childColumns[i] = new ArrowColumnVector(structVector.getVectorById(i)); } } else { throw new UnsupportedOperationException(); @@ -378,10 +377,9 @@ final UTF8String getUTF8String(int rowId) { if (stringResult.isSet == 0) { return null; } else { - return new UTF8String(new OffHeapMemoryBlock( + return UTF8String.fromAddress(null, stringResult.buffer.memoryAddress() + stringResult.start, - stringResult.end - stringResult.start - )); + stringResult.end - stringResult.start); } } } @@ -455,9 +453,9 @@ final boolean isNullAt(int rowId) { @Override final ColumnarArray getArray(int rowId) { ArrowBuf offsets = accessor.getOffsetBuffer(); - int index = rowId * accessor.OFFSET_WIDTH; + int index = 
rowId * ListVector.OFFSET_WIDTH; int start = offsets.getInt(index); - int end = offsets.getInt(index + accessor.OFFSET_WIDTH); + int end = offsets.getInt(index + ListVector.OFFSET_WIDTH); return new ColumnarArray(arrayData, start, end - start); } } @@ -472,7 +470,7 @@ final ColumnarArray getArray(int rowId) { */ private static class StructAccessor extends ArrowVectorAccessor { - StructAccessor(NullableMapVector vector) { + StructAccessor(StructVector vector) { super(vector); } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 4eee3de5f7d4e..ae27690f2e5ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -345,7 +345,7 @@ class Column(val expr: Expression) extends Logging { * * // Java: * import static org.apache.spark.sql.functions.*; - * people.select( people("age").gt(21) ); + * people.select( people.col("age").gt(21) ); * }}} * * @group expr_ops @@ -361,7 +361,7 @@ class Column(val expr: Expression) extends Logging { * * // Java: * import static org.apache.spark.sql.functions.*; - * people.select( people("age").gt(21) ); + * people.select( people.col("age").gt(21) ); * }}} * * @group java_expr_ops @@ -376,7 +376,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("age") < 21 ) * * // Java: - * people.select( people("age").lt(21) ); + * people.select( people.col("age").lt(21) ); * }}} * * @group expr_ops @@ -391,7 +391,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("age") < 21 ) * * // Java: - * people.select( people("age").lt(21) ); + * people.select( people.col("age").lt(21) ); * }}} * * @group java_expr_ops @@ -406,7 +406,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("age") <= 21 ) * * // Java: - * people.select( people("age").leq(21) ); + * people.select( people.col("age").leq(21) ); * }}} * * @group expr_ops @@ -421,7 +421,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("age") <= 21 ) * * // Java: - * people.select( people("age").leq(21) ); + * people.select( people.col("age").leq(21) ); * }}} * * @group java_expr_ops @@ -436,7 +436,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("age") >= 21 ) * * // Java: - * people.select( people("age").geq(21) ) + * people.select( people.col("age").geq(21) ) * }}} * * @group expr_ops @@ -451,7 +451,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("age") >= 21 ) * * // Java: - * people.select( people("age").geq(21) ) + * people.select( people.col("age").geq(21) ) * }}} * * @group java_expr_ops @@ -588,7 +588,7 @@ class Column(val expr: Expression) extends Logging { * people.filter( people("inSchool") || people("isEmployed") ) * * // Java: - * people.filter( people("inSchool").or(people("isEmployed")) ); + * people.filter( people.col("inSchool").or(people.col("isEmployed")) ); * }}} * * @group expr_ops @@ -603,7 +603,7 @@ class Column(val expr: Expression) extends Logging { * people.filter( people("inSchool") || people("isEmployed") ) * * // Java: - * people.filter( people("inSchool").or(people("isEmployed")) ); + * people.filter( people.col("inSchool").or(people.col("isEmployed")) ); * }}} * * @group java_expr_ops @@ -618,7 +618,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("inSchool") && people("isEmployed") ) * * // Java: - * 
people.select( people("inSchool").and(people("isEmployed")) ); + * people.select( people.col("inSchool").and(people.col("isEmployed")) ); * }}} * * @group expr_ops @@ -633,7 +633,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("inSchool") && people("isEmployed") ) * * // Java: - * people.select( people("inSchool").and(people("isEmployed")) ); + * people.select( people.col("inSchool").and(people.col("isEmployed")) ); * }}} * * @group java_expr_ops @@ -648,7 +648,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("height") + people("weight") ) * * // Java: - * people.select( people("height").plus(people("weight")) ); + * people.select( people.col("height").plus(people.col("weight")) ); * }}} * * @group expr_ops @@ -663,7 +663,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("height") + people("weight") ) * * // Java: - * people.select( people("height").plus(people("weight")) ); + * people.select( people.col("height").plus(people.col("weight")) ); * }}} * * @group java_expr_ops @@ -678,7 +678,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("height") - people("weight") ) * * // Java: - * people.select( people("height").minus(people("weight")) ); + * people.select( people.col("height").minus(people.col("weight")) ); * }}} * * @group expr_ops @@ -693,7 +693,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("height") - people("weight") ) * * // Java: - * people.select( people("height").minus(people("weight")) ); + * people.select( people.col("height").minus(people.col("weight")) ); * }}} * * @group java_expr_ops @@ -708,7 +708,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("height") * people("weight") ) * * // Java: - * people.select( people("height").multiply(people("weight")) ); + * people.select( people.col("height").multiply(people.col("weight")) ); * }}} * * @group expr_ops @@ -723,7 +723,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("height") * people("weight") ) * * // Java: - * people.select( people("height").multiply(people("weight")) ); + * people.select( people.col("height").multiply(people.col("weight")) ); * }}} * * @group java_expr_ops @@ -738,7 +738,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("height") / people("weight") ) * * // Java: - * people.select( people("height").divide(people("weight")) ); + * people.select( people.col("height").divide(people.col("weight")) ); * }}} * * @group expr_ops @@ -753,7 +753,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("height") / people("weight") ) * * // Java: - * people.select( people("height").divide(people("weight")) ); + * people.select( people.col("height").divide(people.col("weight")) ); * }}} * * @group java_expr_ops diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index f3a2b70657c48..5288907b7d7ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -494,6 +494,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { case (NumericType, dt) => dt.isInstanceOf[NumericType] case (StringType, dt) => dt == StringType case (BooleanType, dt) => dt == BooleanType + case _ => + throw new 
IllegalArgumentException(s"$targetType is not matched at fillValue") } // Only fill if the column is part of the cols list. if (typeMatches && cols.exists(col => columnEquals(f.name, col))) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index ec9352a7fa055..e6c2cba79841a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema} +import org.apache.spark.sql.sources.v2.{BatchReadSupportProvider, DataSourceOptions, DataSourceV2} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -194,7 +194,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { val ds = cls.newInstance().asInstanceOf[DataSourceV2] - if (ds.isInstanceOf[ReadSupport] || ds.isInstanceOf[ReadSupportWithSchema]) { + if (ds.isInstanceOf[BatchReadSupportProvider]) { val sessionOptions = DataSourceV2Utils.extractSessionConfigs( ds = ds, conf = sparkSession.sessionState.conf) val pathsOption = { @@ -450,7 +450,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { input => rawParser.parse(input, createParser, UTF8String.fromString), parsedOptions.parseMode, schema, - parsedOptions.columnNameOfCorruptRecord) + parsedOptions.columnNameOfCorruptRecord, + parsedOptions.multiLine) iter.flatMap(parser.parse) } sparkSession.internalCreateDataFrame(parsed, schema, isStreaming = jsonDataset.isStreaming) @@ -505,10 +506,11 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine => - CSVDataSource.checkHeader( - firstLine, - new CsvParser(parsedOptions.asParserSettings), + val parser = new CsvParser(parsedOptions.asParserSettings) + val columnNames = parser.parseLine(firstLine) + CSVDataSource.checkHeaderColumnNames( actualSchema, + columnNames, csvDataset.getClass.getCanonicalName, parsedOptions.enforceSchema, sparkSession.sessionState.conf.caseSensitiveAnalysis) @@ -521,7 +523,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { input => Seq(rawParser.parse(input)), parsedOptions.parseMode, schema, - parsedOptions.columnNameOfCorruptRecord) + parsedOptions.columnNameOfCorruptRecord, + parsedOptions.multiLine) iter.flatMap(parser.parse) } sparkSession.internalCreateDataFrame(parsed, schema, isStreaming = csvDataset.isStreaming) @@ -568,6 +571,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * whitespaces from values being read should be skipped. *
<li>`nullValue` (default empty string): sets the string representation of a null value. Since * 2.0.1, this applies to all supported types including the string type.</li>
+ * <li>`emptyValue` (default empty string): sets the string representation of an empty value.</li>
* <li>`nanValue` (default `NaN`): sets the string representation of a non-number value.</li>
* <li>`positiveInf` (default `Inf`): sets the string representation of a positive infinity * value.</li>
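A usage sketch of these read options, including the newly added `emptyValue` (path and values are hypothetical):

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class CsvReadOptionsSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("csv-read-options").getOrCreate();
    Dataset<Row> df = spark.read()
        .option("nullValue", "NULL")   // string that reads back as null
        .option("emptyValue", "")      // string representation of an empty value
        .option("nanValue", "NaN")     // string representation of a non-number value
        .csv("/path/to/input.csv");    // hypothetical path
    df.show();
  }
}
```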
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 90bea2d676e22..dfb8c4718550f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql -import java.text.SimpleDateFormat -import java.util.{Date, Locale, Properties, UUID} +import java.util.{Locale, Properties, UUID} import scala.collection.JavaConverters._ @@ -26,12 +25,11 @@ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.plans.logical.{AnalysisBarrier, InsertIntoTable, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoTable, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils -import org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2 +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, WriteToDataSourceV2} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.types.StructType @@ -240,21 +238,29 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val ds = cls.newInstance() - ds match { - case ws: WriteSupport => - val options = new DataSourceOptions((extraOptions ++ - DataSourceV2Utils.extractSessionConfigs( - ds = ds.asInstanceOf[DataSourceV2], - conf = df.sparkSession.sessionState.conf)).asJava) - // Using a timestamp and a random UUID to distinguish different writing jobs. This is good - // enough as there won't be tons of writing jobs created at the same second.
- val jobId = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) - .format(new Date()) + "-" + UUID.randomUUID() - val writer = ws.createWriter(jobId, df.logicalPlan.schema, mode, options) - if (writer.isPresent) { + val source = cls.newInstance().asInstanceOf[DataSourceV2] + source match { + case provider: BatchWriteSupportProvider => + val options = extraOptions ++ + DataSourceV2Utils.extractSessionConfigs(source, df.sparkSession.sessionState.conf) + + val relation = DataSourceV2Relation.create(source, options.toMap) + if (mode == SaveMode.Append) { runCommand(df.sparkSession, "save") { - WriteToDataSourceV2(writer.get(), df.logicalPlan) + AppendData.byName(relation, df.logicalPlan) + } + + } else { + val writer = provider.createBatchWriteSupport( + UUID.randomUUID().toString, + df.logicalPlan.output.toStructType, + mode, + new DataSourceOptions(options.asJava)) + + if (writer.isPresent) { + runCommand(df.sparkSession, "save") { + WriteToDataSourceV2(writer.get, df.logicalPlan) + } } } @@ -275,7 +281,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { sparkSession = df.sparkSession, className = source, partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) + options = extraOptions.toMap).planForWriting(mode, df.logicalPlan) } } @@ -351,7 +357,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private def assertNotPartitioned(operation: String): Unit = { if (partitioningColumns.isDefined) { - throw new AnalysisException( s"'$operation' does not support partitioning") + throw new AnalysisException(s"'$operation' does not support partitioning") } } @@ -544,8 +550,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
* <ul>
* <li>`compression` (default is the value specified in `spark.sql.parquet.compression.codec`): * compression codec to use when saving to file. This can be one of the known case-insensitive - * shorten names(`none`, `snappy`, `gzip`, and `lzo`). This will override - * `spark.sql.parquet.compression.codec`.</li>
+ * shorten names (`none`, `uncompressed`, `snappy`, `gzip`, `lzo`, `brotli`, `lz4`, and `zstd`). + * This will override `spark.sql.parquet.compression.codec`.</li>
* </ul> * * @since 1.4.0 @@ -629,6 +635,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * enclosed in quotes. Default is to only escape values containing a quote character.</li>
* <li>`header` (default `false`): writes the names of columns as the first line.</li>
* <li>`nullValue` (default empty string): sets the string representation of a null value.</li>
+ * <li>`emptyValue` (default `""`): sets the string representation of an empty value.</li>
+ * <li>`encoding` (by default it is not set): specifies encoding (charset) of saved csv + * files. If it is not set, the UTF-8 charset will be used.</li>
* <li>`compression` (default `null`): compression codec to use when saving to file. This can be * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`, * `snappy` and `deflate`).</li>
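A matching write-side usage sketch for the new `emptyValue` and `encoding` options (path is hypothetical):

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

public class CsvWriteOptionsSketch {
  static void save(Dataset<Row> df) {
    df.write()
        .option("emptyValue", "\"\"")  // string written for empty values
        .option("encoding", "UTF-8")   // charset of the saved csv files
        .option("compression", "gzip")
        .csv("/path/to/output");       // hypothetical path
  }
}
```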
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 2ec236fc75efc..fa14aa14ee968 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -48,7 +48,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload} +import org.apache.spark.sql.execution.arrow.{ArrowBatchStreamWriter, ArrowConverters} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.python.EvaluatePython @@ -65,7 +65,12 @@ private[sql] object Dataset { val dataset = new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]]) // Eagerly bind the encoder so we verify that the encoder matches the underlying // schema. The user will get an error if this is not the case. - dataset.deserializer + // optimization: it is guaranteed that [[InternalRow]] can be converted to [[Row]] so + // do not do this check in that case. this check can be expensive since it requires running + // the whole [[Analyzer]] to resolve the deserializer + if (dataset.exprEnc.clsTag.runtimeClass != classOf[Row]) { + dataset.deserializer + } dataset } @@ -195,9 +200,6 @@ class Dataset[T] private[sql]( } } - // Wraps analyzed logical plans with an analysis barrier so we won't traverse/resolve it again. - @transient private[sql] val planWithBarrier = AnalysisBarrier(logicalPlan) - /** * Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the * passed in encoder to [[ExpressionEncoder]] explicitly, and mark it implicit so that we can use @@ -304,16 +306,16 @@ class Dataset[T] private[sql]( // Compute the width of each column for (row <- rows) { for ((cell, i) <- row.zipWithIndex) { - colWidths(i) = math.max(colWidths(i), cell.length) + colWidths(i) = math.max(colWidths(i), Utils.stringHalfWidth(cell)) } } val paddedRows = rows.map { row => row.zipWithIndex.map { case (cell, i) => if (truncate > 0) { - StringUtils.leftPad(cell, colWidths(i)) + StringUtils.leftPad(cell, colWidths(i) - Utils.stringHalfWidth(cell) + cell.length) } else { - StringUtils.rightPad(cell, colWidths(i)) + StringUtils.rightPad(cell, colWidths(i) - Utils.stringHalfWidth(cell) + cell.length) } } } @@ -335,12 +337,10 @@ class Dataset[T] private[sql]( // Compute the width of field name and data columns val fieldNameColWidth = fieldNames.foldLeft(minimumColWidth) { case (curMax, fieldName) => - math.max(curMax, fieldName.length) + math.max(curMax, Utils.stringHalfWidth(fieldName)) } val dataColWidth = dataRows.foldLeft(minimumColWidth) { case (curMax, row) => - math.max(curMax, row.map(_.length).reduceLeftOption[Int] { case (cellMax, cell) => - math.max(cellMax, cell) - }.getOrElse(0)) + math.max(curMax, row.map(cell => Utils.stringHalfWidth(cell)).max) } dataRows.zipWithIndex.foreach { case (row, i) => @@ -349,8 +349,10 @@ class Dataset[T] private[sql]( s"-RECORD $i", fieldNameColWidth + dataColWidth + 5, "-") sb.append(rowHeader).append("\n") row.zipWithIndex.map { case (cell, j) => - val fieldName = StringUtils.rightPad(fieldNames(j), fieldNameColWidth) - val data = StringUtils.rightPad(cell, dataColWidth) + val fieldName =
StringUtils.rightPad(fieldNames(j), + fieldNameColWidth - Utils.stringHalfWidth(fieldNames(j)) + fieldNames(j).length) + val data = StringUtils.rightPad(cell, + dataColWidth - Utils.stringHalfWidth(cell) + cell.length) s" $fieldName | $data " }.addString(sb, "", "\n", "\n") } @@ -426,7 +428,7 @@ class Dataset[T] private[sql]( */ @Experimental @InterfaceStability.Evolving - def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, planWithBarrier) + def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, logicalPlan) /** * Converts this strongly typed collection of data to generic `DataFrame` with columns renamed. @@ -680,7 +682,7 @@ class Dataset[T] private[sql]( require(parsedDelay.milliseconds >= 0 && parsedDelay.months >= 0, s"delay threshold ($delayThreshold) should not be negative.") EliminateEventTimeWatermark( - EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, planWithBarrier)) + EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan)) } /** @@ -853,7 +855,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def join(right: Dataset[_]): DataFrame = withPlan { - Join(planWithBarrier, right.planWithBarrier, joinType = Inner, None) + Join(logicalPlan, right.logicalPlan, joinType = Inner, None) } /** @@ -931,7 +933,7 @@ class Dataset[T] private[sql]( // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. val joined = sparkSession.sessionState.executePlan( - Join(planWithBarrier, right.planWithBarrier, joinType = JoinType(joinType), None)) + Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None)) .analyzed.asInstanceOf[Join] withPlan { @@ -992,7 +994,7 @@ class Dataset[T] private[sql]( // Trigger analysis so in the case of self-join, the analyzer will clone the plan. // After the cloning, left and right side will have distinct expression ids. val plan = withPlan( - Join(planWithBarrier, right.planWithBarrier, JoinType(joinType), Some(joinExprs.expr))) + Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr))) .queryExecution.analyzed.asInstanceOf[Join] // If auto self join alias is disabled, return the plan. @@ -1001,8 +1003,8 @@ class Dataset[T] private[sql]( } // If left/right have no output set intersection, return the plan. - val lanalyzed = withPlan(this.planWithBarrier).queryExecution.analyzed - val ranalyzed = withPlan(right.planWithBarrier).queryExecution.analyzed + val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed + val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) { return withPlan(plan) } @@ -1016,6 +1018,11 @@ class Dataset[T] private[sql]( catalyst.expressions.EqualTo( withPlan(plan.left).resolve(a.name), withPlan(plan.right).resolve(b.name)) + case catalyst.expressions.EqualNullSafe(a: AttributeReference, b: AttributeReference) + if a.sameRef(b) => + catalyst.expressions.EqualNullSafe( + withPlan(plan.left).resolve(a.name), + withPlan(plan.right).resolve(b.name)) }} withPlan { @@ -1034,7 +1041,7 @@ class Dataset[T] private[sql]( * @since 2.1.0 */ def crossJoin(right: Dataset[_]): DataFrame = withPlan { - Join(planWithBarrier, right.planWithBarrier, joinType = Cross, None) + Join(logicalPlan, right.logicalPlan, joinType = Cross, None) } /** @@ -1066,8 +1073,8 @@ class Dataset[T] private[sql]( // etc. 
val joined = sparkSession.sessionState.executePlan( Join( - this.planWithBarrier, - other.planWithBarrier, + this.logicalPlan, + other.logicalPlan, JoinType(joinType), Some(condition.expr))).analyzed.asInstanceOf[Join] @@ -1288,7 +1295,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def as(alias: String): Dataset[T] = withTypedPlan { - SubqueryAlias(alias, planWithBarrier) + SubqueryAlias(alias, logicalPlan) } /** @@ -1326,7 +1333,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def select(cols: Column*): DataFrame = withPlan { - Project(cols.map(_.named), planWithBarrier) + Project(cols.map(_.named), logicalPlan) } /** @@ -1381,8 +1388,7 @@ class Dataset[T] private[sql]( @InterfaceStability.Evolving def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { implicit val encoder = c1.encoder - val project = Project(c1.withInputType(exprEnc, planWithBarrier.output).named :: Nil, - planWithBarrier) + val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan) if (encoder.flat) { new Dataset[U1](sparkSession, project, encoder) @@ -1400,8 +1406,8 @@ class Dataset[T] private[sql]( protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) val namedColumns = - columns.map(_.withInputType(exprEnc, planWithBarrier.output).named) - val execution = new QueryExecution(sparkSession, Project(namedColumns, planWithBarrier)) + columns.map(_.withInputType(exprEnc, logicalPlan.output).named) + val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan)) new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders)) } @@ -1477,7 +1483,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def filter(condition: Column): Dataset[T] = withTypedPlan { - Filter(condition.expr, planWithBarrier) + Filter(condition.expr, logicalPlan) } /** @@ -1656,15 +1662,14 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { - val inputPlan = planWithBarrier - val withGroupingKey = AppendColumns(func, inputPlan) + val withGroupingKey = AppendColumns(func, logicalPlan) val executed = sparkSession.sessionState.executePlan(withGroupingKey) new KeyValueGroupedDataset( encoderFor[K], encoderFor[T], executed, - inputPlan.output, + logicalPlan.output, withGroupingKey.newColumns) } @@ -1802,7 +1807,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def limit(n: Int): Dataset[T] = withTypedPlan { - Limit(Literal(n), planWithBarrier) + Limit(Literal(n), logicalPlan) } /** @@ -1852,7 +1857,7 @@ class Dataset[T] private[sql]( def union(other: Dataset[T]): Dataset[T] = withSetOperator { // This breaks caching, but it's usually ok because it addresses a very specific use case: // using union to union many files or partitions. - CombineUnions(Union(logicalPlan, other.logicalPlan)).mapChildren(AnalysisBarrier) + CombineUnions(Union(logicalPlan, other.logicalPlan)) } /** @@ -1911,7 +1916,7 @@ class Dataset[T] private[sql]( // This breaks caching, but it's usually ok because it addresses a very specific use case: // using union to union many files or partitions. 
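`CombineUnions` is applied eagerly here because union is commonly used to stitch together many per-file or per-partition DataFrames; flattening the left-deep chain of binary `Union` nodes into one n-ary `Union` up front keeps the analyzer from re-traversing a deep tree now that the `AnalysisBarrier` wrapping is gone. A sketch of the pattern it targets, with illustrative paths:

{{{
// Each read yields one DataFrame; reduce builds a left-deep chain
// Union(Union(Union(df1, df2), df3), ...) which CombineUnions
// collapses into a single n-ary Union as the Dataset is constructed.
val paths = Seq("/data/part-0001", "/data/part-0002", "/data/part-0003")
val all = paths.map(p => spark.read.parquet(p)).reduce(_ union _)
}}}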
- CombineUnions(Union(logicalPlan, rightChild)).mapChildren(AnalysisBarrier) + CombineUnions(Union(logicalPlan, rightChild)) } /** @@ -1925,9 +1930,26 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def intersect(other: Dataset[T]): Dataset[T] = withSetOperator { - Intersect(planWithBarrier, other.planWithBarrier) + Intersect(logicalPlan, other.logicalPlan, isAll = false) } + /** + * Returns a new Dataset containing rows only in both this Dataset and another Dataset while + * preserving the duplicates. + * This is equivalent to `INTERSECT ALL` in SQL. + * + * @note Equality checking is performed directly on the encoded representation of the data + * and thus is not affected by a custom `equals` function defined on `T`. Also as standard + * in SQL, this function resolves columns by position (not by name). + * + * @group typedrel + * @since 2.4.0 + */ + def intersectAll(other: Dataset[T]): Dataset[T] = withSetOperator { + Intersect(logicalPlan, other.logicalPlan, isAll = true) + } + + /** * Returns a new Dataset containing rows in this Dataset but not in another Dataset. * This is equivalent to `EXCEPT DISTINCT` in SQL. @@ -1939,7 +1961,23 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def except(other: Dataset[T]): Dataset[T] = withSetOperator { - Except(planWithBarrier, other.planWithBarrier) + Except(logicalPlan, other.logicalPlan, isAll = false) + } + + /** + * Returns a new Dataset containing rows in this Dataset but not in another Dataset while + * preserving the duplicates. + * This is equivalent to `EXCEPT ALL` in SQL. + * + * @note Equality checking is performed directly on the encoded representation of the data + * and thus is not affected by a custom `equals` function defined on `T`. Also as standard in + * SQL, this function resolves columns by position (not by name). + * + * @group typedrel + * @since 2.4.0 + */ + def exceptAll(other: Dataset[T]): Dataset[T] = withSetOperator { + Except(logicalPlan, other.logicalPlan, isAll = true) } /** @@ -1990,7 +2028,7 @@ class Dataset[T] private[sql]( */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = { withTypedPlan { - Sample(0.0, fraction, withReplacement, seed, planWithBarrier) + Sample(0.0, fraction, withReplacement, seed, logicalPlan) } } @@ -2032,15 +2070,15 @@ class Dataset[T] private[sql]( // overlapping splits. To prevent this, we explicitly sort each input partition to make the // ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out // from the sort order. 
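The new `intersectAll` and `exceptAll` above follow standard SQL bag semantics: `INTERSECT ALL` keeps min(m, n) copies of a row that occurs m times on the left and n times on the right, and `EXCEPT ALL` keeps max(0, m - n) copies. A small illustration with made-up values (result order is not guaranteed):

{{{
import spark.implicits._

val left  = Seq(1, 1, 2, 3).toDS()
val right = Seq(1, 2, 2).toDS()

// 1 occurs twice on the left but once on the right -> one copy survives;
// 3 never occurs on the right -> dropped entirely.
left.intersectAll(right).collect()  // e.g. Array(1, 2)

// 1: 2 - 1 = 1 copy; 2: 1 - 2 -> 0 copies; 3: 1 - 0 = 1 copy.
left.exceptAll(right).collect()     // e.g. Array(1, 3)
}}}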
- val sortOrder = planWithBarrier.output + val sortOrder = logicalPlan.output .filter(attr => RowOrdering.isOrderable(attr.dataType)) .map(SortOrder(_, Ascending)) val plan = if (sortOrder.nonEmpty) { - Sort(sortOrder, global = false, planWithBarrier) + Sort(sortOrder, global = false, logicalPlan) } else { // SPARK-12662: If sort order is empty, we materialize the dataset to guarantee determinism cache() - planWithBarrier + logicalPlan } val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) @@ -2124,7 +2162,7 @@ class Dataset[T] private[sql]( withPlan { Generate(generator, unrequiredChildIndex = Nil, outer = false, - qualifier = None, generatorOutput = Nil, planWithBarrier) + qualifier = None, generatorOutput = Nil, logicalPlan) } } @@ -2165,7 +2203,7 @@ class Dataset[T] private[sql]( withPlan { Generate(generator, unrequiredChildIndex = Nil, outer = false, - qualifier = None, generatorOutput = Nil, planWithBarrier) + qualifier = None, generatorOutput = Nil, logicalPlan) } } @@ -2316,7 +2354,7 @@ class Dataset[T] private[sql]( u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u) case Column(expr: Expression) => expr } - val attrs = this.planWithBarrier.output + val attrs = this.logicalPlan.output val colsAfterDrop = attrs.filter { attr => attr != expression }.map(attr => Column(attr)) @@ -2364,7 +2402,7 @@ class Dataset[T] private[sql]( } cols } - Deduplicate(groupCols, planWithBarrier) + Deduplicate(groupCols, logicalPlan) } /** @@ -2546,7 +2584,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def filter(func: T => Boolean): Dataset[T] = { - withTypedPlan(TypedFilter(func, planWithBarrier)) + withTypedPlan(TypedFilter(func, logicalPlan)) } /** @@ -2560,7 +2598,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def filter(func: FilterFunction[T]): Dataset[T] = { - withTypedPlan(TypedFilter(func, planWithBarrier)) + withTypedPlan(TypedFilter(func, logicalPlan)) } /** @@ -2574,7 +2612,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan { - MapElements[T, U](func, planWithBarrier) + MapElements[T, U](func, logicalPlan) } /** @@ -2589,7 +2627,7 @@ class Dataset[T] private[sql]( @InterfaceStability.Evolving def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { implicit val uEnc = encoder - withTypedPlan(MapElements[T, U](func, planWithBarrier)) + withTypedPlan(MapElements[T, U](func, logicalPlan)) } /** @@ -2605,7 +2643,7 @@ class Dataset[T] private[sql]( def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { new Dataset[U]( sparkSession, - MapPartitions[T, U](func, planWithBarrier), + MapPartitions[T, U](func, logicalPlan), implicitly[Encoder[U]]) } @@ -2636,7 +2674,7 @@ class Dataset[T] private[sql]( val rowEncoder = encoder.asInstanceOf[ExpressionEncoder[Row]] Dataset.ofRows( sparkSession, - MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, planWithBarrier)) + MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, logicalPlan)) } /** @@ -2800,7 +2838,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def repartition(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = true, planWithBarrier) + Repartition(numPartitions, shuffle = true, logicalPlan) } /** @@ -2823,7 +2861,7 @@ class Dataset[T] private[sql]( |For range partitioning use repartitionByRange(...) instead. 
""".stripMargin) withTypedPlan { - RepartitionByExpression(partitionExprs.map(_.expr), planWithBarrier, numPartitions) + RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions) } } @@ -2861,7 +2899,7 @@ class Dataset[T] private[sql]( case expr: Expression => SortOrder(expr, Ascending) }) withTypedPlan { - RepartitionByExpression(sortOrder, planWithBarrier, numPartitions) + RepartitionByExpression(sortOrder, logicalPlan, numPartitions) } } @@ -2900,7 +2938,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def coalesce(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = false, planWithBarrier) + Repartition(numPartitions, shuffle = false, logicalPlan) } /** @@ -2985,7 +3023,7 @@ class Dataset[T] private[sql]( // Represents the `QueryExecution` used to produce the content of the Dataset as an `RDD`. @transient private lazy val rddQueryExecution: QueryExecution = { - val deserialized = CatalystSerde.deserialize[T](planWithBarrier) + val deserialized = CatalystSerde.deserialize[T](logicalPlan) sparkSession.sessionState.executePlan(deserialized) } @@ -3111,7 +3149,7 @@ class Dataset[T] private[sql]( comment = None, properties = Map.empty, originalText = None, - child = planWithBarrier, + child = logicalPlan, allowExisting = false, replace = replace, viewType = viewType) @@ -3235,13 +3273,49 @@ class Dataset[T] private[sql]( } /** - * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. + * Collect a Dataset as Arrow batches and serve stream to PySpark. */ private[sql] def collectAsArrowToPython(): Array[Any] = { + val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone + withAction("collectAsArrowToPython", queryExecution) { plan => - val iter: Iterator[Array[Byte]] = - toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) - PythonRDD.serveIterator(iter, "serve-Arrow") + PythonRDD.serveToStream("serve-Arrow") { out => + val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) + val arrowBatchRdd = toArrowBatchRdd(plan) + val numPartitions = arrowBatchRdd.partitions.length + + // Store collection results for worst case of 1 to N-1 partitions + val results = new Array[Array[Array[Byte]]](numPartitions - 1) + var lastIndex = -1 // index of last partition written + + // Handler to eagerly write partitions to Python in order + def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { + // If result is from next partition in order + if (index - 1 == lastIndex) { + batchWriter.writeBatches(arrowBatches.iterator) + lastIndex += 1 + // Write stored partitions that come next in order + while (lastIndex < results.length && results(lastIndex) != null) { + batchWriter.writeBatches(results(lastIndex).iterator) + results(lastIndex) = null + lastIndex += 1 + } + // After last batch, end the stream + if (lastIndex == results.length) { + batchWriter.end() + } + } else { + // Store partitions received out of order + results(index - 1) = arrowBatches + } + } + + sparkSession.sparkContext.runJob( + arrowBatchRdd, + (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray, + 0 until numPartitions, + handlePartitionBatches) + } } } @@ -3324,7 +3398,7 @@ class Dataset[T] private[sql]( } } withTypedPlan { - Sort(sortOrder, global = global, planWithBarrier) + Sort(sortOrder, global = global, logicalPlan) } } @@ -3348,20 +3422,20 @@ class Dataset[T] private[sql]( } } - /** Convert to an RDD of ArrowPayload byte arrays */ - private[sql] def toArrowPayload(plan: SparkPlan): 
RDD[ArrowPayload] = { + /** Convert to an RDD of serialized ArrowRecordBatches. */ + private[sql] def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = { val schemaCaptured = this.schema val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone plan.execute().mapPartitionsInternal { iter => val context = TaskContext.get() - ArrowConverters.toPayloadIterator( + ArrowConverters.toBatchIterator( iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context) } } // This is only used in tests, for now. - private[sql] def toArrowPayload: RDD[ArrowPayload] = { - toArrowPayload(queryExecution.executedPlan) + private[sql] def toArrowBatchRdd: RDD[Array[Byte]] = { + toArrowBatchRdd(queryExecution.executedPlan) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 36f6038aa9485..6bab21dca0cbd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -49,7 +49,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( private implicit val kExprEnc = encoderFor(kEncoder) private implicit val vExprEnc = encoderFor(vEncoder) - private def logicalPlan = AnalysisBarrier(queryExecution.analyzed) + private def logicalPlan = queryExecution.analyzed private def sparkSession = queryExecution.sparkSession /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index c6449cd5a16b0..d700fb83b9b70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -62,18 +62,17 @@ class RelationalGroupedDataset protected[sql]( groupType match { case RelationalGroupedDataset.GroupByType => - Dataset.ofRows( - df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.planWithBarrier)) + Dataset.ofRows(df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) case RelationalGroupedDataset.RollupType => Dataset.ofRows( - df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.planWithBarrier)) + df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan)) case RelationalGroupedDataset.CubeType => Dataset.ofRows( - df.sparkSession, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.planWithBarrier)) + df.sparkSession, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan)) case RelationalGroupedDataset.PivotType(pivotCol, values) => val aliasedGrps = groupingExprs.map(alias) Dataset.ofRows( - df.sparkSession, Pivot(Some(aliasedGrps), pivotCol, values, aggExprs, df.planWithBarrier)) + df.sparkSession, Pivot(Some(aliasedGrps), pivotCol, values, aggExprs, df.logicalPlan)) } } @@ -315,7 +314,67 @@ class RelationalGroupedDataset protected[sql]( * @param pivotColumn Name of the column to pivot. * @since 1.6.0 */ - def pivot(pivotColumn: String): RelationalGroupedDataset = { + def pivot(pivotColumn: String): RelationalGroupedDataset = pivot(Column(pivotColumn)) + + /** + * Pivots a column of the current `DataFrame` and performs the specified aggregation. + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. 
The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings") + * }}} + * + * @param pivotColumn Name of the column to pivot. + * @param values List of values that will be translated to columns in the output DataFrame. + * @since 1.6.0 + */ + def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = { + pivot(Column(pivotColumn), values) + } + + /** + * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified + * aggregation. + * + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Arrays.asList("dotNET", "Java")).sum("earnings"); + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings"); + * }}} + * + * @param pivotColumn Name of the column to pivot. + * @param values List of values that will be translated to columns in the output DataFrame. + * @since 1.6.0 + */ + def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = { + pivot(Column(pivotColumn), values) + } + + /** + * Pivots a column of the current `DataFrame` and performs the specified aggregation. + * This is an overloaded version of the `pivot` method with `pivotColumn` of the `String` type. + * + * {{{ + * // Or without specifying column values (less efficient) + * df.groupBy($"year").pivot($"course").sum($"earnings"); + * }}} + * + * @param pivotColumn he column to pivot. + * @since 2.4.0 + */ + def pivot(pivotColumn: Column): RelationalGroupedDataset = { // This is to prevent unintended OOM errors when the number of distinct values is large val maxValues = df.sparkSession.sessionState.conf.dataFramePivotMaxValues // Get the distinct values of the column and sort them so its consistent @@ -340,29 +399,24 @@ class RelationalGroupedDataset protected[sql]( /** * Pivots a column of the current `DataFrame` and performs the specified aggregation. - * There are two versions of pivot function: one that requires the caller to specify the list - * of distinct values to pivot on, and one that does not. The latter is more concise but less - * efficient, because Spark needs to first compute the list of distinct values internally. + * This is an overloaded version of the `pivot` method with `pivotColumn` of the `String` type. * * {{{ * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") - * - * // Or without specifying column values (less efficient) - * df.groupBy("year").pivot("course").sum("earnings") + * df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings") * }}} * - * @param pivotColumn Name of the column to pivot. + * @param pivotColumn the column to pivot. 
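The `Column`-based `pivot` overloads introduced in this file (since 2.4.0) become the single implementation, with the older `String` versions delegating to them. Usage is unchanged, as the scaladoc examples show; the value-less form still runs a distinct-value query first, bounded by the `dataFramePivotMaxValues` session conf mentioned above to avoid unintended OOM. Assuming a DataFrame `df` with `year`, `course`, and `earnings` columns as in those examples:

{{{
// Both spellings now resolve to the same Column-based implementation.
df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings")

// Without explicit values, Spark first computes the distinct values of
// the pivot column, trading an extra job for caller convenience.
df.groupBy($"year").pivot($"course").sum($"earnings")
}}}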
* @param values List of values that will be translated to columns in the output DataFrame. - * @since 1.6.0 + * @since 2.4.0 */ - def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = { + def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = { groupType match { case RelationalGroupedDataset.GroupByType => new RelationalGroupedDataset( df, groupingExprs, - RelationalGroupedDataset.PivotType(df.resolve(pivotColumn), values.map(Literal.apply))) + RelationalGroupedDataset.PivotType(pivotColumn.expr, values.map(Literal.apply))) case _: RelationalGroupedDataset.PivotType => throw new UnsupportedOperationException("repeated pivots are not supported") case _ => @@ -372,25 +426,14 @@ class RelationalGroupedDataset protected[sql]( /** * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified - * aggregation. - * - * There are two versions of pivot function: one that requires the caller to specify the list - * of distinct values to pivot on, and one that does not. The latter is more concise but less - * efficient, because Spark needs to first compute the list of distinct values internally. - * - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy("year").pivot("course", Arrays.asList("dotNET", "Java")).sum("earnings"); - * - * // Or without specifying column values (less efficient) - * df.groupBy("year").pivot("course").sum("earnings"); - * }}} + * aggregation. This is an overloaded version of the `pivot` method with `pivotColumn` of + * the `String` type. * - * @param pivotColumn Name of the column to pivot. + * @param pivotColumn the column to pivot. * @param values List of values that will be translated to columns in the output DataFrame. 
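Stepping back to the Arrow change earlier in this diff: `collectAsArrowToPython` now streams record batches to Python as partitions finish, buffering any partition that completes out of order and flushing the longest contiguous prefix on each arrival. A minimal standalone sketch of that ordering discipline, with `write` standing in for `ArrowBatchStreamWriter.writeBatches`; the real handler special-cases partition 0 and keeps an N-1 sized buffer, which this version trades away for clarity:

{{{
import scala.collection.mutable

// runJob may deliver partition results in any order, but the Arrow
// stream must see them in partition order.
def orderedFlusher[T](numPartitions: Int)(write: T => Unit): (Int, T) => Unit = {
  val pending = mutable.Map.empty[Int, T]  // partitions that finished early
  var next = 0                             // next index expected on the stream
  (index: Int, result: T) => {
    pending(index) = result
    // Flush every buffered partition that is now contiguous with `next`.
    while (next < numPartitions && pending.contains(next)) {
      write(pending.remove(next).get)
      next += 1
    }
  }
}
}}}

This plugs into the same `runJob(rdd, func, partitions, resultHandler)` shape used above, with the flusher as the `(Int, U) => Unit` result handler and an end-of-stream marker once all partitions have been written.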
- * @since 1.6.0 + * @since 2.4.0 */ - def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = { + def pivot(pivotColumn: Column, values: java.util.List[Any]): RelationalGroupedDataset = { pivot(pivotColumn, values.asScala) } @@ -433,7 +476,7 @@ class RelationalGroupedDataset protected[sql]( df.exprEnc.schema, groupingAttributes, df.logicalPlan.output, - df.planWithBarrier)) + df.logicalPlan)) } /** @@ -452,14 +495,14 @@ class RelationalGroupedDataset protected[sql]( require(expr.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, "Must pass a grouped map udf") require(expr.dataType.isInstanceOf[StructType], - "The returnType of the udf must be a StructType") + s"The returnType of the udf must be a ${StructType.simpleString}") val groupingNamedExpressions = groupingExprs.map { case ne: NamedExpression => ne case other => Alias(other, other.toString)() } val groupingAttributes = groupingNamedExpressions.map(_.toAttribute) - val child = df.planWithBarrier + val child = df.logicalPlan val project = Project(groupingNamedExpressions ++ child.output, child) val output = expr.dataType.asInstanceOf[StructType].toAttributes val plan = FlatMapGroupsInPandas(groupingAttributes, expr, output, project) @@ -470,8 +513,11 @@ class RelationalGroupedDataset protected[sql]( override def toString: String = { val builder = new StringBuilder builder.append("RelationalGroupedDataset: [grouping expressions: [") - val kFields = groupingExprs.map(_.asInstanceOf[NamedExpression]).map { - case f => s"${f.name}: ${f.dataType.simpleString(2)}" + val kFields = groupingExprs.collect { + case expr: NamedExpression if expr.resolved => + s"${expr.name}: ${expr.dataType.simpleString(2)}" + case expr: NamedExpression => expr.name + case o => o.toString } builder.append(kFields.take(2).mkString(", ")) if (kFields.length > 2) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala index b352e332bc7e0..3c39579149fff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala @@ -132,6 +132,17 @@ class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) { sqlConf.unsetConf(key) } + /** + * Indicates whether the configuration property with the given key + * is modifiable in the current session. + * + * @return `true` if the configuration property is modifiable. For static SQL, Spark Core, + * invalid (not existing) and other non-modifiable configuration properties, + * the returned value is `false`. + * @since 2.4.0 + */ + def isModifiable(key: String): Boolean = sqlConf.isModifiable(key) + /** * Returns whether a particular key is set. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 565042fcf762e..2b847fb6f9458 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -92,7 +92,8 @@ class SparkSession private( // If there is no active SparkSession, uses the default SQL conf. Otherwise, use the session's. 
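`RuntimeConfig.isModifiable` above gives callers a way to probe, before calling `set`, whether a key can actually take effect in the current session. The keys below are standard Spark configuration names used purely for illustration:

{{{
// Runtime SQL confs can be changed per session:
spark.conf.isModifiable("spark.sql.shuffle.partitions")  // true

// Static SQL confs are fixed once the session starts:
spark.conf.isModifiable("spark.sql.warehouse.dir")       // false

// Core (non-SQL) and unknown keys also report false:
spark.conf.isModifiable("spark.master")                  // false
spark.conf.isModifiable("no.such.key")                   // false
}}}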
SQLConf.setSQLConfGetter(() => { - SparkSession.getActiveSession.map(_.sessionState.conf).getOrElse(SQLConf.getFallbackConf) + SparkSession.getActiveSession.filterNot(_.sparkContext.isStopped).map(_.sessionState.conf) + .getOrElse(SQLConf.getFallbackConf) }) /** @@ -269,7 +270,7 @@ class SparkSession private( */ @transient lazy val emptyDataFrame: DataFrame = { - createDataFrame(sparkContext.emptyRDD[Row], StructType(Nil)) + createDataFrame(sparkContext.emptyRDD[Row].setName("empty"), StructType(Nil)) } /** @@ -394,7 +395,7 @@ class SparkSession private( // BeanInfo is not serializable so we must rediscover it remotely for each partition. SQLContext.beansToRows(iter, Utils.classForName(className), attributeSeq) } - Dataset.ofRows(self, LogicalRDD(attributeSeq, rowRdd)(self)) + Dataset.ofRows(self, LogicalRDD(attributeSeq, rowRdd.setName(rdd.name))(self)) } /** @@ -593,7 +594,7 @@ class SparkSession private( } else { rowRDD.map { r: Row => InternalRow.fromSeq(r.toSeq) } } - internalCreateDataFrame(catalystRows, schema) + internalCreateDataFrame(catalystRows.setName(rowRDD.name), schema) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index f94baef39dfad..c37ba0c60c3d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.execution.aggregate.ScalaUDAF import org.apache.spark.sql.execution.python.UserDefinedPythonFunction -import org.apache.spark.sql.expressions.{UserDefinedAggregateFunction, UserDefinedFunction} +import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedFunction} import org.apache.spark.sql.types.DataType import org.apache.spark.util.Utils @@ -113,7 +113,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends (0 to 22).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) - val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor[A$i].dataType :: $s"}) + val inputSchemas = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor[A$i] :: $s"}) println(s""" |/** | * Registers a deterministic Scala closure of $x arguments as user-defined function (UDF). @@ -122,15 +122,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends | */ |def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - | val inputTypes = Try($inputTypes).toOption + | val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try($inputSchemas).toOption | def builder(e: Seq[Expression]) = if (e.length == $x) { - | ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + | ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + | udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) | } else { | throw new AnalysisException("Invalid number of arguments for function " + name + | ". 
Expected: $x; Found: " + e.length) | } | functionRegistry.createOrReplaceTempFunction(name, builder) - | val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + | val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) | if (nullable) udf else udf.asNonNullable() |}""".stripMargin) } @@ -167,15 +168,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 0) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 0; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -186,15 +188,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 1) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 1; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -205,15 +208,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 2) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 2; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -224,15 +228,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 3) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 3; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -243,15 +248,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 4) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 4; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -262,15 +268,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 5) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 5; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -281,15 +288,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 6) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 6; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -300,15 +308,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 7) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 7; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -319,15 +328,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 8) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 8; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -338,15 +348,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 9) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 9; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -357,15 +368,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 10) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 10; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -376,15 +388,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 11) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 11; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -395,15 +408,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 12) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 12; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -414,15 +428,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 13) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 13; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -433,15 +448,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 14) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 14; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -452,15 +468,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 15) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 15; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -471,15 +488,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 16) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 16; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -490,15 +508,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 17) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 17; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -509,15 +528,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 18) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 18; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -528,15 +548,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: ScalaReflection.schemaFor[A19] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 19) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 19; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -547,15 +568,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: ScalaReflection.schemaFor[A19] :: ScalaReflection.schemaFor[A20] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 20) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 20; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -566,15 +588,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: ScalaReflection.schemaFor[A19] :: ScalaReflection.schemaFor[A20] :: ScalaReflection.schemaFor[A21] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 21) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 21; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -585,15 +608,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: ScalaReflection.schemaFor[A19] :: ScalaReflection.schemaFor[A20] :: ScalaReflection.schemaFor[A21] :: ScalaReflection.schemaFor[A22] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 22) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 22; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index b33760b1edbc6..c0830e77b5a87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.api.python -import org.apache.spark.api.java.JavaRDD import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.expressions.ExpressionInfo @@ -34,17 +33,19 @@ private[sql] object PythonSQLUtils { } /** - * Python Callable function to convert ArrowPayloads into a [[DataFrame]]. + * Python callable function to read a file in Arrow stream format and create a [[DataFrame]] + * using each serialized ArrowRecordBatch as a partition. * - * @param payloadRDD A JavaRDD of ArrowPayloads. - * @param schemaString JSON Formatted Schema for ArrowPayloads. * @param sqlContext The active [[SQLContext]]. - * @return The converted [[DataFrame]]. + * @param filename File to read the Arrow stream from. + * @param schemaString JSON Formatted Spark schema for Arrow batches. + * @return A new [[DataFrame]]. */ - def arrowPayloadToDataFrame( - payloadRDD: JavaRDD[Array[Byte]], - schemaString: String, - sqlContext: SQLContext): DataFrame = { - ArrowConverters.toDataFrame(payloadRDD, schemaString, sqlContext) + def arrowReadStreamFromFile( + sqlContext: SQLContext, + filename: String, + schemaString: String): DataFrame = { + val jrdd = ArrowConverters.readArrowStreamFromFile(sqlContext, filename) + ArrowConverters.toDataFrame(jrdd, schemaString, sqlContext) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 39d9a95ca4710..c9929935fb8ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.expressions.SubqueryExpression -import org.apache.spark.sql.catalyst.plans.logical.{AnalysisBarrier, LogicalPlan, ResolvedHint} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ResolvedHint} import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.storage.StorageLevel @@ -97,7 +97,7 @@ class CacheManager extends Logging { val inMemoryRelation = InMemoryRelation( sparkSession.sessionState.conf.useCompression, sparkSession.sessionState.conf.columnBatchSize, storageLevel, - sparkSession.sessionState.executePlan(AnalysisBarrier(planToCache)).executedPlan, + sparkSession.sessionState.executePlan(planToCache).executedPlan, tableName, planToCache) cachedData.add(CachedData(planToCache, inMemoryRelation)) @@ -173,7 +173,7 @@ class CacheManager extends Logging { // Remove the 
cache entry before we create a new one, so that we can have a different // physical plan. it.remove() - val plan = spark.sessionState.executePlan(AnalysisBarrier(cd.plan)).executedPlan + val plan = spark.sessionState.executePlan(cd.plan).executedPlan val newCache = InMemoryRelation( cacheBuilder = cd.cachedRepresentation.cacheBuilder.withCachedPlan(plan), logicalPlan = cd.plan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index d7f2654be0451..36ed016773b67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -166,10 +166,12 @@ case class FileSourceScanExec( override val tableIdentifier: Option[TableIdentifier]) extends DataSourceScanExec with ColumnarBatchScan { - override val supportsBatch: Boolean = relation.fileFormat.supportBatch( + // Note that some vals referring to the file-based relation are intentionally lazy + // so that this plan can be canonicalized on the executor side too. See SPARK-23731. + override lazy val supportsBatch: Boolean = relation.fileFormat.supportBatch( relation.sparkSession, StructType.fromAttributes(output)) - override val needsUnsafeRowConversion: Boolean = { + override lazy val needsUnsafeRowConversion: Boolean = { if (relation.fileFormat.isInstanceOf[ParquetSource]) { SparkSession.getActiveSession.get.sessionState.conf.parquetVectorizedReaderEnabled } else { @@ -199,7 +201,7 @@ case class FileSourceScanExec( ret } - override val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = { + override lazy val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = { val bucketSpec = if (relation.sparkSession.sessionState.conf.bucketingEnabled) { relation.bucketSpec } else { @@ -270,7 +272,7 @@ case class FileSourceScanExec( private val pushedDownFilters = dataFilters.flatMap(DataSourceStrategy.translateFilter) logInfo(s"Pushed Filters: ${pushedDownFilters.mkString(",")}") - override val metadata: Map[String, String] = { + override lazy val metadata: Map[String, String] = { def seqToString(seq: Seq[Any]) = seq.mkString("[", ", ", "]") val location = relation.location val locationDesc = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index be50a1571a2ff..2962becb64e88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -103,6 +103,10 @@ case class ExternalRDDScanExec[T]( override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + private def rddName: String = Option(rdd.name).map(n => s" $n").getOrElse("") + + override val nodeName: String = s"Scan$rddName" + protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") val outputDataType = outputObjAttr.dataType @@ -116,7 +120,7 @@ case class ExternalRDDScanExec[T]( } override def simpleString: String = { - s"Scan $nodeName${output.mkString("[", ",", "]")}" + s"$nodeName${output.mkString("[", ",", "]")}" } } @@ -169,10 +173,14 @@ case class LogicalRDD( case class RDDScanExec( output: Seq[Attribute], rdd: RDD[InternalRow], - override val nodeName: String, + name: String, override val outputPartitioning: 
Partitioning = UnknownPartitioning(0), override val outputOrdering: Seq[SortOrder] = Nil) extends LeafExecNode { + private def rddName: String = Option(rdd.name).map(n => s" $n").getOrElse("") + + override val nodeName: String = s"Scan $name$rddName" + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -189,6 +197,6 @@ case class RDDScanExec( } override def simpleString: String = { - s"Scan $nodeName${Utils.truncatedString(output, "[", ",", "]")}" + s"$nodeName${Utils.truncatedString(output, "[", ",", "]")}" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GetStructFieldObject.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GetStructFieldObject.scala new file mode 100644 index 0000000000000..c88b2f8c034fc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GetStructFieldObject.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField} +import org.apache.spark.sql.types.StructField + +/** + * A Scala extractor that extracts the child expression and struct field from a [[GetStructField]]. + * This is in contrast to the [[GetStructField]] case class extractor which returns the field + * ordinal instead of the field itself. + */ +private[execution] object GetStructFieldObject { + def unapply(getStructField: GetStructField): Option[(Expression, StructField)] = + Some(( + getStructField.child, + getStructField.childSchema(getStructField.ordinal))) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ProjectionOverSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ProjectionOverSchema.scala new file mode 100644 index 0000000000000..612a7b87b9832 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ProjectionOverSchema.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +/** + * A Scala extractor that projects an expression over a given schema. Data types, + * field indexes and field counts of complex type extractors and attributes + * are adjusted to fit the schema. All other expressions are left as-is. This + * class is motivated by columnar nested schema pruning. + */ +private[execution] case class ProjectionOverSchema(schema: StructType) { + private val fieldNames = schema.fieldNames.toSet + + def unapply(expr: Expression): Option[Expression] = getProjection(expr) + + private def getProjection(expr: Expression): Option[Expression] = + expr match { + case a: AttributeReference if fieldNames.contains(a.name) => + Some(a.copy(dataType = schema(a.name).dataType)(a.exprId, a.qualifier)) + case GetArrayItem(child, arrayItemOrdinal) => + getProjection(child).map { projection => GetArrayItem(projection, arrayItemOrdinal) } + case a: GetArrayStructFields => + getProjection(a.child).map(p => (p, p.dataType)).map { + case (projection, ArrayType(projSchema @ StructType(_), _)) => + GetArrayStructFields(projection, + projSchema(a.field.name), + projSchema.fieldIndex(a.field.name), + projSchema.size, + a.containsNull) + case (_, projSchema) => + throw new IllegalStateException( + s"unmatched child schema for GetArrayStructFields: ${projSchema.toString}" + ) + } + case GetMapValue(child, key) => + getProjection(child).map { projection => GetMapValue(projection, key) } + case GetStructFieldObject(child, field: StructField) => + getProjection(child).map(p => (p, p.dataType)).map { + case (projection, projSchema: StructType) => + GetStructField(projection, projSchema.fieldIndex(field.name)) + case (_, projSchema) => + throw new IllegalStateException( + s"unmatched child schema for GetStructField: ${projSchema.toString}" + ) + } + case _ => + None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 3112b306c365e..64f49e2d0d4e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -89,7 +89,6 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { /** A sequence of rules that will be applied in order to the physical plan before execution. */ protected def preparations: Seq[Rule[SparkPlan]] = Seq( - python.ExtractPythonUDFs, PlanSubqueries(sparkSession), EnsureRequirements(sparkSession.sessionState.conf), CollapseCodegenStages(sparkSession.sessionState.conf), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SelectedField.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SelectedField.scala new file mode 100644 index 0000000000000..0e7c593f9fb67 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SelectedField.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +/** + * A Scala extractor that builds a [[org.apache.spark.sql.types.StructField]] from a Catalyst + * complex type extractor. For example, consider a relation with the following schema: + * + * {{{ + * root + * |-- name: struct (nullable = true) + * | |-- first: string (nullable = true) + * | |-- last: string (nullable = true) + * }}} + * + * Further, suppose we take the select expression `name.first`. This will parse into an + * `Alias(child, "first")`. Ignoring the alias, `child` matches the following pattern: + * + * {{{ + * GetStructFieldObject( + * AttributeReference("name", StructType(_), _, _), + * StructField("first", StringType, _, _)) + * }}} + * + * [[SelectedField]] converts that expression into + * + * {{{ + * StructField("name", StructType(Array(StructField("first", StringType)))) + * }}} + * + * by mapping each complex type extractor to a [[org.apache.spark.sql.types.StructField]] with the + * same name as its child (or "parent" going right to left in the select expression) and a data + * type appropriate to the complex type extractor. In our example, the name of the child expression + * is "name" and its data type is a [[org.apache.spark.sql.types.StructType]] with a single string + * field named "first". + * + * @param expr the top-level complex type extractor + */ +private[execution] object SelectedField { + def unapply(expr: Expression): Option[StructField] = { + // If this expression is an alias, work on its child instead + val unaliased = expr match { + case Alias(child, _) => child + case expr => expr + } + selectField(unaliased, None) + } + + private def selectField(expr: Expression, fieldOpt: Option[StructField]): Option[StructField] = { + expr match { + // No children. Returns a StructField with the attribute name or None if fieldOpt is None. + case AttributeReference(name, dataType, nullable, metadata) => + fieldOpt.map(field => + StructField(name, wrapStructType(dataType, field), nullable, metadata)) + // Handles case "expr0.field[n]", where "expr0" is of struct type and "expr0.field" is of + // array type. + case GetArrayItem(x @ GetStructFieldObject(child, field @ StructField(name, + dataType, nullable, metadata)), _) => + val childField = fieldOpt.map(field => StructField(name, + wrapStructType(dataType, field), nullable, metadata)).getOrElse(field) + selectField(child, Some(childField)) + // Handles case "expr0.field[n]", where "expr0.field" is of array type. + case GetArrayItem(child, _) => + selectField(child, fieldOpt) + // Handles case "expr0.field.subfield", where "expr0" and "expr0.field" are of array type. 
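// A hedged illustration of the array cases that follow, using a schema that is
// not part of this patch: given `friends: array<struct<first: string, last: string>>`,
// the select expression `friends.last` is a GetArrayStructFields over the
// `friends` attribute, and selectField rebuilds the pruned schema through
// wrapStructType as
//   StructField("friends",
//     ArrayType(StructType(Array(StructField("last", StringType))), containsNull = true))
// (containsNull shown as true on the assumption that the array admits null elements).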
+ case GetArrayStructFields(child: GetArrayStructFields, + field @ StructField(name, dataType, nullable, metadata), _, _, _) => + val childField = fieldOpt.map(field => StructField(name, + wrapStructType(dataType, field), + nullable, metadata)).orElse(Some(field)) + selectField(child, childField) + // Handles case "expr0.field", where "expr0" is of array type. + case GetArrayStructFields(child, + field @ StructField(name, dataType, nullable, metadata), _, _, _) => + val childField = + fieldOpt.map(field => StructField(name, + wrapStructType(dataType, field), + nullable, metadata)).orElse(Some(field)) + selectField(child, childField) + // Handles case "expr0.field[key]", where "expr0" is of struct type and "expr0.field" is of + // map type. + case GetMapValue(x @ GetStructFieldObject(child, field @ StructField(name, + dataType, + nullable, metadata)), _) => + val childField = fieldOpt.map(field => StructField(name, + wrapStructType(dataType, field), + nullable, metadata)).orElse(Some(field)) + selectField(child, childField) + // Handles case "expr0.field[key]", where "expr0.field" is of map type. + case GetMapValue(child, _) => + selectField(child, fieldOpt) + // Handles case "expr0.field", where expr0 is of struct type. + case GetStructFieldObject(child, + field @ StructField(name, dataType, nullable, metadata)) => + val childField = fieldOpt.map(field => StructField(name, + wrapStructType(dataType, field), + nullable, metadata)).orElse(Some(field)) + selectField(child, childField) + case _ => + None + } + } + + // Constructs a composition of complex types with a StructType(Array(field)) at its core. Returns + // a StructType for a StructType, an ArrayType for an ArrayType and a MapType for a MapType. + private def wrapStructType(dataType: DataType, field: StructField): DataType = { + dataType match { + case _: StructType => + StructType(Array(field)) + case ArrayType(elementType, containsNull) => + ArrayType(wrapStructType(elementType, field), containsNull) + case MapType(keyType, valueType, valueContainsNull) => + MapType(keyType, wrapStructType(valueType, field), valueContainsNull) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 00ff4c8ac310b..6c6d344240cea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -21,20 +21,26 @@ import org.apache.spark.sql.ExperimentalMethods import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions -import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate +import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaPruning +import org.apache.spark.sql.execution.python.{ExtractPythonUDFFromAggregate, ExtractPythonUDFs} class SparkOptimizer( catalog: SessionCatalog, experimentalMethods: ExperimentalMethods) extends Optimizer(catalog) { - override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+ + override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ - Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ - Batch("Prune File Source Table Partitions", Once, 
PruneFileSourcePartitions)) ++ + Batch("Extract Python UDFs", Once, + Seq(ExtractPythonUDFFromAggregate, ExtractPythonUDFs): _*) :+ + Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+ + Batch("Parquet Schema Pruning", Once, ParquetSchemaPruning)) ++ postHocOptimizationBatches :+ Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) + override def nonExcludableRules: Seq[String] = + super.nonExcludableRules :+ ExtractPythonUDFFromAggregate.ruleName + /** * Optimization batches that are executed before the regular optimization batches (also before * the finish analysis batch). diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 398758a3331b4..1f97993e20458 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -47,17 +47,15 @@ import org.apache.spark.util.ThreadUtils abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable { /** - * A handle to the SQL Context that was used to create this plan. Since many operators need + * A handle to the SQL Context that was used to create this plan. Since many operators need * access to the sqlContext for RDD operations or configuration this field is automatically * populated by the query planning infrastructure. */ - @transient - final val sqlContext = SparkSession.getActiveSession.map(_.sqlContext).orNull + @transient final val sqlContext = SparkSession.getActiveSession.map(_.sqlContext).orNull protected def sparkContext = sqlContext.sparkContext // sqlContext will be null when SparkPlan nodes are created without the active sessions. - // So far, this only happens in the test cases. val subexpressionEliminationEnabled: Boolean = if (sqlContext != null) { sqlContext.conf.subexpressionEliminationEnabled } else { @@ -69,7 +67,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** Overridden make copy also propagates sqlContext to copied plan. */ override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = { - SparkSession.setActiveSession(sqlContext.sparkSession) + if (sqlContext != null) { + SparkSession.setActiveSession(sqlContext.sparkSession) + } super.makeCopy(newArgs) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala index 2a2315896831c..59ffd16381116 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution -import com.fasterxml.jackson.annotation.JsonIgnoreProperties - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.metric.SQLMetricInfo @@ -28,11 +26,11 @@ import org.apache.spark.sql.execution.metric.SQLMetricInfo * Stores information about a SQL SparkPlan. */ @DeveloperApi -@JsonIgnoreProperties(Array("metadata")) // The metadata field was removed in Spark 2.3. 
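// A hedged sketch of what the reinstated metadata field (declared below) carries:
// for a FileSourceScanExec, fromSparkPlan now copies the scan's metadata map into
// the event log, e.g. with illustrative values only:
//   new SparkPlanInfo("Scan parquet", "FileScan parquet [id#0L] ...", Nil,
//     Map("Format" -> "Parquet", "Location" -> "InMemoryFileIndex[...]"), Nil)
// All other plan nodes get an empty map, per the match added in fromSparkPlan below.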
class SparkPlanInfo( val nodeName: String, val simpleString: String, val children: Seq[SparkPlanInfo], + val metadata: Map[String, String], val metrics: Seq[SQLMetricInfo]) { override def hashCode(): Int = { @@ -59,6 +57,12 @@ private[execution] object SparkPlanInfo { new SQLMetricInfo(metric.name.getOrElse(key), metric.id, metric.metricType) } - new SparkPlanInfo(plan.nodeName, plan.simpleString, children.map(fromSparkPlan), metrics) + // dump the file scan metadata (e.g file path) to event log + val metadata = plan match { + case fileScan: FileSourceScanExec => fileScan.metadata + case _ => Map[String, String]() + } + new SparkPlanInfo(plan.nodeName, plan.simpleString, children.map(fromSparkPlan), + metadata, metrics) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 75f5ec0e253df..2a4a1c8ef3438 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -36,6 +36,7 @@ class SparkPlanner( override def strategies: Seq[Strategy] = experimentalMethods.extraStrategies ++ extraPlanningStrategies ++ ( + PythonEvals :: DataSourceV2Strategy :: FileSourceStrategy :: DataSourceStrategy(conf) :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 4828fa60a7b58..89cb63784c0f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -1458,6 +1458,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * }}} */ override def visitAlterViewQuery(ctx: AlterViewQueryContext): LogicalPlan = withOrigin(ctx) { + // ALTER VIEW ... AS INSERT INTO is not allowed. + ctx.query.queryNoWith match { + case s: SingleInsertQueryContext if s.insertInto != null => + operationNotAllowed("ALTER VIEW ... AS INSERT INTO", ctx) + case _: MultiInsertQueryContext => + operationNotAllowed("ALTER VIEW ... AS FROM ... 
[INSERT INTO ...]+", ctx) + case _ => // OK + } AlterViewAsCommand( name = visitTableIdentifier(ctx.tableIdentifier), originalText = source(ctx.query), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 07a6fcae83b70..89442a70283f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -27,14 +27,16 @@ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} +import org.apache.spark.sql.execution.python._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.MemoryPlanV2 import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.StreamingQuery +import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery} import org.apache.spark.sql.types.StructType /** @@ -64,29 +66,35 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Plans special cases of limit operators. */ object SpecialLimits extends Strategy { + private def decideTopRankNode(limit: Int, child: LogicalPlan): Seq[SparkPlan] = { + if (limit < conf.topKSortFallbackThreshold) { + child match { + case Sort(order, true, child) => + TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil + case Project(projectList, Sort(order, true, child)) => + TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil + } + } else { + GlobalLimitExec(limit, + LocalLimitExec(limit, planLater(child)), + orderedLimit = true) :: Nil + } + } + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ReturnAnswer(rootPlan) => rootPlan match { - case Limit(IntegerLiteral(limit), Sort(order, true, child)) - if limit < conf.topKSortFallbackThreshold => - TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil - case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) - if limit < conf.topKSortFallbackThreshold => - TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil + case Limit(IntegerLiteral(limit), s @ Sort(order, true, child)) => + decideTopRankNode(limit, s) + case Limit(IntegerLiteral(limit), p @ Project(projectList, Sort(order, true, child))) => + decideTopRankNode(limit, p) case Limit(IntegerLiteral(limit), child) => - // With whole stage codegen, Spark releases resources only when all the output data of the - // query plan are consumed. It's possible that `CollectLimitExec` only consumes a little - // data from child plan and finishes the query without releasing resources. Here we wrap - // the child plan with `LocalLimitExec`, to stop the processing of whole stage codegen and - // trigger the resource releasing work, after we consume `limit` rows. 
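// The replacement below drops that wrapping; a hedged sketch of how a limit over
// sorted input now plans, assuming the default value of
// spark.sql.execution.topKSortFallbackThreshold:
//   spark.range(100).orderBy($"id").limit(10)
//     => TakeOrderedAndProjectExec(10, ...)   // limit below the threshold
// whereas a limit at or above the threshold plans as
//   GlobalLimitExec(limit, LocalLimitExec(limit, ...), orderedLimit = true)
// via decideTopRankNode above.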
- CollectLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil + CollectLimitExec(limit, planLater(child)) :: Nil case other => planLater(other) :: Nil } - case Limit(IntegerLiteral(limit), Sort(order, true, child)) - if limit < conf.topKSortFallbackThreshold => - TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil - case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) - if limit < conf.topKSortFallbackThreshold => - TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil + case Limit(IntegerLiteral(limit), s @ Sort(order, true, child)) => + decideTopRankNode(limit, s) + case Limit(IntegerLiteral(limit), p @ Project(projectList, Sort(order, true, child))) => + decideTopRankNode(limit, p) case _ => Nil } } @@ -332,10 +340,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { "Streaming aggregation doesn't support group aggregate pandas UDF") } + val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION) + aggregate.AggUtils.planStreamingAggregation( namedGroupingExpressions, aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), rewrittenResultExpressions, + stateVersion, planLater(child)) case _ => Nil @@ -354,6 +365,29 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + /** + * Used to plan the streaming global limit operator for streams in append mode. + * We need to check for either a direct Limit or a Limit wrapped in a ReturnAnswer operator, + * following the example of the SpecialLimits Strategy above. + * Streams with limit in Append mode use the stateful StreamingGlobalLimitExec. + * Streams with limit in Complete mode use the stateless CollectLimitExec operator. + * Limit is unsupported for streams in Update mode. + */ + case class StreamingGlobalLimitStrategy(outputMode: OutputMode) extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case ReturnAnswer(rootPlan) => rootPlan match { + case Limit(IntegerLiteral(limit), child) + if plan.isStreaming && outputMode == InternalOutputModes.Append => + StreamingGlobalLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil + case _ => Nil + } + case Limit(IntegerLiteral(limit), child) + if plan.isStreaming && outputMode == InternalOutputModes.Append => + StreamingGlobalLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil + case _ => Nil + } + } + object StreamingJoinStrategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = { plan match { @@ -485,15 +519,30 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case FlatMapGroupsWithState( func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateEnc, outputMode, _, timeout, child) => + val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) val execPlan = FlatMapGroupsWithStateExec( - func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, outputMode, - timeout, batchTimestampMs = None, eventTimeWatermark = None, planLater(child)) + func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, stateVersion, + outputMode, timeout, batchTimestampMs = None, eventTimeWatermark = None, planLater(child)) execPlan :: Nil case _ => Nil } } + /** + * Strategy to convert EvalPython logical operator to physical operator. 
+ */ + object PythonEvals extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case ArrowEvalPython(udfs, output, child) => + ArrowEvalPythonExec(udfs, output, planLater(child)) :: Nil + case BatchEvalPython(udfs, output, child) => + BatchEvalPythonExec(udfs, output, planLater(child)) :: Nil + case _ => + Nil + } + } + object BasicOperators extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case d: DataWritingCommand => DataWritingCommandExec(d, planLater(d.query)) :: Nil @@ -509,12 +558,20 @@ case logical.Distinct(child) => throw new IllegalStateException( "logical distinct operator should have been replaced by aggregate in the optimizer") - case logical.Intersect(left, right) => + case logical.Intersect(left, right, false) => + throw new IllegalStateException( + "logical intersect operator should have been replaced by semi-join in the optimizer") + case logical.Intersect(left, right, true) => throw new IllegalStateException( - "logical intersect operator should have been replaced by semi-join in the optimizer") - case logical.Except(left, right) => + "logical intersect operator should have been replaced by union, aggregate" + + " and generate operators in the optimizer") + case logical.Except(left, right, false) => throw new IllegalStateException( "logical except operator should have been replaced by anti-join in the optimizer") + case logical.Except(left, right, true) => + throw new IllegalStateException( + "logical except (all) operator should have been replaced by union, aggregate" + + " and generate operators in the optimizer") case logical.DeserializeToObject(deserializer, objAttr, child) => execution.DeserializeToObjectExec(deserializer, objAttr, planLater(child)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 372dc3db36ce6..1fc4de9e56015 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -21,6 +21,7 @@ import java.util.Locale import java.util.function.Supplier import scala.collection.mutable +import scala.util.control.NonFatal import org.apache.spark.broadcast import org.apache.spark.rdd.RDD @@ -275,7 +276,7 @@ trait CodegenSupport extends SparkPlan { required: AttributeSet): String = { val evaluateVars = new StringBuilder variables.zipWithIndex.foreach { case (ev, i) => - if (ev.code != "" && required.contains(attributes(i))) { + if (ev.code.nonEmpty && required.contains(attributes(i))) { evaluateVars.append(ev.code.toString + "\n") ev.code = EmptyBlock } @@ -582,7 +583,7 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) val (_, maxCodeSize) = try { CodeGenerator.compile(cleanedSource) } catch { - case _: Exception if !Utils.isTesting && sqlContext.conf.codegenFallback => + case NonFatal(_) if !Utils.isTesting && sqlContext.conf.codegenFallback => // We should have already seen the error message logWarning(s"Whole-stage codegen disabled for plan (id=$codegenStageId):\n $treeString") return child.execute() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index ebbdf1aaa024d..6be88c463dbd9 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -177,6 +177,10 @@ object AggUtils { case agg @ AggregateExpression(aggregateFunction, mode, true, _) => aggregateFunction.transformDown(distinctColumnAttributeLookup) .asInstanceOf[AggregateFunction] + case agg => + throw new IllegalArgumentException( + "Non-distinct aggregate is found in functionsWithDistinct " + + s"at planAggregateWithOneDistinct: $agg") } val partialDistinctAggregate: SparkPlan = { @@ -256,6 +260,7 @@ object AggUtils { groupingExpressions: Seq[NamedExpression], functionsWithoutDistinct: Seq[AggregateExpression], resultExpressions: Seq[NamedExpression], + stateFormatVersion: Int, child: SparkPlan): Seq[SparkPlan] = { val groupingAttributes = groupingExpressions.map(_.toAttribute) @@ -287,7 +292,8 @@ object AggUtils { child = partialAggregate) } - val restored = StateStoreRestoreExec(groupingAttributes, None, partialMerged1) + val restored = StateStoreRestoreExec(groupingAttributes, None, stateFormatVersion, + partialMerged1) val partialMerged2: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) @@ -311,6 +317,7 @@ object AggUtils { stateInfo = None, outputMode = None, eventTimeWatermark = None, + stateFormatVersion = stateFormatVersion, partialMerged2) val finalAndCompleteAggregate: SparkPlan = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 8c7b2c187cccd..98adba50b2973 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -328,7 +328,7 @@ case class HashAggregateExec( initialBuffer, bufferSchema, groupingKeySchema, - TaskContext.get().taskMemoryManager(), + TaskContext.get(), 1024 * 16, // initial capacity TaskContext.get().taskMemoryManager().pageSizeBytes ) @@ -579,6 +579,7 @@ case class HashAggregateExec( case _ => } } + val bitMaxCapacity = sqlContext.conf.fastHashAggregateRowMaxCapacityBit val thisPlan = ctx.addReferenceObj("plan", this) @@ -588,7 +589,7 @@ case class HashAggregateExec( val fastHashMapClassName = ctx.freshName("FastHashMap") if (isVectorizedHashMapEnabled) { val generatedMap = new VectorizedHashMapGenerator(ctx, aggregateExpressions, - fastHashMapClassName, groupingKeySchema, bufferSchema).generate() + fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate() ctx.addInnerClass(generatedMap) // Inline mutable state since not many aggregation operations in a task @@ -598,7 +599,7 @@ case class HashAggregateExec( forceInline = true) } else { val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions, - fastHashMapClassName, groupingKeySchema, bufferSchema).generate() + fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate() ctx.addInnerClass(generatedMap) // Inline mutable state since not many aggregation operations in a task diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala index d5508275c48c5..56cf78d8b7fc1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -39,66 +39,52 @@ class RowBasedHashMapGenerator( aggregateExpressions: Seq[AggregateExpression], generatedClassName: String, groupingKeySchema: StructType, - bufferSchema: StructType) + bufferSchema: StructType, + bitMaxCapacity: Int) extends HashMapGenerator (ctx, aggregateExpressions, generatedClassName, groupingKeySchema, bufferSchema) { override protected def initializeAggregateHashMap(): String = { - val generatedKeySchema: String = - s"new org.apache.spark.sql.types.StructType()" + - groupingKeySchema.map { key => - val keyName = ctx.addReferenceObj("keyName", key.name) - key.dataType match { - case d: DecimalType => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( - |${d.precision}, ${d.scale}))""".stripMargin - case _ => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" - } - }.mkString("\n").concat(";") + val keySchema = ctx.addReferenceObj("keySchemaTerm", groupingKeySchema) + val valueSchema = ctx.addReferenceObj("valueSchemaTerm", bufferSchema) - val generatedValueSchema: String = - s"new org.apache.spark.sql.types.StructType()" + - bufferSchema.map { key => - val keyName = ctx.addReferenceObj("keyName", key.name) - key.dataType match { - case d: DecimalType => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( - |${d.precision}, ${d.scale}))""".stripMargin - case _ => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" - } - }.mkString("\n").concat(";") + val numVarLenFields = groupingKeys.map(_.dataType).count { + case dt if UnsafeRow.isFixedLength(dt) => false + // TODO: consider large decimal and interval type + case _ => true + } s""" | private org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch batch; | private int[] buckets; - | private int capacity = 1 << 16; + | private int capacity = 1 << $bitMaxCapacity; | private double loadFactor = 0.5; | private int numBuckets = (int) (capacity / loadFactor); | private int maxSteps = 2; | private int numRows = 0; - | private org.apache.spark.sql.types.StructType keySchema = $generatedKeySchema - | private org.apache.spark.sql.types.StructType valueSchema = $generatedValueSchema | private Object emptyVBase; | private long emptyVOff; | private int emptyVLen; | private boolean isBatchFull = false; + | private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter; | | | public $generatedClassName( | org.apache.spark.memory.TaskMemoryManager taskMemoryManager, | InternalRow emptyAggregationBuffer) { | batch = org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch - | .allocate(keySchema, valueSchema, taskMemoryManager, capacity); + | .allocate($keySchema, $valueSchema, taskMemoryManager, capacity); | - | final UnsafeProjection valueProjection = UnsafeProjection.create(valueSchema); + | final UnsafeProjection valueProjection = UnsafeProjection.create($valueSchema); | final byte[] emptyBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); | | emptyVBase = emptyBuffer; | emptyVOff = Platform.BYTE_ARRAY_OFFSET; | emptyVLen = emptyBuffer.length; | + | agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter( + | ${groupingKeySchema.length}, ${numVarLenFields * 32}); + | | buckets = new int[numBuckets]; | java.util.Arrays.fill(buckets, -1); | } @@ -136,12 +122,6 @@ class RowBasedHashMapGenerator( * */ protected def generateFindOrInsert(): String = { 
- val numVarLenFields = groupingKeys.map(_.dataType).count { - case dt if UnsafeRow.isFixedLength(dt) => false - // TODO: consider large decimal and interval type - case _ => true - } - val createUnsafeRowForKey = groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => key.dataType match { case t: DecimalType => @@ -154,6 +134,12 @@ class RowBasedHashMapGenerator( } }.mkString(";\n") + val resetNullBits = if (groupingKeySchema.map(_.nullable).forall(_ == false)) { + "" + } else { + "agg_rowWriter.zeroOutNullBytes();" + } + s""" |public org.apache.spark.sql.catalyst.expressions.UnsafeRow findOrInsert(${ groupingKeySignature}) { @@ -164,12 +150,8 @@ class RowBasedHashMapGenerator( | // Return bucket index if it's either an empty slot or already contains the key | if (buckets[idx] == -1) { | if (numRows < capacity && !isBatchFull) { - | // creating the unsafe for new entry - | org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter - | = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter( - | ${groupingKeySchema.length}, ${numVarLenFields * 32}); - | agg_rowWriter.reset(); //TODO: investigate if reset or zeroout are actually needed - | agg_rowWriter.zeroOutNullBytes(); + | agg_rowWriter.reset(); + | $resetNullBits | ${createUnsafeRowForKey}; | org.apache.spark.sql.catalyst.expressions.UnsafeRow agg_result | = agg_rowWriter.getRow(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 9dc334c1ead3c..72505f7fac0c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -166,7 +166,7 @@ class TungstenAggregationIterator( initialAggregationBuffer, StructType.fromAttributes(aggregateFunctions.flatMap(_.aggBufferAttributes)), StructType.fromAttributes(groupingExpressions.map(_.toAttribute)), - TaskContext.get().taskMemoryManager(), + TaskContext.get(), 1024 * 16, // initial capacity TaskContext.get().taskMemoryManager().pageSizeBytes ) @@ -372,7 +372,7 @@ class TungstenAggregationIterator( } } - TaskContext.get().addTaskCompletionListener(_ => { + TaskContext.get().addTaskCompletionListener[Unit](_ => { // At the end of the task, update the task's peak memory usage. Since we destroy // the map to create the sorter, their memory usages should not overlap, so it is safe // to just use the max of the two. 
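The new bitMaxCapacity parameter threads through both generated fast hash maps in place of the previously hard-coded `1 << 16` capacity, and the per-key UnsafeRowWriter allocation moves into the constructor so findOrInsert can reuse a single writer. A rough Scala sketch of the probing scheme the generated Java code implements, with simplified, hypothetical names (FastMapSketch and isKeyAt are illustrative):

// Simplified sketch of the generated fast hash map's bucket probing:
// capacity is 1 << bitMaxCapacity, the load factor is 0.5 (so numBuckets
// stays a power of two), and at most maxSteps buckets are probed before
// giving up, at which point the caller falls back to the regular,
// spill-capable aggregation hash map.
class FastMapSketch(bitMaxCapacity: Int) {
  private val capacity = 1 << bitMaxCapacity
  private val numBuckets = (capacity / 0.5).toInt // 2 * capacity, a power of two
  private val maxSteps = 2
  private val buckets = Array.fill(numBuckets)(-1)

  // Returns a bucket index that is empty or already holds the key, else -1.
  def findBucket(hash: Int, isKeyAt: Int => Boolean): Int = {
    var step = 0
    var idx = hash & (numBuckets - 1)
    while (step < maxSteps) {
      if (buckets(idx) == -1 || isKeyAt(buckets(idx))) return idx
      idx = (idx + 1) & (numBuckets - 1) // linear probe to the next bucket
      step += 1
    }
    -1
  }
}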
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 7b3580cecc60d..f9c4ecc14e6c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -47,59 +47,35 @@ class VectorizedHashMapGenerator( aggregateExpressions: Seq[AggregateExpression], generatedClassName: String, groupingKeySchema: StructType, - bufferSchema: StructType) + bufferSchema: StructType, + bitMaxCapacity: Int) extends HashMapGenerator (ctx, aggregateExpressions, generatedClassName, groupingKeySchema, bufferSchema) { override protected def initializeAggregateHashMap(): String = { - val generatedSchema: String = - s"new org.apache.spark.sql.types.StructType()" + - (groupingKeySchema ++ bufferSchema).map { key => - val keyName = ctx.addReferenceObj("keyName", key.name) - key.dataType match { - case d: DecimalType => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( - |${d.precision}, ${d.scale}))""".stripMargin - case _ => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" - } - }.mkString("\n").concat(";") - - val generatedAggBufferSchema: String = - s"new org.apache.spark.sql.types.StructType()" + - bufferSchema.map { key => - val keyName = ctx.addReferenceObj("keyName", key.name) - key.dataType match { - case d: DecimalType => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( - |${d.precision}, ${d.scale}))""".stripMargin - case _ => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" - } - }.mkString("\n").concat(";") + val schemaStructType = new StructType((groupingKeySchema ++ bufferSchema).toArray) + val schema = ctx.addReferenceObj("schemaTerm", schemaStructType) + val aggBufferSchemaFieldsLength = bufferSchema.fields.length s""" | private ${classOf[OnHeapColumnVector].getName}[] vectors; | private ${classOf[ColumnarBatch].getName} batch; | private ${classOf[MutableColumnarRow].getName} aggBufferRow; | private int[] buckets; - | private int capacity = 1 << 16; + | private int capacity = 1 << $bitMaxCapacity; | private double loadFactor = 0.5; | private int numBuckets = (int) (capacity / loadFactor); | private int maxSteps = 2; | private int numRows = 0; - | private org.apache.spark.sql.types.StructType schema = $generatedSchema - | private org.apache.spark.sql.types.StructType aggregateBufferSchema = - | $generatedAggBufferSchema | | public $generatedClassName() { - | vectors = ${classOf[OnHeapColumnVector].getName}.allocateColumns(capacity, schema); + | vectors = ${classOf[OnHeapColumnVector].getName}.allocateColumns(capacity, $schema); | batch = new ${classOf[ColumnarBatch].getName}(vectors); | | // Generates a projection to return the aggregate buffer only. 
| ${classOf[OnHeapColumnVector].getName}[] aggBufferVectors = - | new ${classOf[OnHeapColumnVector].getName}[aggregateBufferSchema.fields().length]; - | for (int i = 0; i < aggregateBufferSchema.fields().length; i++) { + | new ${classOf[OnHeapColumnVector].getName}[$aggBufferSchemaFieldsLength]; + | for (int i = 0; i < $aggBufferSchemaFieldsLength; i++) { | aggBufferVectors[i] = vectors[i + ${groupingKeys.length}]; | } | aggBufferRow = new ${classOf[MutableColumnarRow].getName}(aggBufferVectors); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 7487564ed64da..1a48bc8398a63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -17,81 +17,83 @@ package org.apache.spark.sql.execution.arrow -import java.io.ByteArrayOutputStream -import java.nio.channels.Channels +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, OutputStream} +import java.nio.channels.{Channels, SeekableByteChannel} import scala.collection.JavaConverters._ +import org.apache.arrow.flatbuf.MessageHeader import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector._ -import org.apache.arrow.vector.ipc.{ArrowFileReader, ArrowFileWriter} -import org.apache.arrow.vector.ipc.message.ArrowRecordBatch -import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel +import org.apache.arrow.vector.ipc.{ArrowStreamWriter, ReadChannel, WriteChannel} +import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer} import org.apache.spark.TaskContext import org.apache.spark.api.java.JavaRDD +import org.apache.spark.network.util.JavaUtils import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ByteBufferOutputStream, Utils} /** - * Store Arrow data in a form that can be serialized by Spark and served to a Python process. + * Writes serialized ArrowRecordBatches to a DataOutputStream in the Arrow stream format. */ -private[sql] class ArrowPayload private[sql] (payload: Array[Byte]) extends Serializable { +private[sql] class ArrowBatchStreamWriter( + schema: StructType, + out: OutputStream, + timeZoneId: String) { - /** - * Convert the ArrowPayload to an ArrowRecordBatch. - */ - def loadBatch(allocator: BufferAllocator): ArrowRecordBatch = { - ArrowConverters.byteArrayToBatch(payload, allocator) - } + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val writeChannel = new WriteChannel(Channels.newChannel(out)) + + // Write the Arrow schema first, before batches + MessageSerializer.serialize(writeChannel, arrowSchema) /** - * Get the ArrowPayload as a type that can be served to Python. + * Consume iterator to write each serialized ArrowRecordBatch to the stream. 
*/ - def asPythonSerializable: Array[Byte] = payload -} - -/** - * Iterator interface to iterate over Arrow record batches and return rows - */ -private[sql] trait ArrowRowIterator extends Iterator[InternalRow] { + def writeBatches(arrowBatchIter: Iterator[Array[Byte]]): Unit = { + arrowBatchIter.foreach(writeChannel.write) + } /** - * Return the schema loaded from the Arrow record batch being iterated over + * End the Arrow stream, does not close output stream. */ - def schema: StructType + def end(): Unit = { + ArrowStreamWriter.writeEndOfStream(writeChannel) + } } private[sql] object ArrowConverters { /** - * Maps Iterator from InternalRow to ArrowPayload. Limit ArrowRecordBatch size in ArrowPayload - * by setting maxRecordsPerBatch or use 0 to fully consume rowIter. + * Maps Iterator from InternalRow to serialized ArrowRecordBatches. Limit ArrowRecordBatch size + * in a batch by setting maxRecordsPerBatch or use 0 to fully consume rowIter. */ - private[sql] def toPayloadIterator( + private[sql] def toBatchIterator( rowIter: Iterator[InternalRow], schema: StructType, maxRecordsPerBatch: Int, timeZoneId: String, - context: TaskContext): Iterator[ArrowPayload] = { + context: TaskContext): Iterator[Array[Byte]] = { val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val allocator = - ArrowUtils.rootAllocator.newChildAllocator("toPayloadIterator", 0, Long.MaxValue) + ArrowUtils.rootAllocator.newChildAllocator("toBatchIterator", 0, Long.MaxValue) val root = VectorSchemaRoot.create(arrowSchema, allocator) + val unloader = new VectorUnloader(root) val arrowWriter = ArrowWriter.create(root) - context.addTaskCompletionListener { _ => + context.addTaskCompletionListener[Unit] { _ => root.close() allocator.close() } - new Iterator[ArrowPayload] { + new Iterator[Array[Byte]] { override def hasNext: Boolean = rowIter.hasNext || { root.close() @@ -99,9 +101,9 @@ private[sql] object ArrowConverters { false } - override def next(): ArrowPayload = { + override def next(): Array[Byte] = { val out = new ByteArrayOutputStream() - val writer = new ArrowFileWriter(root, null, Channels.newChannel(out)) + val writeChannel = new WriteChannel(Channels.newChannel(out)) Utils.tryWithSafeFinally { var rowCount = 0 @@ -111,45 +113,46 @@ private[sql] object ArrowConverters { rowCount += 1 } arrowWriter.finish() - writer.writeBatch() + val batch = unloader.getRecordBatch() + MessageSerializer.serialize(writeChannel, batch) + batch.close() } { arrowWriter.reset() - writer.close() } - new ArrowPayload(out.toByteArray) + out.toByteArray } } } /** - * Maps Iterator from ArrowPayload to InternalRow. Returns a pair containing the row iterator - * and the schema from the first batch of Arrow data read. + * Maps iterator from serialized ArrowRecordBatches to InternalRows. 
*/ - private[sql] def fromPayloadIterator( - payloadIter: Iterator[ArrowPayload], - context: TaskContext): ArrowRowIterator = { + private[sql] def fromBatchIterator( + arrowBatchIter: Iterator[Array[Byte]], + schema: StructType, + timeZoneId: String, + context: TaskContext): Iterator[InternalRow] = { val allocator = - ArrowUtils.rootAllocator.newChildAllocator("fromPayloadIterator", 0, Long.MaxValue) + ArrowUtils.rootAllocator.newChildAllocator("fromBatchIterator", 0, Long.MaxValue) + + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val root = VectorSchemaRoot.create(arrowSchema, allocator) - new ArrowRowIterator { - private var reader: ArrowFileReader = null - private var schemaRead = StructType(Seq.empty) - private var rowIter = if (payloadIter.hasNext) nextBatch() else Iterator.empty + new Iterator[InternalRow] { + private var rowIter = if (arrowBatchIter.hasNext) nextBatch() else Iterator.empty - context.addTaskCompletionListener { _ => - closeReader() + context.addTaskCompletionListener[Unit] { _ => + root.close() allocator.close() } - override def schema: StructType = schemaRead - override def hasNext: Boolean = rowIter.hasNext || { - closeReader() - if (payloadIter.hasNext) { + if (arrowBatchIter.hasNext) { rowIter = nextBatch() true } else { + root.close() allocator.close() false } @@ -157,19 +160,11 @@ private[sql] object ArrowConverters { override def next(): InternalRow = rowIter.next() - private def closeReader(): Unit = { - if (reader != null) { - reader.close() - reader = null - } - } - private def nextBatch(): Iterator[InternalRow] = { - val in = new ByteArrayReadableSeekableByteChannel(payloadIter.next().asPythonSerializable) - reader = new ArrowFileReader(in, allocator) - reader.loadNextBatch() // throws IOException - val root = reader.getVectorSchemaRoot // throws IOException - schemaRead = ArrowUtils.fromArrowSchema(root.getSchema) + val arrowRecordBatch = ArrowConverters.loadBatch(arrowBatchIter.next(), allocator) + val vectorLoader = new VectorLoader(root) + vectorLoader.load(arrowRecordBatch) + arrowRecordBatch.close() val columns = root.getFieldVectors.asScala.map { vector => new ArrowColumnVector(vector).asInstanceOf[ColumnVector] @@ -183,34 +178,106 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { - val in = new ByteArrayReadableSeekableByteChannel(batchBytes) - val reader = new ArrowFileReader(in, allocator) - - // Read a batch from a byte stream, ensure the reader is closed - Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch - } { - reader.close() - } + val in = new ByteArrayInputStream(batchBytes) + MessageSerializer.deserializeRecordBatch( + new ReadChannel(Channels.newChannel(in)), allocator) // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { - val rdd = payloadRDD.rdd.mapPartitions { iter => + val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] + val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone + val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) + } + sqlContext.internalCreateDataFrame(rdd.setName("arrow"), schema) + } + + /** + * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile( + sqlContext: SQLContext, + filename: String): JavaRDD[Array[Byte]] = { + Utils.tryWithResource(new FileInputStream(filename)) { fileStream => + // Create array to consume iterator so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) + } + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. + */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + + // Iterate over the serialized Arrow RecordBatch messages from a stream + new Iterator[Array[Byte]] { + var batch: Array[Byte] = readNextBatch() + + override def hasNext: Boolean = batch != null + + override def next(): Array[Byte] = { + val prevBatch = batch + batch = readNextBatch() + prevBatch + } + + // This gets the next serialized ArrowRecordBatch by reading message metadata to check if it + // is a RecordBatch message, then returning the complete serialized message, which consists + // of an int32 length, the serialized message metadata and a serialized RecordBatch body + def readNextBatch(): Array[Byte] = { + val msgMetadata = MessageSerializer.readMessage(new ReadChannel(in)) + if (msgMetadata == null) { + return null + } + + // Get the length of the body, which has not been read at this point + val bodyLength = msgMetadata.getMessageBodyLength.toInt + + // Only care about RecordBatch messages, skip Schema and unsupported Dictionary messages + if (msgMetadata.getMessage.headerType() == MessageHeader.RecordBatch) { + + // Buffer backed output large enough to hold the complete serialized message + val bbout = new ByteBufferOutputStream(4 + msgMetadata.getMessageLength + bodyLength) + + // Write message metadata to ByteBuffer output stream + MessageSerializer.writeMessageBuffer( + new WriteChannel(Channels.newChannel(bbout)), + msgMetadata.getMessageLength, + msgMetadata.getMessageBuffer) + + // Get a zero-copy ByteBuffer that already contains the message metadata; must close first + bbout.close() + val bb = bbout.toByteBuffer + bb.position(bbout.getCount()) + + // Read the message body directly into the ByteBuffer to avoid a copy, and return the + // backing byte array + bb.limit(bb.capacity()) + JavaUtils.readFully(in, bb) + bb.array() + } else { + if (bodyLength > 0) { + // Skip message body if not a RecordBatch + in.position(in.position() + bodyLength) + } + + // Proceed to next message + readNextBatch() + } + } } } - val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] - sqlContext.internalCreateDataFrame(rdd, schema) } }
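Taken together, readArrowStreamFromFile and toDataFrame give a round trip from an Arrow stream file on disk back to a DataFrame. A hedged usage sketch (both helpers are private[sql], so this assumes code living in the same package; sqlContext, schemaJson and the file path are illustrative):

// Parallelize the serialized record batches of an Arrow stream file, then
// rebuild a DataFrame from them using a JSON-serialized schema string.
val batchRdd = ArrowConverters.readArrowStreamFromFile(sqlContext, "/tmp/batches.arrow")
val df = ArrowConverters.toDataFrame(batchRdd, schemaJson, sqlContext)

diff --git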
a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala index 93c8127681b3e..533097ac399e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala @@ -47,11 +47,13 @@ object ArrowUtils { case DateType => new ArrowType.Date(DateUnit.DAY) case TimestampType => if (timeZoneId == null) { - throw new UnsupportedOperationException("TimestampType must supply timeZoneId parameter") + throw new UnsupportedOperationException( + s"${TimestampType.catalogString} must supply timeZoneId parameter") } else { new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) } - case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dt.simpleString}") + case _ => + throw new UnsupportedOperationException(s"Unsupported data type: ${dt.catalogString}") } def fromArrowType(dt: ArrowType): DataType = dt match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 66888fce7f9f5..8dd484af6e908 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -21,7 +21,6 @@ import scala.collection.JavaConverters._ import org.apache.arrow.vector._ import org.apache.arrow.vector.complex._ -import org.apache.arrow.vector.types.pojo.ArrowType import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters @@ -62,13 +61,13 @@ object ArrowWriter { case (ArrayType(_, _), vector: ListVector) => val elementVector = createFieldWriter(vector.getDataVector()) new ArrayWriter(vector, elementVector) - case (StructType(_), vector: NullableMapVector) => + case (StructType(_), vector: StructVector) => val children = (0 until vector.size()).map { ordinal => createFieldWriter(vector.getChildByOrdinal(ordinal)) } new StructWriter(vector, children.toArray) case (dt, _) => - throw new UnsupportedOperationException(s"Unsupported data type: ${dt.simpleString}") + throw new UnsupportedOperationException(s"Unsupported data type: ${dt.catalogString}") } } } @@ -129,20 +128,7 @@ private[arrow] abstract class ArrowFieldWriter { } def reset(): Unit = { - // TODO: reset() should be in a common interface - valueVector match { - case fixedWidthVector: BaseFixedWidthVector => fixedWidthVector.reset() - case variableWidthVector: BaseVariableWidthVector => variableWidthVector.reset() - case listVector: ListVector => - // Manual "reset" the underlying buffer. - // TODO: When we upgrade to Arrow 0.10.0, we can simply remove this and call - // `listVector.reset()`. 
- val buffers = listVector.getBuffers(false) - buffers.foreach(buf => buf.setZero(0, buf.capacity())) - listVector.setValueCount(0) - listVector.setLastSet(0) - case _ => - } + valueVector.reset() count = 0 } } @@ -323,7 +309,7 @@ private[arrow] class ArrayWriter( } private[arrow] class StructWriter( - val valueVector: NullableMapVector, + val valueVector: StructVector, children: Array[ArrowFieldWriter]) extends ArrowFieldWriter { override def setNull(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala index e9b150fd86095..542a10fc175c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala @@ -717,7 +717,7 @@ private[columnar] object ColumnType { case struct: StructType => STRUCT(struct) case udt: UserDefinedType[_] => apply(udt.sqlType) case other => - throw new Exception(s"Unsupported type: ${other.simpleString}") + throw new Exception(s"Unsupported type: ${other.catalogString}") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 7c8faec53a828..1a8fbaca53f59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan, Statistics} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.LongAccumulator +import org.apache.spark.util.{LongAccumulator, Utils} /** @@ -207,4 +207,7 @@ case class InMemoryRelation( } override protected def otherCopyArgs: Seq[AnyRef] = Seq(statsOfPlanToCache) + + override def simpleString: String = + s"InMemoryRelation [${Utils.truncatedString(output, ", ")}], ${cacheBuilder.storageLevel}" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 997cf92449c68..196d057c2de1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -97,7 +97,7 @@ case class InMemoryTableScanExec( columnarBatch.column(i).asInstanceOf[WritableColumnVector], columnarBatchSchema.fields(i).dataType, rowCount) } - taskContext.foreach(_.addTaskCompletionListener(_ => columnarBatch.close())) + taskContext.foreach(_.addTaskCompletionListener[Unit](_ => columnarBatch.close())) columnarBatch } @@ -183,6 +183,18 @@ case class InMemoryTableScanExec( private val stats = relation.partitionStatistics private def statsFor(a: Attribute) = stats.forAttribute(a) + // Currently, statistics are only used for atomic types, excluding the binary type.
+ private object ExtractableLiteral { + def unapply(expr: Expression): Option[Literal] = expr match { + case lit: Literal => lit.dataType match { + case BinaryType => None + case _: AtomicType => Some(lit) + case _ => None + } + case _ => None + } + } + // Returned filter predicate should return false iff it is impossible for the input expression // to evaluate to `true' based on statistics collected about this partition batch. @transient lazy val buildFilter: PartialFunction[Expression, Expression] = { @@ -194,33 +206,37 @@ case class InMemoryTableScanExec( if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) => buildFilter(lhs) || buildFilter(rhs) - case EqualTo(a: AttributeReference, l: Literal) => + case EqualTo(a: AttributeReference, ExtractableLiteral(l)) => statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound - case EqualTo(l: Literal, a: AttributeReference) => + case EqualTo(ExtractableLiteral(l), a: AttributeReference) => statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound - case EqualNullSafe(a: AttributeReference, l: Literal) => + case EqualNullSafe(a: AttributeReference, ExtractableLiteral(l)) => statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound - case EqualNullSafe(l: Literal, a: AttributeReference) => + case EqualNullSafe(ExtractableLiteral(l), a: AttributeReference) => statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound - case LessThan(a: AttributeReference, l: Literal) => statsFor(a).lowerBound < l - case LessThan(l: Literal, a: AttributeReference) => l < statsFor(a).upperBound + case LessThan(a: AttributeReference, ExtractableLiteral(l)) => statsFor(a).lowerBound < l + case LessThan(ExtractableLiteral(l), a: AttributeReference) => l < statsFor(a).upperBound - case LessThanOrEqual(a: AttributeReference, l: Literal) => statsFor(a).lowerBound <= l - case LessThanOrEqual(l: Literal, a: AttributeReference) => l <= statsFor(a).upperBound + case LessThanOrEqual(a: AttributeReference, ExtractableLiteral(l)) => + statsFor(a).lowerBound <= l + case LessThanOrEqual(ExtractableLiteral(l), a: AttributeReference) => + l <= statsFor(a).upperBound - case GreaterThan(a: AttributeReference, l: Literal) => l < statsFor(a).upperBound - case GreaterThan(l: Literal, a: AttributeReference) => statsFor(a).lowerBound < l + case GreaterThan(a: AttributeReference, ExtractableLiteral(l)) => l < statsFor(a).upperBound + case GreaterThan(ExtractableLiteral(l), a: AttributeReference) => statsFor(a).lowerBound < l - case GreaterThanOrEqual(a: AttributeReference, l: Literal) => l <= statsFor(a).upperBound - case GreaterThanOrEqual(l: Literal, a: AttributeReference) => statsFor(a).lowerBound <= l + case GreaterThanOrEqual(a: AttributeReference, ExtractableLiteral(l)) => + l <= statsFor(a).upperBound + case GreaterThanOrEqual(ExtractableLiteral(l), a: AttributeReference) => + statsFor(a).lowerBound <= l case IsNull(a: Attribute) => statsFor(a).nullCount > 0 case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0 case In(a: AttributeReference, list: Seq[Expression]) - if list.forall(_.isInstanceOf[Literal]) && list.nonEmpty => + if list.forall(ExtractableLiteral.unapply(_).isDefined) && list.nonEmpty => list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] && l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 
640e01336aa75..3fea6d7c7fbfe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -47,7 +47,7 @@ case class AnalyzeColumnCommand( if (tableMeta.tableType == CatalogTableType.VIEW) { throw new AnalysisException("ANALYZE TABLE is not supported on views.") } - val sizeInBytes = CommandUtils.calculateTotalSize(sessionState, tableMeta) + val sizeInBytes = CommandUtils.calculateTotalSize(sparkSession, tableMeta) // Compute stats for each column val (rowCount, newColStats) = computeColumnStats(sparkSession, tableIdentWithDB, columnNames) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala index 5b54b2270b5ec..18fefa0a6f19f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.{AnalysisException, Column, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType, ExternalCatalogUtils} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Literal} import org.apache.spark.sql.execution.datasources.PartitioningUtils @@ -140,7 +140,13 @@ case class AnalyzePartitionCommand( val df = tableDf.filter(Column(filter)).groupBy(partitionColumns: _*).count() df.collect().map { r => - val partitionColumnValues = partitionColumns.indices.map(r.get(_).toString) + val partitionColumnValues = partitionColumns.indices.map { i => + if (r.isNullAt(i)) { + ExternalCatalogUtils.DEFAULT_PARTITION_NAME + } else { + r.get(i).toString + } + } val spec = tableMeta.partitionColumnNames.zip(partitionColumnValues).toMap val count = BigInt(r.getLong(partitionColumns.size)) (spec, count) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 58b53e8b1c551..3076e919dd61f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -39,7 +39,7 @@ case class AnalyzeTableCommand( } // Compute stats for the whole table - val newTotalSize = CommandUtils.calculateTotalSize(sessionState, tableMeta) + val newTotalSize = CommandUtils.calculateTotalSize(sparkSession, tableMeta) val newRowCount = if (noscan) None else Some(BigInt(sparkSession.table(tableIdentWithDB).count())) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala index c27048626c8eb..df71bc9effb3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala @@ -21,12 +21,13 @@ 
import java.net.URI import scala.util.control.NonFatal -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTablePartition} +import org.apache.spark.sql.execution.datasources.{DataSourceUtils, InMemoryFileIndex} import org.apache.spark.sql.internal.SessionState @@ -38,7 +39,7 @@ object CommandUtils extends Logging { val catalog = sparkSession.sessionState.catalog if (sparkSession.sessionState.conf.autoSizeUpdateEnabled) { val newTable = catalog.getTableMetadata(table.identifier) - val newSize = CommandUtils.calculateTotalSize(sparkSession.sessionState, newTable) + val newSize = CommandUtils.calculateTotalSize(sparkSession, newTable) val newStats = CatalogStatistics(sizeInBytes = newSize) catalog.alterTableStats(table.identifier, Some(newStats)) } else { @@ -47,15 +48,29 @@ object CommandUtils extends Logging { } } - def calculateTotalSize(sessionState: SessionState, catalogTable: CatalogTable): BigInt = { + def calculateTotalSize(spark: SparkSession, catalogTable: CatalogTable): BigInt = { + val sessionState = spark.sessionState if (catalogTable.partitionColumnNames.isEmpty) { calculateLocationSize(sessionState, catalogTable.identifier, catalogTable.storage.locationUri) } else { // Calculate table size as a sum of the visible partitions. See SPARK-21079 val partitions = sessionState.catalog.listPartitions(catalogTable.identifier) - partitions.map { p => - calculateLocationSize(sessionState, catalogTable.identifier, p.storage.locationUri) - }.sum + if (spark.sessionState.conf.parallelFileListingInStatsComputation) { + val paths = partitions.map(x => new Path(x.storage.locationUri.get)) + val stagingDir = sessionState.conf.getConfString("hive.exec.stagingdir", ".hive-staging") + val pathFilter = new PathFilter with Serializable { + override def accept(path: Path): Boolean = { + DataSourceUtils.isDataPath(path) && !path.getName.startsWith(stagingDir) + } + } + val fileStatusSeq = InMemoryFileIndex.bulkListLeafFiles( + paths, sessionState.newHadoopConf(), pathFilter, spark) + fileStatusSeq.flatMap(_._2.map(_.getLen)).sum + } else { + partitions.map { p => + calculateLocationSize(sessionState, catalogTable.identifier, p.storage.locationUri) + }.sum + } } } @@ -78,7 +93,8 @@ object CommandUtils extends Logging { val size = if (fileStatus.isDirectory) { fs.listStatus(path) .map { status => - if (!status.getPath.getName.startsWith(stagingDir)) { + if (!status.getPath.getName.startsWith(stagingDir) && + DataSourceUtils.isDataPath(path)) { getPathSize(fs, status.getPath) } else { 0L diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala index e11dbd201004d..a1bb5af1ab723 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker import org.apache.spark.sql.execution.datasources.FileFormatWriter -import 
org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.util.SerializableConfiguration /** @@ -41,8 +41,12 @@ trait DataWritingCommand extends Command { override final def children: Seq[LogicalPlan] = query :: Nil - // Output columns of the analyzed input query plan - def outputColumns: Seq[Attribute] + // Output column names of the analyzed input query plan. + def outputColumnNames: Seq[String] + + // Output columns of the analyzed input query plan. + def outputColumns: Seq[Attribute] = + DataWritingCommand.logicalPlanOutputWithNames(query, outputColumnNames) lazy val metrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.metrics @@ -53,3 +57,21 @@ trait DataWritingCommand extends Command { def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] } + +object DataWritingCommand { + /** + * Returns output attributes with provided names. + * The length of the provided names should be the same as the length of [[LogicalPlan.output]]. + */ + def logicalPlanOutputWithNames( + query: LogicalPlan, + names: Seq[String]): Seq[Attribute] = { + // Save the output attributes to a variable to avoid duplicated function calls. + val outputAttributes = query.output + assert(outputAttributes.length == names.length, + "The length of provided names doesn't match the length of output attributes.") + outputAttributes.zip(names).map { case (attr, outputName) => + attr.withName(outputName) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index f6ef433f2ce15..b2e1f530b5328 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -139,7 +139,7 @@ case class CreateDataSourceTableAsSelectCommand( table: CatalogTable, mode: SaveMode, query: LogicalPlan, - outputColumns: Seq[Attribute]) + outputColumnNames: Seq[String]) extends DataWritingCommand { override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { @@ -214,7 +214,7 @@ case class CreateDataSourceTableAsSelectCommand( catalogTable = if (tableExists) Some(table) else None) try { - dataSource.writeAndRead(mode, query, outputColumns, physicalPlan) + dataSource.writeAndRead(mode, query, outputColumnNames, physicalPlan) } catch { case ex: AnalysisException => logError(s"Failed to write to table ${table.identifier.unquotedString}", ex) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 04bf8c6dd917f..e1faecedd20ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -892,7 +892,8 @@ object DDLUtils { */ def verifyNotReadPath(query: LogicalPlan, outputPath: Path) : Unit = { val inputPaths = query.collect { - case LogicalRelation(r: HadoopFsRelation, _, _, _) => r.location.rootPaths + case LogicalRelation(r: HadoopFsRelation, _, _, _) => + r.location.rootPaths }.flatten if (inputPaths.contains(outputPath)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index ec3961f84bd8d..2eca1c40a5b3f 100644 ---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -18,14 +18,14 @@ package org.apache.spark.sql.execution.command import java.io.File -import java.net.URI +import java.net.{URI, URISyntaxException} import java.nio.file.FileSystems import scala.collection.mutable.ArrayBuffer import scala.util.Try import scala.util.control.NonFatal -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileContext, FsConstants, Path} import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTableType._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.Histogram -import org.apache.spark.sql.catalyst.util.quoteIdentifier +import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier} import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils} import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.json.JsonFileFormat @@ -303,94 +303,44 @@ case class LoadDataCommand( s"partitioned, but a partition spec was provided.") } } - - val loadPath = + val loadPath = { if (isLocal) { - val uri = Utils.resolveURI(path) - val file = new File(uri.getPath) - val exists = if (file.getAbsolutePath.contains("*")) { - val fileSystem = FileSystems.getDefault - val dir = file.getParentFile.getAbsolutePath - if (dir.contains("*")) { - throw new AnalysisException( - s"LOAD DATA input path allows only filename wildcard: $path") - } - - // Note that special characters such as "*" on Windows are not allowed as a path. - // Calling `WindowsFileSystem.getPath` throws an exception if there are in the path. - val dirPath = fileSystem.getPath(dir) - val pathPattern = new File(dirPath.toAbsolutePath.toString, file.getName).toURI.getPath - val safePathPattern = if (Utils.isWindows) { - // On Windows, the pattern should not start with slashes for absolute file paths. - pathPattern.stripPrefix("/") - } else { - pathPattern - } - val files = new File(dir).listFiles() - if (files == null) { - false - } else { - val matcher = fileSystem.getPathMatcher("glob:" + safePathPattern) - files.exists(f => matcher.matches(fileSystem.getPath(f.getAbsolutePath))) - } - } else { - new File(file.getAbsolutePath).exists() - } - if (!exists) { - throw new AnalysisException(s"LOAD DATA input path does not exist: $path") - } - uri + val localFS = FileContext.getLocalFSFileContext() + makeQualified(FsConstants.LOCAL_FS_URI, localFS.getWorkingDirectory(), new Path(path)) } else { - val uri = new URI(path) - val hdfsUri = if (uri.getScheme() != null && uri.getAuthority() != null) { - uri - } else { - // Follow Hive's behavior: - // If no schema or authority is provided with non-local inpath, - // we will use hadoop configuration "fs.defaultFS". 
- val defaultFSConf = sparkSession.sessionState.newHadoopConf().get("fs.defaultFS") - val defaultFS = if (defaultFSConf == null) { - new URI("") - } else { - new URI(defaultFSConf) - } - - val scheme = if (uri.getScheme() != null) { - uri.getScheme() - } else { - defaultFS.getScheme() - } - val authority = if (uri.getAuthority() != null) { - uri.getAuthority() - } else { - defaultFS.getAuthority() - } - - if (scheme == null) { - throw new AnalysisException( - s"LOAD DATA: URI scheme is required for non-local input paths: '$path'") - } - - // Follow Hive's behavior: - // If LOCAL is not specified, and the path is relative, - // then the path is interpreted relative to "/user/" - val uriPath = uri.getPath() - val absolutePath = if (uriPath != null && uriPath.startsWith("/")) { - uriPath - } else { - s"/user/${System.getProperty("user.name")}/$uriPath" - } - new URI(scheme, authority, absolutePath, uri.getQuery(), uri.getFragment()) - } - val hadoopConf = sparkSession.sessionState.newHadoopConf() - val srcPath = new Path(hdfsUri) - val fs = srcPath.getFileSystem(hadoopConf) - if (!fs.exists(srcPath)) { - throw new AnalysisException(s"LOAD DATA input path does not exist: $path") - } - hdfsUri + val loadPath = new Path(path) + // Follow Hive's behavior: + // If no schema or authority is provided with non-local inpath, + // we will use hadoop configuration "fs.defaultFS". + val defaultFSConf = sparkSession.sessionState.newHadoopConf().get("fs.defaultFS") + val defaultFS = if (defaultFSConf == null) new URI("") else new URI(defaultFSConf) + // Follow Hive's behavior: + // If LOCAL is not specified, and the path is relative, + // then the path is interpreted relative to "/user/" + val uriPath = new Path(s"/user/${System.getProperty("user.name")}/") + // makeQualified() ignores the query part while creating a Path, so the entire string + // is treated as the path. This is done with the wildcard scenario in mind: under the + // old URI-based logic the query component was parsed separately, so a path containing + // the wildcard character '?' lost all characters after the '?' when the URI was formed. + makeQualified(defaultFS, uriPath, loadPath) } - + } + val fs = loadPath.getFileSystem(sparkSession.sessionState.newHadoopConf()) + // While resolving invalid URLs starting with file:///, the system throws an + // IllegalArgumentException from the globStatus API, so the call is wrapped in a + // try-catch block and, once the runtime exception is caught, a generic error is + // reported to the user. + try { + val fileStatus = fs.globStatus(loadPath) + if (fileStatus == null || fileStatus.isEmpty) { + throw new AnalysisException(s"LOAD DATA input path does not exist: $path") + } + } catch { + case e: IllegalArgumentException => + log.warn(s"Exception while validating the load path $path ", e) + throw new AnalysisException(s"LOAD DATA input path does not exist: $path") + } if (partition.nonEmpty) { catalog.loadPartition( targetTable.identifier, @@ -413,6 +363,36 @@ case class LoadDataCommand( CommandUtils.updateTableStats(sparkSession, targetTable) Seq.empty[Row] }
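The wildcard concern called out above is concrete: java.net.URI treats '?' as the start of a query string, while org.apache.hadoop.fs.Path keeps the whole string as a path. A small sketch of the difference, using an illustrative sample path:

import java.net.URI
import org.apache.hadoop.fs.Path

val raw = "/data/logs/part-0000?.txt"
// URI parsing splits at '?': the glob character and everything after it
// land in the query component and disappear from the path.
println(new URI(raw).getPath)        // prints /data/logs/part-0000
// Path construction keeps the entire string, preserving the glob.
println(new Path(raw).toUri.getPath) // prints /data/logs/part-0000?.txt

The makeQualified helper below applies the same Path-based qualification. + + /** + * Returns a qualified path object. Method ported from org.apache.hadoop.fs.Path class. + * + * @param defaultUri default uri corresponding to the filesystem provided. + * @param workingDir the working directory for the particular child path wd-relative names.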
+ * @param path Path instance based on the path string specified by the user. + * @return qualified path object + */ + private def makeQualified(defaultUri: URI, workingDir: Path, path: Path): Path = { + val pathUri = if (path.isAbsolute()) path.toUri() else new Path(workingDir, path).toUri() + if (pathUri.getScheme == null || pathUri.getAuthority == null && + defaultUri.getAuthority != null) { + val scheme = if (pathUri.getScheme == null) defaultUri.getScheme else pathUri.getScheme + val authority = if (pathUri.getAuthority == null) { + if (defaultUri.getAuthority == null) "" else defaultUri.getAuthority + } else { + pathUri.getAuthority + } + try { + val newUri = new URI(scheme, authority, pathUri.getPath, pathUri.getFragment) + new Path(newUri) + } catch { + case e: URISyntaxException => + throw new IllegalArgumentException(e) + } + } else { + path + } + } } /** @@ -960,6 +940,9 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman case EXTERNAL => " EXTERNAL TABLE" case VIEW => " VIEW" case MANAGED => " TABLE" + case t => + throw new IllegalArgumentException( + s"Unknown table type is found at showCreateHiveTable: $t") } builder ++= s"CREATE$tableTypeString ${table.quotedString}" @@ -982,7 +965,7 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman private def showHiveTableHeader(metadata: CatalogTable, builder: StringBuilder): Unit = { val columns = metadata.schema.filterNot { column => metadata.partitionColumnNames.contains(column.name) - }.map(columnToDDLFragment) + }.map(_.toDDL) if (columns.nonEmpty) { builder ++= columns.mkString("(", ", ", ")\n") @@ -994,14 +977,10 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman .foreach(builder.append) } - private def columnToDDLFragment(column: StructField): String = { - val comment = column.getComment().map(escapeSingleQuotedString).map(" COMMENT '" + _ + "'") - s"${quoteIdentifier(column.name)} ${column.dataType.catalogString}${comment.getOrElse("")}" - } private def showHiveTableNonDataColumns(metadata: CatalogTable, builder: StringBuilder): Unit = { if (metadata.partitionColumnNames.nonEmpty) { - val partCols = metadata.partitionSchema.map(columnToDDLFragment) + val partCols = metadata.partitionSchema.map(_.toDDL) builder ++= partCols.mkString("PARTITIONED BY (", ", ", ")\n") } @@ -1072,7 +1051,7 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman private def showDataSourceTableDataColumns( metadata: CatalogTable, builder: StringBuilder): Unit = { - val columns = metadata.schema.fields.map(f => s"${quoteIdentifier(f.name)} ${f.dataType.sql}") + val columns = metadata.schema.fields.map(_.toDDL) builder ++= columns.mkString("(", ", ", ")\n") } @@ -1117,15 +1096,4 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman } } } - - private def escapeSingleQuotedString(str: String): String = { - val builder = StringBuilder.newBuilder - - str.foreach { - case '\'' => builder ++= s"\\\'" - case ch => builder += ch - } - - builder.toString() - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala index c0df6c779d7bd..9fddfad249e5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala @@ -50,7 +50,7 @@ object CodecStreams { */ 
def createInputStreamWithCloseResource(config: Configuration, path: Path): InputStream = { val inputStream = createInputStream(config, path) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => inputStream.close())) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => inputStream.close())) inputStream } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index f16d824201e77..ce3bc3dd48327 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.command.DataWritingCommand import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider import org.apache.spark.sql.execution.datasources.json.JsonFileFormat @@ -396,6 +397,7 @@ case class DataSource( hs.partitionSchema.map(_.name), "in the partition schema", equality) + DataSourceUtils.verifyReadSchema(hs.fileFormat, hs.dataSchema) case _ => SchemaUtils.checkColumnNameDuplication( relation.schema.map(_.name), @@ -449,7 +451,7 @@ case class DataSource( mode = mode, catalogTable = catalogTable, fileIndex = fileIndex, - outputColumns = data.output) + outputColumnNames = data.output.map(_.name)) } /** @@ -459,9 +461,9 @@ case class DataSource( * @param mode The save mode for this writing. * @param data The input query plan that produces the data to be written. Note that this plan * is analyzed and optimized. - * @param outputColumns The original output columns of the input query plan. The optimizer may not - * preserve the output column's names' case, so we need this parameter - * instead of `data.output`. + * @param outputColumnNames The original output column names of the input query plan. The + * optimizer may not preserve the case of the output columns' names, so we need + * this parameter instead of `data.output`. * @param physicalPlan The physical plan of the input query plan.
We should run the writing * command with this physical plan instead of creating a new physical plan, * so that the metrics can be correctly linked to the given physical plan and @@ -470,8 +472,9 @@ case class DataSource( def writeAndRead( mode: SaveMode, data: LogicalPlan, - outputColumns: Seq[Attribute], + outputColumnNames: Seq[String], physicalPlan: SparkPlan): BaseRelation = { + val outputColumns = DataWritingCommand.logicalPlanOutputWithNames(data, outputColumnNames) if (outputColumns.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { throw new AnalysisException("Cannot save interval data type into external storage.") } @@ -494,7 +497,9 @@ case class DataSource( s"Unable to resolve $name given [${data.output.map(_.name).mkString(", ")}]") } } - val resolved = cmd.copy(partitionColumns = resolvedPartCols, outputColumns = outputColumns) + val resolved = cmd.copy( + partitionColumns = resolvedPartCols, + outputColumnNames = outputColumnNames) resolved.run(sparkSession, physicalPlan) // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring copy(userSpecifiedSchema = Some(outputColumns.toStructType.asNullable)).resolveRelation() @@ -613,6 +618,8 @@ object DataSource extends Logging { case name if name.equalsIgnoreCase("orc") && conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "hive" => "org.apache.spark.sql.hive.orc.OrcFileFormat" + case "com.databricks.spark.avro" if conf.replaceDatabricksSparkAvroEnabled => + "org.apache.spark.sql.avro.AvroFileFormat" case name => name } val provider2 = s"$provider1.DefaultSource" @@ -635,11 +642,17 @@ object DataSource extends Logging { "Please use the native ORC data source by setting 'spark.sql.orc.impl' to " + "'native'") } else if (provider1.toLowerCase(Locale.ROOT) == "avro" || - provider1 == "com.databricks.spark.avro") { + provider1 == "com.databricks.spark.avro" || + provider1 == "org.apache.spark.sql.avro") { throw new AnalysisException( - s"Failed to find data source: ${provider1.toLowerCase(Locale.ROOT)}. " + - "Please find an Avro package at " + - "http://spark.apache.org/third-party-projects.html") + s"Failed to find data source: $provider1. Avro is built-in but external data " + + "source module since Spark 2.4. Please deploy the application as per " + + "the deployment section of \"Apache Avro Data Source Guide\".") + } else if (provider1.toLowerCase(Locale.ROOT) == "kafka") { + throw new AnalysisException( + s"Failed to find data source: $provider1. Please deploy the application as " + + "per the deployment section of " + + "\"Structured Streaming + Kafka Integration Guide\".") } else { throw new ClassNotFoundException( s"Failed to find data source: $provider1. 
Please find packages at " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 7b129435c45db..c6000442fae76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -131,7 +131,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast projectList } - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case CreateTable(tableDesc, mode, None) if DDLUtils.isDatasourceTable(tableDesc) => DDLUtils.checkDataColNames(tableDesc) CreateDataSourceTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore) @@ -139,7 +139,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast case CreateTable(tableDesc, mode, Some(query)) if query.resolved && DDLUtils.isDatasourceTable(tableDesc) => DDLUtils.checkDataColNames(tableDesc.copy(schema = query.schema)) - CreateDataSourceTableAsSelectCommand(tableDesc, mode, query, query.output) + CreateDataSourceTableAsSelectCommand(tableDesc, mode, query, query.output.map(_.name)) case InsertIntoTable(l @ LogicalRelation(_: InsertableRelation, _, _, _), parts, query, overwrite, false) if parts.isEmpty => @@ -209,7 +209,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast mode, table, Some(t.location), - actualQuery.output) + actualQuery.output.map(_.name)) } } @@ -252,7 +252,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] table.partitionSchema.asNullable.toAttributes) } - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case i @ InsertIntoTable(UnresolvedCatalogRelation(tableMeta), _, _, _, _) if DDLUtils.isDatasourceTable(tableMeta) => i.copy(table = readDataSourceTable(tableMeta)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala index c5347218c4b40..90cec5e72c1a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.execution.datasources -import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat -import org.apache.spark.sql.execution.datasources.json.JsonFileFormat -import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat -import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.types._ @@ -42,65 +41,22 @@ object DataSourceUtils { /** * Verify if the schema is supported in datasource. This verification should be done - in a driver side, e.g., `prepareWrite`, `buildReader`, and `buildReaderWithPartitionValues` - in `FileFormat`. - * - * Unsupported data types of csv, json, orc, and parquet are as follows; - * csv -> R/W: Interval, Null, Array, Map, Struct - * json -> W: Interval - * orc -> W: Interval, Null - * parquet -> R/W: Interval, Null + on the driver side.
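Returning to the `lookupDataSource` changes above, a hedged sketch of the user-visible effect (`spark` is a SparkSession and the paths are hypothetical; the legacy-name remapping is gated by the `replaceDatabricksSparkAvroEnabled` flag referenced in the match):

    // With the legacy flag enabled, the old package name resolves to the
    // built-in Avro implementation instead of failing the lookup:
    spark.read.format("com.databricks.spark.avro").load("/data/events.avro")

    // Without the external avro module on the classpath, "avro" (and likewise
    // "kafka") now fails with a pointer to the matching deployment guide rather
    // than the generic third-party-packages hint:
    spark.read.format("avro").load("/data/events.avro") // AnalysisException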
*/ private def verifySchema(format: FileFormat, schema: StructType, isReadPath: Boolean): Unit = { - def throwUnsupportedException(dataType: DataType): Unit = { - throw new UnsupportedOperationException( - s"$format data source does not support ${dataType.simpleString} data type.") - } - - def verifyType(dataType: DataType): Unit = dataType match { - case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | - StringType | BinaryType | DateType | TimestampType | _: DecimalType => - - // All the unsupported types for CSV - case _: NullType | _: CalendarIntervalType | _: StructType | _: ArrayType | _: MapType - if format.isInstanceOf[CSVFileFormat] => - throwUnsupportedException(dataType) - - case st: StructType => st.foreach { f => verifyType(f.dataType) } - - case ArrayType(elementType, _) => verifyType(elementType) - - case MapType(keyType, valueType, _) => - verifyType(keyType) - verifyType(valueType) - - case udt: UserDefinedType[_] => verifyType(udt.sqlType) - - // Interval type not supported in all the write path - case _: CalendarIntervalType if !isReadPath => - throwUnsupportedException(dataType) - - // JSON and ORC don't support an Interval type, but we pass it in read pass - // for back-compatibility. - case _: CalendarIntervalType if format.isInstanceOf[JsonFileFormat] || - format.isInstanceOf[OrcFileFormat] => - - // Interval type not supported in the other read path - case _: CalendarIntervalType => - throwUnsupportedException(dataType) - - // For JSON & ORC backward-compatibility - case _: NullType if format.isInstanceOf[JsonFileFormat] || - (isReadPath && format.isInstanceOf[OrcFileFormat]) => - - // Null type not supported in the other path - case _: NullType => - throwUnsupportedException(dataType) - - // We keep this default case for safeguards - case _ => throwUnsupportedException(dataType) + schema.foreach { field => + if (!format.supportDataType(field.dataType, isReadPath)) { + throw new AnalysisException( + s"$format data source does not support ${field.dataType.catalogString} data type.") + } } + } - schema.foreach(field => verifyType(field.dataType)) + // SPARK-24626: Metadata files and temporary files should not be + // counted as data files, so that they shouldn't participate in tasks like + // location size calculation. 
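For instance (paths illustrative), the `isDataPath` check defined just below keeps ordinary data files and filters out metadata and hidden files:

    import org.apache.hadoop.fs.Path

    DataSourceUtils.isDataPath(new Path("/tbl/part-00000.snappy.parquet")) // true
    DataSourceUtils.isDataPath(new Path("/tbl/_SUCCESS"))                  // false
    DataSourceUtils.isDataPath(new Path("/tbl/.staging"))                  // false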
+ private[sql] def isDataPath(path: Path): Boolean = { + val name = path.getName + !(name.startsWith("_") || name.startsWith(".")) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala index 43591a9ff524a..90e81661bae7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String @@ -28,7 +29,8 @@ class FailureSafeParser[IN]( rawParser: IN => Seq[InternalRow], mode: ParseMode, schema: StructType, - columnNameOfCorruptRecord: String) { + columnNameOfCorruptRecord: String, + isMultiLine: Boolean) { private val corruptFieldIndex = schema.getFieldIndex(columnNameOfCorruptRecord) private val actualSchema = StructType(schema.filterNot(_.name == columnNameOfCorruptRecord)) @@ -56,9 +58,15 @@ class FailureSafeParser[IN]( } } + private val skipParsing = !isMultiLine && mode == PermissiveMode && schema.isEmpty + def parse(input: IN): Iterator[InternalRow] = { try { - rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null)) + if (skipParsing) { + Iterator.single(InternalRow.empty) + } else { + rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null)) + } } catch { case e: BadRecordException => mode match { case PermissiveMode => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala index 023e127888290..2c162e23644ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, StructType} /** @@ -57,7 +57,7 @@ trait FileFormat { dataSchema: StructType): OutputWriterFactory /** - * Returns whether this format support returning columnar batch or not. + * Returns whether this format supports returning columnar batch or not. * * TODO: we should just have different traits for the different formats. */ @@ -152,6 +152,11 @@ trait FileFormat { } } + /** + * Returns whether this format supports the given [[DataType]] in read/write path. + * By default all data types are supported. 
+ */ + def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = true } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 52da8356ab835..7c6ab4bc922fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -96,9 +96,11 @@ object FileFormatWriter extends Logging { val caseInsensitiveOptions = CaseInsensitiveMap(options) + val dataSchema = dataColumns.toStructType + DataSourceUtils.verifyWriteSchema(fileFormat, dataSchema) // Note: prepareWrite has side effect. It sets "job". val outputWriterFactory = - fileFormat.prepareWrite(sparkSession, job, caseInsensitiveOptions, dataColumns.toStructType) + fileFormat.prepareWrite(sparkSession, job, caseInsensitiveOptions, dataSchema) val description = new WriteJobDescription( uuid = UUID.randomUUID().toString, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 28c36b6020d33..345c9d82ca0e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -89,14 +89,6 @@ class FileScanRDD( inputMetrics.setBytesRead(existingBytesRead + getBytesReadCallback()) } - // If we can't get the bytes read from the FS stats, fall back to the file size, - // which may be inaccurate. - private def updateBytesReadWithFileSize(): Unit = { - if (currentFile != null) { - inputMetrics.incBytesRead(currentFile.length) - } - } - private[this] val files = split.asInstanceOf[FilePartition].files.toIterator private[this] var currentFile: PartitionedFile = null private[this] var currentIterator: Iterator[Object] = null @@ -139,7 +131,6 @@ class FileScanRDD( /** Advances to the next file. Returns true if a new non-empty iterator is available. */ private def nextIterator(): Boolean = { - updateBytesReadWithFileSize() if (files.hasNext) { currentFile = files.next() logInfo(s"Reading File $currentFile") @@ -208,13 +199,12 @@ class FileScanRDD( override def close(): Unit = { updateBytesRead() - updateBytesReadWithFileSize() InputFileBlockHolder.unset() } } // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener(_ => iterator.close()) + context.addTaskCompletionListener[Unit](_ => iterator.close()) iterator.asInstanceOf[Iterator[InternalRow]] // This is an erasure hack. 
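A note on the recurring `addTaskCompletionListener[Unit]` changes in these hunks: under Scala 2.12, a bare lambda can SAM-convert to a `TaskCompletionListener` as well as match the `TaskContext => U` overload, making the call ambiguous; the explicit type argument pins the functional overload. A minimal sketch of the pattern:

    import org.apache.spark.TaskContext

    def closeOnTaskCompletion(closeable: AutoCloseable): Unit = {
      // The explicit [Unit] selects the `TaskContext => U` overload; without it,
      // Scala 2.12 could also SAM-convert the lambda to a TaskCompletionListener.
      Option(TaskContext.get())
        .foreach(_.addTaskCompletionListener[Unit](_ => closeable.close()))
    }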
} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index 9d9f8bd5bb58e..dc5c2ff927e4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -162,7 +162,7 @@ object InMemoryFileIndex extends Logging { * * @return for each input path, the set of discovered files for the path */ - private def bulkListLeafFiles( + private[sql] def bulkListLeafFiles( paths: Seq[Path], hadoopConf: Configuration, filter: PathFilter, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index dd7ef0d15c140..484942d35c857 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogT import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode @@ -55,14 +56,14 @@ case class InsertIntoHadoopFsRelationCommand( mode: SaveMode, catalogTable: Option[CatalogTable], fileIndex: Option[FileIndex], - outputColumns: Seq[Attribute]) + outputColumnNames: Seq[String]) extends DataWritingCommand { import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { // Most formats don't do well with duplicate columns, so let's not allow that - SchemaUtils.checkSchemaColumnNameDuplication( - query.schema, + SchemaUtils.checkColumnNameDuplication( + outputColumnNames, s"when inserting into $outputPath", sparkSession.sessionState.conf.caseSensitiveAnalysis) @@ -91,8 +92,12 @@ case class InsertIntoHadoopFsRelationCommand( val pathExists = fs.exists(qualifiedOutputPath) - val enableDynamicOverwrite = - sparkSession.sessionState.conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC + val parameters = CaseInsensitiveMap(options) + + val partitionOverwriteMode = parameters.get("partitionOverwriteMode") + .map(mode => PartitionOverwriteMode.withName(mode.toUpperCase)) + .getOrElse(sparkSession.sessionState.conf.partitionOverwriteMode) + val enableDynamicOverwrite = partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC // This config only makes sense when we are overwriting a partitioned dataset with dynamic // partition columns. val dynamicPartitionOverwrite = enableDynamicOverwrite && mode == SaveMode.Overwrite && @@ -166,7 +171,15 @@ case class InsertIntoHadoopFsRelationCommand( // update metastore partition metadata - refreshUpdatedPartitions(updatedPartitionPaths) + if (updatedPartitionPaths.isEmpty && staticPartitions.nonEmpty + && partitionColumns.length == staticPartitions.size) { + // Ensure an empty static partition is still registered, so it can be loaded from the datasource table.
+ val staticPathFragment = + PartitioningUtils.getPathFragment(staticPartitions, partitionColumns) + refreshUpdatedPartitions(Set(staticPathFragment)) + } else { + refreshUpdatedPartitions(updatedPartitionPaths) + } // refresh cached files in FileIndex fileIndex.foreach(_.refresh()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index f9a24806953e6..3183fd30e5e0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils @@ -284,6 +284,10 @@ object PartitioningUtils { }.mkString("/") } + def getPathFragment(spec: TablePartitionSpec, partitionColumns: Seq[Attribute]): String = { + getPathFragment(spec, StructType.fromAttributes(partitionColumns)) + } + /** * Normalize the column names in partition specification, w.r.t. the real partition column names * and case sensitivity. e.g., if the partition spec has a column named `monTh`, and there is a @@ -410,7 +414,7 @@ object PartitioningUtils { val dateTry = Try { // try and parse the date, if no exception occurs this is a candidate to be resolved as // DateType - DateTimeUtils.getThreadLocalDateFormat.parse(raw) + DateTimeUtils.getThreadLocalDateFormat(DateTimeUtils.defaultTimeZone()).parse(raw) // SPARK-23436: Casting the string to date may still return null if a bad Date is provided. // This can happen since DateFormat.parse may not use the entire text of the given string: // so if there are extra-characters after the date, it returns correctly. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 82322df407521..e840ff1682502 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -54,7 +54,8 @@ abstract class CSVDataSource extends Serializable { requiredSchema: StructType, // Actual schema of data in the csv file dataSchema: StructType, - caseSensitive: Boolean): Iterator[InternalRow] + caseSensitive: Boolean, + columnPruning: Boolean): Iterator[InternalRow] /** * Infers the schema from `inputPaths` files. @@ -181,25 +182,6 @@ object CSVDataSource extends Logging { } } } - - /** - * Checks that CSV header contains the same column names as fields names in the given schema - * by taking into account case sensitivity. 
- */ - def checkHeader( - header: String, - parser: CsvParser, - schema: StructType, - fileName: String, - enforceSchema: Boolean, - caseSensitive: Boolean): Unit = { - checkHeaderColumnNames( - schema, - parser.parseLine(header), - fileName, - enforceSchema, - caseSensitive) - } } object TextInputCSVDataSource extends CSVDataSource { @@ -211,10 +193,11 @@ object TextInputCSVDataSource extends CSVDataSource { parser: UnivocityParser, requiredSchema: StructType, dataSchema: StructType, - caseSensitive: Boolean): Iterator[InternalRow] = { + caseSensitive: Boolean, + columnPruning: Boolean): Iterator[InternalRow] = { val lines = { val linesReader = new HadoopFileLinesReader(file, conf) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close())) linesReader.map { line => new String(line.getBytes, 0, line.getLength, parser.options.charset) } @@ -227,10 +210,11 @@ object TextInputCSVDataSource extends CSVDataSource { // Note: if there are only comments in the first block, the header would probably // be not extracted. CSVUtils.extractHeader(lines, parser.options).foreach { header => - CSVDataSource.checkHeader( - header, - parser.tokenizer, - dataSchema, + val schema = if (columnPruning) requiredSchema else dataSchema + val columnNames = parser.tokenizer.parseLine(header) + CSVDataSource.checkHeaderColumnNames( + schema, + columnNames, file.filePath, parser.options.enforceSchema, caseSensitive) @@ -256,23 +240,25 @@ object TextInputCSVDataSource extends CSVDataSource { sparkSession: SparkSession, csv: Dataset[String], maybeFirstLine: Option[String], - parsedOptions: CSVOptions): StructType = maybeFirstLine match { - case Some(firstLine) => - val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine) - val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis - val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) - val sampled: Dataset[String] = CSVUtils.sample(csv, parsedOptions) - val tokenRDD = sampled.rdd.mapPartitions { iter => - val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) - val linesWithoutHeader = - CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions) - val parser = new CsvParser(parsedOptions.asParserSettings) - linesWithoutHeader.map(parser.parseLine) - } - CSVInferSchema.infer(tokenRDD, header, parsedOptions) - case None => - // If the first line could not be read, just return the empty schema. - StructType(Nil) + parsedOptions: CSVOptions): StructType = { + val csvParser = new CsvParser(parsedOptions.asParserSettings) + maybeFirstLine.map(csvParser.parseLine(_)) match { + case Some(firstRow) if firstRow != null => + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) + val sampled: Dataset[String] = CSVUtils.sample(csv, parsedOptions) + val tokenRDD = sampled.rdd.mapPartitions { iter => + val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) + val linesWithoutHeader = + CSVUtils.filterHeaderLine(filteredLines, maybeFirstLine.get, parsedOptions) + val parser = new CsvParser(parsedOptions.asParserSettings) + linesWithoutHeader.map(parser.parseLine) + } + CSVInferSchema.infer(tokenRDD, header, parsedOptions) + case _ => + // If the first line could not be read, just return the empty schema. 
+ StructType(Nil) + } } private def createBaseDataset( @@ -308,10 +294,12 @@ object MultiLineCSVDataSource extends CSVDataSource { parser: UnivocityParser, requiredSchema: StructType, dataSchema: StructType, - caseSensitive: Boolean): Iterator[InternalRow] = { + caseSensitive: Boolean, + columnPruning: Boolean): Iterator[InternalRow] = { def checkHeader(header: Array[String]): Unit = { + val schema = if (columnPruning) requiredSchema else dataSchema CSVDataSource.checkHeaderColumnNames( - dataSchema, + schema, header, file.filePath, parser.options.enforceSchema, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index fa366ccce6b61..9aad0bd55e736 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.csv +import java.nio.charset.Charset + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce._ @@ -66,7 +68,6 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - DataSourceUtils.verifyWriteSchema(this, dataSchema) val conf = job.getConfiguration val csvOptions = new CSVOptions( options, @@ -98,7 +99,6 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { - DataSourceUtils.verifyReadSchema(this, dataSchema) val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) @@ -131,6 +131,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { ) } val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val columnPruning = sparkSession.sessionState.conf.csvColumnPruning (file: PartitionedFile) => { val conf = broadcastedHadoopConf.value.value @@ -144,7 +145,8 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { parser, requiredSchema, dataSchema, - caseSensitive) + caseSensitive, + columnPruning) } } @@ -153,6 +155,15 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { override def hashCode(): Int = getClass.hashCode() override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat] + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _ => false + } + } private[csv] class CsvOutputWriter( @@ -161,7 +172,9 @@ private[csv] class CsvOutputWriter( context: TaskAttemptContext, params: CSVOptions) extends OutputWriter with Logging { - private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path)) + private val charset = Charset.forName(params.charset) + + private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path), charset) private val gen = new UnivocityGenerator(dataSchema, writer, params) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index fab8d62da0c1d..492a21be6df3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -162,6 +162,21 @@ class CSVOptions( */ val enforceSchema = getBool("enforceSchema", default = true) + + /** + * String representation of an empty value in read and in write. + */ + val emptyValue = parameters.get("emptyValue") + /** + * This string is returned when the CSV reader sees no characters for an input value, + * or an empty quoted string `""`. The default value is the empty string. + */ + val emptyValueInRead = emptyValue.getOrElse("") + /** + * This value is written in place of an empty string. The default value is `""`. + */ + val emptyValueInWrite = emptyValue.getOrElse("\"\"") + def asWriterSettings: CsvWriterSettings = { val writerSettings = new CsvWriterSettings() val format = writerSettings.getFormat @@ -173,7 +188,7 @@ class CSVOptions( writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite) writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite) writerSettings.setNullValue(nullValue) - writerSettings.setEmptyValue("\"\"") + writerSettings.setEmptyValue(emptyValueInWrite) writerSettings.setSkipEmptyLines(true) writerSettings.setQuoteAllFields(quoteAll) writerSettings.setQuoteEscapingEnabled(escapeQuotes) @@ -194,7 +209,7 @@ class CSVOptions( settings.setInputBufferSize(inputBufferSize) settings.setMaxColumns(maxColumns) settings.setNullValue(nullValue) - settings.setEmptyValue("") + settings.setEmptyValue(emptyValueInRead) settings.setMaxCharsPerColumn(maxCharsPerColumn) settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER) settings diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index aa545e1a0c00a..9088d43905e28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -33,29 +33,49 @@ import org.apache.spark.sql.execution.datasources.FailureSafeParser import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String + +/** + * Constructs a parser for a given schema that translates CSV data to an [[InternalRow]]. + * + * @param dataSchema The CSV data schema that is specified by the user, or inferred from underlying + * data files. + * @param requiredSchema The schema of the data that should be output for each row. This should be a + * subset of the columns in dataSchema. + * @param options Configuration options for a CSV parser. + */ class UnivocityParser( dataSchema: StructType, requiredSchema: StructType, val options: CSVOptions) extends Logging { require(requiredSchema.toSet.subsetOf(dataSchema.toSet), - "requiredSchema should be the subset of schema.") + s"requiredSchema (${requiredSchema.catalogString}) should be a subset of " + s"dataSchema (${dataSchema.catalogString}).") def this(schema: StructType, options: CSVOptions) = this(schema, schema, options) // A `ValueConverter` is responsible for converting the given value to a desired type.
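Stepping back to the `emptyValue` option added to CSVOptions above, a hedged usage sketch (`spark` is a SparkSession, `df` a DataFrame, and the marker string and path are hypothetical):

    // Write: empty-string cells are emitted as the marker instead of the default `""`.
    df.write.option("emptyValue", "EMPTY").csv("/tmp/out")

    // Read: fields with no characters, or an empty quoted string, come back as the
    // marker; without the option they are read as the empty string.
    spark.read.option("emptyValue", "EMPTY").csv("/tmp/out")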
private type ValueConverter = String => Any + // This index is used to reorder parsed tokens + private val tokenIndexArr = + requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f))).toArray + + // When column pruning is enabled, the parser only parses the required columns based on + // their positions in the data schema. + private val parsedSchema = if (options.columnPruning) requiredSchema else dataSchema + val tokenizer = { val parserSetting = options.asParserSettings - if (options.columnPruning && requiredSchema.length < dataSchema.length) { - val tokenIndexArr = requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f))) + // When the to-be-parsed schema is shorter than the data schema to be read, we let the + // Univocity CSV parser select a sequence of fields for reading by their positions. + if (parsedSchema.length < dataSchema.length) { parserSetting.selectIndexes(tokenIndexArr: _*) } new CsvParser(parserSetting) } - private val schema = if (options.columnPruning) requiredSchema else dataSchema - private val row = new GenericInternalRow(schema.length) + private val row = new GenericInternalRow(requiredSchema.length) // Retrieve the raw record string. private def getCurrentInput: UTF8String = { @@ -82,7 +102,7 @@ class UnivocityParser( // // output row - ["A", 2] private val valueConverters: Array[ValueConverter] = { - schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray + requiredSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray } /** @@ -183,29 +203,32 @@ class UnivocityParser( } } - private val doParse = if (schema.nonEmpty) { - (input: String) => convert(tokenizer.parseLine(input)) - } else { - // If `columnPruning` enabled and partition attributes scanned only, - // `schema` gets empty. - (_: String) => InternalRow.empty - } - /** * Parses a single CSV string and turns it into either one resulting row or no row (if the * record is malformed). */ - def parse(input: String): InternalRow = doParse(input) + def parse(input: String): InternalRow = convert(tokenizer.parseLine(input)) + + private val getToken = if (options.columnPruning) { + (tokens: Array[String], index: Int) => tokens(index) + } else { + (tokens: Array[String], index: Int) => tokens(tokenIndexArr(index)) + } private def convert(tokens: Array[String]): InternalRow = { - if (tokens.length != schema.length) { + if (tokens == null) { + throw BadRecordException( + () => getCurrentInput, + () => None, + new RuntimeException("Malformed CSV record")) + } else if (tokens.length != parsedSchema.length) { // If the number of tokens doesn't match the schema, we should treat it as a malformed record. // However, we still have a chance to parse some of the tokens, by adding extra null tokens in // the tail if the number is smaller, or by dropping extra tokens if the number is larger.
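To make the token bookkeeping above concrete, a small sketch (schemas hypothetical):

    import org.apache.spark.sql.types._

    val dataSchema = StructType(Seq("a", "b", "c", "d").map(StructField(_, StringType)))
    val requiredSchema = StructType(Seq("c", "a").map(StructField(_, StringType)))
    // tokenIndexArr == Array(2, 0)
    //
    // columnPruning on:  tokenizer.selectIndexes(2, 0) makes Univocity return only
    //                    the selected fields, already in required-schema order, so
    //                    getToken(tokens, i) is simply tokens(i).
    // columnPruning off: the full row ["a", "b", "c", "d"] comes back, and
    //                    getToken(tokens, i) = tokens(tokenIndexArr(i)) picks c, then a.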
- val checkedTokens = if (schema.length > tokens.length) { - tokens ++ new Array[String](schema.length - tokens.length) + val checkedTokens = if (parsedSchema.length > tokens.length) { + tokens ++ new Array[String](parsedSchema.length - tokens.length) } else { - tokens.take(schema.length) + tokens.take(parsedSchema.length) } def getPartialResult(): Option[InternalRow] = { try { @@ -222,9 +245,11 @@ class UnivocityParser( new RuntimeException("Malformed CSV record")) } else { try { + // When the length of the returned tokens is identical to the length of the parsed schema, + // we just need to convert the tokens that correspond to the required columns. var i = 0 - while (i < schema.length) { - row(i) = valueConverters(i).apply(tokens(i)) + while (i < requiredSchema.length) { + row(i) = valueConverters(i).apply(getToken(tokens, i)) i += 1 } row @@ -265,7 +290,8 @@ private[csv] object UnivocityParser { input => Seq(parser.convert(input)), parser.options.parseMode, schema, - parser.options.columnNameOfCorruptRecord) + parser.options.columnNameOfCorruptRecord, + parser.options.multiLine) convertStream(inputStream, shouldDropHeader, tokenizer, checkHeader) { tokens => safeParser.parse(tokens) }.flatten @@ -313,7 +339,8 @@ private[csv] object UnivocityParser { input => Seq(parser.parse(input)), parser.options.parseMode, schema, - parser.options.columnNameOfCorruptRecord) + parser.options.columnNameOfCorruptRecord, + parser.options.multiLine) filteredLines.flatMap(safeParser.parse) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index eea966d30948b..7dfbb9d8b5c05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -119,9 +119,9 @@ class JDBCOptions( // the column used to partition val partitionColumn = parameters.get(JDBC_PARTITION_COLUMN) // the lower bound of partition column - val lowerBound = parameters.get(JDBC_LOWER_BOUND).map(_.toLong) + val lowerBound = parameters.get(JDBC_LOWER_BOUND) // the upper bound of the partition column - val upperBound = parameters.get(JDBC_UPPER_BOUND).map(_.toLong) + val upperBound = parameters.get(JDBC_UPPER_BOUND) // numPartitions is also used for data source writing require((partitionColumn.isEmpty && lowerBound.isEmpty && upperBound.isEmpty) || (partitionColumn.isDefined && lowerBound.isDefined && upperBound.isDefined && @@ -157,6 +157,8 @@ class JDBCOptions( // ------------------------------------------------------------ // if to truncate the table from the JDBC database val isTruncate = parameters.getOrElse(JDBC_TRUNCATE, "false").toBoolean + + val isCascadeTruncate: Option[Boolean] = parameters.get(JDBC_CASCADE_TRUNCATE).map(_.toBoolean) // the create table option , which can be table_options or partition_options. 
// E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" // TODO: to reuse the existing partition parameters for those partition specific options @@ -181,6 +183,9 @@ class JDBCOptions( } // An option to execute custom SQL before fetching data from the remote DB val sessionInitStatement = parameters.get(JDBC_SESSION_INIT_STATEMENT) + + // An option to allow/disallow pushing down predicate into JDBC data source + val pushDownPredicate = parameters.getOrElse(JDBC_PUSHDOWN_PREDICATE, "true").toBoolean } class JdbcOptionsInWrite( @@ -225,10 +230,12 @@ object JDBCOptions { val JDBC_QUERY_TIMEOUT = newOption("queryTimeout") val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize") val JDBC_TRUNCATE = newOption("truncate") + val JDBC_CASCADE_TRUNCATE = newOption("cascadeTruncate") val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions") val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes") val JDBC_CUSTOM_DATAFRAME_COLUMN_TYPES = newOption("customSchema") val JDBC_BATCH_INSERT_SIZE = newOption("batchsize") val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel") val JDBC_SESSION_INIT_STATEMENT = newOption("sessionInitStatement") + val JDBC_PUSHDOWN_PREDICATE = newOption("pushDownPredicate") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 1b3b17c75e756..16b493892e3be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -265,7 +265,7 @@ private[jdbc] class JDBCRDD( closed = true } - context.addTaskCompletionListener{ context => close() } + context.addTaskCompletionListener[Unit]{ context => close() } val inputMetrics = context.taskMetrics().inputMetrics val part = thePart.asInstanceOf[JDBCPartition] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index 97e2d255cb7be..f15014442e3fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.jdbc +import java.sql.{Date, Timestamp} + import scala.collection.mutable.ArrayBuffer import org.apache.spark.Partition @@ -24,9 +26,10 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, DateType, NumericType, StructType, TimestampType} import org.apache.spark.util.Utils /** @@ -34,6 +37,7 @@ import org.apache.spark.util.Utils */ private[sql] case class JDBCPartitioningInfo( column: String, + columnType: DataType, lowerBound: Long, upperBound: Long, numPartitions: Int) @@ -51,16 +55,43 @@ private[sql] object JDBCRelation extends Logging { * the rows with null value for the partitions column. 
* * @param schema resolved schema of a JDBC table - * @param partitioning partition information to generate the where clause for each partition * @param resolver function used to determine if two identifiers are equal + * @param timeZoneId timezone ID to be used if a partition column type is date or timestamp * @param jdbcOptions JDBC options that contains url * @return an array of partitions with where clause for each partition */ def columnPartition( schema: StructType, - partitioning: JDBCPartitioningInfo, resolver: Resolver, + timeZoneId: String, jdbcOptions: JDBCOptions): Array[Partition] = { + val partitioning = { + import JDBCOptions._ + + val partitionColumn = jdbcOptions.partitionColumn + val lowerBound = jdbcOptions.lowerBound + val upperBound = jdbcOptions.upperBound + val numPartitions = jdbcOptions.numPartitions + + if (partitionColumn.isEmpty) { + assert(lowerBound.isEmpty && upperBound.isEmpty, "When 'partitionColumn' is not " + + s"specified, '$JDBC_LOWER_BOUND' and '$JDBC_UPPER_BOUND' are expected to be empty") + null + } else { + assert(lowerBound.nonEmpty && upperBound.nonEmpty && numPartitions.nonEmpty, + s"When 'partitionColumn' is specified, '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND', and " + + s"'$JDBC_NUM_PARTITIONS' are also required") + + val (column, columnType) = verifyAndGetNormalizedPartitionColumn( + schema, partitionColumn.get, resolver, jdbcOptions) + + val lowerBoundValue = toInternalBoundValue(lowerBound.get, columnType) + val upperBoundValue = toInternalBoundValue(upperBound.get, columnType) + JDBCPartitioningInfo( + column, columnType, lowerBoundValue, upperBoundValue, numPartitions.get) + } + } + if (partitioning == null || partitioning.numPartitions <= 1 || partitioning.lowerBound == partitioning.upperBound) { return Array[Partition](JDBCPartition(null, 0)) @@ -72,6 +103,8 @@ private[sql] object JDBCRelation extends Logging { "Operation not allowed: the lower bound of partitioning column is larger than the upper " + s"bound. Lower bound: $lowerBound; Upper bound: $upperBound") + val boundValueToString: Long => String = + toBoundValueInWhereClause(_, partitioning.columnType, timeZoneId) val numPartitions = if ((upperBound - lowerBound) >= partitioning.numPartitions || /* check for overflow */ (upperBound - lowerBound) < 0) { @@ -80,24 +113,25 @@ private[sql] object JDBCRelation extends Logging { logWarning("The number of partitions is reduced because the specified number of " + "partitions is less than the difference between upper bound and lower bound. " + s"Updated number of partitions: ${upperBound - lowerBound}; Input number of " + - s"partitions: ${partitioning.numPartitions}; Lower bound: $lowerBound; " + - s"Upper bound: $upperBound.") + s"partitions: ${partitioning.numPartitions}; " + + s"Lower bound: ${boundValueToString(lowerBound)}; " + + s"Upper bound: ${boundValueToString(upperBound)}.") upperBound - lowerBound } // Overflow and silliness can happen if you subtract then divide. // Here we get a little roundoff, but that's (hopefully) OK. 
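A worked example of the new date/timestamp bound handling, under hypothetical option values (a DateType column `d`, four partitions, identifiers quoted per the dialect):

    val df = spark.read.format("jdbc")
      .option("url", jdbcUrl) // hypothetical connection URL
      .option("dbtable", "events")
      .option("partitionColumn", "d") // a DATE column
      .option("lowerBound", "2018-01-01")
      .option("upperBound", "2018-01-05")
      .option("numPartitions", "4")
      .load()
    // The bounds become epoch days internally, the stride works out to one day,
    // and the generated WHERE clauses read roughly:
    //   "d" < '2018-01-02' or "d" is null
    //   "d" >= '2018-01-02' AND "d" < '2018-01-03'
    //   "d" >= '2018-01-03' AND "d" < '2018-01-04'
    //   "d" >= '2018-01-04'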
val stride: Long = upperBound / numPartitions - lowerBound / numPartitions - val column = verifyAndGetNormalizedColumnName( - schema, partitioning.column, resolver, jdbcOptions) - var i: Int = 0 - var currentValue: Long = lowerBound + val column = partitioning.column + var currentValue = lowerBound val ans = new ArrayBuffer[Partition]() while (i < numPartitions) { - val lBound = if (i != 0) s"$column >= $currentValue" else null + val lBoundValue = boundValueToString(currentValue) + val lBound = if (i != 0) s"$column >= $lBoundValue" else null currentValue += stride - val uBound = if (i != numPartitions - 1) s"$column < $currentValue" else null + val uBoundValue = boundValueToString(currentValue) + val uBound = if (i != numPartitions - 1) s"$column < $uBoundValue" else null val whereClause = if (uBound == null) { lBound @@ -109,23 +143,58 @@ private[sql] object JDBCRelation extends Logging { ans += JDBCPartition(whereClause, i) i = i + 1 } - ans.toArray + val partitions = ans.toArray + logInfo(s"Number of partitions: $numPartitions, WHERE clauses of these partitions: " + + partitions.map(_.asInstanceOf[JDBCPartition].whereClause).mkString(", ")) + partitions } - // Verify column name based on the JDBC resolved schema - private def verifyAndGetNormalizedColumnName( + // Verify column name and type based on the JDBC resolved schema + private def verifyAndGetNormalizedPartitionColumn( schema: StructType, columnName: String, resolver: Resolver, - jdbcOptions: JDBCOptions): String = { + jdbcOptions: JDBCOptions): (String, DataType) = { val dialect = JdbcDialects.get(jdbcOptions.url) - schema.map(_.name).find { fieldName => - resolver(fieldName, columnName) || - resolver(dialect.quoteIdentifier(fieldName), columnName) - }.map(dialect.quoteIdentifier).getOrElse { + val column = schema.find { f => + resolver(f.name, columnName) || resolver(dialect.quoteIdentifier(f.name), columnName) + }.getOrElse { throw new AnalysisException(s"User-defined partition column $columnName not " + s"found in the JDBC relation: ${schema.simpleString(Utils.maxNumToStringFields)}") } + column.dataType match { + case _: NumericType | DateType | TimestampType => + case _ => + throw new AnalysisException( + s"Partition column type should be ${NumericType.simpleString}, " + + s"${DateType.catalogString}, or ${TimestampType.catalogString}, but " + + s"${column.dataType.catalogString} found.") + } + (dialect.quoteIdentifier(column.name), column.dataType) + } + + private def toInternalBoundValue(value: String, columnType: DataType): Long = columnType match { + case _: NumericType => value.toLong + case DateType => DateTimeUtils.fromJavaDate(Date.valueOf(value)).toLong + case TimestampType => DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(value)) + } + + private def toBoundValueInWhereClause( + value: Long, + columnType: DataType, + timeZoneId: String): String = { + def dateTimeToString(): String = { + val timeZone = DateTimeUtils.getTimeZone(timeZoneId) + val dateTimeStr = columnType match { + case DateType => DateTimeUtils.dateToString(value.toInt, timeZone) + case TimestampType => DateTimeUtils.timestampToString(value, timeZone) + } + s"'$dateTimeStr'" + } + columnType match { + case _: NumericType => value.toString + case DateType | TimestampType => dateTimeToString() + } } /** @@ -172,7 +241,11 @@ private[sql] case class JDBCRelation( // Check if JDBCRDD.compileFilter can accept input filters override def unhandledFilters(filters: Array[Filter]): Array[Filter] = { - filters.filter(JDBCRDD.compileFilter(_, 
JdbcDialects.get(jdbcOptions.url)).isEmpty) + if (jdbcOptions.pushDownPredicate) { + filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty) + } else { + filters + } } override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index 782d626c1573c..e7456f9c8ed0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -29,28 +29,11 @@ class JdbcRelationProvider extends CreatableRelationProvider override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { - import JDBCOptions._ - val jdbcOptions = new JDBCOptions(parameters) - val partitionColumn = jdbcOptions.partitionColumn - val lowerBound = jdbcOptions.lowerBound - val upperBound = jdbcOptions.upperBound - val numPartitions = jdbcOptions.numPartitions - - val partitionInfo = if (partitionColumn.isEmpty) { - assert(lowerBound.isEmpty && upperBound.isEmpty, "When 'partitionColumn' is not specified, " + - s"'$JDBC_LOWER_BOUND' and '$JDBC_UPPER_BOUND' are expected to be empty") - null - } else { - assert(lowerBound.nonEmpty && upperBound.nonEmpty && numPartitions.nonEmpty, - s"When 'partitionColumn' is specified, '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND', and " + - s"'$JDBC_NUM_PARTITIONS' are also required") - JDBCPartitioningInfo( - partitionColumn.get, lowerBound.get, upperBound.get, numPartitions.get) - } val resolver = sqlContext.conf.resolver + val timeZoneId = sqlContext.conf.sessionLocalTimeZone val schema = JDBCRelation.getSchema(resolver, jdbcOptions) - val parts = JDBCRelation.columnPartition(schema, partitionInfo, resolver, jdbcOptions) + val parts = JDBCRelation.columnPartition(schema, resolver, timeZoneId, jdbcOptions) JDBCRelation(schema, parts, jdbcOptions)(sqlContext.sparkSession) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index b81737eda475b..edea549748b47 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -105,7 +105,12 @@ object JdbcUtils extends Logging { val statement = conn.createStatement try { statement.setQueryTimeout(options.queryTimeout) - statement.executeUpdate(dialect.getTruncateQuery(options.table)) + val truncateQuery = if (options.isCascadeTruncate.isDefined) { + dialect.getTruncateQuery(options.table, options.isCascadeTruncate) + } else { + dialect.getTruncateQuery(options.table) + } + statement.executeUpdate(truncateQuery) } finally { statement.close() } @@ -175,7 +180,7 @@ object JdbcUtils extends Logging { private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = { dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt)).getOrElse( - throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}")) + throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.catalogString}")) } /** @@ -480,7 +485,7 @@ object JdbcUtils extends Logging { case LongType if metadata.contains("binarylong") => throw new 
IllegalArgumentException(s"Unsupported array element " + - s"type ${dt.simpleString} based on binary") + s"type ${dt.catalogString} based on binary") case ArrayType(_, _) => throw new IllegalArgumentException("Nested arrays unsupported") @@ -494,7 +499,7 @@ object JdbcUtils extends Logging { array => new GenericArrayData(elementConversion.apply(array.getArray))) row.update(pos, array) - case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}") + case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.catalogString}") } private def nullSafeConvert[T](input: T, f: T => Any): Any = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 3b6df45e949e8..76f58371ae264 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -33,7 +33,7 @@ import org.apache.spark.input.{PortableDataStream, StreamInputFormat} import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JsonInferSchema, JSONOptions} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat @@ -130,7 +130,7 @@ object TextInputJsonDataSource extends JsonDataSource { parser: JacksonParser, schema: StructType): Iterator[InternalRow] = { val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close())) val textParser = parser.options.encoding .map(enc => CreateJacksonParser.text(enc, _: JsonFactory, _: Text)) .getOrElse(CreateJacksonParser.text(_: JsonFactory, _: Text)) @@ -139,7 +139,8 @@ object TextInputJsonDataSource extends JsonDataSource { input => parser.parse(input, textParser, textToUTF8String), parser.options.parseMode, schema, - parser.options.columnNameOfCorruptRecord) + parser.options.columnNameOfCorruptRecord, + parser.options.multiLine) linesReader.flatMap(safeParser.parse) } @@ -223,7 +224,8 @@ object MultiLineJsonDataSource extends JsonDataSource { input => parser.parse[InputStream](input, streamParser, partitionedFileString), parser.options.parseMode, schema, - parser.options.columnNameOfCorruptRecord) + parser.options.columnNameOfCorruptRecord, + parser.options.multiLine) safeParser.parse( CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath)))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 383bff1375a93..a9241afba537b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.json.{JacksonGenerator, 
JacksonParser, JSON import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { @@ -65,8 +65,6 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - DataSourceUtils.verifyWriteSchema(this, dataSchema) - val conf = job.getConfiguration val parsedOptions = new JSONOptions( options, @@ -98,8 +96,6 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { - DataSourceUtils.verifyReadSchema(this, dataSchema) - val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) @@ -148,6 +144,23 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { override def hashCode(): Int = getClass.hashCode() override def equals(other: Any): Boolean = other.isInstanceOf[JsonFileFormat] + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) } + + case ArrayType(elementType, _) => supportDataType(elementType, isReadPath) + + case MapType(keyType, valueType, _) => + supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath) + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _: NullType => true + + case _ => false + } } private[json] class JsonOutputWriter( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index df488a748e3e5..4574f8247af54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -59,6 +59,19 @@ private[sql] object OrcFileFormat { def checkFieldNames(names: Seq[String]): Unit = { names.foreach(checkFieldName) } + + def getQuotedSchemaString(dataType: DataType): String = dataType match { + case _: AtomicType => dataType.catalogString + case StructType(fields) => + fields.map(f => s"`${f.name}`:${getQuotedSchemaString(f.dataType)}") + .mkString("struct<", ",", ">") + case ArrayType(elementType, _) => + s"array<${getQuotedSchemaString(elementType)}>" + case MapType(keyType, valueType, _) => + s"map<${getQuotedSchemaString(keyType)},${getQuotedSchemaString(valueType)}>" + case _ => // UDT and others + dataType.catalogString + } } /** @@ -89,13 +102,11 @@ class OrcFileFormat job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - DataSourceUtils.verifyWriteSchema(this, dataSchema) - val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf) val conf = job.getConfiguration - conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, dataSchema.catalogString) + conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, OrcFileFormat.getQuotedSchemaString(dataSchema)) conf.set(COMPRESS.getAttribute, orcOptions.compressionCodec) @@ -143,8 +154,6 @@ class 
OrcFileFormat filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { - DataSourceUtils.verifyReadSchema(this, dataSchema) - if (sparkSession.sessionState.conf.orcFilterPushDown) { OrcFilters.createFilter(dataSchema, filters).foreach { f => OrcInputFormat.setSearchArgument(hadoopConf, f, dataSchema.fieldNames) @@ -196,7 +205,7 @@ class OrcFileFormat // There is a possibility that `initialize` and `initBatch` hit some errors (like OOM) // after opening a file. val iter = new RecordReaderIterator(batchReader) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) batchReader.initialize(fileSplit, taskAttemptContext) batchReader.initBatch( @@ -211,7 +220,7 @@ class OrcFileFormat val orcRecordReader = new OrcInputFormat[OrcStruct] .createRecordReader(fileSplit, taskAttemptContext) val iter = new RecordReaderIterator[OrcStruct](orcRecordReader) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema) @@ -228,4 +237,21 @@ class OrcFileFormat } } } + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) } + + case ArrayType(elementType, _) => supportDataType(elementType, isReadPath) + + case MapType(keyType, valueType, _) => + supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath) + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _: NullType => isReadPath + + case _ => false + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 4f44ae4fa1d71..dbafc468c6c40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.execution.datasources.orc -import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument, SearchArgumentFactory} +import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument} import org.apache.orc.storage.ql.io.sarg.SearchArgument.Builder +import org.apache.orc.storage.ql.io.sarg.SearchArgumentFactory.newBuilder import org.apache.orc.storage.serde2.io.HiveDecimalWritable -import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.sources.{And, Filter} import org.apache.spark.sql.types._ /** @@ -54,7 +55,17 @@ import org.apache.spark.sql.types._ * builder methods mentioned above can only be found in test code, where all tested filters are * known to be convertible. 
*/ -private[orc] object OrcFilters { +private[sql] object OrcFilters { + private[sql] def buildTree(filters: Seq[Filter]): Option[Filter] = { + filters match { + case Seq() => None + case Seq(filter) => Some(filter) + case Seq(filter1, filter2) => Some(And(filter1, filter2)) + case _ => // length > 2 + val (left, right) = filters.splitAt(filters.length / 2) + Some(And(buildTree(left).get, buildTree(right).get)) + } + } /** * Create ORC filter as a SearchArgument instance. @@ -66,14 +77,14 @@ private[orc] object OrcFilters { // collect all convertible ones to build the final `SearchArgument`. val convertibleFilters = for { filter <- filters - _ <- buildSearchArgument(dataTypeMap, filter, SearchArgumentFactory.newBuilder()) + _ <- buildSearchArgument(dataTypeMap, filter, newBuilder) } yield filter for { // Combines all convertible filters using `And` to produce a single conjunction - conjunction <- convertibleFilters.reduceOption(org.apache.spark.sql.sources.And) + conjunction <- buildTree(convertibleFilters) // Then tries to build a single ORC `SearchArgument` for the conjunction predicate - builder <- buildSearchArgument(dataTypeMap, conjunction, SearchArgumentFactory.newBuilder()) + builder <- buildSearchArgument(dataTypeMap, conjunction, newBuilder) } yield builder.build() } @@ -98,7 +109,7 @@ private[orc] object OrcFilters { case DateType => PredicateLeaf.Type.DATE case TimestampType => PredicateLeaf.Type.TIMESTAMP case _: DecimalType => PredicateLeaf.Type.DECIMAL - case _ => throw new UnsupportedOperationException(s"DataType: $dataType") + case _ => throw new UnsupportedOperationException(s"DataType: ${dataType.catalogString}") } /** @@ -127,8 +138,6 @@ private[orc] object OrcFilters { dataTypeMap: Map[String, DataType], expression: Filter, builder: Builder): Option[Builder] = { - def newBuilder = SearchArgumentFactory.newBuilder() - def getType(attribute: String): PredicateLeaf.Type = getPredicateLeafType(dataTypeMap(attribute)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala index 899af0750cadf..90d1268028096 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala @@ -223,6 +223,6 @@ class OrcSerializer(dataSchema: StructType) { * Return a Orc value object for the given Spark schema. 
 */
  private def createOrcValue(dataType: DataType) = {
-    OrcStruct.createValue(TypeDescription.fromString(dataType.catalogString))
+    OrcStruct.createValue(TypeDescription.fromString(OrcFileFormat.getQuotedSchemaString(dataType)))
  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
index 460194ba61c8b..95fb25bf5addb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
@@ -17,6 +17,8 @@
 package org.apache.spark.sql.execution.datasources.orc
 
+import java.util.Locale
+
 import scala.collection.JavaConverters._
 
 import org.apache.hadoop.conf.Configuration
@@ -27,7 +29,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution}
+import org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.types._
 
@@ -79,9 +81,10 @@ object OrcUtils extends Logging {
     val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles
     val conf = sparkSession.sessionState.newHadoopConf()
     // TODO: We need to support schema merging. Please see SPARK-11412.
-    files.map(_.getPath).flatMap(readSchema(_, conf, ignoreCorruptFiles)).headOption.map { schema =>
-      logDebug(s"Reading schema from file $files, got Hive schema string: $schema")
-      CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType]
+    files.toIterator.map(file => readSchema(file.getPath, conf, ignoreCorruptFiles)).collectFirst {
+      case Some(schema) =>
+        logDebug(s"Reading schema from file $files, got Hive schema string: $schema")
+        CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType]
     }
   }
 
@@ -104,7 +107,7 @@ object OrcUtils extends Logging {
       // This is an ORC file written by Hive: there are no field names in the physical schema,
       // so assume the physical schema maps to the data schema by index.
       assert(orcFieldNames.length <= dataSchema.length, "The given data schema " +
-        s"${dataSchema.simpleString} has less fields than the actual ORC physical schema, " +
+        s"${dataSchema.catalogString} has fewer fields than the actual ORC physical schema, " +
        "no idea which columns were dropped, fail to read.")
       Some(requiredSchema.fieldNames.map { name =>
         val index = dataSchema.fieldIndex(name)
@@ -115,8 +118,29 @@ object OrcUtils extends Logging {
         }
       })
     } else {
-      val resolver = if (isCaseSensitive) caseSensitiveResolution else caseInsensitiveResolution
-      Some(requiredSchema.fieldNames.map { name => orcFieldNames.indexWhere(resolver(_, name)) })
+      if (isCaseSensitive) {
+        Some(requiredSchema.fieldNames.map { name =>
+          orcFieldNames.indexWhere(caseSensitiveResolution(_, name))
+        })
+      } else {
+        // Do case-insensitive resolution only if in case-insensitive mode
+        val caseInsensitiveOrcFieldMap =
+          orcFieldNames.zipWithIndex.groupBy(_._1.toLowerCase(Locale.ROOT))
+        Some(requiredSchema.fieldNames.map { requiredFieldName =>
+          caseInsensitiveOrcFieldMap
+            .get(requiredFieldName.toLowerCase(Locale.ROOT))
+            .map { matchedOrcFields =>
+              if (matchedOrcFields.size > 1) {
+                // Need to fail if there is ambiguity, i.e. more than one field is matched.
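// ----------------------------------------------------------------------------
// Aside (illustration only, not part of the patch): the case-insensitive branch
// above reduces to "group the physical field names by their lower-cased form and
// refuse ambiguous matches". A self-contained sketch of that resolution logic:
import java.util.Locale

object CaseInsensitiveResolutionSketch {
  // Returns the index of the unique case-insensitive match, or -1 if absent;
  // throws if more than one physical field matches the requested name.
  def resolve(physicalFieldNames: Seq[String], requested: String): Int = {
    val byLowerCase = physicalFieldNames.zipWithIndex.groupBy(_._1.toLowerCase(Locale.ROOT))
    byLowerCase.get(requested.toLowerCase(Locale.ROOT)) match {
      case Some(matches) if matches.size > 1 =>
        val names = matches.map(_._1).mkString("[", ", ", "]")
        throw new RuntimeException(
          s"""Found duplicate field(s) "$requested": $names in case-insensitive mode""")
      case Some(matches) => matches.head._2
      case None => -1
    }
  }
  // resolve(Seq("ID", "name"), "id") == 0; resolve(Seq("id", "ID"), "id") throws.
}
// ----------------------------------------------------------------------------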
+ val matchedOrcFieldsString = matchedOrcFields.map(_._1).mkString("[", ", ", "]") + throw new RuntimeException(s"""Found duplicate field(s) "$requiredFieldName": """ + + s"$matchedOrcFieldsString in case-insensitive mode") + } else { + matchedOrcFields.head._2 + } + }.getOrElse(-1) + }) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 93de1faef527a..ea4f1592a7c2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -22,7 +22,6 @@ import java.net.URI import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.parallel.ForkJoinTaskSupport import scala.util.{Failure, Try} import org.apache.hadoop.conf.Configuration @@ -78,8 +77,6 @@ class ParquetFileFormat job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - DataSourceUtils.verifyWriteSchema(this, dataSchema) - val parquetOptions = new ParquetOptions(options, sparkSession.sessionState.conf) val conf = ContextUtil.getConfiguration(job) @@ -303,8 +300,6 @@ class ParquetFileFormat filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { - DataSourceUtils.verifyReadSchema(this, dataSchema) - hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) hadoopConf.set( ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, @@ -315,6 +310,9 @@ class ParquetFileFormat hadoopConf.set( SQLConf.SESSION_LOCAL_TIMEZONE.key, sparkSession.sessionState.conf.sessionLocalTimeZone) + hadoopConf.setBoolean( + SQLConf.CASE_SENSITIVE.key, + sparkSession.sessionState.conf.caseSensitiveAnalysis) ParquetWriteSupport.setSchema(requiredSchema, hadoopConf) @@ -338,40 +336,29 @@ class ParquetFileFormat val enableVectorizedReader: Boolean = sqlConf.parquetVectorizedReaderEnabled && resultSchema.forall(_.dataType.isInstanceOf[AtomicType]) - val enableRecordFilter: Boolean = - sparkSession.sessionState.conf.parquetRecordFilterEnabled - val timestampConversion: Boolean = - sparkSession.sessionState.conf.isParquetINT96TimestampConversion + val enableRecordFilter: Boolean = sqlConf.parquetRecordFilterEnabled + val timestampConversion: Boolean = sqlConf.isParquetINT96TimestampConversion val capacity = sqlConf.parquetVectorizedReaderBatchSize - val enableParquetFilterPushDown: Boolean = - sparkSession.sessionState.conf.parquetFilterPushDown + val enableParquetFilterPushDown: Boolean = sqlConf.parquetFilterPushDown // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) val pushDownDate = sqlConf.parquetFilterPushDownDate + val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp + val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith + val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold + val isCaseSensitive = sqlConf.caseSensitiveAnalysis (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) - // Try to push down filters when filter push-down is enabled. 
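// ----------------------------------------------------------------------------
// Aside (illustration only, not part of the patch): OrcFilters.buildTree, added
// above, combines convertible filters into a *balanced* conjunction instead of
// reduceOption(And). A left-leaning reduce nests one And per filter, so the tree
// can get very deep (and recursive conversion correspondingly fragile) when
// thousands of predicates are pushed; the balanced tree keeps depth logarithmic:
import org.apache.spark.sql.sources.{And, EqualTo, Filter}

object BalancedAndSketch {
  def buildTree(filters: Seq[Filter]): Option[Filter] = filters match {
    case Seq() => None
    case Seq(filter) => Some(filter)
    case _ =>
      val (left, right) = filters.splitAt(filters.length / 2)
      Some(And(buildTree(left).get, buildTree(right).get))
  }

  def depth(filter: Filter): Int = filter match {
    case And(l, r) => 1 + math.max(depth(l), depth(r))
    case _ => 1
  }

  def main(args: Array[String]): Unit = {
    val filters: Seq[Filter] = (1 to 1024).map(i => EqualTo(s"c$i", i))
    println(depth(filters.reduceLeft(And))) // 1024: one level per filter
    println(depth(buildTree(filters).get))  // 11: log2(1024) levels of And
  }
}
// ----------------------------------------------------------------------------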
-      val pushed = if (enableParquetFilterPushDown) {
-        filters
-          // Collects all converted Parquet filter predicates. Notice that not all predicates can be
-          // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap`
-          // is used here.
-          .flatMap(new ParquetFilters(pushDownDate, pushDownStringStartWith)
-            .createFilter(requiredSchema, _))
-          .reduceOption(FilterApi.and)
-      } else {
-        None
-      }
-
       val fileSplit =
         new FileSplit(new Path(new URI(file.filePath)), file.start, file.length, Array.empty)
+      val filePath = fileSplit.getPath
 
       val split =
         new org.apache.parquet.hadoop.ParquetInputSplit(
-          fileSplit.getPath,
+          filePath,
           fileSplit.getStart,
           fileSplit.getStart + fileSplit.getLength,
           fileSplit.getLength,
@@ -379,16 +366,34 @@ class ParquetFileFormat
           null)
 
       val sharedConf = broadcastedHadoopConf.value.value
+
+      lazy val footerFileMetaData =
+        ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS).getFileMetaData
+      // Try to push down filters when filter push-down is enabled.
+      val pushed = if (enableParquetFilterPushDown) {
+        val parquetSchema = footerFileMetaData.getSchema
+        val parquetFilters = new ParquetFilters(pushDownDate, pushDownTimestamp, pushDownDecimal,
+          pushDownStringStartWith, pushDownInFilterThreshold, isCaseSensitive)
+        filters
+          // Collects all converted Parquet filter predicates. Notice that not all predicates can be
+          // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap`
+          // is used here.
+          .flatMap(parquetFilters.createFilter(parquetSchema, _))
+          .reduceOption(FilterApi.and)
+      } else {
+        None
+      }
+
       // PARQUET_INT96_TIMESTAMP_CONVERSION says to apply timezone conversions to int96 timestamps
       // *only* if the file was created by something other than "parquet-mr", so check the actual
       // writer here for this file. We have to do this per-file, as each file in the table may
       // have different writers.
-      def isCreatedByParquetMr(): Boolean = {
-        val footer = ParquetFileReader.readFooter(sharedConf, fileSplit.getPath, SKIP_ROW_GROUPS)
-        footer.getFileMetaData().getCreatedBy().startsWith("parquet-mr")
-      }
+      // Define isCreatedByParquetMr as a function to avoid unnecessary parquet footer reads.
+      def isCreatedByParquetMr: Boolean =
+        footerFileMetaData.getCreatedBy().startsWith("parquet-mr")
+
       val convertTz =
-        if (timestampConversion && !isCreatedByParquetMr()) {
+        if (timestampConversion && !isCreatedByParquetMr) {
           Some(DateTimeUtils.getTimeZone(sharedConf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key)))
         } else {
           None
@@ -409,7 +414,7 @@ class ParquetFileFormat
             convertTz.orNull, enableOffHeapColumnVector && taskContext.isDefined, capacity)
         val iter = new RecordReaderIterator(vectorizedReader)
         // SPARK-23457 Register a task completion listener before `initialization`.
-        taskContext.foreach(_.addTaskCompletionListener(_ => iter.close()))
+        taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close()))
         vectorizedReader.initialize(split, hadoopAttemptContext)
         logDebug(s"Appending $partitionSchema ${file.partitionValues}")
         vectorizedReader.initBatch(partitionSchema, file.partitionValues)
@@ -430,7 +435,7 @@ class ParquetFileFormat
         }
         val iter = new RecordReaderIterator(reader)
         // SPARK-23457 Register a task completion listener before `initialization`.
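// ----------------------------------------------------------------------------
// Aside (illustration only, not part of the patch): the [Unit] type argument added
// throughout this patch pins the (TaskContext => U) overload of
// addTaskCompletionListener. Under Scala 2.12 a bare lambda is also eligible for
// SAM conversion to the TaskCompletionListener overload, which would make the
// call ambiguous. Typical usage, assuming some Closeable reader opened in a task:
import java.io.Closeable
import org.apache.spark.TaskContext

object CompletionListenerSketch {
  def closeOnTaskCompletion(reader: Closeable): Unit = {
    // Register before doing any work, so the reader is closed even if the task
    // fails right after opening the file. TaskContext.get() is null off-task.
    Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => reader.close()))
  }
}
// ----------------------------------------------------------------------------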
- taskContext.foreach(_.addTaskCompletionListener(_ => iter.close())) + taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) reader.initialize(split, hadoopAttemptContext) val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes @@ -450,6 +455,21 @@ class ParquetFileFormat } } } + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) } + + case ArrayType(elementType, _) => supportDataType(elementType, isReadPath) + + case MapType(keyType, valueType, _) => + supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath) + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _ => false + } } object ParquetFileFormat extends Logging { @@ -515,30 +535,23 @@ object ParquetFileFormat extends Logging { conf: Configuration, partFiles: Seq[FileStatus], ignoreCorruptFiles: Boolean): Seq[Footer] = { - val parFiles = partFiles.par - val pool = ThreadUtils.newForkJoinPool("readingParquetFooters", 8) - parFiles.tasksupport = new ForkJoinTaskSupport(pool) - try { - parFiles.flatMap { currentFile => - try { - // Skips row group information since we only need the schema. - // ParquetFileReader.readFooter throws RuntimeException, instead of IOException, - // when it can't read the footer. - Some(new Footer(currentFile.getPath(), - ParquetFileReader.readFooter( - conf, currentFile, SKIP_ROW_GROUPS))) - } catch { case e: RuntimeException => - if (ignoreCorruptFiles) { - logWarning(s"Skipped the footer in the corrupted file: $currentFile", e) - None - } else { - throw new IOException(s"Could not read footer for file: $currentFile", e) - } + ThreadUtils.parmap(partFiles, "readingParquetFooters", 8) { currentFile => + try { + // Skips row group information since we only need the schema. + // ParquetFileReader.readFooter throws RuntimeException, instead of IOException, + // when it can't read the footer. 
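// ----------------------------------------------------------------------------
// Aside (illustration only, not part of the patch): ThreadUtils.parmap above
// replaces the parallel-collection + ForkJoinPool setup deleted in this hunk.
// A minimal stand-in with the same shape (this sketch ignores the thread-name
// prefix and is not the actual ThreadUtils implementation):
import java.util.concurrent.Executors
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

object ParmapSketch {
  def parmap[A, B](in: Seq[A], prefix: String, maxThreads: Int)(f: A => B): Seq[B] = {
    val pool = Executors.newFixedThreadPool(maxThreads)
    implicit val ec: ExecutionContext = ExecutionContext.fromExecutorService(pool)
    try {
      // Fan the work out over at most maxThreads threads, then wait for all results.
      Await.result(Future.sequence(in.map(a => Future(f(a)))), Duration.Inf)
    } finally {
      pool.shutdown() // always release the pool, mirroring the removed finally block
    }
  }
}
// ----------------------------------------------------------------------------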
+        Some(new Footer(currentFile.getPath(),
+          ParquetFileReader.readFooter(
+            conf, currentFile, SKIP_ROW_GROUPS)))
+      } catch { case e: RuntimeException =>
+        if (ignoreCorruptFiles) {
+          logWarning(s"Skipped the footer in the corrupted file: $currentFile", e)
+          None
+        } else {
+          throw new IOException(s"Could not read footer for file: $currentFile", e)
         }
-      }.seq
-    } finally {
-      pool.shutdown()
-    }
+      }
+    }.flatten
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
index 21c9e2e4f82b4..0c286defb9406 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
@@ -17,197 +17,424 @@
 package org.apache.spark.sql.execution.datasources.parquet
 
-import java.sql.Date
+import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Long => JLong}
+import java.math.{BigDecimal => JBigDecimal}
+import java.sql.{Date, Timestamp}
+import java.util.Locale
+
+import scala.collection.JavaConverters.asScalaBufferConverter
 
 import org.apache.parquet.filter2.predicate._
 import org.apache.parquet.filter2.predicate.FilterApi._
 import org.apache.parquet.io.api.Binary
-import org.apache.parquet.schema.PrimitiveComparator
+import org.apache.parquet.schema.{DecimalMetadata, MessageType, OriginalType, PrimitiveComparator}
+import org.apache.parquet.schema.OriginalType._
+import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
+import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
 
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate
 import org.apache.spark.sql.sources
-import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * Some utility functions to convert Spark data source filters to Parquet filters.
  */
-private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith: Boolean) {
+private[parquet] class ParquetFilters(
+    pushDownDate: Boolean,
+    pushDownTimestamp: Boolean,
+    pushDownDecimal: Boolean,
+    pushDownStartWith: Boolean,
+    pushDownInFilterThreshold: Int,
+    caseSensitive: Boolean) {
+
+  /**
+   * Holds the information of a single field as stored in the underlying Parquet file.
+ * + * @param fieldName field name in parquet file + * @param fieldType field type related info in parquet file + */ + private case class ParquetField( + fieldName: String, + fieldType: ParquetSchemaType) + + private case class ParquetSchemaType( + originalType: OriginalType, + primitiveTypeName: PrimitiveTypeName, + length: Int, + decimalMetadata: DecimalMetadata) + + private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, 0, null) + private val ParquetByteType = ParquetSchemaType(INT_8, INT32, 0, null) + private val ParquetShortType = ParquetSchemaType(INT_16, INT32, 0, null) + private val ParquetIntegerType = ParquetSchemaType(null, INT32, 0, null) + private val ParquetLongType = ParquetSchemaType(null, INT64, 0, null) + private val ParquetFloatType = ParquetSchemaType(null, FLOAT, 0, null) + private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, 0, null) + private val ParquetStringType = ParquetSchemaType(UTF8, BINARY, 0, null) + private val ParquetBinaryType = ParquetSchemaType(null, BINARY, 0, null) + private val ParquetDateType = ParquetSchemaType(DATE, INT32, 0, null) + private val ParquetTimestampMicrosType = ParquetSchemaType(TIMESTAMP_MICROS, INT64, 0, null) + private val ParquetTimestampMillisType = ParquetSchemaType(TIMESTAMP_MILLIS, INT64, 0, null) private def dateToDays(date: Date): SQLDate = { DateTimeUtils.fromJavaDate(date) } - private val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case BooleanType => - (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) - case IntegerType => - (n: String, v: Any) => FilterApi.eq(intColumn(n), v.asInstanceOf[Integer]) - case LongType => - (n: String, v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + private def decimalToInt32(decimal: JBigDecimal): Integer = decimal.unscaledValue().intValue() + + private def decimalToInt64(decimal: JBigDecimal): JLong = decimal.unscaledValue().longValue() + + private def decimalToByteArray(decimal: JBigDecimal, numBytes: Int): Binary = { + val decimalBuffer = new Array[Byte](numBytes) + val bytes = decimal.unscaledValue().toByteArray + + val fixedLengthBytes = if (bytes.length == numBytes) { + bytes + } else { + val signByte = if (bytes.head < 0) -1: Byte else 0: Byte + java.util.Arrays.fill(decimalBuffer, 0, numBytes - bytes.length, signByte) + System.arraycopy(bytes, 0, decimalBuffer, numBytes - bytes.length, bytes.length) + decimalBuffer + } + Binary.fromConstantByteArray(fixedLengthBytes, 0, numBytes) + } + + private val makeEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetBooleanType => + (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[JBoolean]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => FilterApi.eq( + intColumn(n), + Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) + case ParquetLongType => + (n: String, v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: String, v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[JDouble]) // Binary.fromString and Binary.fromByteArray don't accept null 
values - case StringType => + case ParquetStringType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) - case BinaryType => + case ParquetBinaryType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case DateType if pushDownDate => + case ParquetDateType if pushDownDate => (n: String, v: Any) => FilterApi.eq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.eq( + longColumn(n), + Option(v).map(t => DateTimeUtils.fromJavaTimestamp(t.asInstanceOf[Timestamp]) + .asInstanceOf[JLong]).orNull) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.eq( + longColumn(n), + Option(v).map(_.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]).orNull) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.eq( + intColumn(n), + Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.eq( + longColumn(n), + Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.eq( + binaryColumn(n), + Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) } - private val makeNotEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case BooleanType => - (n: String, v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) - case IntegerType => - (n: String, v: Any) => FilterApi.notEq(intColumn(n), v.asInstanceOf[Integer]) - case LongType => - (n: String, v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - - case StringType => + private val makeNotEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetBooleanType => + (n: String, v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[JBoolean]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => FilterApi.notEq( + intColumn(n), + Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) + case ParquetLongType => + (n: String, v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: String, v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: String, v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) - case BinaryType => + case ParquetBinaryType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case DateType if pushDownDate => + case ParquetDateType if pushDownDate => (n: String, v: Any) => FilterApi.notEq( intColumn(n), Option(v).map(date => 
dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.notEq( + longColumn(n), + Option(v).map(t => DateTimeUtils.fromJavaTimestamp(t.asInstanceOf[Timestamp]) + .asInstanceOf[JLong]).orNull) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.notEq( + longColumn(n), + Option(v).map(_.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]).orNull) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.notEq( + intColumn(n), + Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.notEq( + longColumn(n), + Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.notEq( + binaryColumn(n), + Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) } - private val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.lt(intColumn(n), v.asInstanceOf[Integer]) - case LongType => - (n: String, v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - - case StringType => + private val makeLt: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => (n: String, v: Any) => - FilterApi.lt(binaryColumn(n), - Binary.fromString(v.asInstanceOf[String])) - case BinaryType => + FilterApi.lt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: String, v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: String, v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: String, v: Any) => + FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => (n: String, v: Any) => FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if pushDownDate => + case ParquetDateType if pushDownDate => + (n: String, v: Any) => + FilterApi.lt(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => (n: String, v: Any) => FilterApi.lt( - intColumn(n), - Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + longColumn(n), + DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong]) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.lt( + longColumn(n), + v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.lt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => + 
FilterApi.lt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.lt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) } - private val makeLtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.ltEq(intColumn(n), v.asInstanceOf[java.lang.Integer]) - case LongType => - (n: String, v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - - case StringType => + private val makeLtEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => (n: String, v: Any) => - FilterApi.ltEq(binaryColumn(n), - Binary.fromString(v.asInstanceOf[String])) - case BinaryType => + FilterApi.ltEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: String, v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: String, v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: String, v: Any) => + FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => (n: String, v: Any) => FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if pushDownDate => + case ParquetDateType if pushDownDate => + (n: String, v: Any) => + FilterApi.ltEq(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => (n: String, v: Any) => FilterApi.ltEq( - intColumn(n), - Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + longColumn(n), + DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong]) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.ltEq( + longColumn(n), + v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.ltEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.ltEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.ltEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) } - private val makeGt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.gt(intColumn(n), v.asInstanceOf[java.lang.Integer]) - case LongType => - (n: String, v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - - case StringType => + 
private val makeGt: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => (n: String, v: Any) => - FilterApi.gt(binaryColumn(n), - Binary.fromString(v.asInstanceOf[String])) - case BinaryType => + FilterApi.gt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: String, v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: String, v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: String, v: Any) => + FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => (n: String, v: Any) => FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if pushDownDate => + case ParquetDateType if pushDownDate => + (n: String, v: Any) => + FilterApi.gt(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => (n: String, v: Any) => FilterApi.gt( - intColumn(n), - Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + longColumn(n), + DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong]) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.gt( + longColumn(n), + v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.gt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.gt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.gt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) } - private val makeGtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.gtEq(intColumn(n), v.asInstanceOf[java.lang.Integer]) - case LongType => - (n: String, v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - - case StringType => + private val makeGtEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => (n: String, v: Any) => - FilterApi.gtEq(binaryColumn(n), - Binary.fromString(v.asInstanceOf[String])) - case BinaryType => + FilterApi.gtEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: String, v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: String, v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: String, v: Any) => + FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => 
       (n: String, v: Any) => FilterApi.gtEq(binaryColumn(n),
         Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]]))
-    case DateType if pushDownDate =>
+    case ParquetDateType if pushDownDate =>
+      (n: String, v: Any) =>
+        FilterApi.gtEq(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer])
+    case ParquetTimestampMicrosType if pushDownTimestamp =>
       (n: String, v: Any) => FilterApi.gtEq(
-        intColumn(n),
-        Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull)
+        longColumn(n),
+        DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong])
+    case ParquetTimestampMillisType if pushDownTimestamp =>
+      (n: String, v: Any) => FilterApi.gtEq(
+        longColumn(n),
+        v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong])
+
+    case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal =>
+      (n: String, v: Any) =>
+        FilterApi.gtEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal]))
+    case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal =>
+      (n: String, v: Any) =>
+        FilterApi.gtEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal]))
+    case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal =>
+      (n: String, v: Any) =>
+        FilterApi.gtEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length))
   }
 
   /**
-   * Returns a map from name of the column to the data type, if predicate push down applies.
+   * Returns a map from Parquet field names to their field information, if predicate push-down
+   * applies.
    */
-  private def getFieldMap(dataType: DataType): Map[String, DataType] = dataType match {
-    case StructType(fields) =>
-      // Here we don't flatten the fields in the nested schema but just look up through
-      // root fields. Currently, accessing to nested fields does not push down filters
-      // and it does not support to create filters for them.
-      fields.map(f => f.name -> f.dataType).toMap
-    case _ => Map.empty[String, DataType]
+  private def getFieldMap(dataType: MessageType): Map[String, ParquetField] = {
+    // Here we don't flatten the fields of the nested schema but just look at the root-level
+    // fields. Currently, accessing nested fields does not push down filters, and creating
+    // filters for them is not supported.
+    val primitiveFields =
+      dataType.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f =>
+        f.getName -> ParquetField(f.getName,
+          ParquetSchemaType(f.getOriginalType,
+            f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata))
+      }
+    if (caseSensitive) {
+      primitiveFields.toMap
+    } else {
+      // We don't resolve ambiguity here: when more than one field matches in case-insensitive
+      // mode, we simply skip pushdown for those fields; reading them will raise an exception
+      // anyway. See SPARK-25132.
+      val dedupPrimitiveFields =
+        primitiveFields
+          .groupBy(_._1.toLowerCase(Locale.ROOT))
+          .filter(_._2.size == 1)
+          .mapValues(_.head._2)
+      CaseInsensitiveMap(dedupPrimitiveFields)
+    }
   }
 
   /**
    * Converts data sources filters to Parquet filter predicates.
    */
-  def createFilter(schema: StructType, predicate: sources.Filter): Option[FilterPredicate] = {
-    val nameToType = getFieldMap(schema)
+  def createFilter(schema: MessageType, predicate: sources.Filter): Option[FilterPredicate] = {
+    val nameToParquetField = getFieldMap(schema)
+
+    // For decimal columns, the filter value's scale must match the scale in the file schema;
+    // pushing down a value with a mismatched scale would corrupt the results.
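// ----------------------------------------------------------------------------
// Aside (illustration only, not part of the patch): why the scale check below
// matters. Parquet stores a DECIMAL as an unscaled integer; the scale lives in
// the file schema, so pushing down a value whose scale differs from the file's
// would compare the wrong unscaled integer:
import java.math.{BigDecimal => JBigDecimal}

object DecimalScaleSketch {
  def main(args: Array[String]): Unit = {
    val v = new JBigDecimal("1.23")
    println(v.scale())                     // 2
    println(v.unscaledValue())             // 123
    // For a column declared DECIMAL(9, 3) the file stores 1.23 as 1230, so pushing
    // the unscaled value 123 would effectively filter on 0.123 instead of 1.23.
    println(v.setScale(3).unscaledValue()) // 1230
  }
}
// ----------------------------------------------------------------------------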
+ def isDecimalMatched(value: Any, decimalMeta: DecimalMetadata): Boolean = value match { + case decimal: JBigDecimal => + decimal.scale == decimalMeta.getScale + case _ => false + } + + // Parquet's type in the given file should be matched to the value's type + // in the pushed filter in order to push down the filter to Parquet. + def valueCanMakeFilterOn(name: String, value: Any): Boolean = { + value == null || (nameToParquetField(name).fieldType match { + case ParquetBooleanType => value.isInstanceOf[JBoolean] + case ParquetByteType | ParquetShortType | ParquetIntegerType => value.isInstanceOf[Number] + case ParquetLongType => value.isInstanceOf[JLong] + case ParquetFloatType => value.isInstanceOf[JFloat] + case ParquetDoubleType => value.isInstanceOf[JDouble] + case ParquetStringType => value.isInstanceOf[String] + case ParquetBinaryType => value.isInstanceOf[Array[Byte]] + case ParquetDateType => value.isInstanceOf[Date] + case ParquetTimestampMicrosType | ParquetTimestampMillisType => + value.isInstanceOf[Timestamp] + case ParquetSchemaType(DECIMAL, INT32, _, decimalMeta) => + isDecimalMatched(value, decimalMeta) + case ParquetSchemaType(DECIMAL, INT64, _, decimalMeta) => + isDecimalMatched(value, decimalMeta) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, _, decimalMeta) => + isDecimalMatched(value, decimalMeta) + case _ => false + }) + } // Parquet does not allow dots in the column name because dots are used as a column path // delimiter. Since Parquet 1.8.2 (PARQUET-389), Parquet accepts the filter predicates // with missing columns. The incorrect results could be got from Parquet when we push down // filters for the column having dots in the names. Thus, we do not push down such filters. // See SPARK-20364. - def canMakeFilterOn(name: String): Boolean = nameToType.contains(name) && !name.contains(".") + def canMakeFilterOn(name: String, value: Any): Boolean = { + nameToParquetField.contains(name) && !name.contains(".") && valueCanMakeFilterOn(name, value) + } // NOTE: // @@ -225,30 +452,40 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith: // Probably I missed something and obviously this should be changed. 
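// ----------------------------------------------------------------------------
// Aside (illustration only, not part of the patch): the sources.In case below
// expands a small IN list into OR'ed equality predicates, and skips pushdown
// entirely once the distinct value count exceeds the configured threshold. The
// shape of that rewrite, over a toy predicate type:
object InExpansionSketch {
  sealed trait Pred
  case class Eq(column: String, value: Any) extends Pred
  case class Or(left: Pred, right: Pred) extends Pred

  def expandIn(column: String, values: Seq[Any], threshold: Int): Option[Pred] = {
    val distinct = values.distinct
    if (distinct.length > threshold) None // too many values: leave filtering to Spark
    else distinct.map(v => Eq(column, v): Pred).reduceLeftOption(Or)
  }
  // expandIn("a", Seq(1, 2, 2, 3), 10)
  //   == Some(Or(Or(Eq("a", 1), Eq("a", 2)), Eq("a", 3)))
}
// ----------------------------------------------------------------------------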
predicate match { - case sources.IsNull(name) if canMakeFilterOn(name) => - makeEq.lift(nameToType(name)).map(_(name, null)) - case sources.IsNotNull(name) if canMakeFilterOn(name) => - makeNotEq.lift(nameToType(name)).map(_(name, null)) - - case sources.EqualTo(name, value) if canMakeFilterOn(name) => - makeEq.lift(nameToType(name)).map(_(name, value)) - case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name) => - makeNotEq.lift(nameToType(name)).map(_(name, value)) - - case sources.EqualNullSafe(name, value) if canMakeFilterOn(name) => - makeEq.lift(nameToType(name)).map(_(name, value)) - case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name) => - makeNotEq.lift(nameToType(name)).map(_(name, value)) - - case sources.LessThan(name, value) if canMakeFilterOn(name) => - makeLt.lift(nameToType(name)).map(_(name, value)) - case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name) => - makeLtEq.lift(nameToType(name)).map(_(name, value)) - - case sources.GreaterThan(name, value) if canMakeFilterOn(name) => - makeGt.lift(nameToType(name)).map(_(name, value)) - case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name) => - makeGtEq.lift(nameToType(name)).map(_(name, value)) + case sources.IsNull(name) if canMakeFilterOn(name, null) => + makeEq.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, null)) + case sources.IsNotNull(name) if canMakeFilterOn(name, null) => + makeNotEq.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, null)) + + case sources.EqualTo(name, value) if canMakeFilterOn(name, value) => + makeEq.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, value)) + case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) => + makeNotEq.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, value)) + + case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) => + makeEq.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, value)) + case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) => + makeNotEq.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, value)) + + case sources.LessThan(name, value) if canMakeFilterOn(name, value) => + makeLt.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, value)) + case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) => + makeLtEq.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, value)) + + case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) => + makeGt.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, value)) + case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) => + makeGtEq.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, value)) case sources.And(lhs, rhs) => // At here, it is not safe to just convert one side if we do not understand the @@ -272,7 +509,15 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith: case sources.Not(pred) => createFilter(schema, pred).map(FilterApi.not) - case sources.StringStartsWith(name, prefix) if pushDownStartWith && canMakeFilterOn(name) => + case sources.In(name, values) if canMakeFilterOn(name, values.head) + && values.distinct.length <= 
pushDownInFilterThreshold => + values.distinct.flatMap { v => + makeEq.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, v)) + }.reduceLeftOption(FilterApi.or) + + case sources.StringStartsWith(name, prefix) + if pushDownStartWith && canMakeFilterOn(name, prefix) => Option(prefix).map { v => FilterApi.userDefined(binaryColumn(name), new UserDefinedPredicate[Binary] with Serializable { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala index 40ce5d5e0564e..3319e73f2b313 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.parquet -import java.util.{Map => JMap, TimeZone} +import java.util.{Locale, Map => JMap, TimeZone} import scala.collection.JavaConverters._ @@ -30,6 +30,7 @@ import org.apache.parquet.schema.Type.Repetition import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -71,8 +72,10 @@ private[parquet] class ParquetReadSupport(val convertTz: Option[TimeZone]) StructType.fromString(schemaString) } - val parquetRequestedSchema = - ParquetReadSupport.clipParquetSchema(context.getFileSchema, catalystRequestedSchema) + val caseSensitive = context.getConfiguration.getBoolean(SQLConf.CASE_SENSITIVE.key, + SQLConf.CASE_SENSITIVE.defaultValue.get) + val parquetRequestedSchema = ParquetReadSupport.clipParquetSchema( + context.getFileSchema, catalystRequestedSchema, caseSensitive) new ReadContext(parquetRequestedSchema, Map.empty[String, String].asJava) } @@ -117,8 +120,12 @@ private[parquet] object ParquetReadSupport { * Tailors `parquetSchema` according to `catalystSchema` by removing column paths don't exist * in `catalystSchema`, and adding those only exist in `catalystSchema`. */ - def clipParquetSchema(parquetSchema: MessageType, catalystSchema: StructType): MessageType = { - val clippedParquetFields = clipParquetGroupFields(parquetSchema.asGroupType(), catalystSchema) + def clipParquetSchema( + parquetSchema: MessageType, + catalystSchema: StructType, + caseSensitive: Boolean = true): MessageType = { + val clippedParquetFields = clipParquetGroupFields( + parquetSchema.asGroupType(), catalystSchema, caseSensitive) if (clippedParquetFields.isEmpty) { ParquetSchemaConverter.EMPTY_MESSAGE } else { @@ -129,20 +136,21 @@ private[parquet] object ParquetReadSupport { } } - private def clipParquetType(parquetType: Type, catalystType: DataType): Type = { + private def clipParquetType( + parquetType: Type, catalystType: DataType, caseSensitive: Boolean): Type = { catalystType match { case t: ArrayType if !isPrimitiveCatalystType(t.elementType) => // Only clips array types with nested type as element type. 
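// ----------------------------------------------------------------------------
// Aside (illustration only, not part of the patch): at the top level, schema
// "clipping" tailors the file schema to the requested Catalyst schema: requested
// fields found in the file keep the file's type, and missing ones are synthesized
// so the reader can return nulls for them. A toy model over field names only:
object ClipSchemaSketch {
  case class Field(name: String, fromFile: Boolean)

  def clip(fileFields: Seq[String], requested: Seq[String]): Seq[Field] = {
    val inFile = fileFields.toSet // case-sensitive lookup, as in the default path above
    requested.map(name => Field(name, fromFile = inFile.contains(name)))
  }
  // clip(Seq("a", "b"), Seq("b", "c")) == Seq(Field("b", true), Field("c", false))
}
// ----------------------------------------------------------------------------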
- clipParquetListType(parquetType.asGroupType(), t.elementType) + clipParquetListType(parquetType.asGroupType(), t.elementType, caseSensitive) case t: MapType if !isPrimitiveCatalystType(t.keyType) || !isPrimitiveCatalystType(t.valueType) => // Only clips map types with nested key type or value type - clipParquetMapType(parquetType.asGroupType(), t.keyType, t.valueType) + clipParquetMapType(parquetType.asGroupType(), t.keyType, t.valueType, caseSensitive) case t: StructType => - clipParquetGroup(parquetType.asGroupType(), t) + clipParquetGroup(parquetType.asGroupType(), t, caseSensitive) case _ => // UDTs and primitive types are not clipped. For UDTs, a clipped version might not be able @@ -168,14 +176,15 @@ private[parquet] object ParquetReadSupport { * of the [[ArrayType]] should also be a nested type, namely an [[ArrayType]], a [[MapType]], or a * [[StructType]]. */ - private def clipParquetListType(parquetList: GroupType, elementType: DataType): Type = { + private def clipParquetListType( + parquetList: GroupType, elementType: DataType, caseSensitive: Boolean): Type = { // Precondition of this method, should only be called for lists with nested element types. assert(!isPrimitiveCatalystType(elementType)) // Unannotated repeated group should be interpreted as required list of required element, so // list element type is just the group itself. Clip it. if (parquetList.getOriginalType == null && parquetList.isRepetition(Repetition.REPEATED)) { - clipParquetType(parquetList, elementType) + clipParquetType(parquetList, elementType, caseSensitive) } else { assert( parquetList.getOriginalType == OriginalType.LIST, @@ -207,7 +216,7 @@ private[parquet] object ParquetReadSupport { Types .buildGroup(parquetList.getRepetition) .as(OriginalType.LIST) - .addField(clipParquetType(repeatedGroup, elementType)) + .addField(clipParquetType(repeatedGroup, elementType, caseSensitive)) .named(parquetList.getName) } else { // Otherwise, the repeated field's type is the element type with the repeated field's @@ -218,7 +227,7 @@ private[parquet] object ParquetReadSupport { .addField( Types .repeatedGroup() - .addField(clipParquetType(repeatedGroup.getType(0), elementType)) + .addField(clipParquetType(repeatedGroup.getType(0), elementType, caseSensitive)) .named(repeatedGroup.getName)) .named(parquetList.getName) } @@ -231,7 +240,10 @@ private[parquet] object ParquetReadSupport { * a [[StructType]]. */ private def clipParquetMapType( - parquetMap: GroupType, keyType: DataType, valueType: DataType): GroupType = { + parquetMap: GroupType, + keyType: DataType, + valueType: DataType, + caseSensitive: Boolean): GroupType = { // Precondition of this method, only handles maps with nested key types or value types. assert(!isPrimitiveCatalystType(keyType) || !isPrimitiveCatalystType(valueType)) @@ -243,8 +255,8 @@ private[parquet] object ParquetReadSupport { Types .repeatedGroup() .as(repeatedGroup.getOriginalType) - .addField(clipParquetType(parquetKeyType, keyType)) - .addField(clipParquetType(parquetValueType, valueType)) + .addField(clipParquetType(parquetKeyType, keyType, caseSensitive)) + .addField(clipParquetType(parquetValueType, valueType, caseSensitive)) .named(repeatedGroup.getName) Types @@ -262,8 +274,9 @@ private[parquet] object ParquetReadSupport { * [[MessageType]]. Because it's legal to construct an empty requested schema for column * pruning. 
*/ - private def clipParquetGroup(parquetRecord: GroupType, structType: StructType): GroupType = { - val clippedParquetFields = clipParquetGroupFields(parquetRecord, structType) + private def clipParquetGroup( + parquetRecord: GroupType, structType: StructType, caseSensitive: Boolean): GroupType = { + val clippedParquetFields = clipParquetGroupFields(parquetRecord, structType, caseSensitive) Types .buildGroup(parquetRecord.getRepetition) .as(parquetRecord.getOriginalType) @@ -277,14 +290,35 @@ private[parquet] object ParquetReadSupport { * @return A list of clipped [[GroupType]] fields, which can be empty. */ private def clipParquetGroupFields( - parquetRecord: GroupType, structType: StructType): Seq[Type] = { - val parquetFieldMap = parquetRecord.getFields.asScala.map(f => f.getName -> f).toMap + parquetRecord: GroupType, structType: StructType, caseSensitive: Boolean): Seq[Type] = { val toParquet = new SparkToParquetSchemaConverter(writeLegacyParquetFormat = false) - structType.map { f => - parquetFieldMap - .get(f.name) - .map(clipParquetType(_, f.dataType)) - .getOrElse(toParquet.convertField(f)) + if (caseSensitive) { + val caseSensitiveParquetFieldMap = + parquetRecord.getFields.asScala.map(f => f.getName -> f).toMap + structType.map { f => + caseSensitiveParquetFieldMap + .get(f.name) + .map(clipParquetType(_, f.dataType, caseSensitive)) + .getOrElse(toParquet.convertField(f)) + } + } else { + // Do case-insensitive resolution only if in case-insensitive mode + val caseInsensitiveParquetFieldMap = + parquetRecord.getFields.asScala.groupBy(_.getName.toLowerCase(Locale.ROOT)) + structType.map { f => + caseInsensitiveParquetFieldMap + .get(f.name.toLowerCase(Locale.ROOT)) + .map { parquetTypes => + if (parquetTypes.size > 1) { + // Need to fail if there is ambiguity, i.e. 
more than one field is matched + val parquetTypesString = parquetTypes.map(_.getName).mkString("[", ", ", "]") + throw new RuntimeException(s"""Found duplicate field(s) "${f.name}": """ + + s"$parquetTypesString in case-insensitive mode") + } else { + clipParquetType(parquetTypes.head, f.dataType, caseSensitive) + } + }.getOrElse(toParquet.convertField(f)) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index c61be077d309f..8ce8a86d2f026 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -26,7 +26,6 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.parquet.schema.Type.Repetition._ import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter.maxPrecisionForBytes import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -171,7 +170,7 @@ class ParquetToSparkSchemaConverter( case FIXED_LEN_BYTE_ARRAY => originalType match { - case DECIMAL => makeDecimalType(maxPrecisionForBytes(field.getTypeLength)) + case DECIMAL => makeDecimalType(Decimal.maxPrecisionForBytes(field.getTypeLength)) case INTERVAL => typeNotImplemented() case _ => illegalType() } @@ -411,7 +410,7 @@ class SparkToParquetSchemaConverter( .as(DECIMAL) .precision(precision) .scale(scale) - .length(ParquetSchemaConverter.minBytesForPrecision(precision)) + .length(Decimal.minBytesForPrecision(precision)) .named(field.name) // ======================== @@ -445,7 +444,7 @@ class SparkToParquetSchemaConverter( .as(DECIMAL) .precision(precision) .scale(scale) - .length(ParquetSchemaConverter.minBytesForPrecision(precision)) + .length(Decimal.minBytesForPrecision(precision)) .named(field.name) // =================================== @@ -555,7 +554,7 @@ class SparkToParquetSchemaConverter( convertField(field.copy(dataType = udt.sqlType)) case _ => - throw new AnalysisException(s"Unsupported data type $field.dataType") + throw new AnalysisException(s"Unsupported data type ${field.dataType.catalogString}") } } } @@ -584,23 +583,4 @@ private[sql] object ParquetSchemaConverter { throw new AnalysisException(message) } } - - private def computeMinBytesForPrecision(precision : Int) : Int = { - var numBytes = 1 - while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) { - numBytes += 1 - } - numBytes - } - - // Returns the minimum number of bytes needed to store a decimal with a given `precision`. 
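// ----------------------------------------------------------------------------
// Aside (illustration only, not part of the patch): the helpers removed in this
// hunk now live on org.apache.spark.sql.types.Decimal. The underlying arithmetic:
// n bytes of two's complement hold unscaled values up to 2^(8n-1) - 1, so
// minBytesForPrecision(p) is the smallest n whose signed range covers 10^p - 1:
object DecimalBytesSketch {
  def minBytesForPrecision(precision: Int): Int =
    Iterator.from(1).find(n => math.pow(2.0, 8 * n - 1) >= math.pow(10.0, precision)).get

  def main(args: Array[String]): Unit = {
    println(minBytesForPrecision(9))  // 4:  DECIMAL(<=9, _) fits in an INT32
    println(minBytesForPrecision(18)) // 8:  DECIMAL(<=18, _) fits in an INT64
    println(minBytesForPrecision(38)) // 16: needs FIXED_LEN_BYTE_ARRAY(16)
  }
}
// ----------------------------------------------------------------------------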
- val minBytesForPrecision = Array.tabulate[Int](39)(computeMinBytesForPrecision) - - // Max precision of a decimal value stored in `numBytes` bytes - def maxPrecisionForBytes(numBytes: Int): Int = { - Math.round( // convert double to long - Math.floor(Math.log10( // number of base-10 digits - Math.pow(2, 8 * numBytes - 1) - 1))) // max value stored in numBytes - .asInstanceOf[Int] - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala new file mode 100644 index 0000000000000..91080b15727d6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{ProjectionOverSchema, SelectedField} +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType} + +/** + * Prunes unnecessary Parquet columns given a [[PhysicalOperation]] over a + * [[ParquetRelation]]. By "Parquet column", we mean a column as defined in the + * Parquet format. In Spark SQL, a root-level Parquet column corresponds to a + * SQL column, and a nested Parquet column corresponds to a [[StructField]]. + */ +private[sql] object ParquetSchemaPruning extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = + if (SQLConf.get.nestedSchemaPruningEnabled) { + apply0(plan) + } else { + plan + } + + private def apply0(plan: LogicalPlan): LogicalPlan = + plan transformDown { + case op @ PhysicalOperation(projects, filters, + l @ LogicalRelation(hadoopFsRelation: HadoopFsRelation, _, _, _)) + if canPruneRelation(hadoopFsRelation) => + val (normalizedProjects, normalizedFilters) = + normalizeAttributeRefNames(l, projects, filters) + val requestedRootFields = identifyRootFields(normalizedProjects, normalizedFilters) + + // If requestedRootFields includes a nested field, continue. 
Otherwise, + // return op + if (requestedRootFields.exists { root: RootField => !root.derivedFromAtt }) { + val dataSchema = hadoopFsRelation.dataSchema + val prunedDataSchema = pruneDataSchema(dataSchema, requestedRootFields) + + // If the data schema is different from the pruned data schema, continue. Otherwise, + // return op. We make this comparison by counting the number of "leaf" fields in + // each schema, assuming the fields in prunedDataSchema are a subset of the fields + // in dataSchema. + if (countLeaves(dataSchema) > countLeaves(prunedDataSchema)) { + val prunedParquetRelation = + hadoopFsRelation.copy(dataSchema = prunedDataSchema)(hadoopFsRelation.sparkSession) + + val prunedRelation = buildPrunedRelation(l, prunedParquetRelation) + val projectionOverSchema = ProjectionOverSchema(prunedDataSchema) + + buildNewProjection(normalizedProjects, normalizedFilters, prunedRelation, + projectionOverSchema) + } else { + op + } + } else { + op + } + } + + /** + * Checks to see if the given relation is Parquet and can be pruned. + */ + private def canPruneRelation(fsRelation: HadoopFsRelation) = + fsRelation.fileFormat.isInstanceOf[ParquetFileFormat] + + /** + * Normalizes the names of the attribute references in the given projects and filters to reflect + * the names in the given logical relation. This makes it possible to compare attributes and + * fields by name. Returns a tuple with the normalized projects and filters, respectively. + */ + private def normalizeAttributeRefNames( + logicalRelation: LogicalRelation, + projects: Seq[NamedExpression], + filters: Seq[Expression]): (Seq[NamedExpression], Seq[Expression]) = { + val normalizedAttNameMap = logicalRelation.output.map(att => (att.exprId, att.name)).toMap + val normalizedProjects = projects.map(_.transform { + case att: AttributeReference if normalizedAttNameMap.contains(att.exprId) => + att.withName(normalizedAttNameMap(att.exprId)) + }).map { case expr: NamedExpression => expr } + val normalizedFilters = filters.map(_.transform { + case att: AttributeReference if normalizedAttNameMap.contains(att.exprId) => + att.withName(normalizedAttNameMap(att.exprId)) + }) + (normalizedProjects, normalizedFilters) + } + + /** + * Returns the set of fields from the Parquet file that the query plan needs. + */ + private def identifyRootFields(projects: Seq[NamedExpression], filters: Seq[Expression]) = { + val projectionRootFields = projects.flatMap(getRootFields) + val filterRootFields = filters.flatMap(getRootFields) + + // Some kinds of expressions don't need to access any content of a root field, e.g., `IsNotNull`. + // For such expressions, if the query accesses any nested field of that root elsewhere, we + // don't need to add a separate root field access. + // For example, for a query `SELECT name.first FROM contacts WHERE name IS NOT NULL`, + // we don't need to read any nested field of the `name` struct other than `first`. + val (rootFields, optRootFields) = (projectionRootFields ++ filterRootFields) + .distinct.partition(_.contentAccessed) + + optRootFields.filter { opt => + !rootFields.exists(_.field.name == opt.field.name) + } ++ rootFields + } + + /** + * Builds the new output [[Project]] Spark SQL operator that has the pruned output relation.
+ */ + private def buildNewProjection( + projects: Seq[NamedExpression], filters: Seq[Expression], prunedRelation: LogicalRelation, + projectionOverSchema: ProjectionOverSchema) = { + // Construct a new target for our projection by rewriting and + // including the original filters where available + val projectionChild = + if (filters.nonEmpty) { + val projectedFilters = filters.map(_.transformDown { + case projectionOverSchema(expr) => expr + }) + val newFilterCondition = projectedFilters.reduce(And) + Filter(newFilterCondition, prunedRelation) + } else { + prunedRelation + } + + // Construct the new projections of our Project by + // rewriting the original projections + val newProjects = projects.map(_.transformDown { + case projectionOverSchema(expr) => expr + }).map { case expr: NamedExpression => expr } + + if (log.isDebugEnabled) { + logDebug(s"New projects:\n${newProjects.map(_.treeString).mkString("\n")}") + } + + Project(newProjects, projectionChild) + } + + /** + * Filters the schema from the given file by the requested fields. + * Schema field ordering from the file is preserved. + */ + private def pruneDataSchema( + fileDataSchema: StructType, + requestedRootFields: Seq[RootField]) = { + // Merge the requested root fields into a single schema. Note the ordering of the fields + // in the resulting schema may differ from their ordering in the logical relation's + // original schema + val mergedSchema = requestedRootFields + .map { case root: RootField => StructType(Array(root.field)) } + .reduceLeft(_ merge _) + val dataSchemaFieldNames = fileDataSchema.fieldNames.toSet + val mergedDataSchema = + StructType(mergedSchema.filter(f => dataSchemaFieldNames.contains(f.name))) + // Sort the fields of mergedDataSchema according to their order in dataSchema, + // recursively. This makes mergedDataSchema a pruned schema of dataSchema + sortLeftFieldsByRight(mergedDataSchema, fileDataSchema).asInstanceOf[StructType] + } + + /** + * Builds a pruned logical relation from the original output relation and the schema of the + * pruned base relation. + */ + private def buildPrunedRelation( + outputRelation: LogicalRelation, + prunedBaseRelation: HadoopFsRelation) = { + // We need to replace the expression ids of the pruned relation output attributes + // with the expression ids of the original relation output attributes so that + // references to the original relation's output are not broken + val outputIdMap = outputRelation.output.map(att => (att.name, att.exprId)).toMap + val prunedRelationOutput = + prunedBaseRelation + .schema + .toAttributes + .map { + case att if outputIdMap.contains(att.name) => + att.withExprId(outputIdMap(att.name)) + case att => att + } + outputRelation.copy(relation = prunedBaseRelation, output = prunedRelationOutput) + } + + /** + * Gets the root (aka top-level, no-parent) [[StructField]]s for the given [[Expression]]. + * When expr is an [[Attribute]], construct a field around it and indicate that the + * field was derived from an attribute. + */ + private def getRootFields(expr: Expression): Seq[RootField] = { + expr match { + case att: Attribute => + RootField(StructField(att.name, att.dataType, att.nullable), derivedFromAtt = true) :: Nil + case SelectedField(field) => RootField(field, derivedFromAtt = false) :: Nil + // Root field accesses by `IsNotNull` and `IsNull` are special cases as the expressions + // don't actually use any nested fields.
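To make the special case concrete, consider `SELECT name.first FROM contacts WHERE name IS NOT NULL`. A hedged sketch of the schemas involved (names invented for illustration):

import org.apache.spark.sql.types._

// Full file schema: contacts(name: struct<first, last>, address).
val dataSchema = StructType(Seq(
  StructField("name", StructType(Seq(
    StructField("first", StringType),
    StructField("last", StringType)))),
  StructField("address", StringType)))

// `name.first` is selected; `name IS NOT NULL` contributes a root field whose
// content is never accessed, so it requires nothing beyond `name.first`.
val prunedDataSchema = StructType(Seq(
  StructField("name", StructType(Seq(
    StructField("first", StringType))))))

Here countLeaves(dataSchema) is 3 while countLeaves(prunedDataSchema) is 1, so the rule rewrites the plan over the smaller schema.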
These root field accesses might be excluded later + // if there are any nested field accesses in the query plan. + case IsNotNull(SelectedField(field)) => + RootField(field, derivedFromAtt = false, contentAccessed = false) :: Nil + case IsNull(SelectedField(field)) => + RootField(field, derivedFromAtt = false, contentAccessed = false) :: Nil + case IsNotNull(_: Attribute) | IsNull(_: Attribute) => + expr.children.flatMap(getRootFields).map(_.copy(contentAccessed = false)) + case _ => + expr.children.flatMap(getRootFields) + } + } + + /** + * Counts the "leaf" fields of the given dataType. Informally, this is the + * number of fields of non-complex data type in the tree representation of + * [[DataType]]. + */ + private def countLeaves(dataType: DataType): Int = { + dataType match { + case array: ArrayType => countLeaves(array.elementType) + case map: MapType => countLeaves(map.keyType) + countLeaves(map.valueType) + case struct: StructType => + struct.map(field => countLeaves(field.dataType)).sum + case _ => 1 + } + } + + /** + * Sorts the fields and descendant fields of structs in left according to their order in + * right. This function assumes that the fields of left are a subset of the fields of + * right, recursively. That is, left is a "subschema" of right, ignoring order of + * fields. + */ + private def sortLeftFieldsByRight(left: DataType, right: DataType): DataType = + (left, right) match { + case (ArrayType(leftElementType, containsNull), ArrayType(rightElementType, _)) => + ArrayType( + sortLeftFieldsByRight(leftElementType, rightElementType), + containsNull) + case (MapType(leftKeyType, leftValueType, containsNull), + MapType(rightKeyType, rightValueType, _)) => + MapType( + sortLeftFieldsByRight(leftKeyType, rightKeyType), + sortLeftFieldsByRight(leftValueType, rightValueType), + containsNull) + case (leftStruct: StructType, rightStruct: StructType) => + val filteredRightFieldNames = rightStruct.fieldNames.filter(leftStruct.fieldNames.contains) + val sortedLeftFields = filteredRightFieldNames.map { fieldName => + val leftFieldType = leftStruct(fieldName).dataType + val rightFieldType = rightStruct(fieldName).dataType + val sortedLeftFieldType = sortLeftFieldsByRight(leftFieldType, rightFieldType) + StructField(fieldName, sortedLeftFieldType) + } + StructType(sortedLeftFields) + case _ => left + } + + /** + * This represents a "root" schema field (aka top-level, no-parent). `field` is the + * `StructField` for field name and datatype. `derivedFromAtt` indicates whether it + * was derived from an attribute or had a proper child. `contentAccessed` indicates whether + * its content is actually accessed by the expressions that refer to it.
+ */ + private case class RootField(field: StructField, derivedFromAtt: Boolean, + contentAccessed: Boolean = true) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala index af4e1433c876f..b40b8c2e61f33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala @@ -33,7 +33,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter.minBytesForPrecision import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -73,7 +72,8 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit private val timestampBuffer = new Array[Byte](12) // Reusable byte array used to write decimal values - private val decimalBuffer = new Array[Byte](minBytesForPrecision(DecimalType.MAX_PRECISION)) + private val decimalBuffer = + new Array[Byte](Decimal.minBytesForPrecision(DecimalType.MAX_PRECISION)) override def init(configuration: Configuration): WriteContext = { val schemaString = configuration.get(ParquetWriteSupport.SPARK_ROW_SCHEMA) @@ -212,7 +212,7 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit precision <= DecimalType.MAX_PRECISION, s"Decimal precision $precision exceeds max precision ${DecimalType.MAX_PRECISION}") - val numBytes = minBytesForPrecision(precision) + val numBytes = Decimal.minBytesForPrecision(precision) val int32Writer = (row: SpecializedGetters, ordinal: Int) => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index cab00251622b8..949aa665527ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -39,7 +39,7 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { sparkSession.sessionState.conf.runSQLonFile && u.tableIdentifier.database.isDefined } - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedRelation if maybeSQLFile(u) => try { val dataSource = DataSource( @@ -73,7 +73,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi // catalog is a def and not a val/lazy val as the latter would introduce a circular reference private def catalog = sparkSession.sessionState.catalog - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // When we CREATE TABLE without specifying the table schema, we should fail the query if // bucketing information is specified, as we can't infer bucketing from data files currently. 
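As an illustration of that check, a table created over existing files without an explicit schema cannot declare buckets. Something like the following (hypothetical path and a locally built session) is expected to fail analysis:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").appName("bucketing-check").getOrCreate()
// No schema is given, but bucketing is, so the rule should reject the query.
spark.sql(
  """CREATE TABLE t USING parquet
    |CLUSTERED BY (id) INTO 4 BUCKETS
    |LOCATION '/tmp/existing_parquet_data'""".stripMargin)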
// Since the runtime inferred partition columns could be different from what user specified, @@ -281,7 +281,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi schema.filter(f => normalizedPartitionCols.contains(f.name)).map(_.dataType).foreach { case _: AtomicType => // OK - case other => failAnalysis(s"Cannot use ${other.simpleString} for partition column") + case other => failAnalysis(s"Cannot use ${other.catalogString} for partition column") } normalizedPartitionCols @@ -307,7 +307,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi normalizedBucketSpec.sortColumnNames.map(schema(_)).map(_.dataType).foreach { case dt if RowOrdering.isOrderable(dt) => // OK - case other => failAnalysis(s"Cannot use ${other.simpleString} for sorting column") + case other => failAnalysis(s"Cannot use ${other.catalogString} for sorting column") } Some(normalizedBucketSpec) @@ -365,7 +365,7 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { } } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case i @ InsertIntoTable(table, _, query, _, _) if table.resolved && query.resolved => table match { case relation: HiveTableRelation => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index e93908da43535..268297148b522 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.types.{DataType, StringType, StructType} import org.apache.spark.util.SerializableConfiguration /** @@ -47,11 +47,6 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { throw new AnalysisException( s"Text data source supports only a single column, and you have ${schema.size} columns.") } - val tpe = schema(0).dataType - if (tpe != StringType) { - throw new AnalysisException( - s"Text data source supports only a string column, but you have ${tpe.simpleString}.") - } } override def isSplitable( @@ -125,7 +120,7 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { } else { new HadoopFileWholeTextReader(file, confValue) } - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => reader.close())) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => reader.close())) if (requiredSchema.isEmpty) { val emptyUnsafeRow = new UnsafeRow(0) reader.map(_ => emptyUnsafeRow) @@ -141,6 +136,9 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { } } } + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = + dataType == StringType } class TextOutputWriter( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 8d6fb3820d420..f62f7349d1da7 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -17,20 +17,22 @@ package org.apache.spark.sql.execution.datasources.v2 -import scala.collection.JavaConverters._ -import scala.reflect.ClassTag - -import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} +import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.sources.v2.reader.InputPartition +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReader, PartitionReaderFactory} -class DataSourceRDDPartition[T : ClassTag](val index: Int, val inputPartition: InputPartition[T]) +class DataSourceRDDPartition(val index: Int, val inputPartition: InputPartition) extends Partition with Serializable -class DataSourceRDD[T: ClassTag]( +// TODO: we should have 2 RDDs: an RDD[InternalRow] for row-based scan, an `RDD[ColumnarBatch]` for +// columnar scan. +class DataSourceRDD( sc: SparkContext, - @transient private val inputPartitions: Seq[InputPartition[T]]) - extends RDD[T](sc, Nil) { + @transient private val inputPartitions: Seq[InputPartition], + partitionReaderFactory: PartitionReaderFactory, + columnarReads: Boolean) + extends RDD[InternalRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { inputPartitions.zipWithIndex.map { @@ -38,11 +40,21 @@ class DataSourceRDD[T: ClassTag]( }.toArray } - override def compute(split: Partition, context: TaskContext): Iterator[T] = { - val reader = split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition - .createPartitionReader() - context.addTaskCompletionListener(_ => reader.close()) - val iter = new Iterator[T] { + private def castPartition(split: Partition): DataSourceRDDPartition = split match { + case p: DataSourceRDDPartition => p + case _ => throw new SparkException(s"[BUG] Not a DataSourceRDDPartition: $split") + } + + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + val inputPartition = castPartition(split).inputPartition + val reader: PartitionReader[_] = if (columnarReads) { + partitionReaderFactory.createColumnarReader(inputPartition) + } else { + partitionReaderFactory.createReader(inputPartition) + } + + context.addTaskCompletionListener[Unit](_ => reader.close()) + val iter = new Iterator[Any] { private[this] var valuePrepared = false override def hasNext: Boolean = { @@ -52,7 +64,7 @@ class DataSourceRDD[T: ClassTag]( valuePrepared } - override def next(): T = { + override def next(): Any = { if (!hasNext) { throw new java.util.NoSuchElementException("End of stream") } @@ -60,10 +72,11 @@ class DataSourceRDD[T: ClassTag]( reader.get() } } - new InterruptibleIterator(context, iter) + // TODO: SPARK-25083 remove the type erasure hack in data source scan + new InterruptibleIterator(context, iter.asInstanceOf[Iterator[InternalRow]]) } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition.preferredLocations() + castPartition(split).inputPartition.preferredLocations() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 7613eb210c659..f7e29593a6353 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -17,43 +17,53 @@ package org.apache.spark.sql.execution.datasources.v2 +import java.util.UUID + import scala.collection.JavaConverters._ -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.{AnalysisException, SaveMode} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema} -import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsReportStatistics} +import org.apache.spark.sql.sources.v2.{BatchReadSupportProvider, BatchWriteSupportProvider, DataSourceOptions, DataSourceV2} +import org.apache.spark.sql.sources.v2.reader.{BatchReadSupport, ReadSupport, ScanConfigBuilder, SupportsReportStatistics} +import org.apache.spark.sql.sources.v2.writer.BatchWriteSupport import org.apache.spark.sql.types.StructType /** * A logical plan representing a data source v2 scan. * * @param source An instance of a [[DataSourceV2]] implementation. - * @param options The options for this scan. Used to create fresh [[DataSourceReader]]. - * @param userSpecifiedSchema The user-specified schema for this scan. Used to create fresh - * [[DataSourceReader]]. + * @param options The options for this scan. Used to create fresh [[BatchWriteSupport]]. + * @param userSpecifiedSchema The user-specified schema for this scan. 
*/ case class DataSourceV2Relation( source: DataSourceV2, + readSupport: BatchReadSupport, output: Seq[AttributeReference], options: Map[String, String], - userSpecifiedSchema: Option[StructType]) - extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { + tableIdent: Option[TableIdentifier] = None, + userSpecifiedSchema: Option[StructType] = None) + extends LeafNode with MultiInstanceRelation with NamedRelation with DataSourceV2StringFormat { import DataSourceV2Relation._ + override def name: String = { + tableIdent.map(_.unquotedString).getOrElse(s"${source.name}:unknown") + } + override def pushedFilters: Seq[Expression] = Seq.empty override def simpleString: String = "RelationV2 " + metadataString - def newReader(): DataSourceReader = source.createReader(options, userSpecifiedSchema) + def newWriteSupport(): BatchWriteSupport = source.createWriteSupport(options, schema) - override def computeStats(): Statistics = newReader match { + override def computeStats(): Statistics = readSupport match { case r: SupportsReportStatistics => - Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) + val statistics = r.estimateStatistics(readSupport.newScanConfigBuilder().build()) + Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) case _ => Statistics(sizeInBytes = conf.defaultSizeInBytes) } @@ -74,7 +84,8 @@ case class StreamingDataSourceV2Relation( output: Seq[AttributeReference], source: DataSourceV2, options: Map[String, String], - reader: DataSourceReader) + readSupport: ReadSupport, + scanConfigBuilder: ScanConfigBuilder) extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { override def isStreaming: Boolean = true @@ -88,7 +99,8 @@ case class StreamingDataSourceV2Relation( // TODO: unify the equal/hashCode implementation for all data source v2 query plans. override def equals(other: Any): Boolean = other match { case other: StreamingDataSourceV2Relation => - output == other.output && reader.getClass == other.reader.getClass && options == other.options + output == other.output && readSupport.getClass == other.readSupport.getClass && + options == other.options case _ => false } @@ -96,9 +108,10 @@ case class StreamingDataSourceV2Relation( Seq(output, source, options).hashCode() } - override def computeStats(): Statistics = reader match { + override def computeStats(): Statistics = readSupport match { case r: SupportsReportStatistics => - Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) + val statistics = r.estimateStatistics(scanConfigBuilder.build()) + Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) case _ => Statistics(sizeInBytes = conf.defaultSizeInBytes) } @@ -106,28 +119,21 @@ case class StreamingDataSourceV2Relation( object DataSourceV2Relation { private implicit class SourceHelpers(source: DataSourceV2) { - def asReadSupport: ReadSupport = { + def asReadSupportProvider: BatchReadSupportProvider = { source match { - case support: ReadSupport => - support - case _: ReadSupportWithSchema => - // this method is only called if there is no user-supplied schema. if there is no - // user-supplied schema and ReadSupport was not implemented, throw a helpful exception. 
- throw new AnalysisException(s"Data source requires a user-supplied schema: $name") + case provider: BatchReadSupportProvider => + provider case _ => throw new AnalysisException(s"Data source is not readable: $name") } } - def asReadSupportWithSchema: ReadSupportWithSchema = { + def asWriteSupportProvider: BatchWriteSupportProvider = { source match { - case support: ReadSupportWithSchema => - support - case _: ReadSupport => - throw new AnalysisException( - s"Data source does not support user-supplied schema: $name") + case provider: BatchWriteSupportProvider => + provider case _ => - throw new AnalysisException(s"Data source is not readable: $name") + throw new AnalysisException(s"Data source is not writable: $name") } } @@ -140,25 +146,44 @@ object DataSourceV2Relation { } } - def createReader( + def createReadSupport( options: Map[String, String], - userSpecifiedSchema: Option[StructType]): DataSourceReader = { + userSpecifiedSchema: Option[StructType]): BatchReadSupport = { val v2Options = new DataSourceOptions(options.asJava) userSpecifiedSchema match { case Some(s) => - asReadSupportWithSchema.createReader(s, v2Options) + asReadSupportProvider.createBatchReadSupport(s, v2Options) case _ => - asReadSupport.createReader(v2Options) + asReadSupportProvider.createBatchReadSupport(v2Options) } } + + def createWriteSupport( + options: Map[String, String], + schema: StructType): BatchWriteSupport = { + asWriteSupportProvider.createBatchWriteSupport( + UUID.randomUUID().toString, + schema, + SaveMode.Append, + new DataSourceOptions(options.asJava)).get + } } def create( source: DataSourceV2, options: Map[String, String], - userSpecifiedSchema: Option[StructType]): DataSourceV2Relation = { - val reader = source.createReader(options, userSpecifiedSchema) + tableIdent: Option[TableIdentifier] = None, + userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = { + val readSupport = source.createReadSupport(options, userSpecifiedSchema) + val output = readSupport.fullSchema().toAttributes + val ident = tableIdent.orElse(tableFromOptions(options)) DataSourceV2Relation( - source, reader.readSchema().toAttributes, options, userSpecifiedSchema) + source, readSupport, output, options, ident, userSpecifiedSchema) + } + + private def tableFromOptions(options: Map[String, String]): Option[TableIdentifier] = { + options + .get(DataSourceOptions.TABLE_KEY) + .map(TableIdentifier(_, options.get(DataSourceOptions.DATABASE_KEY))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index c6a7684bf6ab0..04a97735d024d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -17,12 +17,8 @@ package org.apache.spark.sql.execution.datasources.v2 -import scala.collection.JavaConverters._ - import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.catalyst.plans.physical.SinglePartition @@ -30,9 +26,7 @@ import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeSta import 
org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReaderFactory, ContinuousReadSupport, MicroBatchReadSupport} /** * Physical plan node for scanning data from a data source. @@ -42,7 +36,8 @@ case class DataSourceV2ScanExec( @transient source: DataSourceV2, @transient options: Map[String, String], @transient pushedFilters: Seq[Expression], - @transient reader: DataSourceReader) + @transient readSupport: ReadSupport, + @transient scanConfig: ScanConfig) extends LeafExecNode with DataSourceV2StringFormat with ColumnarBatchScan { override def simpleString: String = "ScanV2 " + metadataString @@ -50,7 +45,8 @@ case class DataSourceV2ScanExec( // TODO: unify the equal/hashCode implementation for all data source v2 query plans. override def equals(other: Any): Boolean = other match { case other: DataSourceV2ScanExec => - output == other.output && reader.getClass == other.reader.getClass && options == other.options + output == other.output && readSupport.getClass == other.readSupport.getClass && + options == other.options case _ => false } @@ -58,40 +54,39 @@ case class DataSourceV2ScanExec( Seq(output, source, options).hashCode() } - override def outputPartitioning: physical.Partitioning = reader match { - case r: SupportsScanColumnarBatch if r.enableBatchRead() && batchPartitions.size == 1 => - SinglePartition - - case r: SupportsScanColumnarBatch if !r.enableBatchRead() && partitions.size == 1 => - SinglePartition - - case r if !r.isInstanceOf[SupportsScanColumnarBatch] && partitions.size == 1 => + override def outputPartitioning: physical.Partitioning = readSupport match { + case _ if partitions.length == 1 => SinglePartition case s: SupportsReportPartitioning => new DataSourcePartitioning( - s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name))) + s.outputPartitioning(scanConfig), AttributeMap(output.map(a => a -> a.name))) case _ => super.outputPartitioning } - private lazy val partitions: Seq[InputPartition[UnsafeRow]] = reader match { - case r: SupportsScanUnsafeRow => r.planUnsafeInputPartitions().asScala - case _ => - reader.planInputPartitions().asScala.map { - new RowToUnsafeRowInputPartition(_, reader.readSchema()): InputPartition[UnsafeRow] - } + private lazy val partitions: Seq[InputPartition] = readSupport.planInputPartitions(scanConfig) + + private lazy val readerFactory = readSupport match { + case r: BatchReadSupport => r.createReaderFactory(scanConfig) + case r: MicroBatchReadSupport => r.createReaderFactory(scanConfig) + case r: ContinuousReadSupport => r.createContinuousReaderFactory(scanConfig) + case _ => throw new IllegalStateException("unknown read support: " + readSupport) } - private lazy val batchPartitions: Seq[InputPartition[ColumnarBatch]] = reader match { - case r: SupportsScanColumnarBatch if r.enableBatchRead() => - assert(!reader.isInstanceOf[ContinuousReader], - "continuous stream reader does not support columnar read yet.") - r.planBatchInputPartitions().asScala + // TODO: clean this up when we have dedicated scan plan for continuous streaming. 
+ override val supportsBatch: Boolean = { + require(partitions.forall(readerFactory.supportColumnarReads) || + !partitions.exists(readerFactory.supportColumnarReads), + "Cannot mix row-based and columnar input partitions.") + + partitions.exists(readerFactory.supportColumnarReads) } - private lazy val inputRDD: RDD[InternalRow] = reader match { - case _: ContinuousReader => + private lazy val inputRDD: RDD[InternalRow] = readSupport match { + case _: ContinuousReadSupport => + assert(!supportsBatch, + "continuous stream reader does not support columnar read yet.") EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), sparkContext.env) @@ -100,22 +95,17 @@ case class DataSourceV2ScanExec( sparkContext, sqlContext.conf.continuousStreamingExecutorQueueSize, sqlContext.conf.continuousStreamingExecutorPollIntervalMs, - partitions).asInstanceOf[RDD[InternalRow]] - - case r: SupportsScanColumnarBatch if r.enableBatchRead() => - new DataSourceRDD(sparkContext, batchPartitions).asInstanceOf[RDD[InternalRow]] + partitions, + schema, + readerFactory.asInstanceOf[ContinuousPartitionReaderFactory]) case _ => - new DataSourceRDD(sparkContext, partitions).asInstanceOf[RDD[InternalRow]] + new DataSourceRDD( + sparkContext, partitions, readerFactory.asInstanceOf[PartitionReaderFactory], supportsBatch) } override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) - override val supportsBatch: Boolean = reader match { - case r: SupportsScanColumnarBatch if r.enableBatchRead() => true - case _ => false - } - override protected def needsUnsafeRowConversion: Boolean = false override protected def doExecute(): RDD[InternalRow] = { @@ -130,27 +120,3 @@ case class DataSourceV2ScanExec( } } } - -class RowToUnsafeRowInputPartition(partition: InputPartition[Row], schema: StructType) - extends InputPartition[UnsafeRow] { - - override def preferredLocations: Array[String] = partition.preferredLocations - - override def createPartitionReader: InputPartitionReader[UnsafeRow] = { - new RowToUnsafeInputPartitionReader( - partition.createPartitionReader, RowEncoder.apply(schema).resolveAndBind()) - } -} - -class RowToUnsafeInputPartitionReader( - val rowReader: InputPartitionReader[Row], - encoder: ExpressionEncoder[Row]) - - extends InputPartitionReader[UnsafeRow] { - - override def next: Boolean = rowReader.next - - override def get: UnsafeRow = encoder.toRow(rowReader.get).asInstanceOf[UnsafeRow] - - override def close(): Unit = rowReader.close() -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 2a7f1de2c7c19..9a3109e7c199e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -22,12 +22,12 @@ import scala.collection.mutable import org.apache.spark.sql.{sources, Strategy} import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Repartition} import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy 
import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} -import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns} -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport object DataSourceV2Strategy extends Strategy { @@ -37,14 +37,9 @@ object DataSourceV2Strategy extends Strategy { * @return pushed filter and post-scan filters. */ private def pushFilters( - reader: DataSourceReader, + configBuilder: ScanConfigBuilder, filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { - reader match { - case r: SupportsPushDownCatalystFilters => - val postScanFilters = r.pushCatalystFilters(filters.toArray) - val pushedFilters = r.pushedCatalystFilters() - (pushedFilters, postScanFilters) - + configBuilder match { case r: SupportsPushDownFilters => // A map from translated data source filters to original catalyst filter expressions. val translatedFilterToExpr = mutable.HashMap.empty[sources.Filter, Expression] @@ -76,41 +71,43 @@ object DataSourceV2Strategy extends Strategy { /** * Applies column pruning to the data source, w.r.t. the references of the given expressions. * - * @return new output attributes after column pruning. + * @return the created `ScanConfig` (since column pruning is the last step of operator pushdown), + * and the new output attributes after column pruning. */ // TODO: nested column pruning. private def pruneColumns( - reader: DataSourceReader, + configBuilder: ScanConfigBuilder, relation: DataSourceV2Relation, - exprs: Seq[Expression]): Seq[AttributeReference] = { - reader match { + exprs: Seq[Expression]): (ScanConfig, Seq[AttributeReference]) = { + configBuilder match { case r: SupportsPushDownRequiredColumns => val requiredColumns = AttributeSet(exprs.flatMap(_.references)) val neededOutput = relation.output.filter(requiredColumns.contains) if (neededOutput != relation.output) { r.pruneColumns(neededOutput.toStructType) + val config = r.build() val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap - r.readSchema().toAttributes.map { + config -> config.readSchema().toAttributes.map { // We have to keep the attribute id during transformation. a => a.withExprId(nameToAttr(a.name).exprId) } } else { - relation.output + r.build() -> relation.output } - case _ => relation.output + case _ => configBuilder.build() -> relation.output } } override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => - val reader = relation.newReader() + val configBuilder = relation.readSupport.newScanConfigBuilder() // `pushedFilters` will be pushed down and evaluated in the underlying data sources. // `postScanFilters` need to be evaluated after the scan. // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter.
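That overlap is easiest to see with a toy capability check (hypothetical; real sources advertise this through `SupportsPushDownFilters`): a filter the source accepts may still be re-evaluated by Spark after the scan, so the two sets are not required to be disjoint.

import org.apache.spark.sql.sources.{EqualTo, Filter, GreaterThan}

// Toy source that can push only equality filters; everything else must run post-scan.
def push(filters: Seq[Filter]): (Seq[Filter], Seq[Filter]) =
  filters.partition(_.isInstanceOf[EqualTo])

val (pushed, postScan) = push(Seq(EqualTo("a", 1), GreaterThan("b", 0)))
// pushed   == Seq(EqualTo("a", 1))     evaluated inside the source
// postScan == Seq(GreaterThan("b", 0)) evaluated by a FilterExec above the scan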
- val (pushedFilters, postScanFilters) = pushFilters(reader, filters) - val output = pruneColumns(reader, relation, project ++ postScanFilters) + val (pushedFilters, postScanFilters) = pushFilters(configBuilder, filters) + val (config, output) = pruneColumns(configBuilder, relation, project ++ postScanFilters) logInfo( s""" |Pushing operators to ${relation.source.getClass} @@ -120,31 +117,40 @@ object DataSourceV2Strategy extends Strategy { """.stripMargin) val scan = DataSourceV2ScanExec( - output, relation.source, relation.options, pushedFilters, reader) + output, + relation.source, + relation.options, + pushedFilters, + relation.readSupport, + config) val filterCondition = postScanFilters.reduceLeftOption(And) val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan) - val withProjection = if (withFilter.output != project) { - ProjectExec(project, withFilter) - } else { - withFilter - } - - withProjection :: Nil + // always add the projection, which will produce unsafe rows required by some operators + ProjectExec(project, withFilter) :: Nil case r: StreamingDataSourceV2Relation => - DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil + // TODO: support operator pushdown for streaming data sources. + val scanConfig = r.scanConfigBuilder.build() + // ensure there is a projection, which will produce unsafe rows required by some operators + ProjectExec(r.output, + DataSourceV2ScanExec( + r.output, r.source, r.options, r.pushedFilters, r.readSupport, scanConfig)) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil + case AppendData(r: DataSourceV2Relation, query, _) => + WriteToDataSourceV2Exec(r.newWriteSupport(), planLater(query)) :: Nil + case WriteToContinuousDataSource(writer, query) => WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil case Repartition(1, false, child) => - val isContinuous = child.collectFirst { - case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r + val isContinuous = child.find { + case s: StreamingDataSourceV2Relation => s.readSupport.isInstanceOf[ContinuousReadSupport] + case _ => false }.isDefined if (isContinuous) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala index 5267f5f1580c3..e9cc3991155c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -21,6 +21,7 @@ import java.util.regex.Pattern import org.apache.spark.internal.Logging import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.{DataSourceV2, SessionConfigSupport} private[sql] object DataSourceV2Utils extends Logging { @@ -55,4 +56,12 @@ private[sql] object DataSourceV2Utils extends Logging { case _ => Map.empty } + + def failForUserSpecifiedSchema[T](ds: DataSourceV2): T = { + val name = ds match { + case register: DataSourceRegister => register.shortName() + case _ => ds.getClass.getName + } + throw new UnsupportedOperationException(name + " source does not support user-specified schema") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala similarity index 66% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index b1148c0f62f7c..c3f7b690ef636 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -23,21 +23,20 @@ import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.executor.CommitDeniedException import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.streaming.MicroBatchExecution import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils /** - * The logical plan for writing data into data source v2. + * Deprecated logical plan for writing data into data source v2. This is being replaced by more + * specific logical plans, like [[org.apache.spark.sql.catalyst.plans.logical.AppendData]]. */ -case class WriteToDataSourceV2(writer: DataSourceWriter, query: LogicalPlan) extends LogicalPlan { +@deprecated("Use specific logical plans like AppendData instead", "2.4.0") +case class WriteToDataSourceV2(writeSupport: BatchWriteSupport, query: LogicalPlan) + extends LogicalPlan { override def children: Seq[LogicalPlan] = Seq(query) override def output: Seq[Attribute] = Nil } @@ -45,50 +44,48 @@ case class WriteToDataSourceV2(writer: DataSourceWriter, query: LogicalPlan) ext /** * The physical plan for writing data into data source v2. */ -case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) extends SparkPlan { +case class WriteToDataSourceV2Exec(writeSupport: BatchWriteSupport, query: SparkPlan) + extends SparkPlan { + override def children: Seq[SparkPlan] = Seq(query) override def output: Seq[Attribute] = Nil override protected def doExecute(): RDD[InternalRow] = { - val writeTask = writer match { - case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory() - case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) - } - - val useCommitCoordinator = writer.useCommitCoordinator + val writerFactory = writeSupport.createBatchWriterFactory() + val useCommitCoordinator = writeSupport.useCommitCoordinator val rdd = query.execute() val messages = new Array[WriterCommitMessage](rdd.partitions.length) - logInfo(s"Start processing data source writer: $writer. " + + logInfo(s"Start processing data source write support: $writeSupport. 
" + s"The input RDD has ${messages.length} partitions.") try { sparkContext.runJob( rdd, (context: TaskContext, iter: Iterator[InternalRow]) => - DataWritingSparkTask.run(writeTask, context, iter, useCommitCoordinator), + DataWritingSparkTask.run(writerFactory, context, iter, useCommitCoordinator), rdd.partitions.indices, (index, message: WriterCommitMessage) => { messages(index) = message - writer.onDataWriterCommit(message) + writeSupport.onDataWriterCommit(message) } ) - logInfo(s"Data source writer $writer is committing.") - writer.commit(messages) - logInfo(s"Data source writer $writer committed.") + logInfo(s"Data source write support $writeSupport is committing.") + writeSupport.commit(messages) + logInfo(s"Data source write support $writeSupport committed.") } catch { case cause: Throwable => - logError(s"Data source writer $writer is aborting.") + logError(s"Data source write support $writeSupport is aborting.") try { - writer.abort(messages) + writeSupport.abort(messages) } catch { case t: Throwable => - logError(s"Data source writer $writer failed to abort.") + logError(s"Data source write support $writeSupport failed to abort.") cause.addSuppressed(t) throw new SparkException("Writing job failed.", cause) } - logError(s"Data source writer $writer aborted.") + logError(s"Data source write support $writeSupport aborted.") cause match { // Only wrap non fatal exceptions. case NonFatal(e) => throw new SparkException("Writing job aborted.", e) @@ -102,7 +99,7 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e object DataWritingSparkTask extends Logging { def run( - writeTask: DataWriterFactory[InternalRow], + writerFactory: DataWriterFactory, context: TaskContext, iter: Iterator[InternalRow], useCommitCoordinator: Boolean): WriterCommitMessage = { @@ -111,8 +108,7 @@ object DataWritingSparkTask extends Logging { val partId = context.partitionId() val taskId = context.taskAttemptId() val attemptId = context.attemptNumber() - val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0") - val dataWriter = writeTask.createDataWriter(partId, taskId, epochId.toLong) + val dataWriter = writerFactory.createWriter(partId, taskId) // write the data and commit this writer. 
Utils.tryWithSafeFinallyAndFailureCallbacks(block = { @@ -155,27 +151,3 @@ object DataWritingSparkTask extends Logging { }) } } - -class InternalRowDataWriterFactory( - rowWriterFactory: DataWriterFactory[Row], - schema: StructType) extends DataWriterFactory[InternalRow] { - - override def createDataWriter( - partitionId: Int, - taskId: Long, - epochId: Long): DataWriter[InternalRow] = { - new InternalRowDataWriter( - rowWriterFactory.createDataWriter(partitionId, taskId, epochId), - RowEncoder.apply(schema).resolveAndBind()) - } -} - -class InternalRowDataWriter(rowWriter: DataWriter[Row], encoder: ExpressionEncoder[Row]) - extends DataWriter[InternalRow] { - - override def write(record: InternalRow): Unit = rowWriter.write(encoder.fromRow(record)) - - override def commit(): WriterCommitMessage = rowWriter.commit() - - override def abort(): Unit = rowWriter.abort() -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index a717cbd4a7df9..366e1fe6a4aaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -29,6 +29,9 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.trees.TreeNodeRef +import org.apache.spark.sql.execution.streaming.{StreamExecution, StreamingQueryWrapper} +import org.apache.spark.sql.execution.streaming.continuous.WriteToContinuousDataSourceExec +import org.apache.spark.sql.streaming.StreamingQuery import org.apache.spark.util.{AccumulatorV2, LongAccumulator} /** @@ -40,6 +43,16 @@ import org.apache.spark.util.{AccumulatorV2, LongAccumulator} * sql("SELECT 1").debug() * sql("SELECT 1").debugCodegen() * }}} + * + * or, for the streaming case (Structured Streaming): + * {{{ + * import org.apache.spark.sql.execution.debug._ + * val query = df.writeStream.<...>.start() + * query.debugCodegen() + * }}} + * + * Note that `debug()` is not supported in Structured Streaming, because it doesn't make sense + * to execute the plan once as a batch while the main streaming query is running concurrently. */ package object debug { @@ -88,14 +101,50 @@ package object debug { } } + /** + * Gets the WholeStageCodegenExec subtrees and their generated code from a query plan as one String + * + * @param query the streaming query for codegen + * @return a single String containing all WholeStageCodegen subtrees and their corresponding codegen + */ + def codegenString(query: StreamingQuery): String = { + val w = asStreamExecution(query) + if (w.lastExecution != null) { + codegenString(w.lastExecution.executedPlan) + } else { + "No physical plan. Waiting for data."
+ } + } + + /** + * Gets the WholeStageCodegenExec subtrees and their generated code from a query plan + * + * @param query the streaming query for codegen + * @return a Sequence of WholeStageCodegen subtrees and their corresponding codegen + */ + def codegenStringSeq(query: StreamingQuery): Seq[(String, String)] = { + val w = asStreamExecution(query) + if (w.lastExecution != null) { + codegenStringSeq(w.lastExecution.executedPlan) + } else { + Seq.empty + } + } + + private def asStreamExecution(query: StreamingQuery): StreamExecution = query match { + case wrapper: StreamingQueryWrapper => wrapper.streamingQuery + case q: StreamExecution => q + case _ => throw new IllegalArgumentException("Parameter should be an instance of " + + "StreamExecution!") + } + /** * Augments [[Dataset]]s with debug methods. */ implicit class DebugQuery(query: Dataset[_]) extends Logging { def debug(): Unit = { - val plan = query.queryExecution.executedPlan val visited = new collection.mutable.HashSet[TreeNodeRef]() - val debugPlan = plan transform { + val debugPlan = query.queryExecution.executedPlan transform { case s: SparkPlan if !visited.contains(new TreeNodeRef(s)) => visited += new TreeNodeRef(s) DebugExec(s) @@ -116,6 +165,12 @@ package object debug { } } + implicit class DebugStreamQuery(query: StreamingQuery) extends Logging { + def debugCodegen(): Unit = { + debugPrint(codegenString(query)) + } + } + case class DebugExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport { def output: Seq[Attribute] = child.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index c55f9b8f1a7fc..a80673c705f1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.exchange +import java.util.concurrent.TimeoutException + import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ import scala.util.control.NonFatal @@ -140,7 +142,16 @@ case class BroadcastExchangeExec( } override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { - ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]] + try { + ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]] + } catch { + case ex: TimeoutException => + logError(s"Could not execute broadcast in ${timeout.toSeconds} secs.", ex) + throw new SparkException(s"Could not execute broadcast in ${timeout.toSeconds} secs.
" + + s"You can increase the timeout for broadcasts via ${SQLConf.BROADCAST_TIMEOUT.key} or " + + s"disable broadcast join by setting ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1", + ex) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index ad95879d86f42..d2d5011bbcb97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -82,7 +82,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { if (adaptiveExecutionEnabled && supportsCoordinator) { val coordinator = new ExchangeCoordinator( - children.length, targetPostShuffleInputSize, minNumPostShufflePartitions) children.zip(requiredChildDistributions).map { @@ -279,13 +278,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { */ private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = { plan match { - case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, - right) => - val (reorderedLeftKeys, reorderedRightKeys) = - reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) - BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition, - left, right) - case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) => val (reorderedLeftKeys, reorderedRightKeys) = reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala index 051e610eb2705..f5d93ee5fa914 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala @@ -83,7 +83,6 @@ import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan} * - post-shuffle partition 3: pre-shuffle partition 3 and 4 (size 50 MB) */ class ExchangeCoordinator( - numExchanges: Int, advisoryTargetPostShuffleInputSize: Long, minNumPostShufflePartitions: Option[Int] = None) extends Logging { @@ -91,8 +90,14 @@ class ExchangeCoordinator( // The registered Exchange operators. private[this] val exchanges = ArrayBuffer[ShuffleExchangeExec]() + // `lazy val` is used here so that we could notice the wrong use of this class, e.g., all the + // exchanges should be registered before `postShuffleRDD` called first time. If a new exchange is + // registered after the `postShuffleRDD` call, `assert(exchanges.length == numExchanges)` fails + // in `doEstimationIfNecessary`. + private[this] lazy val numExchanges = exchanges.size + // This map is used to lookup the post-shuffle ShuffledRowRDD for an Exchange operator. - private[this] val postShuffleRDDs: JMap[ShuffleExchangeExec, ShuffledRowRDD] = + private[this] lazy val postShuffleRDDs: JMap[ShuffleExchangeExec, ShuffledRowRDD] = new JHashMap[ShuffleExchangeExec, ShuffledRowRDD](numExchanges) // A boolean that indicates if this coordinator has made decision on how to shuffle data. 
@@ -117,10 +122,6 @@ class ExchangeCoordinator(
    */
   def estimatePartitionStartIndices(
       mapOutputStatistics: Array[MapOutputStatistics]): Array[Int] = {
-    // If we have mapOutputStatistics.length < numExchange, it is because we do not submit
-    // a stage when the number of partitions of this dependency is 0.
-    assert(mapOutputStatistics.length <= numExchanges)
-
     // If minNumPostShufflePartitions is defined, it is possible that we need to use a
     // value less than advisoryTargetPostShuffleInputSize as the target input size of
     // a post shuffle task.
@@ -228,6 +229,10 @@ class ExchangeCoordinator(
       j += 1
     }
 
+    // If we have mapOutputStatistics.length < numExchanges, it is because we do not submit
+    // a stage when the number of partitions of this dependency is 0.
+    assert(mapOutputStatistics.length <= numExchanges)
+
     // Now, we estimate partitionStartIndices. partitionStartIndices.length will be the
     // number of post-shuffle partitions.
     val partitionStartIndices =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index b89203719541b..9576605b1a214 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -231,6 +231,11 @@ object ShuffleExchangeExec {
       override def numPartitions: Int = 1
       override def getPartition(key: Any): Int = 0
     }
+    case l: LocalPartitioning =>
+      new Partitioner {
+        override def numPartitions: Int = l.numPartitions
+        override def getPartition(key: Any): Int = key.asInstanceOf[Int]
+      }
     case _ => sys.error(s"Exchange not implemented for $newPartitioning")
     // TODO: Handle BroadcastPartitioning.
   }
@@ -247,9 +252,15 @@ object ShuffleExchangeExec {
       val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes)
       row => projection(row).getInt(0)
     case RangePartitioning(_, _) | SinglePartition => identity
+    case _: LocalPartitioning =>
+      val partitionId = TaskContext.get().partitionId()
+      _ => partitionId
     case _ => sys.error(s"Exchange not implemented for $newPartitioning")
   }
 
+  val isRoundRobin = newPartitioning.isInstanceOf[RoundRobinPartitioning] &&
+    newPartitioning.numPartitions > 1
+
   val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = {
     // [SPARK-23207] Have to make sure the generated RoundRobinPartitioning is deterministic,
     // otherwise a retry task may output different rows and thus lead to data loss.
@@ -259,9 +270,7 @@ object ShuffleExchangeExec {
     //
     // Note that we don't perform local sort if the new partitioning has only 1 partition, under
     // that case all output rows go to the same partition.
-    val newRdd = if (SQLConf.get.sortBeforeRepartition &&
-      newPartitioning.numPartitions > 1 &&
-      newPartitioning.isInstanceOf[RoundRobinPartitioning]) {
+    val newRdd = if (isRoundRobin && SQLConf.get.sortBeforeRepartition) {
       rdd.mapPartitionsInternal { iter =>
         val recordComparatorSupplier = new Supplier[RecordComparator] {
           override def get: RecordComparator = new RecordBinaryComparator()
@@ -297,17 +306,19 @@ object ShuffleExchangeExec {
       rdd
     }
 
+    // round-robin function is order sensitive if we don't sort the input.
+ val isOrderSensitive = isRoundRobin && !SQLConf.get.sortBeforeRepartition if (needToCopyObjectsBeforeShuffle(part)) { - newRdd.mapPartitionsInternal { iter => + newRdd.mapPartitionsWithIndexInternal((_, iter) => { val getPartitionKey = getPartitionKeyExtractor() iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) } - } + }, isOrderSensitive = isOrderSensitive) } else { - newRdd.mapPartitionsInternal { iter => + newRdd.mapPartitionsWithIndexInternal((_, iter) => { val getPartitionKey = getPartitionKeyExtractor() val mutablePair = new MutablePair[Int, InternalRow]() iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) } - } + }, isOrderSensitive = isOrderSensitive) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 0da0e8610c392..a6f3ea47c8492 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -320,7 +320,7 @@ case class BroadcastHashJoinExec( |if (!$conditionPassed) { | $matched = null; | // reset the variables those are already evaluated. - | ${buildVars.filter(_.code == "").map(v => s"${v.isNull} = true;").mkString("\n")} + | ${buildVars.filter(_.code.isEmpty).map(v => s"${v.isNull} = true;").mkString("\n")} |} |$numOutput.add(1); |${consume(ctx, resultVars)} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 0396168d3f311..dab873bf9b9a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -214,7 +214,7 @@ trait HashJoin { } // At the end of the task, we update the avg hash probe. - TaskContext.get().addTaskCompletionListener(_ => + TaskContext.get().addTaskCompletionListener[Unit](_ => avgHashProbe.set(hashed.getAverageProbesPerLookup)) val resultProj = createResultProjection diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 20ce01f4ce8cc..86eb47a70f1ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -772,6 +772,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap array = readLongArray(readBuffer, length) val pageLength = readLong().toInt page = readLongArray(readBuffer, pageLength) + // Restore cursor variable to make this map able to be serialized again on executors. 
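+    // (`page` is an Array[Long] and `cursor` is a byte offset from the array's memory base, so
+    // after reading `pageLength` longs the next free byte sits at
+    // `pageLength * 8 + Platform.LONG_ARRAY_OFFSET`.)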
+ cursor = pageLength * 8 + Platform.LONG_ARRAY_OFFSET } override def readExternal(in: ObjectInput): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 897a4dae39f32..2b59ed6e4d16b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -57,7 +57,7 @@ case class ShuffledHashJoinExec( buildTime += (System.nanoTime() - start) / 1000000 buildDataSize += relation.estimatedSize // This relation is usually used until the end of task. - context.addTaskCompletionListener(_ => relation.close()) + context.addTaskCompletionListener[Unit](_ => relation.close()) relation } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 66bcda8913738..1a09632f93ca1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -47,13 +47,16 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode } /** - * Helper trait which defines methods that are shared by both - * [[LocalLimitExec]] and [[GlobalLimitExec]]. + * Take the first `limit` elements of each child partition, but do not collect or shuffle them. */ -trait BaseLimitExec extends UnaryExecNode with CodegenSupport { - val limit: Int +case class LocalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode with CodegenSupport { + override def output: Seq[Attribute] = child.output + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def outputPartitioning: Partitioning = child.outputPartitioning + protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => iter.take(limit) } @@ -93,25 +96,96 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { } /** - * Take the first `limit` elements of each child partition, but do not collect or shuffle them. + * Take the `limit` elements of the child output. */ -case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { +case class GlobalLimitExec(limit: Int, child: SparkPlan, + orderedLimit: Boolean = false) extends UnaryExecNode { - override def outputOrdering: Seq[SortOrder] = child.outputOrdering + override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning -} -/** - * Take the first `limit` elements of the child's single output partition. - */ -case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { + override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil + private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) - override def outputPartitioning: Partitioning = child.outputPartitioning + protected override def doExecute(): RDD[InternalRow] = { + val childRDD = child.execute() + val partitioner = LocalPartitioning(childRDD) + val shuffleDependency = ShuffleExchangeExec.prepareShuffleDependency( + childRDD, child.output, partitioner, serializer) + val numberOfOutput: Seq[Long] = if (shuffleDependency.rdd.getNumPartitions != 0) { + // submitMapStage does not accept RDD with 0 partition. 
+      // So, we will not submit this dependency.
+      val submittedStageFuture = sparkContext.submitMapStage(shuffleDependency)
+      submittedStageFuture.get().recordsByPartitionId.toSeq
+    } else {
+      Nil
+    }
 
-  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+    // This is an optimization to evenly distribute limited rows across all partitions.
+    // When enabled, Spark takes rows from each partition repeatedly, round-robin style, until it
+    // reaches the limit. When disabled, Spark takes all rows from the first partition, then from
+    // the second partition, and so on, until it reaches the limit.
+    // The optimization is disabled when we need to keep the order of rows produced by a
+    // preceding global sort, e.g., select * from table order by col limit 10.
+    val flatGlobalLimit = sqlContext.conf.limitFlatGlobalLimit && !orderedLimit
+
+    val shuffled = new ShuffledRowRDD(shuffleDependency)
+
+    val sumOfOutput = numberOfOutput.sum
+    if (sumOfOutput <= limit) {
+      shuffled
+    } else if (!flatGlobalLimit) {
+      var numRowTaken = 0
+      val takeAmounts = numberOfOutput.map { num =>
+        if (numRowTaken + num < limit) {
+          numRowTaken += num.toInt
+          num.toInt
+        } else {
+          val toTake = limit - numRowTaken
+          numRowTaken += toTake
+          toTake
+        }
+      }
+      val broadMap = sparkContext.broadcast(takeAmounts)
+      shuffled.mapPartitionsWithIndexInternal { case (index, iter) =>
+        iter.take(broadMap.value(index).toInt)
+      }
+    } else {
+      // Try to evenly spread the requested number of rows across all the child RDD's partitions.
+      var rowsNeedToTake: Long = limit
+      val takeAmountByPartition: Array[Long] = Array.fill[Long](numberOfOutput.length)(0L)
+      val remainingRowsByPartition: Array[Long] = Array(numberOfOutput: _*)
+
+      while (rowsNeedToTake > 0) {
+        val nonEmptyParts = remainingRowsByPartition.count(_ > 0)
+        // If the number of rows still needed is less than the number of non-empty partitions,
+        // take one row from each non-empty partition until we reach `limit` rows.
+        // Otherwise, evenly divide the needed rows among the non-empty partitions.
+        val takePerPart = math.max(1, rowsNeedToTake / nonEmptyParts)
+        remainingRowsByPartition.zipWithIndex.foreach { case (num, index) =>
+          // In case `rowsNeedToTake` < `nonEmptyParts`, we may run out of `rowsNeedToTake` during
+          // the traversal, so we need to add this check.
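+          // For example (hypothetical numbers): with limit = 10 and numberOfOutput = [100, 3, 50],
+          // the first pass sees 3 non-empty partitions and takePerPart = max(1, 10 / 3) = 3,
+          // taking 3 rows from each partition (exhausting the second) and leaving
+          // rowsNeedToTake = 1; the second pass sees 2 non-empty partitions and takePerPart = 1,
+          // so the last row comes from the first partition, for final take amounts of [4, 3, 3].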
+ if (rowsNeedToTake > 0 && num > 0) { + if (num >= takePerPart) { + rowsNeedToTake -= takePerPart + takeAmountByPartition(index) += takePerPart + remainingRowsByPartition(index) -= takePerPart + } else { + rowsNeedToTake -= num + takeAmountByPartition(index) += num + remainingRowsByPartition(index) -= num + } + } + } + } + val broadMap = sparkContext.broadcast(takeAmountByPartition) + shuffled.mapPartitionsWithIndexInternal { case (index, iter) => + iter.take(broadMap.value(index).toInt) + } + } + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index b4f0ae1eb1a18..cbf707f4a9cfd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.metric import java.text.NumberFormat import java.util.Locale -import java.util.concurrent.atomic.LongAdder import org.apache.spark.SparkContext import org.apache.spark.scheduler.AccumulableInfo @@ -33,45 +32,40 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils} * on the driver side must be explicitly posted using [[SQLMetrics.postDriverMetricUpdates()]]. */ class SQLMetric(val metricType: String, initValue: Long = 0L) extends AccumulatorV2[Long, Long] { - // This is a workaround for SPARK-11013. // We may use -1 as initial value of the accumulator, if the accumulator is valid, we will // update it at the end of task and the value will be at least 0. Then we can filter out the -1 // values before calculate max, min, etc. - private[this] val _value = new LongAdder - private val _zeroValue = initValue - _value.add(initValue) + private[this] var _value = initValue + private var _zeroValue = initValue override def copy(): SQLMetric = { - val newAcc = new SQLMetric(metricType, initValue) - newAcc.add(_value.sum()) + val newAcc = new SQLMetric(metricType, _value) + newAcc._zeroValue = initValue newAcc } - override def reset(): Unit = this.set(_zeroValue) + override def reset(): Unit = _value = _zeroValue override def merge(other: AccumulatorV2[Long, Long]): Unit = other match { - case o: SQLMetric => _value.add(o.value) + case o: SQLMetric => _value += o.value case _ => throw new UnsupportedOperationException( s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") } - override def isZero(): Boolean = _value.sum() == _zeroValue + override def isZero(): Boolean = _value == _zeroValue - override def add(v: Long): Unit = _value.add(v) + override def add(v: Long): Unit = _value += v // We can set a double value to `SQLMetric` which stores only long value, if it is // average metrics. def set(v: Double): Unit = SQLMetrics.setDoubleForAverageMetrics(this, v) - def set(v: Long): Unit = { - _value.reset() - _value.add(v) - } + def set(v: Long): Unit = _value = v - def +=(v: Long): Unit = _value.add(v) + def +=(v: Long): Unit = _value += v - override def value: Long = _value.sum() + override def value: Long = _value // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = { @@ -110,7 +104,7 @@ object SQLMetrics { * spill size, etc. 
*/ def createSizeMetric(sc: SparkContext, name: String): SQLMetric = { - // The final result of this metric in physical operator UI may looks like: + // The final result of this metric in physical operator UI may look like: // data size total (min, med, max): // 100GB (100MB, 1GB, 10GB) val acc = new SQLMetric(SIZE_METRIC, -1) @@ -159,7 +153,7 @@ object SQLMetrics { Seq.fill(3)(0L) } else { val sorted = validValues.sorted - Seq(sorted.head, sorted(validValues.length / 2), sorted(validValues.length - 1)) + Seq(sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) } metric.map(v => numberFormat.format(v.toDouble / baseForAvgMetric)) } @@ -179,8 +173,7 @@ object SQLMetrics { Seq.fill(4)(0L) } else { val sorted = validValues.sorted - Seq(sorted.sum, sorted.head, sorted(validValues.length / 2), - sorted(validValues.length - 1)) + Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) } metric.map(strFormat) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index d00f6f042d6e0..2ab7240556aaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -79,8 +79,6 @@ case class AggregateInPandasExec( override protected def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute() - val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) - val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) val sessionLocalTimeZone = conf.sessionLocalTimeZone val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) @@ -125,7 +123,7 @@ case class AggregateInPandasExec( // combine input with output from Python. val queue = HybridRowQueue(context.taskMemoryManager(), new File(Utils.getLocalDir(SparkEnv.get.conf)), groupingExpressions.length) - context.addTaskCompletionListener { _ => + context.addTaskCompletionListener[Unit] { _ => queue.close() } @@ -137,8 +135,6 @@ case class AggregateInPandasExec( val columnarBatchIter = new ArrowPythonRunner( pyFuncs, - bufferSize, - reuseWorker, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, argOffsets, aggInputSchema, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 0bc21c0986e69..2b87796dc6833 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -23,6 +23,7 @@ import org.apache.spark.TaskContext import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.StructType @@ -57,7 +58,13 @@ private class BatchIterator[T](iter: Iterator[T], batchSize: Int) } /** - * A physical plan that evaluates a [[PythonUDF]], + * A logical plan that evaluates a [[PythonUDF]]. 
+ */ +case class ArrowEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan) + extends UnaryNode + +/** + * A physical plan that evaluates a [[PythonUDF]]. */ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan) extends EvalPythonExec(udfs, output, child) { @@ -68,8 +75,6 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi protected override def evaluate( funcs: Seq[ChainedPythonFunctions], - bufferSize: Int, - reuseWorker: Boolean, argOffsets: Array[Array[Int]], iter: Iterator[InternalRow], schema: StructType, @@ -82,8 +87,6 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val columnarBatchIter = new ArrowPythonRunner( funcs, - bufferSize, - reuseWorker, PythonEvalType.SQL_SCALAR_PANDAS_UDF, argOffsets, schema, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index ca665652f204d..18992d7a9f974 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -39,15 +39,13 @@ import org.apache.spark.util.Utils */ class ArrowPythonRunner( funcs: Seq[ChainedPythonFunctions], - bufferSize: Int, - reuseWorker: Boolean, evalType: Int, argOffsets: Array[Array[Int]], schema: StructType, timeZoneId: String, conf: Map[String, String]) extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( - funcs, bufferSize, reuseWorker, evalType, argOffsets) { + funcs, evalType, argOffsets) { protected override def newWriterThread( env: SparkEnv, @@ -131,7 +129,7 @@ class ArrowPythonRunner( private var schema: StructType = _ private var vectors: Array[ColumnVector] = _ - context.addTaskCompletionListener { _ => + context.addTaskCompletionListener[Unit] { _ => if (reader != null) { reader.close(false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index f4d83e8dc7c2b..b08b7e60e130b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -25,9 +25,16 @@ import org.apache.spark.TaskContext import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.types.{StructField, StructType} +/** + * A logical plan that evaluates a [[PythonUDF]] + */ +case class BatchEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan) + extends UnaryNode + /** * A physical plan that evaluates a [[PythonUDF]] */ @@ -36,8 +43,6 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi protected override def evaluate( funcs: Seq[ChainedPythonFunctions], - bufferSize: Int, - reuseWorker: Boolean, argOffsets: Array[Array[Int]], iter: Iterator[InternalRow], schema: StructType, @@ -68,8 +73,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi }.grouped(100).map(x => pickle.dumps(x.toArray)) // Output iterator for results 
from Python. - val outputIterator = new PythonUDFRunner( - funcs, bufferSize, reuseWorker, PythonEvalType.SQL_BATCHED_UDF, argOffsets) + val outputIterator = new PythonUDFRunner(funcs, PythonEvalType.SQL_BATCHED_UDF, argOffsets) .compute(inputIterator, context.partitionId(), context) val unpickle = new Unpickler diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala index 860dc78c1dd1b..942a6db57416e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala @@ -78,8 +78,6 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil protected def evaluate( funcs: Seq[ChainedPythonFunctions], - bufferSize: Int, - reuseWorker: Boolean, argOffsets: Array[Array[Int]], iter: Iterator[InternalRow], schema: StructType, @@ -87,8 +85,6 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil protected override def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute().map(_.copy()) - val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) - val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) inputRDD.mapPartitions { iter => val context = TaskContext.get() @@ -97,7 +93,7 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil // combine input with output from Python. val queue = HybridRowQueue(context.taskMemoryManager(), new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length) - context.addTaskCompletionListener { ctx => + context.addTaskCompletionListener[Unit] { ctx => queue.close() } @@ -129,7 +125,7 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil } val outputRowIterator = evaluate( - pyFuncs, bufferSize, reuseWorker, argOffsets, projectedRowIter, schema, context) + pyFuncs, argOffsets, projectedRowIter, schema, context) val joined = new JoinedRow val resultProj = UnsafeProjection.create(output, output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 1e096100f7f43..90b5325919e96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -21,11 +21,11 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.api.python.PythonEvalType +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} /** @@ -92,38 +92,54 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { * This has the limitation that the input to the Python UDF is not allowed include attributes from * multiple child operators. 
*/ -object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { +object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper { - private def hasPythonUDF(e: Expression): Boolean = { + private type EvalType = Int + private type EvalTypeChecker = EvalType => Boolean + + private def hasScalarPythonUDF(e: Expression): Boolean = { e.find(PythonUDF.isScalarPythonUDF).isDefined } private def canEvaluateInPython(e: PythonUDF): Boolean = { e.children match { // single PythonUDF child could be chained and evaluated in Python - case Seq(u: PythonUDF) => canEvaluateInPython(u) + case Seq(u: PythonUDF) => e.evalType == u.evalType && canEvaluateInPython(u) // Python UDF can't be evaluated directly in JVM - case children => !children.exists(hasPythonUDF) + case children => !children.exists(hasScalarPythonUDF) } } - private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match { - case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) => Seq(udf) - case e => e.children.flatMap(collectEvaluatableUDF) + private def collectEvaluableUDFsFromExpressions(expressions: Seq[Expression]): Seq[PythonUDF] = { + // Eval type checker is set once when we find the first evaluable UDF and its value + // shouldn't change later. + // Used to check if subsequent UDFs are of the same type as the first UDF. (since we can only + // extract UDFs of the same eval type) + var evalTypeChecker: Option[EvalTypeChecker] = None + + def collectEvaluableUDFs(expr: Expression): Seq[PythonUDF] = expr match { + case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) + && evalTypeChecker.isEmpty => + evalTypeChecker = Some((otherEvalType: EvalType) => otherEvalType == udf.evalType) + Seq(udf) + case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) + && evalTypeChecker.get(udf.evalType) => + Seq(udf) + case e => e.children.flatMap(collectEvaluableUDFs) + } + + expressions.flatMap(collectEvaluableUDFs) } - def apply(plan: SparkPlan): SparkPlan = plan transformUp { - // AggregateInPandasExec and FlatMapGroupsInPandas can be evaluated directly in python worker - // Therefore we don't need to extract the UDFs - case plan: FlatMapGroupsInPandasExec => plan - case plan: SparkPlan => extract(plan) + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case plan: LogicalPlan => extract(plan) } /** * Extract all the PythonUDFs from the current operator and evaluate them before the operator. 
*/ - private def extract(plan: SparkPlan): SparkPlan = { - val udfs = plan.expressions.flatMap(collectEvaluatableUDF) + private def extract(plan: LogicalPlan): LogicalPlan = { + val udfs = collectEvaluableUDFsFromExpressions(plan.expressions) // ignore the PythonUDF that come from second/third aggregate, which is not used .filter(udf => udf.references.subsetOf(plan.inputSet)) if (udfs.isEmpty) { @@ -134,7 +150,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { val prunedChildren = plan.children.map { child => val allNeededOutput = inputsForPlan.intersect(child.outputSet).toSeq if (allNeededOutput.length != child.output.length) { - ProjectExec(allNeededOutput, child) + Project(allNeededOutput, child) } else { child } @@ -163,11 +179,12 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { _.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF ) match { case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty => - ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child) + ArrowEvalPython(vectorizedUdfs, child.output ++ resultAttrs, child) case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty => - BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child) + BatchEvalPython(plainUdfs, child.output ++ resultAttrs, child) case _ => - throw new IllegalArgumentException("Can not mix vectorized and non-vectorized UDFs") + throw new AnalysisException( + "Expected either Scalar Pandas UDFs or Batched UDFs but got both") } attributeMap ++= validUdfs.zip(resultAttrs) @@ -191,7 +208,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { val newPlan = extract(rewritten) if (newPlan.output != plan.output) { // Trim away the new UDF value if it was only used for filtering or something. - ProjectExec(plan.output, newPlan) + Project(plan.output, newPlan) } else { newPlan } @@ -200,15 +217,15 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { // Split the original FilterExec to two FilterExecs. Only push down the first few predicates // that are all deterministic. 
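Concretely, for a hypothetical plan (not taken from the patch) of the form Filter(a > 1 AND pyUdf(b) AND rand() > 0.5, child), the rewrite below produces Filter(pyUdf(b) AND rand() > 0.5, Filter(a > 1, child)): the deterministic, UDF-free conjunct a > 1 is pushed into its own Filter so it can be evaluated without a round trip to the Python worker.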
- private def trySplitFilter(plan: SparkPlan): SparkPlan = { + private def trySplitFilter(plan: LogicalPlan): LogicalPlan = { plan match { - case filter: FilterExec => + case filter: Filter => val (candidates, nonDeterministic) = splitConjunctivePredicates(filter.condition).partition(_.deterministic) - val (pushDown, rest) = candidates.partition(!hasPythonUDF(_)) + val (pushDown, rest) = candidates.partition(!hasScalarPythonUDF(_)) if (pushDown.nonEmpty) { - val newChild = FilterExec(pushDown.reduceLeft(And), filter.child) - FilterExec((rest ++ nonDeterministic).reduceLeft(And), newChild) + val newChild = Filter(pushDown.reduceLeft(And), filter.child) + Filter((rest ++ nonDeterministic).reduceLeft(And), newChild) } else { filter } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index f5a563baf52df..e9cff1a5a2007 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -74,8 +74,6 @@ case class FlatMapGroupsInPandasExec( override protected def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute() - val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) - val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) val sessionLocalTimeZone = conf.sessionLocalTimeZone val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) @@ -141,8 +139,6 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, - bufferSize, - reuseWorker, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, dedupSchema, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala index a58773122922f..a4e9b3305052f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala @@ -45,10 +45,7 @@ class PythonForeachWriter(func: PythonFunction, schema: StructType) } private lazy val pythonRunner = { - val conf = SparkEnv.get.conf - val bufferSize = conf.getInt("spark.buffer.size", 65536) - val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true) - PythonRunner(func, bufferSize, reuseWorker) + PythonRunner(func) } private lazy val outputIterator = @@ -56,7 +53,7 @@ class PythonForeachWriter(func: PythonFunction, schema: StructType) override def open(partitionId: Long, version: Long): Boolean = { outputIterator // initialize everything - TaskContext.get.addTaskCompletionListener { _ => buffer.close() } + TaskContext.get.addTaskCompletionListener[Unit] { _ => buffer.close() } true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala index e28def1c4b423..cc61faa7e7051 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -29,12 +29,10 @@ import org.apache.spark.api.python._ */ class PythonUDFRunner( funcs: 
Seq[ChainedPythonFunctions], - bufferSize: Int, - reuseWorker: Boolean, evalType: Int, argOffsets: Array[Array[Int]]) extends BasePythonRunner[Array[Byte], Array[Byte]]( - funcs, bufferSize, reuseWorker, evalType, argOffsets) { + funcs, evalType, argOffsets) { protected override def newWriterThread( env: SparkEnv, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala index 628029b13a6c3..27bed1137e5b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -95,8 +95,6 @@ case class WindowInPandasExec( protected override def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute() - val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) - val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) val sessionLocalTimeZone = conf.sessionLocalTimeZone val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) @@ -142,7 +140,7 @@ case class WindowInPandasExec( // combine input with output from Python. val queue = HybridRowQueue(context.taskMemoryManager(), new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length) - context.addTaskCompletionListener { _ => + context.addTaskCompletionListener[Unit] { _ => queue.close() } @@ -156,8 +154,6 @@ case class WindowInPandasExec( val windowFunctionResult = new ArrowPythonRunner( pyFuncs, - bufferSize, - reuseWorker, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, argOffsets, windowInputSchema, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 685d5841ab551..bea652cc33076 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -157,7 +157,7 @@ object StatFunctions extends Logging { cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { case (name, data) => require(data.nonEmpty, s"Couldn't find column with name $name") require(data.get.dataType.isInstanceOf[NumericType], s"Currently $functionName calculation " + - s"for columns with dataType ${data.get.dataType} not supported.") + s"for columns with dataType ${data.get.dataType.catalogString} not supported.") } val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType))) df.select(columns: _*).queryExecution.toRdd.treeAggregate(new CovarianceCounter)( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala index 5b114242558dc..0063318db332d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala @@ -22,6 +22,9 @@ import java.nio.charset.StandardCharsets._ import scala.io.{Source => IOSource} +import org.json4s.NoTypeHints +import org.json4s.jackson.Serialization + import org.apache.spark.sql.SparkSession /** @@ -43,36 +46,28 @@ import org.apache.spark.sql.SparkSession * line 2: metadata (optional json string) */ class CommitLog(sparkSession: SparkSession, path: String) - extends HDFSMetadataLog[String](sparkSession, path) { + extends 
HDFSMetadataLog[CommitMetadata](sparkSession, path) {
 
   import CommitLog._
 
-  def add(batchId: Long): Unit = {
-    super.add(batchId, EMPTY_JSON)
-  }
-
-  override def add(batchId: Long, metadata: String): Boolean = {
-    throw new UnsupportedOperationException(
-      "CommitLog does not take any metadata, use 'add(batchId)' instead")
-  }
-
-  override protected def deserialize(in: InputStream): String = {
+  override protected def deserialize(in: InputStream): CommitMetadata = {
     // called inside a try-finally where the underlying stream is closed in the caller
     val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines()
     if (!lines.hasNext) {
       throw new IllegalStateException("Incomplete log file in the offset commit log")
     }
     parseVersion(lines.next.trim, VERSION)
-    EMPTY_JSON
+    val metadataJson = if (lines.hasNext) lines.next else EMPTY_JSON
+    CommitMetadata(metadataJson)
   }
 
-  override protected def serialize(metadata: String, out: OutputStream): Unit = {
+  override protected def serialize(metadata: CommitMetadata, out: OutputStream): Unit = {
     // called inside a try-finally where the underlying stream is closed in the caller
     out.write(s"v${VERSION}".getBytes(UTF_8))
     out.write('\n')
 
     // write metadata
-    out.write(EMPTY_JSON.getBytes(UTF_8))
+    out.write(metadata.json.getBytes(UTF_8))
   }
 }
 
@@ -81,3 +76,13 @@ object CommitLog {
   private val EMPTY_JSON = "{}"
 }
+
+case class CommitMetadata(nextBatchWatermarkMs: Long = 0) {
+  def json: String = Serialization.write(this)(CommitMetadata.format)
+}
+
+object CommitMetadata {
+  implicit val format = Serialization.formats(NoTypeHints)
+
+  def apply(json: String): CommitMetadata = Serialization.read[CommitMetadata](json)
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousRecordEndpoint.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousRecordEndpoint.scala
new file mode 100644
index 0000000000000..c9c2ebc875f28
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousRecordEndpoint.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset
+
+case class ContinuousRecordPartitionOffset(partitionId: Int, offset: Int) extends PartitionOffset
+case class GetRecord(offset: ContinuousRecordPartitionOffset)
+
+/**
+ * An RPC endpoint for continuous readers to poll for
+ * records from the driver.
+ *
+ * @param buckets the data buckets. Each bucket contains a sequence of items to be
+ *                returned for a partition. The number of buckets should be equal
+ *                to the number of partitions.
+ * @param lock a lock object for locking the buckets for read
+ */
+class ContinuousRecordEndpoint(buckets: Seq[Seq[Any]], lock: Object)
+  extends ThreadSafeRpcEndpoint {
+
+  private var startOffsets: Seq[Int] = List.fill(buckets.size)(0)
+
+  /**
+   * Sets the start offset.
+   *
+   * @param offsets the base offset per partition to be used
+   *                while retrieving the data in {#receiveAndReply}.
+   */
+  def setStartOffsets(offsets: Seq[Int]): Unit = {
+    lock.synchronized {
+      startOffsets = offsets
+    }
+  }
+
+  override val rpcEnv: RpcEnv = SparkEnv.get.rpcEnv
+
+  /**
+   * Process messages from `RpcEndpointRef.ask`. If receiving an unmatched message,
+   * `SparkException` will be thrown and sent to `onError`.
+   */
+  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+    case GetRecord(ContinuousRecordPartitionOffset(partitionId, offset)) =>
+      lock.synchronized {
+        val bufOffset = offset - startOffsets(partitionId)
+        val buf = buckets(partitionId)
+        val record = if (buf.size <= bufOffset) None else Some(buf(bufOffset))
+
+        context.reply(record.map(InternalRow(_)))
+      }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
index 8c016abc5b643..103fa7ce9066d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
@@ -50,7 +50,7 @@ class FileStreamSource(
   @transient private val fs = new Path(path).getFileSystem(hadoopConf)
 
   private val qualifiedBasePath: Path = {
-    fs.makeQualified(new Path(path))  // can contains glob patterns
+    fs.makeQualified(new Path(path))  // can contain glob patterns
   }
 
   private val optionsWithPartitionBasePath = sourceOptions.optionMapWithoutPath ++ {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index 8e82cccbc8fa3..bfe7d00f56048 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -23,10 +23,8 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Attribut
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
-import org.apache.spark.sql.types.IntegerType
 import org.apache.spark.util.CompletionIterator
 
 /**
@@ -52,6 +50,7 @@ case class FlatMapGroupsWithStateExec(
     outputObjAttr: Attribute,
     stateInfo: Option[StatefulOperatorStateInfo],
     stateEncoder: ExpressionEncoder[Any],
+    stateFormatVersion: Int,
     outputMode: OutputMode,
     timeoutConf: GroupStateTimeout,
     batchTimestampMs: Option[Long],
@@ -60,32 +59,15 @@ case class FlatMapGroupsWithStateExec(
   ) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport {
 
   import GroupStateImpl._
+  import FlatMapGroupsWithStateExecHelper._
 
   private val isTimeoutEnabled =
timeoutConf != NoTimeout - private val timestampTimeoutAttribute = - AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)() - private val stateAttributes: Seq[Attribute] = { - val encSchemaAttribs = stateEncoder.schema.toAttributes - if (isTimeoutEnabled) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs - } - // Get the serializer for the state, taking into account whether we need to save timestamps - private val stateSerializer = { - val encoderSerializer = stateEncoder.namedExpressions - if (isTimeoutEnabled) { - encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP) - } else { - encoderSerializer - } - } - // Get the deserializer for the state. Note that this must be done in the driver, as - // resolving and binding of deserializer expressions to the encoded type can be safely done - // only in the driver. - private val stateDeserializer = stateEncoder.resolveAndBind().deserializer - private val watermarkPresent = child.output.exists { case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true case _ => false } + private[sql] val stateManager = + createStateManager(stateEncoder, isTimeoutEnabled, stateFormatVersion) /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = @@ -125,11 +107,11 @@ case class FlatMapGroupsWithStateExec( child.execute().mapPartitionsWithStateStore[InternalRow]( getStateInfo, groupingAttributes.toStructType, - stateAttributes.toStructType, + stateManager.stateSchema, indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => - val updater = new StateStoreUpdater(store) + val processor = new InputProcessor(store) // If timeout is based on event time, then filter late data based on watermark val filteredIter = watermarkPredicateForData match { @@ -143,7 +125,7 @@ case class FlatMapGroupsWithStateExec( // all the data has been processed. This is to ensure that the timeout information of all // the keys with data is updated before they are processed for timeouts. 
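The two call sites combined below correspond to the two ways a user-supplied `flatMapGroupsWithState` function is invoked: per key with new data (`processNewData`), and per expired key with an empty iterator and `hasTimedOut = true` (`processTimedOutState`). A user-level sketch exercising both paths, assuming a hypothetical event-counting query registered with `GroupStateTimeout.ProcessingTimeTimeout` (none of these names appear in the patch):

    import org.apache.spark.sql.streaming.GroupState

    def countEvents(
        key: String,
        events: Iterator[Long],
        state: GroupState[Long]): Iterator[(String, Long)] = {
      if (state.hasTimedOut) {
        // Reached via processTimedOutState(): no new rows, only the expired state.
        val result = Iterator((key, state.get))
        state.remove()
        result
      } else {
        // Reached via processNewData(): fold the new rows into the state.
        state.update(state.getOption.getOrElse(0L) + events.size)
        state.setTimeoutDuration("10 seconds")
        Iterator.empty
      }
    }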
val outputIterator = - updater.updateStateForKeysWithData(filteredIter) ++ updater.updateStateForTimedOutKeys() + processor.processNewData(filteredIter) ++ processor.processTimedOutState() // Return an iterator of all the rows generated by all the keys, such that when fully // consumed, all the state updates will be committed by the state store @@ -158,7 +140,7 @@ case class FlatMapGroupsWithStateExec( } /** Helper class to update the state store */ - class StateStoreUpdater(store: StateStore) { + class InputProcessor(store: StateStore) { // Converters for translating input keys, values, output data between rows and Java objects private val getKeyObj = @@ -167,14 +149,6 @@ case class FlatMapGroupsWithStateExec( ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) - // Converters for translating state between rows and Java objects - private val getStateObjFromRow = ObjectOperator.deserializeRowToObject( - stateDeserializer, stateAttributes) - private val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer) - - // Index of the additional metadata fields in the state row - private val timeoutTimestampIndex = stateAttributes.indexOf(timestampTimeoutAttribute) - // Metrics private val numUpdatedStateRows = longMetric("numUpdatedStateRows") private val numOutputRows = longMetric("numOutputRows") @@ -183,20 +157,19 @@ case class FlatMapGroupsWithStateExec( * For every group, get the key, values and corresponding state and call the function, * and return an iterator of rows */ - def updateStateForKeysWithData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { + def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output) groupedIter.flatMap { case (keyRow, valueRowIter) => val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] callFunctionAndUpdateState( - keyUnsafeRow, + stateManager.getState(store, keyUnsafeRow), valueRowIter, - store.get(keyUnsafeRow), hasTimedOut = false) } } /** Find the groups that have timeout set and are timing out right now, and call the function */ - def updateStateForTimedOutKeys(): Iterator[InternalRow] = { + def processTimedOutState(): Iterator[InternalRow] = { if (isTimeoutEnabled) { val timeoutThreshold = timeoutConf match { case ProcessingTimeTimeout => batchTimestampMs.get @@ -205,12 +178,11 @@ case class FlatMapGroupsWithStateExec( throw new IllegalStateException( s"Cannot filter timed out keys for $timeoutConf") } - val timingOutPairs = store.getRange(None, None).filter { rowPair => - val timeoutTimestamp = getTimeoutTimestamp(rowPair.value) - timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold + val timingOutPairs = stateManager.getAllState(store).filter { state => + state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold } - timingOutPairs.flatMap { rowPair => - callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true) + timingOutPairs.flatMap { stateData => + callFunctionAndUpdateState(stateData, Iterator.empty, hasTimedOut = true) } } else Iterator.empty } @@ -220,22 +192,19 @@ case class FlatMapGroupsWithStateExec( * iterator. Note that the store updating is lazy, that is, the store will be updated only * after the returned iterator is fully consumed. 
* - * @param keyRow Row representing the key, cannot be null + * @param stateData All the data related to the state to be updated * @param valueRowIter Iterator of values as rows, cannot be null, but can be empty - * @param prevStateRow Row representing the previous state, can be null * @param hasTimedOut Whether this function is being called for a key timeout */ private def callFunctionAndUpdateState( - keyRow: UnsafeRow, + stateData: StateData, valueRowIter: Iterator[InternalRow], - prevStateRow: UnsafeRow, hasTimedOut: Boolean): Iterator[InternalRow] = { - val keyObj = getKeyObj(keyRow) // convert key to objects + val keyObj = getKeyObj(stateData.keyRow) // convert key to objects val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects - val stateObj = getStateObj(prevStateRow) - val keyedState = GroupStateImpl.createForStreaming( - Option(stateObj), + val groupState = GroupStateImpl.createForStreaming( + Option(stateData.stateObj), batchTimestampMs.getOrElse(NO_TIMESTAMP), eventTimeWatermark.getOrElse(NO_TIMESTAMP), timeoutConf, @@ -243,50 +212,24 @@ case class FlatMapGroupsWithStateExec( watermarkPresent) // Call function, get the returned objects and convert them to rows - val mappedIterator = func(keyObj, valueObjIter, keyedState).map { obj => + val mappedIterator = func(keyObj, valueObjIter, groupState).map { obj => numOutputRows += 1 getOutputRow(obj) } // When the iterator is consumed, then write changes to state def onIteratorCompletion: Unit = { - - val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp - // If the state has not yet been set but timeout has been set, then - // we have to generate a row to save the timeout. However, attempting serialize - // null using case class encoder throws - - // java.lang.NullPointerException: Null value appeared in non-nullable field: - // If the schema is inferred from a Scala tuple / case class, or a Java bean, please - // try to use scala.Option[_] or other nullable types. 
- if (!keyedState.exists && currentTimeoutTimestamp != NO_TIMESTAMP) { - throw new IllegalStateException( - "Cannot set timeout when state is not defined, that is, state has not been" + - "initialized or has been removed") - } - - if (keyedState.hasRemoved) { - store.remove(keyRow) + if (groupState.hasRemoved && groupState.getTimeoutTimestamp == NO_TIMESTAMP) { + stateManager.removeState(store, stateData.keyRow) numUpdatedStateRows += 1 - } else { - val previousTimeoutTimestamp = getTimeoutTimestamp(prevStateRow) - val stateRowToWrite = if (keyedState.hasUpdated) { - getStateRow(keyedState.get) - } else { - prevStateRow - } - - val hasTimeoutChanged = currentTimeoutTimestamp != previousTimeoutTimestamp - val shouldWriteState = keyedState.hasUpdated || hasTimeoutChanged + val currentTimeoutTimestamp = groupState.getTimeoutTimestamp + val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp + val shouldWriteState = groupState.hasUpdated || groupState.hasRemoved || hasTimeoutChanged if (shouldWriteState) { - if (stateRowToWrite == null) { - // This should never happen because checks in GroupStateImpl should avoid cases - // where empty state would need to be written - throw new IllegalStateException("Attempting to write empty state") - } - setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp) - store.put(keyRow, stateRowToWrite) + val updatedStateObj = if (groupState.exists) groupState.get else null + stateManager.putState(store, stateData.keyRow, updatedStateObj, currentTimeoutTimestamp) numUpdatedStateRows += 1 } } @@ -295,28 +238,5 @@ case class FlatMapGroupsWithStateExec( // Return an iterator of rows such that fully consumed, the updated state value will be saved CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion) } - - /** Returns the state as Java object if defined */ - def getStateObj(stateRow: UnsafeRow): Any = { - if (stateRow != null) getStateObjFromRow(stateRow) else null - } - - /** Returns the row for an updated state */ - def getStateRow(obj: Any): UnsafeRow = { - assert(obj != null) - getStateRowFromObj(obj) - } - - /** Returns the timeout timestamp of a state row is set */ - def getTimeoutTimestamp(stateRow: UnsafeRow): Long = { - if (isTimeoutEnabled && stateRow != null) { - stateRow.getLong(timeoutTimestampIndex) - } else NO_TIMESTAMP - } - - /** Set the timestamp in a state row */ - def setTimeoutTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = { - if (isTimeoutEnabled) stateRow.setLong(timeoutTimestampIndex, timeoutTimestamps) - } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index c480b96626f84..fad287e28877d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -22,7 +22,7 @@ import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SparkSession, Strategy} -import org.apache.spark.sql.catalyst.expressions.CurrentBatchTimestamp +import org.apache.spark.sql.catalyst.expressions.{CurrentBatchTimestamp, ExpressionWithRandomSeed} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, HashPartitioning, SinglePartition} import 
org.apache.spark.sql.catalyst.rules.Rule @@ -30,6 +30,7 @@ import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.util.Utils /** * A variant of [[QueryExecution]] that allows the execution of the given [[LogicalPlan]] @@ -59,7 +60,8 @@ class IncrementalExecution( StatefulAggregationStrategy :: FlatMapGroupsWithStateStrategy :: StreamingRelationStrategy :: - StreamingDeduplicationStrategy :: Nil + StreamingDeduplicationStrategy :: + StreamingGlobalLimitStrategy(outputMode) :: Nil } private[sql] val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key) @@ -76,6 +78,7 @@ class IncrementalExecution( case ts @ CurrentBatchTimestamp(timestamp, _, _) => logInfo(s"Current batch timestamp = $timestamp") ts.toLiteral + case e: ExpressionWithRandomSeed => e.withNewSeed(Utils.random.nextLong()) } } @@ -99,19 +102,21 @@ class IncrementalExecution( val state = new Rule[SparkPlan] { override def apply(plan: SparkPlan): SparkPlan = plan transform { - case StateStoreSaveExec(keys, None, None, None, + case StateStoreSaveExec(keys, None, None, None, stateFormatVersion, UnaryExecNode(agg, - StateStoreRestoreExec(_, None, child))) => + StateStoreRestoreExec(_, None, _, child))) => val aggStateInfo = nextStatefulOperationStateInfo StateStoreSaveExec( keys, Some(aggStateInfo), Some(outputMode), Some(offsetSeqMetadata.batchWatermarkMs), + stateFormatVersion, agg.withNewChildren( StateStoreRestoreExec( keys, Some(aggStateInfo), + stateFormatVersion, child) :: Nil)) case StreamingDeduplicateExec(keys, child, None, None) => @@ -134,8 +139,12 @@ class IncrementalExecution( stateWatermarkPredicates = StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates( j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition.full, - Some(offsetSeqMetadata.batchWatermarkMs)) - ) + Some(offsetSeqMetadata.batchWatermarkMs))) + + case l: StreamingGlobalLimitExec => + l.copy( + stateInfo = Some(nextStatefulOperationStateInfo), + outputMode = Some(outputMode)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala index 66b11ecddf233..8709822acff12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.execution.streaming +import java.text.SimpleDateFormat + import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.{Source => CodahaleSource} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.streaming.StreamingQueryProgress /** @@ -39,6 +42,23 @@ class MetricsReporter( registerGauge("processingRate-total", _.processedRowsPerSecond, 0.0) registerGauge("latency", _.durationMs.get("triggerExecution").longValue(), 0L) + private val timestampFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") // ISO8601 + timestampFormat.setTimeZone(DateTimeUtils.getTimeZone("UTC")) + + registerGauge("eventTime-watermark", + progress => convertStringDateToMillis(progress.eventTime.get("watermark")), 0L) + + registerGauge("states-rowsTotal", 
_.stateOperators.map(_.numRowsTotal).sum, 0L) + registerGauge("states-usedBytes", _.stateOperators.map(_.memoryUsedBytes).sum, 0L) + + private def convertStringDateToMillis(isoUtcDateStr: String) = { + if (isoUtcDateStr != null) { + timestampFormat.parse(isoUtcDateStr).getTime + } else { + 0L + } + } + private def registerGauge[T]( name: String, f: StreamingQueryProgress => T, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 17ffa2a517312..2cac86599ef19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.streaming -import java.util.Optional - import scala.collection.JavaConverters._ import scala.collection.mutable.{Map => MutableMap} @@ -28,10 +26,9 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp, import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} -import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} -import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow +import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWritSupport, RateControlMicroBatchReadSupport} +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} @@ -52,8 +49,8 @@ class MicroBatchExecution( @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty - private val readerToDataSourceMap = - MutableMap.empty[MicroBatchReader, (DataSourceV2, Map[String, String])] + private val readSupportToDataSourceMap = + MutableMap.empty[MicroBatchReadSupport, (DataSourceV2, Map[String, String])] private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) @@ -61,7 +58,7 @@ class MicroBatchExecution( case _ => throw new IllegalStateException(s"Unknown type of trigger: $trigger") } - private val watermarkTracker = new WatermarkTracker() + private var watermarkTracker: WatermarkTracker = _ override lazy val logicalPlan: LogicalPlan = { assert(queryExecutionThread eq Thread.currentThread, @@ -92,20 +89,19 @@ class MicroBatchExecution( StreamingExecutionRelation(source, output)(sparkSession) }) case s @ StreamingRelationV2( - dataSourceV2: MicroBatchReadSupport, sourceName, options, output, _) if + dataSourceV2: MicroBatchReadSupportProvider, sourceName, options, output, _) if !disabledSources.contains(dataSourceV2.getClass.getCanonicalName) => v2ToExecutionRelationMap.getOrElseUpdate(s, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - val reader = dataSourceV2.createMicroBatchReader( - Optional.empty(), // user specified schema + val 
readSupport = dataSourceV2.createMicroBatchReadSupport( metadataPath, new DataSourceOptions(options.asJava)) nextSourceId += 1 - readerToDataSourceMap(reader) = dataSourceV2 -> options - logInfo(s"Using MicroBatchReader [$reader] from " + + readSupportToDataSourceMap(readSupport) = dataSourceV2 -> options + logInfo(s"Using MicroBatchReadSupport [$readSupport] from " + s"DataSourceV2 named '$sourceName' [$dataSourceV2]") - StreamingExecutionRelation(reader, output)(sparkSession) + StreamingExecutionRelation(readSupport, output)(sparkSession) }) case s @ StreamingRelationV2(dataSourceV2, sourceName, _, output, v1Relation) => v2ToExecutionRelationMap.getOrElseUpdate(s, { @@ -184,6 +180,9 @@ class MicroBatchExecution( isCurrentBatchConstructed = constructNextBatch(noDataBatchesEnabled) } + // Record the trigger offset range for progress reporting *before* processing the batch + recordTriggerOffsets(from = committedOffsets, to = availableOffsets) + // Remember whether the current batch has data or not. This will be required later // for bookkeeping after running the batch, when `isNewDataAvailable` will have changed // to false as the batch would have already processed the available data. @@ -201,6 +200,10 @@ class MicroBatchExecution( finishTrigger(currentBatchHasNewData) // Must be outside reportTimeTaken so it is recorded + // Signal waiting threads. Note this must be after finishTrigger() to ensure all + // activities (progress generation, etc.) have completed before signaling. + withProgressLocked { awaitProgressLockCondition.signalAll() } + // If the current batch has been executed, then increment the batch id and reset flag. // Otherwise, there was no data to execute the batch and sleep for some time if (isCurrentBatchConstructed) { @@ -257,6 +260,7 @@ class MicroBatchExecution( OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.conf) offsetSeqMetadata = OffsetSeqMetadata( metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf) + watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf) watermarkTracker.setWatermark(metadata.batchWatermarkMs) } @@ -264,7 +268,7 @@ class MicroBatchExecution( * latest batch id in the offset log, then we can safely move to the next batch * i.e., committedBatchId + 1 */ commitLog.getLatest() match { - case Some((latestCommittedBatchId, _)) => + case Some((latestCommittedBatchId, commitMetadata)) => if (latestBatchId == latestCommittedBatchId) { /* The last batch was successfully committed, so we can safely process a * new next batch but first: @@ -282,7 +286,8 @@ class MicroBatchExecution( currentBatchId = latestCommittedBatchId + 1 isCurrentBatchConstructed = false committedOffsets ++= availableOffsets - // Construct a new batch be recomputing availableOffsets + watermarkTracker.setWatermark( + math.max(watermarkTracker.currentWatermark, commitMetadata.nextBatchWatermarkMs)) } else if (latestCommittedBatchId < latestBatchId - 1) { logWarning(s"Batch completion log latest batch id is " + s"${latestCommittedBatchId}, which is not trailing " + @@ -295,6 +300,7 @@ class MicroBatchExecution( case None => // We are starting this stream for the first time. 
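A note on the watermark recovery above: the restored watermark must never move backwards, so the value read from the commit log is combined with the tracker's current value via max; old checkpoints that carry no commit metadata report 0 and therefore leave the current value untouched. A minimal sketch of that rule, as a hypothetical standalone helper (names are illustrative, not from this patch):

def recoveredWatermarkMs(currentMs: Long, nextBatchWatermarkMs: Long): Long =
  math.max(currentMs, nextBatchWatermarkMs)

recoveredWatermarkMs(5000L, 7000L)  // 7000: the commit log is ahead, adopt it
recoveredWatermarkMs(5000L, 0L)     // 5000: no metadata recovered, keep the current value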
logInfo(s"Starting new streaming query.") currentBatchId = 0 + watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf) } } @@ -335,19 +341,19 @@ class MicroBatchExecution( reportTimeTaken("getOffset") { (s, s.getOffset) } - case s: MicroBatchReader => + case s: RateControlMicroBatchReadSupport => updateStatusMessage(s"Getting offsets from $s") - reportTimeTaken("setOffsetRange") { - // Once v1 streaming source execution is gone, we can refactor this away. - // For now, we set the range here to get the source to infer the available end offset, - // get that offset, and then set the range again when we later execute. - s.setOffsetRange( - toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), - Optional.empty()) + reportTimeTaken("latestOffset") { + val startOffset = availableOffsets + .get(s).map(off => s.deserializeOffset(off.json)) + .getOrElse(s.initialOffset()) + (s, Option(s.latestOffset(startOffset))) + } + case s: MicroBatchReadSupport => + updateStatusMessage(s"Getting offsets from $s") + reportTimeTaken("latestOffset") { + (s, Option(s.latestOffset())) } - - val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() } - (s, Option(currentOffset)) }.toMap availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get) @@ -387,8 +393,11 @@ class MicroBatchExecution( if (prevBatchOff.isDefined) { prevBatchOff.get.toStreamProgress(sources).foreach { case (src: Source, off) => src.commit(off) - case (reader: MicroBatchReader, off) => - reader.commit(reader.deserializeOffset(off.json)) + case (readSupport: MicroBatchReadSupport, off) => + readSupport.commit(readSupport.deserializeOffset(off.json)) + case (src, _) => + throw new IllegalArgumentException( + s"Unknown source is found at constructNextBatch: $src") } } else { throw new IllegalStateException(s"batch ${currentBatchId - 1} doesn't exist") @@ -429,30 +438,34 @@ class MicroBatchExecution( s"${batch.queryExecution.logical}") logDebug(s"Retrieving data from $source: $current -> $available") Some(source -> batch.logicalPlan) - case (reader: MicroBatchReader, available) - if committedOffsets.get(reader).map(_ != available).getOrElse(true) => - val current = committedOffsets.get(reader).map(off => reader.deserializeOffset(off.json)) - val availableV2: OffsetV2 = available match { - case v1: SerializedOffset => reader.deserializeOffset(v1.json) + + // TODO(cloud-fan): for data source v2, the new batch is just a new `ScanConfigBuilder`, but + // to be compatible with streaming source v1, we return a logical plan as a new batch here. + case (readSupport: MicroBatchReadSupport, available) + if committedOffsets.get(readSupport).map(_ != available).getOrElse(true) => + val current = committedOffsets.get(readSupport).map { + off => readSupport.deserializeOffset(off.json) + } + val endOffset: OffsetV2 = available match { + case v1: SerializedOffset => readSupport.deserializeOffset(v1.json) case v2: OffsetV2 => v2 } - reader.setOffsetRange( - toJava(current), - Optional.of(availableV2)) - logDebug(s"Retrieving data from $reader: $current -> $availableV2") + val startOffset = current.getOrElse(readSupport.initialOffset) + val scanConfigBuilder = readSupport.newScanConfigBuilder(startOffset, endOffset) + logDebug(s"Retrieving data from $readSupport: $current -> $endOffset") - val (source, options) = reader match { + val (source, options) = readSupport match { // `MemoryStream` is special. It's for test only and doesn't have a `DataSourceV2` // implementation. 
We provide a fake one here for explain. case _: MemoryStream[_] => MemoryStreamDataSource -> Map.empty[String, String] // Provide a fake value here just in case something went wrong, e.g. the reader gives // a wrong `equals` implementation. - case _ => readerToDataSourceMap.getOrElse(reader, { + case _ => readSupportToDataSourceMap.getOrElse(readSupport, { FakeDataSourceV2 -> Map.empty[String, String] }) } - Some(reader -> StreamingDataSourceV2Relation( - reader.readSchema().toAttributes, source, options, reader)) + Some(readSupport -> StreamingDataSourceV2Relation( + readSupport.fullSchema().toAttributes, source, options, readSupport, scanConfigBuilder)) case _ => None } } @@ -486,23 +499,20 @@ class MicroBatchExecution( val triggerLogicalPlan = sink match { case _: Sink => newAttributePlan - case s: StreamWriteSupport => - val writer = s.createStreamWriter( + case s: StreamingWriteSupportProvider => + val writer = s.createStreamingWriteSupport( s"$runId", newAttributePlan.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) - if (writer.isInstanceOf[SupportsWriteInternalRow]) { - WriteToDataSourceV2( - new InternalRowMicroBatchWriter(currentBatchId, writer), newAttributePlan) - } else { - WriteToDataSourceV2(new MicroBatchWriter(currentBatchId, writer), newAttributePlan) - } + WriteToDataSourceV2(new MicroBatchWritSupport(currentBatchId, writer), newAttributePlan) case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") } sparkSessionToRunBatch.sparkContext.setLocalProperty( MicroBatchExecution.BATCH_ID_KEY, currentBatchId.toString) + sparkSessionToRunBatch.sparkContext.setLocalProperty( + StreamExecution.IS_CONTINUOUS_PROCESSING, false.toString) reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( @@ -523,7 +533,7 @@ class MicroBatchExecution( SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) { sink match { case s: Sink => s.addBatch(currentBatchId, nextBatch) - case _: StreamWriteSupport => + case _: StreamingWriteSupportProvider => // This doesn't accumulate any data - it just forces execution of the microbatch writer. 
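For context on the collect() call that follows: with a v2 sink the write path is part of the physical plan (the WriteToDataSourceV2 node built above), so any action runs it and the returned rows are discarded. The BATCH_ID_KEY local property set above travels with the batch's tasks; a hedged sketch of how executor-side code could read it back (hypothetical helper, not part of this patch):

import org.apache.spark.TaskContext

def batchIdOfRunningTask(): Long =
  TaskContext.get().getLocalProperty(MicroBatchExecution.BATCH_ID_KEY).toLong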
nextBatch.collect() } } withProgressLocked { - commitLog.add(currentBatchId) + watermarkTracker.updateWatermark(lastExecution.executedPlan) + commitLog.add(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark)) committedOffsets ++= availableOffsets - awaitProgressLockCondition.signalAll() } - watermarkTracker.updateWatermark(lastExecution.executedPlan) logDebug(s"Completed batch ${currentBatchId}") } @@ -548,10 +557,6 @@ class MicroBatchExecution( awaitProgressLock.unlock() } } - - private def toJava(scalaOption: Option[OffsetV2]): Optional[OffsetV2] = { - Optional.ofNullable(scalaOption.orNull) - } } object MicroBatchExecution { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 787174481ff08..73cf355dbe758 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -22,7 +22,8 @@ import org.json4s.jackson.Serialization import org.apache.spark.internal.Logging import org.apache.spark.sql.RuntimeConfig -import org.apache.spark.sql.internal.SQLConf.{SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS} +import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StreamingAggregationStateManager} +import org.apache.spark.sql.internal.SQLConf.{FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, _} /** * An ordered collection of offsets, used to track the progress of processing data from one or more @@ -86,7 +87,27 @@ case class OffsetSeqMetadata( object OffsetSeqMetadata extends Logging { private implicit val format = Serialization.formats(NoTypeHints) - private val relevantSQLConfs = Seq(SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS) + private val relevantSQLConfs = Seq( + SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY, + FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, STREAMING_AGGREGATION_STATE_FORMAT_VERSION) + + /** + * Default values of relevant configurations that are used for backward compatibility. + * As new configurations are added to the metadata, existing checkpoints may not have those + * confs. The values in this list ensure that the confs without recovered values are + * set to a default value that ensures the same behavior of the streaming query as it was before + * the restart. + * + * Note that this is optional; set values here only if you *have* to override the existing + * session conf with a specific default value to ensure the same behavior of the query as before. + */ + private val relevantSQLConfDefaultValues = Map[String, String]( + STREAMING_MULTIPLE_WATERMARK_POLICY.key -> MultipleWatermarkPolicy.DEFAULT_POLICY_NAME, + FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> + FlatMapGroupsWithStateExecHelper.legacyVersion.toString, + STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> + StreamingAggregationStateManager.legacyVersion.toString + ) def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json) @@ -115,8 +136,22 @@ object OffsetSeqMetadata extends Logging { case None => // For backward compatibility, if a config was not recorded in the offset log, - // then log it, and let the existing conf value in SparkSession prevail.
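The replacement logic just below splits this case in two: inject a registered compatibility default, or fall back to whatever the session already has. A compact sketch of the decision, using a mutable Map as a stand-in for the session conf (illustrative names only, not this file's code):

import scala.collection.mutable

def resolveMissingConf(
    confKey: String,
    sessionConf: mutable.Map[String, String],
    defaults: Map[String, String]): Unit = defaults.get(confKey) match {
  case Some(default) =>
    // Pin pre-upgrade behavior: checkpoints written before this conf existed must
    // keep acting as if it were set to the legacy default.
    sessionConf(confKey) = default
  case None =>
    // No compatibility default registered: the existing session value wins.
}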
- logWarning (s"Conf '$confKey' was not found in the offset log, using existing value") + // then either inject a default value (if specified in `relevantSQLConfDefaultValues`) or + // let the existing conf value in SparkSession prevail. + relevantSQLConfDefaultValues.get(confKey) match { + + case Some(defaultValue) => + sessionConf.set(confKey, defaultValue) + logWarning(s"Conf '$confKey' was not found in the offset log, " + + s"using default value '$defaultValue'") + + case None => + val valueStr = sessionConf.getOption(confKey).map { v => + s" Using existing session conf value '$v'." + }.getOrElse { " No value set in session conf." } + logWarning(s"Conf '$confKey' was not found in the offset log. $valueStr") + + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 16ad3ef9a3d4a..73b180468d367 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -24,12 +24,12 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec -import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport import org.apache.spark.sql.streaming._ import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent import org.apache.spark.util.Clock @@ -56,8 +56,6 @@ trait ProgressReporter extends Logging { protected def logicalPlan: LogicalPlan protected def lastExecution: QueryExecution protected def newData: Map[BaseStreamingSource, LogicalPlan] - protected def availableOffsets: StreamProgress - protected def committedOffsets: StreamProgress protected def sources: Seq[BaseStreamingSource] protected def sink: BaseStreamingSink protected def offsetSeqMetadata: OffsetSeqMetadata @@ -68,8 +66,11 @@ trait ProgressReporter extends Logging { // Local timestamps and counters. private var currentTriggerStartTimestamp = -1L private var currentTriggerEndTimestamp = -1L + private var currentTriggerStartOffsets: Map[BaseStreamingSource, String] = _ + private var currentTriggerEndOffsets: Map[BaseStreamingSource, String] = _ // TODO: Restore this from the checkpoint when possible. private var lastTriggerStartTimestamp = -1L + private val currentDurationsMs = new mutable.HashMap[String, Long]() /** Flag that signals whether any error with input metrics have already been logged */ @@ -114,9 +115,20 @@ trait ProgressReporter extends Logging { lastTriggerStartTimestamp = currentTriggerStartTimestamp currentTriggerStartTimestamp = triggerClock.getTimeMillis() currentStatus = currentStatus.copy(isTriggerActive = true) + currentTriggerStartOffsets = null + currentTriggerEndOffsets = null currentDurationsMs.clear() } + /** + * Record the offsets range this trigger will process. Call this before updating + * `committedOffsets` in `StreamExecution` to make sure that the correct range is recorded. 
+ */ + protected def recordTriggerOffsets(from: StreamProgress, to: StreamProgress): Unit = { + currentTriggerStartOffsets = from.mapValues(_.json) + currentTriggerEndOffsets = to.mapValues(_.json) + } + private def updateProgress(newProgress: StreamingQueryProgress): Unit = { progressBuffer.synchronized { progressBuffer += newProgress @@ -130,6 +142,7 @@ trait ProgressReporter extends Logging { /** Finalizes the query progress and adds it to list of recent status updates. */ protected def finishTrigger(hasNewData: Boolean): Unit = { + assert(currentTriggerStartOffsets != null && currentTriggerEndOffsets != null) currentTriggerEndTimestamp = triggerClock.getTimeMillis() val executionStats = extractExecutionStats(hasNewData) @@ -147,13 +160,14 @@ trait ProgressReporter extends Logging { val numRecords = executionStats.inputRows.getOrElse(source, 0L) new SourceProgress( description = source.toString, - startOffset = committedOffsets.get(source).map(_.json).orNull, - endOffset = availableOffsets.get(source).map(_.json).orNull, + startOffset = currentTriggerStartOffsets.get(source).orNull, + endOffset = currentTriggerEndOffsets.get(source).orNull, numInputRows = numRecords, inputRowsPerSecond = numRecords / inputTimeSec, processedRowsPerSecond = numRecords / processingTimeSec ) } + val sinkProgress = new SinkProgress(sink.toString) val newProgress = new StreamingQueryProgress( @@ -237,44 +251,16 @@ trait ProgressReporter extends Logging { // Check whether the streaming query's logical plan has only V2 data sources val allStreamingLeaves = logicalPlan.collect { case s: StreamingExecutionRelation => s } - allStreamingLeaves.forall { _.source.isInstanceOf[MicroBatchReader] } + allStreamingLeaves.forall { _.source.isInstanceOf[MicroBatchReadSupport] } } if (onlyDataSourceV2Sources) { - // DataSourceV2ScanExec is the execution plan leaf that is responsible for reading data - // from a V2 source and has a direct reference to the V2 source that generated it. Each - // DataSourceV2ScanExec records the number of rows it has read using SQLMetrics. However, - // just collecting all DataSourceV2ScanExec nodes and getting the metric is not correct as - // a DataSourceV2ScanExec instance may be referred to in the execution plan from two (or - // even multiple times) points and considering it twice will leads to double counting. We - // can't dedup them using their hashcode either because two different instances of - // DataSourceV2ScanExec can have the same hashcode but account for separate sets of - // records read, and deduping them to consider only one of them would be undercounting the - // records read. Therefore the right way to do this is to consider the unique instances of - // DataSourceV2ScanExec (using their identity hash codes) and get metrics from them. - // Hence we calculate in the following way. - // - // 1. Collect all the unique DataSourceV2ScanExec instances using IdentityHashMap. - // - // 2. Extract the source and the number of rows read from the DataSourceV2ScanExec instanes. - // - // 3. Multiple DataSourceV2ScanExec instance may refer to the same source (can happen with - // self-unions or self-joins). Add up the number of rows for each unique source. 
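A compact sketch of the three-step procedure the removed comment above describes, with generic, illustrative types (the patch below replaces this bookkeeping with a plain collect over the executed plan):

import java.util.IdentityHashMap
import scala.collection.JavaConverters._

def sumRowsPerSource[Leaf <: AnyRef, Src](
    planLeaves: Seq[Leaf],           // may contain the same instance more than once
    sourceOf: Leaf => Src,
    rowsOf: Leaf => Long): Map[Src, Long] = {
  val unique = new IdentityHashMap[Leaf, Leaf]()
  planLeaves.foreach(l => unique.put(l, l))       // 1. dedup by instance identity
  unique.values.asScala.toSeq
    .map(l => sourceOf(l) -> rowsOf(l))           // 2. (source, numRows) per instance
    .groupBy(_._1)
    .mapValues(_.map(_._2).sum)                   // 3. sum rows per unique source
    .toMap
}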
- val uniqueStreamingExecLeavesMap = - new IdentityHashMap[DataSourceV2ScanExec, DataSourceV2ScanExec]() - - lastExecution.executedPlan.collectLeaves().foreach { - case s: DataSourceV2ScanExec if s.reader.isInstanceOf[BaseStreamingSource] => - uniqueStreamingExecLeavesMap.put(s, s) - case _ => - } - - val sourceToInputRowsTuples = - uniqueStreamingExecLeavesMap.values.asScala.map { execLeaf => - val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L) - val source = execLeaf.reader.asInstanceOf[BaseStreamingSource] + val sourceToInputRowsTuples = lastExecution.executedPlan.collect { + case s: DataSourceV2ScanExec if s.readSupport.isInstanceOf[BaseStreamingSource] => + val numRows = s.metrics.get("numOutputRows").map(_.value).getOrElse(0L) + val source = s.readSupport.asInstanceOf[BaseStreamingSource] source -> numRows - }.toSeq + } logDebug("Source -> # input rows\n\t" + sourceToInputRowsTuples.mkString("\n\t")) sumRows(sourceToInputRowsTuples) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala new file mode 100644 index 0000000000000..1be071614d92e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.sources.v2.reader.{ScanConfig, ScanConfigBuilder} +import org.apache.spark.sql.types.StructType + +/** + * A very simple [[ScanConfigBuilder]] implementation that creates a simple [[ScanConfig]] to + * carry schema and offsets for streaming data sources. + */ +class SimpleStreamingScanConfigBuilder( + schema: StructType, + start: Offset, + end: Option[Offset] = None) + extends ScanConfigBuilder { + + override def build(): ScanConfig = SimpleStreamingScanConfig(schema, start, end) +} + +case class SimpleStreamingScanConfig( + readSchema: StructType, + start: Offset, + end: Option[Offset]) + extends ScanConfig diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 290de873c5cfb..f6c60c1c92124 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -382,7 +382,7 @@ abstract class StreamExecution( * Blocks the current thread until processing for data from the given `source` has reached at * least the given `Offset`. 
This method is intended for use primarily when writing tests. */ - private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset): Unit = { + private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset, timeoutMs: Long): Unit = { assertAwaitThread() def notDone = { val localCommittedOffsets = committedOffsets @@ -398,7 +398,7 @@ while (notDone) { awaitProgressLock.lock() try { - awaitProgressLockCondition.await(100, TimeUnit.MILLISECONDS) + awaitProgressLockCondition.await(timeoutMs, TimeUnit.MILLISECONDS) if (streamDeathCause != null) { throw streamDeathCause } @@ -529,6 +529,7 @@ abstract class StreamExecution( object StreamExecution { val QUERY_ID_KEY = "sql.streaming.queryId" + val IS_CONTINUOUS_PROCESSING = "__is_continuous_processing" def isInterruptionException(e: Throwable): Boolean = e match { // InterruptedIOException - thrown when an I/O operation is interrupted diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingGlobalLimitExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingGlobalLimitExec.scala new file mode 100644 index 0000000000000..bf4af60c8cf03 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingGlobalLimitExec.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.TimeUnit.NANOSECONDS + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, Partitioning} +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.streaming.state.StateStoreOps +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.{LongType, NullType, StructField, StructType} +import org.apache.spark.util.CompletionIterator + +/** + * A physical operator for executing a streaming limit, which makes sure no more than streamLimit + * rows are returned. This operator is meant for streams in Append mode only.
+ */ +case class StreamingGlobalLimitExec( + streamLimit: Long, + child: SparkPlan, + stateInfo: Option[StatefulOperatorStateInfo] = None, + outputMode: Option[OutputMode] = None) + extends UnaryExecNode with StateStoreWriter { + + private val keySchema = StructType(Array(StructField("key", NullType))) + private val valueSchema = StructType(Array(StructField("value", LongType))) + + override protected def doExecute(): RDD[InternalRow] = { + metrics // force lazy init at driver + + assert(outputMode.isDefined && outputMode.get == InternalOutputModes.Append, + "StreamingGlobalLimitExec is only valid for streams in Append output mode") + + child.execute().mapPartitionsWithStateStore( + getStateInfo, + keySchema, + valueSchema, + indexOrdinal = None, + sqlContext.sessionState, + Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => + val key = UnsafeProjection.create(keySchema)(new GenericInternalRow(Array[Any](null))) + val numOutputRows = longMetric("numOutputRows") + val numUpdatedStateRows = longMetric("numUpdatedStateRows") + val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") + val commitTimeMs = longMetric("commitTimeMs") + val updatesStartTimeNs = System.nanoTime + + val preBatchRowCount: Long = Option(store.get(key)).map(_.getLong(0)).getOrElse(0L) + var cumulativeRowCount = preBatchRowCount + + val result = iter.filter { r => + val x = cumulativeRowCount < streamLimit + if (x) { + cumulativeRowCount += 1 + } + x + } + + CompletionIterator[InternalRow, Iterator[InternalRow]](result, { + if (cumulativeRowCount > preBatchRowCount) { + numUpdatedStateRows += 1 + numOutputRows += cumulativeRowCount - preBatchRowCount + store.put(key, getValueRow(cumulativeRowCount)) + } + allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) + commitTimeMs += timeTakenMs { store.commit() } + setStoreMetrics(store) + }) + } + } + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = AllTuples :: Nil + + private def getValueRow(value: Long): UnsafeRow = { + UnsafeProjection.create(valueSchema)(new GenericInternalRow(Array[Any](value))) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 24195b5657e8a..4b696dfa57359 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceV2} object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { @@ -83,7 +83,7 @@ case class StreamingExecutionRelation( // We have to pack in the V1 data source as a shim, for the case when a source implements // continuous processing (which is always V2) but only has V1 microbatch support. 
We don't -// know at read time whether the query is conntinuous or not, so we need to be able to +// know at read time whether the query is continuous or not, so we need to be able to // swap a V1 relation back in. /** * Used to link a [[DataSourceV2]] into a streaming @@ -113,7 +113,7 @@ case class StreamingRelationV2( * Used to link a [[DataSourceV2]] into a continuous processing execution. */ case class ContinuousExecutionRelation( - source: ContinuousReadSupport, + source: ContinuousReadSupportProvider, extraOptions: Map[String, String], output: Seq[Attribute])(session: SparkSession) extends LeafNode with MultiInstanceRelation { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala index 4aba76cad367e..2d4c3c10e6445 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala @@ -144,7 +144,7 @@ object StreamingSymmetricHashJoinHelper extends Logging { // Join keys of both sides generate rows of the same fields, that is, same sequence of data - // types. If one side (say left side) has a column (say timestmap) that has a watermark on it, + // types. If one side (say left side) has a column (say timestamp) that has a watermark on it, // then it will never consider joining keys that are < state key watermark (i.e. event time // watermark). On the other side (i.e. right side), even if there is no watermark defined, // there has to be an equivalent column (i.e., timestamp). And any right side data that has the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala index 80865669558dd..7b30db44a2090 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala @@ -20,15 +20,68 @@ package org.apache.spark.sql.execution.streaming import scala.collection.mutable import org.apache.spark.internal.Logging +import org.apache.spark.sql.RuntimeConfig import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.internal.SQLConf -class WatermarkTracker extends Logging { +/** + * Policy to define how to choose a new global watermark value if there are + * multiple watermark operators in a streaming query. + */ +sealed trait MultipleWatermarkPolicy { + def chooseGlobalWatermark(operatorWatermarks: Seq[Long]): Long +} + +object MultipleWatermarkPolicy { + val DEFAULT_POLICY_NAME = "min" + + def apply(policyName: String): MultipleWatermarkPolicy = { + policyName.toLowerCase match { + case DEFAULT_POLICY_NAME => MinWatermark + case "max" => MaxWatermark + case _ => + throw new IllegalArgumentException(s"Could not recognize watermark policy '$policyName'") + } + } +} + +/** + * Policy to choose the *min* of the operator watermark values as the global watermark value. + * Note that this is the safe (hence default) policy as the global watermark will advance + * only if all the individual operator watermarks have advanced. In other words, in a + * streaming query with multiple input streams and watermarks defined on all of them, + * the global watermark will advance as slowly as the slowest input. 
So if there is watermark + * based state cleanup or late-data dropping, then this policy is the most conservative one. + */ +case object MinWatermark extends MultipleWatermarkPolicy { + def chooseGlobalWatermark(operatorWatermarks: Seq[Long]): Long = { + assert(operatorWatermarks.nonEmpty) + operatorWatermarks.min + } +} + +/** + * Policy to choose the *max* of the operator watermark values as the global watermark value. So the + * global watermark will advance if any of the individual operator watermarks has advanced. + * In other words, in a streaming query with multiple input streams and watermarks defined on all + * of them, the global watermark will advance as fast as the fastest input. So if there is watermark + * based state cleanup or late-data dropping, then this policy is the most aggressive one and + * may lead to unexpected behavior if the data of the slow stream is delayed. + */ +case object MaxWatermark extends MultipleWatermarkPolicy { + def chooseGlobalWatermark(operatorWatermarks: Seq[Long]): Long = { + assert(operatorWatermarks.nonEmpty) + operatorWatermarks.max + } +} + +/** Tracks the watermark value of a streaming query based on a given `policy` */ +case class WatermarkTracker(policy: MultipleWatermarkPolicy) extends Logging { private val operatorToWatermarkMap = mutable.HashMap[Int, Long]() - private var watermarkMs: Long = 0 - private var updated = false + private var globalWatermarkMs: Long = 0 def setWatermark(newWatermarkMs: Long): Unit = synchronized { - watermarkMs = newWatermarkMs + globalWatermarkMs = newWatermarkMs } def updateWatermark(executedPlan: SparkPlan): Unit = synchronized { @@ -37,7 +90,6 @@ class WatermarkTracker { } if (watermarkOperators.isEmpty) return - watermarkOperators.zipWithIndex.foreach { case (e, index) if e.eventTimeStats.value.count > 0 => logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}") @@ -58,16 +110,28 @@ class WatermarkTracker { // This is the safest option, because only the global watermark is fault-tolerant. Making // it the minimum of all individual watermarks guarantees it will never advance past where // any individual watermark operator would be if it were in a plan by itself. - val newWatermarkMs = operatorToWatermarkMap.minBy(_._2)._2 - if (newWatermarkMs > watermarkMs) { - logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms") - watermarkMs = newWatermarkMs - updated = true + val chosenGlobalWatermark = policy.chooseGlobalWatermark(operatorToWatermarkMap.values.toSeq) + if (chosenGlobalWatermark > globalWatermarkMs) { + logInfo(s"Updating event-time watermark from $globalWatermarkMs to $chosenGlobalWatermark ms") + globalWatermarkMs = chosenGlobalWatermark } else { - logDebug(s"Event time watermark didn't move: $chosenGlobalWatermark < $globalWatermarkMs") no wait
If there is no policy explicitly + // saved in the checkpoint (e.g., old checkpoints), then the default `min` policy is enforced + // through defaults specified in OffsetSeqMetadata.setSessionConf(). + val policyName = conf.get( + SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY, MultipleWatermarkPolicy.DEFAULT_POLICY_NAME) + new WatermarkTracker(MultipleWatermarkPolicy(policyName)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index cfba1001c6de0..9c5c16f4f5d13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql._ -import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter +import org.apache.spark.sql.execution.streaming.sources.ConsoleWriteSupport import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamingWriteSupportProvider} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -31,16 +31,16 @@ case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) } class ConsoleSinkProvider extends DataSourceV2 - with StreamWriteSupport + with StreamingWriteSupportProvider with DataSourceRegister with CreatableRelationProvider { - override def createStreamWriter( + override def createStreamingWriteSupport( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamWriter = { - new ConsoleWriter(schema, options) + options: DataSourceOptions): StreamingWriteSupport = { + new ConsoleWriteSupport(schema, options) } def createRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala index ba85b355f974f..aec756c0eb2a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala @@ -40,7 +40,7 @@ case class ContinuousCoalesceRDDPartition( queueSize, numShuffleWriters, epochIntervalMs, env) val endpoint = env.setupEndpoint(endpointName, receiver) - TaskContext.get().addTaskCompletionListener { ctx => + TaskContext.get().addTaskCompletionListener[Unit] { ctx => env.stop(endpoint) } (receiver, endpoint) @@ -118,7 +118,7 @@ class ContinuousCoalesceRDD( } } - context.addTaskCompletionListener { ctx => + context.addTaskCompletionListener[Unit] { ctx => threadPool.shutdownNow() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala index 73868d5967e90..b68f67e0b22d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala @@ -19,16 +19,15 @@ package org.apache.spark.sql.execution.streaming.continuous import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeInputPartitionReader} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, PartitionOffset} -import org.apache.spark.util.{NextIterator, ThreadUtils} +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousPartitionReaderFactory +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.NextIterator class ContinuousDataSourceRDDPartition( val index: Int, - val inputPartition: InputPartition[UnsafeRow]) + val inputPartition: InputPartition) extends Partition with Serializable { // This is semantically a lazy val - it's initialized once the first time a call to @@ -51,37 +50,46 @@ class ContinuousDataSourceRDD( sc: SparkContext, dataQueueSize: Int, epochPollIntervalMs: Long, - private val readerInputPartitions: Seq[InputPartition[UnsafeRow]]) - extends RDD[UnsafeRow](sc, Nil) { + private val inputPartitions: Seq[InputPartition], + schema: StructType, + partitionReaderFactory: ContinuousPartitionReaderFactory) + extends RDD[InternalRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { - readerInputPartitions.zipWithIndex.map { + inputPartitions.zipWithIndex.map { case (inputPartition, index) => new ContinuousDataSourceRDDPartition(index, inputPartition) }.toArray } + private def castPartition(split: Partition): ContinuousDataSourceRDDPartition = split match { + case p: ContinuousDataSourceRDDPartition => p + case _ => throw new SparkException(s"[BUG] Not a ContinuousDataSourceRDDPartition: $split") + } + /** * Initialize the shared reader for this partition if needed, then read rows from it until * it returns null to signal the end of the epoch. */ - override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { // If attempt number isn't 0, this is a task retry, which we don't support. 
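Two small patterns recur here and are worth a hedged note: castPartition above prefers a pattern match over asInstanceOf so a foreign Partition type fails with a descriptive SparkException rather than a bare ClassCastException, and the [Unit] type argument that keeps appearing on addTaskCompletionListener in this patch pins the function overload rather than the one taking a TaskCompletionListener object. Illustrative usage (not from this patch):

import org.apache.spark.TaskContext

TaskContext.get().addTaskCompletionListener[Unit] { _ =>
  // release per-partition resources here; the explicit [Unit] disambiguates
  // this lambda overload from the TaskCompletionListener overload
}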
if (context.attemptNumber() != 0) { throw new ContinuousTaskRetryException() } val readerForPartition = { - val partition = split.asInstanceOf[ContinuousDataSourceRDDPartition] + val partition = castPartition(split) if (partition.queueReader == null) { - partition.queueReader = - new ContinuousQueuedDataReader(partition, context, dataQueueSize, epochPollIntervalMs) + val partitionReader = partitionReaderFactory.createReader( + partition.inputPartition) + partition.queueReader = new ContinuousQueuedDataReader( + partition.index, partitionReader, schema, context, dataQueueSize, epochPollIntervalMs) } partition.queueReader } - new NextIterator[UnsafeRow] { - override def getNext(): UnsafeRow = { + new NextIterator[InternalRow] { + override def getNext(): InternalRow = { readerForPartition.next() match { case null => finished = true @@ -95,19 +103,6 @@ class ContinuousDataSourceRDD( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[ContinuousDataSourceRDDPartition].inputPartition.preferredLocations() - } -} - -object ContinuousDataSourceRDD { - private[continuous] def getContinuousReader( - reader: InputPartitionReader[UnsafeRow]): ContinuousInputPartitionReader[_] = { - reader match { - case r: ContinuousInputPartitionReader[UnsafeRow] => r - case wrapped: RowToUnsafeInputPartitionReader => - wrapped.rowReader.asInstanceOf[ContinuousInputPartitionReader[Row]] - case _ => - throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}") - } + castPartition(split).inputPartition.preferredLocations() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index a0bb8292d7766..ccca72667a217 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -29,13 +29,12 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, StreamingDataSourceV2Relation} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2 -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, StreamingWriteSupportProvider} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} -import org.apache.spark.sql.types.StructType import org.apache.spark.util.{Clock, Utils} class ContinuousExecution( @@ -43,7 +42,7 @@ class ContinuousExecution( name: String, checkpointRoot: String, analyzedPlan: LogicalPlan, - sink: StreamWriteSupport, + sink: StreamingWriteSupportProvider, trigger: Trigger, triggerClock: Clock, outputMode: OutputMode, @@ 
-53,7 +52,7 @@ class ContinuousExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { - @volatile protected var continuousSources: Seq[ContinuousReader] = Seq() + @volatile protected var continuousSources: Seq[ContinuousReadSupport] = Seq() override protected def sources: Seq[BaseStreamingSource] = continuousSources // For use only in test harnesses. @@ -63,7 +62,8 @@ class ContinuousExecution( val toExecutionRelationMap = MutableMap[StreamingRelationV2, ContinuousExecutionRelation]() analyzedPlan.transform { case r @ StreamingRelationV2( - source: ContinuousReadSupport, _, extraReaderOptions, output, _) => + source: ContinuousReadSupportProvider, _, extraReaderOptions, output, _) => + // TODO: shall we create `ContinuousReadSupport` here instead of each reconfiguration? toExecutionRelationMap.getOrElseUpdate(r, { ContinuousExecutionRelation(source, extraReaderOptions, output)(sparkSession) }) @@ -148,8 +148,7 @@ class ContinuousExecution( val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" nextSourceId += 1 - dataSource.createContinuousReader( - java.util.Optional.empty[StructType](), + dataSource.createContinuousReadSupport( metadataPath, new DataSourceOptions(extraReaderOptions.asJava)) } @@ -160,9 +159,9 @@ class ContinuousExecution( var insertedSourceId = 0 val withNewSources = logicalPlan transform { case ContinuousExecutionRelation(source, options, output) => - val reader = continuousSources(insertedSourceId) + val readSupport = continuousSources(insertedSourceId) insertedSourceId += 1 - val newOutput = reader.readSchema().toAttributes + val newOutput = readSupport.fullSchema().toAttributes assert(output.size == newOutput.size, s"Invalid reader: ${Utils.truncatedString(output, ",")} != " + @@ -170,9 +169,10 @@ class ContinuousExecution( replacements ++= output.zip(newOutput) val loggedOffset = offsets.offsets(0) - val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) - reader.setStartOffset(java.util.Optional.ofNullable(realOffset.orNull)) - StreamingDataSourceV2Relation(newOutput, source, options, reader) + val realOffset = loggedOffset.map(off => readSupport.deserializeOffset(off.json)) + val startOffset = realOffset.getOrElse(readSupport.initialOffset) + val scanConfigBuilder = readSupport.newScanConfigBuilder(startOffset) + StreamingDataSourceV2Relation(newOutput, source, options, readSupport, scanConfigBuilder) } // Rewire the plan to use the new attributes that were returned by the source. 
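To recap the handshake introduced above: the start offset is the checkpointed offset when one exists, otherwise the source's initial offset, and it is handed to the source as a ScanConfigBuilder. A hedged sketch (recoveredOffsetJson is a hypothetical Option[String]; SimpleStreamingScanConfigBuilder is the helper added earlier in this patch):

val start = recoveredOffsetJson
  .map(readSupport.deserializeOffset)
  .getOrElse(readSupport.initialOffset)
val scanConfigBuilder = readSupport.newScanConfigBuilder(start)
// For a source built on SimpleStreamingScanConfigBuilder, scanConfigBuilder.build()
// yields SimpleStreamingScanConfig(readSupport.fullSchema(), start, end = None).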
@@ -185,17 +185,13 @@ class ContinuousExecution( "CurrentTimestamp and CurrentDate not yet supported for continuous processing") } - val writer = sink.createStreamWriter( + val writer = sink.createStreamingWriteSupport( s"$runId", triggerLogicalPlan.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) val withSink = WriteToContinuousDataSource(writer, triggerLogicalPlan) - val reader = withSink.collect { - case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r - }.head - reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( sparkSessionForQuery, @@ -208,6 +204,13 @@ class ContinuousExecution( lastExecution.executedPlan // Force the lazy generation of execution plan } + val (readSupport, scanConfig) = lastExecution.executedPlan.collect { + case scan: DataSourceV2ScanExec if scan.readSupport.isInstanceOf[ContinuousReadSupport] => + scan.readSupport.asInstanceOf[ContinuousReadSupport] -> scan.scanConfig + }.head + + sparkSessionForQuery.sparkContext.setLocalProperty( + StreamExecution.IS_CONTINUOUS_PROCESSING, true.toString) sparkSessionForQuery.sparkContext.setLocalProperty( ContinuousExecution.START_EPOCH_KEY, currentBatchId.toString) // Add another random ID on top of the run ID, to distinguish epoch coordinators across @@ -223,14 +226,16 @@ class ContinuousExecution( // Use the parent Spark session for the endpoint since it's where this query ID is registered. val epochEndpoint = EpochCoordinatorRef.create( - writer, reader, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) + writer, readSupport, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) val epochUpdateThread = new Thread(new Runnable { override def run: Unit = { try { triggerExecutor.execute(() => { startTrigger() - if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) { + val shouldReconfigure = readSupport.needsReconfiguration(scanConfig) && + state.compareAndSet(ACTIVE, RECONFIGURING) + if (shouldReconfigure) { if (queryExecutionThread.isAlive) { queryExecutionThread.interrupt() } @@ -280,10 +285,12 @@ class ContinuousExecution( * Report ending partition offsets for the given reader at the given epoch. 
*/ def addOffset( - epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = { + epoch: Long, + readSupport: ContinuousReadSupport, + partitionOffsets: Seq[PartitionOffset]): Unit = { assert(continuousSources.length == 1, "only one continuous source supported currently") - val globalOffset = reader.mergeOffsets(partitionOffsets.toArray) + val globalOffset = readSupport.mergeOffsets(partitionOffsets.toArray) val oldOffset = synchronized { offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) offsetLog.get(epoch - 1) @@ -309,9 +316,12 @@ class ContinuousExecution( def commit(epoch: Long): Unit = { assert(continuousSources.length == 1, "only one continuous source supported currently") assert(offsetLog.get(epoch).isDefined, s"offset for epoch $epoch not reported before commit") + synchronized { + // Record offsets before updating `committedOffsets` + recordTriggerOffsets(from = committedOffsets, to = availableOffsets) if (queryExecutionThread.isAlive) { - commitLog.add(epoch) + commitLog.add(epoch, CommitMetadata()) val offset = continuousSources(0).deserializeOffset(offsetLog.get(epoch).get.offsets(0).get.json) committedOffsets ++= Seq(continuousSources(0) -> offset) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala index 8c74b8244d096..65c5fc63c2f46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala @@ -24,9 +24,10 @@ import scala.util.control.NonFatal import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} -import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReader, PartitionOffset} +import org.apache.spark.sql.types.StructType import org.apache.spark.util.ThreadUtils /** @@ -37,22 +38,21 @@ import org.apache.spark.util.ThreadUtils * offsets across epochs. Each compute() should call the next() method here until null is returned. */ class ContinuousQueuedDataReader( - partition: ContinuousDataSourceRDDPartition, + partitionIndex: Int, + reader: ContinuousPartitionReader[InternalRow], + schema: StructType, context: TaskContext, dataQueueSize: Int, epochPollIntervalMs: Long) extends Closeable { - private val reader = partition.inputPartition.createPartitionReader() - // Important sequencing - we must get our starting point before the provider threads start running - private var currentOffset: PartitionOffset = - ContinuousDataSourceRDD.getContinuousReader(reader).getOffset + private var currentOffset: PartitionOffset = reader.getOffset /** * The record types in the read buffer. 
 */
  sealed trait ContinuousRecord
  case object EpochMarker extends ContinuousRecord
-  case class ContinuousRow(row: UnsafeRow, offset: PartitionOffset) extends ContinuousRecord
+  case class ContinuousRow(row: InternalRow, offset: PartitionOffset) extends ContinuousRecord
 
  private val queue = new ArrayBlockingQueue[ContinuousRecord](dataQueueSize)
@@ -66,11 +66,11 @@ class ContinuousQueuedDataReader(
   epochMarkerExecutor.scheduleWithFixedDelay(
     epochMarkerGenerator, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS)
 
-  private val dataReaderThread = new DataReaderThread
+  private val dataReaderThread = new DataReaderThread(schema)
   dataReaderThread.setDaemon(true)
   dataReaderThread.start()
 
-  context.addTaskCompletionListener(_ => {
+  context.addTaskCompletionListener[Unit](_ => {
     this.close()
   })
@@ -79,12 +79,12 @@ class ContinuousQueuedDataReader(
   }
 
  /**
-  * Return the next UnsafeRow to be read in the current epoch, or null if the epoch is done.
+  * Return the next row to be read in the current epoch, or null if the epoch is done.
   *
   * After returning null, the [[ContinuousDataSourceRDD]] compute() for the following epoch
   * will call next() again to start getting rows.
   */
-  def next(): UnsafeRow = {
+  def next(): InternalRow = {
    val POLL_TIMEOUT_MS = 1000
 
    var currentEntry: ContinuousRecord = null
@@ -113,7 +113,7 @@ class ContinuousQueuedDataReader(
    currentEntry match {
      case EpochMarker =>
        epochCoordEndpoint.send(ReportPartitionOffset(
-          partition.index, EpochTracker.getCurrentEpoch.get, currentOffset))
+          partitionIndex, EpochTracker.getCurrentEpoch.get, currentOffset))
        null
      case ContinuousRow(row, offset) =>
        currentOffset = offset
@@ -128,16 +128,16 @@ class ContinuousQueuedDataReader(
 
  /**
   * The data component of [[ContinuousQueuedDataReader]]. Pushes (row, offset) to the queue when
-  * a new row arrives to the [[InputPartitionReader]].
+  * a new row arrives to the [[ContinuousPartitionReader]].
   */
-  class DataReaderThread extends Thread(
+  class DataReaderThread(schema: StructType) extends Thread(
      s"continuous-reader--${context.partitionId()}--" +
      s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") with Logging {
    @volatile private[continuous] var failureReason: Throwable = _
+    private val toUnsafe = UnsafeProjection.create(schema)
 
    override def run(): Unit = {
      TaskContext.setTaskContext(context)
-      val baseReader = ContinuousDataSourceRDD.getContinuousReader(reader)
      try {
        while (!shouldStop()) {
          if (!reader.next()) {
@@ -149,8 +149,9 @@ class ContinuousQueuedDataReader(
            return
          }
        }
-
-        queue.put(ContinuousRow(reader.get().copy(), baseReader.getOffset))
+        // `InternalRow#copy` may not be properly implemented, so for safety we convert to an
+        // unsafe row before copying here.
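+        // Note that `UnsafeProjection` reuses a single output row across calls, so the copy()
+        // below is still required after the conversion.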
+ queue.put(ContinuousRow(toUnsafe(reader.get()).copy(), reader.getOffset)) } } catch { case _: InterruptedException => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 516a563bdcc7a..a6cde2b8a710f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -17,25 +17,22 @@ package org.apache.spark.sql.execution.streaming.continuous -import scala.collection.JavaConverters._ - import org.json4s.DefaultFormats import org.json4s.jackson.Serialization -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} +import org.apache.spark.sql.execution.streaming.{RateStreamOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder, ValueRunTimeMsPair} import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming._ import org.apache.spark.sql.types.StructType case class RateStreamPartitionOffset( partition: Int, currentValue: Long, currentTimeMs: Long) extends PartitionOffset -class RateStreamContinuousReader(options: DataSourceOptions) - extends ContinuousReader { +class RateStreamContinuousReadSupport(options: DataSourceOptions) extends ContinuousReadSupport { implicit val defaultFormats: DefaultFormats = DefaultFormats val creationTime = System.currentTimeMillis() @@ -57,18 +54,18 @@ class RateStreamContinuousReader(options: DataSourceOptions) RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) } - override def readSchema(): StructType = RateStreamProvider.SCHEMA - - private var offset: Offset = _ + override def fullSchema(): StructType = RateStreamProvider.SCHEMA - override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { - this.offset = offset.orElse(createInitialOffset(numPartitions, creationTime)) + override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start) } - override def getStartOffset(): Offset = offset + override def initialOffset: Offset = createInitialOffset(numPartitions, creationTime) - override def planInputPartitions(): java.util.List[InputPartition[Row]] = { - val partitionStartMap = offset match { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val startOffset = config.asInstanceOf[SimpleStreamingScanConfig].start + + val partitionStartMap = startOffset match { case off: RateStreamOffset => off.partitionToValueAndRunTimeMs case off => throw new IllegalArgumentException( @@ -91,8 +88,12 @@ class RateStreamContinuousReader(options: DataSourceOptions) i, numPartitions, perPartitionRate) - .asInstanceOf[InputPartition[Row]] - }.asJava + }.toArray + } + + override def createContinuousReaderFactory( + config: ScanConfig): ContinuousPartitionReaderFactory = { + RateStreamContinuousReaderFactory } override 
def commit(end: Offset): Unit = {} @@ -119,37 +120,28 @@ case class RateStreamContinuousInputPartition( partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends ContinuousInputPartition[Row] { - - override def createContinuousReader(offset: PartitionOffset): InputPartitionReader[Row] = { - val rateStreamOffset = offset.asInstanceOf[RateStreamPartitionOffset] - require(rateStreamOffset.partition == partitionIndex, - s"Expected partitionIndex: $partitionIndex, but got: ${rateStreamOffset.partition}") - new RateStreamContinuousInputPartitionReader( - rateStreamOffset.currentValue, - rateStreamOffset.currentTimeMs, - partitionIndex, - increment, - rowsPerSecond) - } + extends InputPartition - override def createPartitionReader(): InputPartitionReader[Row] = - new RateStreamContinuousInputPartitionReader( - startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) +object RateStreamContinuousReaderFactory extends ContinuousPartitionReaderFactory { + override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { + val p = partition.asInstanceOf[RateStreamContinuousInputPartition] + new RateStreamContinuousPartitionReader( + p.startValue, p.startTimeMs, p.partitionIndex, p.increment, p.rowsPerSecond) + } } -class RateStreamContinuousInputPartitionReader( +class RateStreamContinuousPartitionReader( startValue: Long, startTimeMs: Long, partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends ContinuousInputPartitionReader[Row] { + extends ContinuousPartitionReader[InternalRow] { private var nextReadTime: Long = startTimeMs private val readTimeIncrement: Long = (1000 / rowsPerSecond).toLong private var currentValue = startValue - private var currentRow: Row = null + private var currentRow: InternalRow = null override def next(): Boolean = { currentValue += increment @@ -165,14 +157,14 @@ class RateStreamContinuousInputPartitionReader( return false } - currentRow = Row( - DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromMillis(nextReadTime)), + currentRow = InternalRow( + DateTimeUtils.fromMillis(nextReadTime), currentValue) true } - override def get: Row = currentRow + override def get: InternalRow = currentRow override def close(): Unit = {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala new file mode 100644 index 0000000000000..28ab2448a6633 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous
+
+import java.io.{BufferedReader, InputStreamReader, IOException}
+import java.net.Socket
+import java.sql.Timestamp
+import java.util.Calendar
+import javax.annotation.concurrent.GuardedBy
+
+import scala.collection.mutable.ListBuffer
+
+import org.json4s.{DefaultFormats, NoTypeHints}
+import org.json4s.jackson.Serialization
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.streaming.{Offset => _, _}
+import org.apache.spark.sql.execution.streaming.sources.TextSocketReader
+import org.apache.spark.sql.sources.v2.DataSourceOptions
+import org.apache.spark.sql.sources.v2.reader._
+import org.apache.spark.sql.sources.v2.reader.streaming._
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.RpcUtils
+
+
+/**
+ * A ContinuousReadSupport that reads text lines through a TCP socket, designed only for tutorials
+ * and debugging. This ContinuousReadSupport will *not* work in production applications for
+ * multiple reasons, including no support for fault recovery.
+ *
+ * The driver maintains a socket connection to the host-port, keeps the received messages in
+ * buckets and serves the messages to the executors via an RPC endpoint.
+ */
+class TextSocketContinuousReadSupport(options: DataSourceOptions)
+  extends ContinuousReadSupport with Logging {
+
+  implicit val defaultFormats: DefaultFormats = DefaultFormats
+
+  private val host: String = options.get("host").get()
+  private val port: Int = options.get("port").get().toInt
+
+  assert(SparkSession.getActiveSession.isDefined)
+  private val spark = SparkSession.getActiveSession.get
+  private val numPartitions = spark.sparkContext.defaultParallelism
+
+  @GuardedBy("this")
+  private var socket: Socket = _
+
+  @GuardedBy("this")
+  private var readThread: Thread = _
+
+  @GuardedBy("this")
+  private val buckets = Seq.fill(numPartitions)(new ListBuffer[(String, Timestamp)])
+
+  @GuardedBy("this")
+  private var currentOffset: Int = -1
+
+  // Exposed for tests.
+  private[spark] var startOffset: TextSocketOffset = _
+
+  private val recordEndpoint = new ContinuousRecordEndpoint(buckets, this)
+  @volatile private var endpointRef: RpcEndpointRef = _
+
+  initialize()
+
+  override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = {
+    assert(offsets.length == numPartitions)
+    val offs = offsets
+      .map(_.asInstanceOf[ContinuousRecordPartitionOffset])
+      .sortBy(_.partitionId)
+      .map(_.offset)
+      .toList
+    TextSocketOffset(offs)
+  }
+
+  override def deserializeOffset(json: String): Offset = {
+    TextSocketOffset(Serialization.read[List[Int]](json))
+  }
+
+  override def initialOffset(): Offset = {
+    startOffset = TextSocketOffset(List.fill(numPartitions)(0))
+    startOffset
+  }
+
+  override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = {
+    new SimpleStreamingScanConfigBuilder(fullSchema(), start)
+  }
+
+  override def fullSchema(): StructType = {
+    if (includeTimestamp) {
+      TextSocketReader.SCHEMA_TIMESTAMP
+    } else {
+      TextSocketReader.SCHEMA_REGULAR
+    }
+  }
+
+  override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
+    val startOffset = config.asInstanceOf[SimpleStreamingScanConfig]
+      .start.asInstanceOf[TextSocketOffset]
+    recordEndpoint.setStartOffsets(startOffset.offsets)
+    val endpointName = s"TextSocketContinuousReaderEndpoint-${java.util.UUID.randomUUID()}"
+    endpointRef = recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint)
+
+    val offsets = startOffset match {
+      case off: TextSocketOffset => off.offsets
+      case off =>
+        throw new IllegalArgumentException(
+          s"invalid offset type ${off.getClass} for TextSocketContinuousReader")
+    }
+
+    if (offsets.size != numPartitions) {
+      throw new IllegalArgumentException(
+        s"The previous run contained ${offsets.size} partitions, but" +
+          s" $numPartitions partitions are currently configured. The numPartitions option" +
+          " cannot be changed.")
+    }
+
+    startOffset.offsets.zipWithIndex.map {
+      case (offset, i) =>
+        TextSocketContinuousInputPartition(endpointName, i, offset, includeTimestamp)
+    }.toArray
+  }
+
+  override def createContinuousReaderFactory(
+      config: ScanConfig): ContinuousPartitionReaderFactory = {
+    TextSocketReaderFactory
+  }
+
+  override def commit(end: Offset): Unit = synchronized {
+    val endOffset = end match {
+      case off: TextSocketOffset => off
+      case _ => throw new IllegalArgumentException(s"TextSocketContinuousReader.commit()" +
+        s" received an offset ($end) that did not originate with an instance of this class")
+    }
+
+    endOffset.offsets.zipWithIndex.foreach {
+      case (offset, partition) =>
+        val max = startOffset.offsets(partition) + buckets(partition).size
+        if (offset > max) {
+          throw new IllegalStateException("Invalid offset " + offset + " to commit" +
+            " for partition " + partition + ". Max valid offset: " + max)
+        }
+        val n = offset - startOffset.offsets(partition)
+        buckets(partition).trimStart(n)
+    }
+    startOffset = endOffset
+    recordEndpoint.setStartOffsets(startOffset.offsets)
+  }
+
+  /** Stop this source. */
+  override def stop(): Unit = synchronized {
+    if (socket != null) {
+      try {
+        // Unfortunately, BufferedReader.readLine() cannot be interrupted, so the only way to
+        // stop the readThread is to close the socket.
+ socket.close() + } catch { + case e: IOException => + } + socket = null + } + if (endpointRef != null) recordEndpoint.rpcEnv.stop(endpointRef) + } + + private def initialize(): Unit = synchronized { + socket = new Socket(host, port) + val reader = new BufferedReader(new InputStreamReader(socket.getInputStream)) + // Thread continuously reads from a socket and inserts data into buckets + readThread = new Thread(s"TextSocketContinuousReader($host, $port)") { + setDaemon(true) + + override def run(): Unit = { + try { + while (true) { + val line = reader.readLine() + if (line == null) { + // End of file reached + logWarning(s"Stream closed by $host:$port") + return + } + TextSocketContinuousReadSupport.this.synchronized { + currentOffset += 1 + val newData = (line, + Timestamp.valueOf( + TextSocketReader.DATE_FORMAT.format(Calendar.getInstance().getTime())) + ) + buckets(currentOffset % numPartitions) += newData + } + } + } catch { + case e: IOException => + } + } + } + + readThread.start() + } + + override def toString: String = s"TextSocketContinuousReader[host: $host, port: $port]" + + private def includeTimestamp: Boolean = options.getBoolean("includeTimestamp", false) + +} + +/** + * Continuous text socket input partition. + */ +case class TextSocketContinuousInputPartition( + driverEndpointName: String, + partitionId: Int, + startOffset: Int, + includeTimestamp: Boolean) extends InputPartition + + +object TextSocketReaderFactory extends ContinuousPartitionReaderFactory { + + override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { + val p = partition.asInstanceOf[TextSocketContinuousInputPartition] + new TextSocketContinuousPartitionReader( + p.driverEndpointName, p.partitionId, p.startOffset, p.includeTimestamp) + } +} + + +/** + * Continuous text socket input partition reader. + * + * Polls the driver endpoint for new records. + */ +class TextSocketContinuousPartitionReader( + driverEndpointName: String, + partitionId: Int, + startOffset: Int, + includeTimestamp: Boolean) + extends ContinuousPartitionReader[InternalRow] { + + private val endpoint = RpcUtils.makeDriverRef( + driverEndpointName, + SparkEnv.get.conf, + SparkEnv.get.rpcEnv) + + private var currentOffset = startOffset + private var current: Option[InternalRow] = None + + override def next(): Boolean = { + try { + current = getRecord + while (current.isEmpty) { + Thread.sleep(100) + current = getRecord + } + currentOffset += 1 + } catch { + case _: InterruptedException => + // Someone's trying to end the task; just let them. 
+ return false + } + true + } + + override def get(): InternalRow = { + current.get + } + + override def close(): Unit = {} + + override def getOffset: PartitionOffset = + ContinuousRecordPartitionOffset(partitionId, currentOffset) + + private def getRecord: Option[InternalRow] = + endpoint.askSync[Option[InternalRow]](GetRecord( + ContinuousRecordPartitionOffset(partitionId, currentOffset))).map(rec => + if (includeTimestamp) { + rec + } else { + InternalRow(rec.get(0, TextSocketReader.SCHEMA_TIMESTAMP) + .asInstanceOf[(String, Timestamp)]._1) + } + ) +} + +case class TextSocketOffset(offsets: List[Int]) extends Offset { + private implicit val formats = Serialization.formats(NoTypeHints) + override def json: String = Serialization.write(offsets) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala index 76f3f5baa8d56..a08411d746abe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala @@ -17,13 +17,11 @@ package org.apache.spark.sql.execution.streaming.continuous -import java.util.concurrent.atomic.AtomicLong - import org.apache.spark.{Partition, SparkEnv, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo} -import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.DataWriter +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingDataWriterFactory import org.apache.spark.util.Utils /** @@ -34,7 +32,7 @@ import org.apache.spark.util.Utils * * We keep repeating prev.compute() and writing new epochs until the query is shut down. */ -class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactory[InternalRow]) +class ContinuousWriteRDD(var prev: RDD[InternalRow], writerFactory: StreamingDataWriterFactory) extends RDD[Unit](prev) { override val partitioner = prev.partitioner @@ -47,14 +45,13 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor SparkEnv.get) EpochTracker.initializeCurrentEpoch( context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong) - while (!context.isInterrupted() && !context.isCompleted()) { var dataWriter: DataWriter[InternalRow] = null // write the data and commit this writer. 
Utils.tryWithSafeFinallyAndFailureCallbacks(block = { try { val dataIterator = prev.compute(split, context) - dataWriter = writeTask.createDataWriter( + dataWriter = writerFactory.createWriter( context.partitionId(), context.taskAttemptId(), EpochTracker.getCurrentEpoch.get) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 8877ebeb26735..2238ce26e7b46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -23,9 +23,9 @@ import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.util.RpcUtils private[continuous] sealed trait EpochCoordinatorMessage extends Serializable @@ -82,15 +82,15 @@ private[sql] object EpochCoordinatorRef extends Logging { * Create a reference to a new [[EpochCoordinator]]. */ def create( - writer: StreamWriter, - reader: ContinuousReader, + writeSupport: StreamingWriteSupport, + readSupport: ContinuousReadSupport, query: ContinuousExecution, epochCoordinatorId: String, startEpoch: Long, session: SparkSession, env: SparkEnv): RpcEndpointRef = synchronized { val coordinator = new EpochCoordinator( - writer, reader, query, startEpoch, session, env.rpcEnv) + writeSupport, readSupport, query, startEpoch, session, env.rpcEnv) val ref = env.rpcEnv.setupEndpoint(endpointName(epochCoordinatorId), coordinator) logInfo("Registered EpochCoordinator endpoint") ref @@ -115,8 +115,8 @@ private[sql] object EpochCoordinatorRef extends Logging { * have both committed and reported an end offset for a given epoch. */ private[continuous] class EpochCoordinator( - writer: StreamWriter, - reader: ContinuousReader, + writeSupport: StreamingWriteSupport, + readSupport: ContinuousReadSupport, query: ContinuousExecution, startEpoch: Long, session: SparkSession, @@ -198,7 +198,7 @@ private[continuous] class EpochCoordinator( s"and is ready to be committed. Committing epoch $epoch.") // Sequencing is important here. We must commit to the writer before recording the commit // in the query, or we will end up dropping the commit if we restart in the middle. 
- writer.commit(epoch, messages.toArray) + writeSupport.commit(epoch, messages.toArray) query.commit(epoch) } @@ -220,7 +220,7 @@ private[continuous] class EpochCoordinator( partitionOffsets.collect { case ((e, _), o) if e == epoch => o } if (thisEpochOffsets.size == numReaderPartitions) { logDebug(s"Epoch $epoch has offsets reported from all partitions: $thisEpochOffsets") - query.addOffset(epoch, reader, thisEpochOffsets.toSeq) + query.addOffset(epoch, readSupport, thisEpochOffsets.toSeq) resolveCommitsAtEpoch(epoch) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala index 943c731a70529..7ad21cc304e7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.execution.streaming.continuous import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport /** * The logical plan for writing data in a continuous stream. */ case class WriteToContinuousDataSource( - writer: StreamWriter, query: LogicalPlan) extends LogicalPlan { + writeSupport: StreamingWriteSupport, query: LogicalPlan) extends LogicalPlan { override def children: Seq[LogicalPlan] = Seq(query) override def output: Seq[Attribute] = Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala index e0af3a2f1b85d..c216b61383856 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala @@ -19,36 +19,28 @@ package org.apache.spark.sql.execution.streaming.continuous import scala.util.control.NonFatal -import org.apache.spark.{SparkEnv, SparkException, TaskContext} +import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.datasources.v2.{DataWritingSparkTask, InternalRowDataWriterFactory} -import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo} import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter -import org.apache.spark.util.Utils +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport /** - * The physical plan for writing data into a continuous processing [[StreamWriter]]. + * The physical plan for writing data into a continuous processing [[StreamingWriteSupport]]. 
*/ -case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPlan) +case class WriteToContinuousDataSourceExec(writeSupport: StreamingWriteSupport, query: SparkPlan) extends SparkPlan with Logging { override def children: Seq[SparkPlan] = Seq(query) override def output: Seq[Attribute] = Nil override protected def doExecute(): RDD[InternalRow] = { - val writerFactory = writer match { - case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory() - case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) - } - + val writerFactory = writeSupport.createStreamingWriterFactory() val rdd = new ContinuousWriteRDD(query.execute(), writerFactory) - logInfo(s"Start processing data source writer: $writer. " + + logInfo(s"Start processing data source write support: $writeSupport. " + s"The input RDD has ${rdd.partitions.length} partitions.") EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index 518223f3cd008..9b13f6398d837 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -40,7 +40,7 @@ case class ContinuousShuffleReadPartition( queueSize, numShuffleWriters, epochIntervalMs, env) val endpoint = env.setupEndpoint(endpointName, receiver) - TaskContext.get().addTaskCompletionListener { ctx => + TaskContext.get().addTaskCompletionListener[Unit] { ctx => env.stop(endpoint) } (receiver, endpoint) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 7fa13c4aa2c01..adf52aba21a04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -17,25 +17,22 @@ package org.apache.spark.sql.execution.streaming -import java.{util => ju} -import java.util.Optional import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy -import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow} -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode import 
org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -67,7 +64,7 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas addData(data.toTraversable) } - def readSchema(): StructType = encoder.schema + def fullSchema(): StructType = encoder.schema protected def logicalPlan: LogicalPlan @@ -80,8 +77,7 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas * available. */ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) - extends MemoryStreamBase[A](sqlContext) - with MicroBatchReader with SupportsScanUnsafeRow with Logging { + extends MemoryStreamBase[A](sqlContext) with MicroBatchReadSupport with Logging { protected val logicalPlan: LogicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession) @@ -123,24 +119,22 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]" - override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = { - synchronized { - startOffset = start.orElse(LongOffset(-1)).asInstanceOf[LongOffset] - endOffset = end.orElse(currentOffset).asInstanceOf[LongOffset] - } - } - override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong) - override def getStartOffset: OffsetV2 = synchronized { - if (startOffset.offset == -1) null else startOffset + override def initialOffset: OffsetV2 = LongOffset(-1) + + override def latestOffset(): OffsetV2 = { + if (currentOffset.offset == -1) null else currentOffset } - override def getEndOffset: OffsetV2 = synchronized { - if (endOffset.offset == -1) null else endOffset + override def newScanConfigBuilder(start: OffsetV2, end: OffsetV2): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) } - override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val sc = config.asInstanceOf[SimpleStreamingScanConfig] + val startOffset = sc.start.asInstanceOf[LongOffset] + val endOffset = sc.end.get.asInstanceOf[LongOffset] synchronized { // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) val startOrdinal = startOffset.offset.toInt + 1 @@ -157,11 +151,15 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal)) newBlocks.map { block => - new MemoryStreamInputPartition(block).asInstanceOf[InputPartition[UnsafeRow]] - }.asJava + new MemoryStreamInputPartition(block) + }.toArray } } + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + MemoryStreamReaderFactory + } + private def generateDebugString( rows: Seq[UnsafeRow], startOrdinal: Int, @@ -202,10 +200,12 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } -class MemoryStreamInputPartition(records: Array[UnsafeRow]) - extends InputPartition[UnsafeRow] { - override def createPartitionReader(): InputPartitionReader[UnsafeRow] = { - new InputPartitionReader[UnsafeRow] { +class MemoryStreamInputPartition(val records: Array[UnsafeRow]) extends InputPartition + +object MemoryStreamReaderFactory extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val records = partition.asInstanceOf[MemoryStreamInputPartition].records + new PartitionReader[InternalRow] { private 
var currentIndex = -1 override def next(): Boolean = { @@ -222,60 +222,19 @@ class MemoryStreamInputPartition(records: Array[UnsafeRow]) } /** A common trait for MemorySinks with methods used for testing */ -trait MemorySinkBase extends BaseStreamingSink with Logging { +trait MemorySinkBase extends BaseStreamingSink { def allData: Seq[Row] def latestBatchData: Seq[Row] def dataSinceBatch(sinceBatchId: Long): Seq[Row] def latestBatchId: Option[Long] - - /** - * Truncates the given rows to return at most maxRows rows. - * @param rows The data that may need to be truncated. - * @param batchLimit Number of rows to keep in this batch; the rest will be truncated - * @param sinkLimit Total number of rows kept in this sink, for logging purposes. - * @param batchId The ID of the batch that sent these rows, for logging purposes. - * @return Truncated rows. - */ - protected def truncateRowsIfNeeded( - rows: Array[Row], - batchLimit: Int, - sinkLimit: Int, - batchId: Long): Array[Row] = { - if (rows.length > batchLimit && batchLimit >= 0) { - logWarning(s"Truncating batch $batchId to $batchLimit rows because of sink limit $sinkLimit") - rows.take(batchLimit) - } else { - rows - } - } -} - -/** - * Companion object to MemorySinkBase. - */ -object MemorySinkBase { - val MAX_MEMORY_SINK_ROWS = "maxRows" - val MAX_MEMORY_SINK_ROWS_DEFAULT = -1 - - /** - * Gets the max number of rows a MemorySink should store. This number is based on the memory - * sink row limit option if it is set. If not, we use a large value so that data truncates - * rather than causing out of memory errors. - * @param options Options for writing from which we get the max rows option - * @return The maximum number of rows a memorySink should store. - */ - def getMemorySinkCapacity(options: DataSourceOptions): Int = { - val maxRows = options.getInt(MAX_MEMORY_SINK_ROWS, MAX_MEMORY_SINK_ROWS_DEFAULT) - if (maxRows >= 0) maxRows else Int.MaxValue - 10 - } } /** * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSourceOptions) - extends Sink with MemorySinkBase with Logging { +class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink + with MemorySinkBase with Logging { private case class AddedData(batchId: Long, data: Array[Row]) @@ -283,12 +242,6 @@ class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSo @GuardedBy("this") private val batches = new ArrayBuffer[AddedData]() - /** The number of rows in this MemorySink. */ - private var numRows = 0 - - /** The capacity in rows of this sink. */ - val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options) - /** Returns all rows that are stored in this [[Sink]]. 
 */
  def allData: Seq[Row] = synchronized {
    batches.flatMap(_.data)
@@ -321,23 +274,14 @@ class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSo
    logDebug(s"Committing batch $batchId to $this")
    outputMode match {
      case Append | Update =>
-        var rowsToAdd = data.collect()
-        synchronized {
-          rowsToAdd =
-            truncateRowsIfNeeded(rowsToAdd, sinkCapacity - numRows, sinkCapacity, batchId)
-          val rows = AddedData(batchId, rowsToAdd)
-          batches += rows
-          numRows += rowsToAdd.length
-        }
+        val rows = AddedData(batchId, data.collect())
+        synchronized { batches += rows }
 
      case Complete =>
-        var rowsToAdd = data.collect()
+        val rows = AddedData(batchId, data.collect())
        synchronized {
-          rowsToAdd = truncateRowsIfNeeded(rowsToAdd, sinkCapacity, sinkCapacity, batchId)
-          val rows = AddedData(batchId, rowsToAdd)
          batches.clear()
          batches += rows
-          numRows = rowsToAdd.length
        }
 
      case _ =>
@@ -351,7 +295,6 @@ class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSo
 
  def clear(): Unit = synchronized {
    batches.clear()
-    numRows = 0
  }
 
  override def toString(): String = "MemorySink"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala
similarity index 80%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala
index d276403190b3c..833e62f35ede1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala
@@ -17,18 +17,17 @@ package org.apache.spark.sql.execution.streaming.sources
 
-import scala.collection.JavaConverters._
-
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.sources.v2.DataSourceOptions
-import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage}
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage
+import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport}
 import org.apache.spark.sql.types.StructType
 
 /** Common methods used to create writes for the console sink */
-class ConsoleWriter(schema: StructType, options: DataSourceOptions)
-  extends StreamWriter with Logging {
+class ConsoleWriteSupport(schema: StructType, options: DataSourceOptions)
+  extends StreamingWriteSupport with Logging {
 
  // Number of rows to display, by default 20 rows
  protected val numRowsToShow = options.getInt("numRows", 20)
@@ -39,7 +38,7 @@ class ConsoleWriter(schema: StructType, options: DataSourceOptions)
  assert(SparkSession.getActiveSession.isDefined)
  protected val spark = SparkSession.getActiveSession.get
 
-  def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory
+  def createStreamingWriterFactory(): StreamingDataWriterFactory = PackedRowWriterFactory
 
  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
    // We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2
@@ -62,8 +61,7 @@ class ConsoleWriter(schema:
StructType, options: DataSourceOptions) println(printMessage) println("-------------------------------------------") // scalastyle:off println - spark - .createDataFrame(rows.toList.asJava, schema) + Dataset.ofRows(spark, LocalRelation(schema.toAttributes, rows)) .show(numRowsToShow, isTruncated) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index 0bf90b8063326..dbcc4483e5770 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -17,27 +17,22 @@ package org.apache.spark.sql.execution.streaming.sources -import java.{util => ju} -import java.util.Optional import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy -import scala.collection.JavaConverters._ -import scala.collection.SortedMap import scala.collection.mutable.ListBuffer import org.json4s.NoTypeHints import org.json4s.jackson.Serialization import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} -import org.apache.spark.sql.{Encoder, Row, SQLContext} -import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream.GetRecord -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions} -import org.apache.spark.sql.sources.v2.reader.InputPartition -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} -import org.apache.spark.sql.types.StructType +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql.{Encoder, SQLContext} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.streaming.{Offset => _, _} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, ScanConfig, ScanConfigBuilder} +import org.apache.spark.sql.sources.v2.reader.streaming._ import org.apache.spark.util.RpcUtils /** @@ -49,7 +44,9 @@ import org.apache.spark.util.RpcUtils * the specified offset within the list, or null if that offset doesn't yet have a record. 
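 * Executors poll that endpoint with their partition and current offset to fetch each record.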
*/ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2) - extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport { + extends MemoryStreamBase[A](sqlContext) + with ContinuousReadSupportProvider with ContinuousReadSupport { + private implicit val formats = Serialization.formats(NoTypeHints) protected val logicalPlan = @@ -60,10 +57,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa @GuardedBy("this") private val records = Seq.fill(numPartitions)(new ListBuffer[A]) - @GuardedBy("this") - private var startOffset: ContinuousMemoryStreamOffset = _ - - private val recordEndpoint = new RecordEndpoint() + private val recordEndpoint = new ContinuousRecordEndpoint(records, this) @volatile private var endpointRef: RpcEndpointRef = _ def addData(data: TraversableOnce[A]): Offset = synchronized { @@ -76,15 +70,8 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, records(i).size)).toMap) } - override def setStartOffset(start: Optional[Offset]): Unit = synchronized { - // Inferred initial offset is position 0 in each partition. - startOffset = start.orElse { - ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, 0)).toMap) - }.asInstanceOf[ContinuousMemoryStreamOffset] - } - - override def getStartOffset: Offset = synchronized { - startOffset + override def initialOffset(): Offset = { + ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, 0)).toMap) } override def deserializeOffset(json: String): ContinuousMemoryStreamOffset = { @@ -94,60 +81,48 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa override def mergeOffsets(offsets: Array[PartitionOffset]): ContinuousMemoryStreamOffset = { ContinuousMemoryStreamOffset( offsets.map { - case ContinuousMemoryStreamPartitionOffset(part, num) => (part, num) + case ContinuousRecordPartitionOffset(part, num) => (part, num) }.toMap ) } - override def planInputPartitions(): ju.List[InputPartition[Row]] = { + override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start) + } + + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val startOffset = config.asInstanceOf[SimpleStreamingScanConfig] + .start.asInstanceOf[ContinuousMemoryStreamOffset] synchronized { val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id" endpointRef = recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint) startOffset.partitionNums.map { - case (part, index) => - new ContinuousMemoryStreamInputPartition( - endpointName, part, index): InputPartition[Row] - }.toList.asJava + case (part, index) => ContinuousMemoryStreamInputPartition(endpointName, part, index) + }.toArray } } + override def createContinuousReaderFactory( + config: ScanConfig): ContinuousPartitionReaderFactory = { + ContinuousMemoryStreamReaderFactory + } + override def stop(): Unit = { if (endpointRef != null) recordEndpoint.rpcEnv.stop(endpointRef) } override def commit(end: Offset): Unit = {} - // ContinuousReadSupport implementation + // ContinuousReadSupportProvider implementation // This is necessary because of how StreamTest finds the source for AddDataMemory steps. 
- def createContinuousReader( - schema: Optional[StructType], + override def createContinuousReadSupport( checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = { - this - } - - /** - * Endpoint for executors to poll for records. - */ - private class RecordEndpoint extends ThreadSafeRpcEndpoint { - override val rpcEnv: RpcEnv = SparkEnv.get.rpcEnv - - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case GetRecord(ContinuousMemoryStreamPartitionOffset(part, index)) => - ContinuousMemoryStream.this.synchronized { - val buf = records(part) - val record = if (buf.size <= index) None else Some(buf(index)) - - context.reply(record.map(Row(_))) - } - } - } + options: DataSourceOptions): ContinuousReadSupport = this } object ContinuousMemoryStream { - case class GetRecord(offset: ContinuousMemoryStreamPartitionOffset) protected val memoryStreamId = new AtomicInteger(0) def apply[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = @@ -160,12 +135,16 @@ object ContinuousMemoryStream { /** * An input partition for continuous memory stream. */ -class ContinuousMemoryStreamInputPartition( +case class ContinuousMemoryStreamInputPartition( driverEndpointName: String, partition: Int, - startOffset: Int) extends InputPartition[Row] { - override def createPartitionReader: ContinuousMemoryStreamInputPartitionReader = - new ContinuousMemoryStreamInputPartitionReader(driverEndpointName, partition, startOffset) + startOffset: Int) extends InputPartition + +object ContinuousMemoryStreamReaderFactory extends ContinuousPartitionReaderFactory { + override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { + val p = partition.asInstanceOf[ContinuousMemoryStreamInputPartition] + new ContinuousMemoryStreamPartitionReader(p.driverEndpointName, p.partition, p.startOffset) + } } /** @@ -173,17 +152,17 @@ class ContinuousMemoryStreamInputPartition( * * Polls the driver endpoint for new records. */ -class ContinuousMemoryStreamInputPartitionReader( +class ContinuousMemoryStreamPartitionReader( driverEndpointName: String, partition: Int, - startOffset: Int) extends ContinuousInputPartitionReader[Row] { + startOffset: Int) extends ContinuousPartitionReader[InternalRow] { private val endpoint = RpcUtils.makeDriverRef( driverEndpointName, SparkEnv.get.conf, SparkEnv.get.rpcEnv) private var currentOffset = startOffset - private var current: Option[Row] = None + private var current: Option[InternalRow] = None // Defense-in-depth against failing to propagate the task context. 
Since it's not inheritable, // we have to do a bit of error prone work to get it into every thread used by continuous @@ -203,16 +182,16 @@ class ContinuousMemoryStreamInputPartitionReader( true } - override def get(): Row = current.get + override def get(): InternalRow = current.get override def close(): Unit = {} - override def getOffset: ContinuousMemoryStreamPartitionOffset = - ContinuousMemoryStreamPartitionOffset(partition, currentOffset) + override def getOffset: ContinuousRecordPartitionOffset = + ContinuousRecordPartitionOffset(partition, currentOffset) - private def getRecord: Option[Row] = - endpoint.askSync[Option[Row]]( - GetRecord(ContinuousMemoryStreamPartitionOffset(partition, currentOffset))) + private def getRecord: Option[InternalRow] = + endpoint.askSync[Option[InternalRow]]( + GetRecord(ContinuousRecordPartitionOffset(partition, currentOffset))) } case class ContinuousMemoryStreamOffset(partitionNums: Map[Int, Int]) @@ -220,6 +199,3 @@ case class ContinuousMemoryStreamOffset(partitionNums: Map[Int, Int]) private implicit val formats = Serialization.formats(NoTypeHints) override def json(): String = Serialization.write(partitionNums) } - -case class ContinuousMemoryStreamPartitionOffset(partition: Int, numProcessed: Int) - extends PartitionOffset diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala similarity index 79% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala index bc9b6d93ce7d9..4218fd51ad206 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala @@ -17,14 +17,14 @@ package org.apache.spark.sql.execution.streaming.sources -import org.apache.spark.sql.{Encoder, ForeachWriter, SparkSession} +import org.apache.spark.sql.{ForeachWriter, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.python.PythonForeachWriter -import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamingWriteSupportProvider} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -37,20 +37,21 @@ import org.apache.spark.sql.types.StructType * a [[ExpressionEncoder]] or a direct converter function. * @tparam T The expected type of the sink. 
*/ -case class ForeachWriterProvider[T]( +case class ForeachWriteSupportProvider[T]( writer: ForeachWriter[T], - converter: Either[ExpressionEncoder[T], InternalRow => T]) extends StreamWriteSupport { + converter: Either[ExpressionEncoder[T], InternalRow => T]) + extends StreamingWriteSupportProvider { - override def createStreamWriter( + override def createStreamingWriteSupport( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamWriter = { - new StreamWriter with SupportsWriteInternalRow { + options: DataSourceOptions): StreamingWriteSupport = { + new StreamingWriteSupport { override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = { + override def createStreamingWriterFactory(): StreamingDataWriterFactory = { val rowConverter: InternalRow => T = converter match { case Left(enc) => val boundEnc = enc.resolveAndBind( @@ -68,16 +69,16 @@ case class ForeachWriterProvider[T]( } } -object ForeachWriterProvider { +object ForeachWriteSupportProvider { def apply[T]( writer: ForeachWriter[T], - encoder: ExpressionEncoder[T]): ForeachWriterProvider[_] = { + encoder: ExpressionEncoder[T]): ForeachWriteSupportProvider[_] = { writer match { case pythonWriter: PythonForeachWriter => - new ForeachWriterProvider[UnsafeRow]( + new ForeachWriteSupportProvider[UnsafeRow]( pythonWriter, Right((x: InternalRow) => x.asInstanceOf[UnsafeRow])) case _ => - new ForeachWriterProvider[T](writer, Left(encoder)) + new ForeachWriteSupportProvider[T](writer, Left(encoder)) } } } @@ -85,8 +86,8 @@ object ForeachWriterProvider { case class ForeachWriterFactory[T]( writer: ForeachWriter[T], rowConverter: InternalRow => T) - extends DataWriterFactory[InternalRow] { - override def createDataWriter( + extends StreamingDataWriterFactory { + override def createWriter( partitionId: Int, taskId: Long, epochId: Long): ForeachDataWriter[T] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWritSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWritSupport.scala new file mode 100644 index 0000000000000..9f88416871f8e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWritSupport.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.sources.v2.writer.{BatchWriteSupport, DataWriter, DataWriterFactory, WriterCommitMessage}
+import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport}
+
+/**
+ * A [[BatchWriteSupport]] used to hook V2 stream writers into a microbatch plan. It implements
+ * the non-streaming interface, forwarding the epoch ID determined at construction to a wrapped
+ * streaming write support.
+ */
+class MicroBatchWritSupport(epochId: Long, val writeSupport: StreamingWriteSupport)
+  extends BatchWriteSupport {
+
+  override def commit(messages: Array[WriterCommitMessage]): Unit = {
+    writeSupport.commit(epochId, messages)
+  }
+
+  override def abort(messages: Array[WriterCommitMessage]): Unit = {
+    writeSupport.abort(epochId, messages)
+  }
+
+  override def createBatchWriterFactory(): DataWriterFactory = {
+    new MicroBatchWriterFactory(epochId, writeSupport.createStreamingWriterFactory())
+  }
+}
+
+class MicroBatchWriterFactory(epochId: Long, streamingWriterFactory: StreamingDataWriterFactory)
+  extends DataWriterFactory {
+
+  override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
+    streamingWriterFactory.createWriter(partitionId, taskId, epochId)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala
deleted file mode 100644
index 56f7ff25cbed0..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming.sources
-
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage}
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
-
-/**
- * A [[DataSourceWriter]] used to hook V2 stream writers into a microbatch plan. It implements
- * the non-streaming interface, forwarding the batch ID determined at construction to a wrapped
- * streaming writer.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala deleted file mode 100644 index 56f7ff25cbed0..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter - -/** - * A [[DataSourceWriter]] used to hook V2 stream writers into a microbatch plan. It implements - * the non-streaming interface, forwarding the batch ID determined at construction to a wrapped - * streaming writer. - */ -class MicroBatchWriter(batchId: Long, writer: StreamWriter) extends DataSourceWriter { - override def commit(messages: Array[WriterCommitMessage]): Unit = { - writer.commit(batchId, messages) - } - - override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages) - - override def createWriterFactory(): DataWriterFactory[Row] = writer.createWriterFactory() -} - -class InternalRowMicroBatchWriter(batchId: Long, writer: StreamWriter) - extends DataSourceWriter with SupportsWriteInternalRow { - override def commit(messages: Array[WriterCommitMessage]): Unit = { - writer.commit(batchId, messages) - } - - override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages) - - override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = - writer match { - case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory() - case _ => throw new IllegalStateException( - "InternalRowMicroBatchWriter should only be created with base writer support") - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala index b501d90c81f06..ac3c71cc222b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala @@ -20,21 +20,22 @@ package org.apache.spark.sql.execution.streaming.sources import scala.collection.mutable import org.apache.spark.internal.Logging -import org.apache.spark.sql.Row -import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources.v2.writer.{BatchWriteSupport, DataWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingDataWriterFactory /** * A simple [[DataWriterFactory]] whose tasks just pack rows into the commit message for delivery - * to a [[DataSourceWriter]] on the driver. + * to a [[BatchWriteSupport]] on the driver. * * Note that, because it sends all rows to the driver, this factory will generally be unsuitable * for production-quality sinks. It's intended for use in tests. */ -case object PackedRowWriterFactory extends DataWriterFactory[Row] { - override def createDataWriter( +case object PackedRowWriterFactory extends StreamingDataWriterFactory { + override def createWriter( partitionId: Int, taskId: Long, - epochId: Long): DataWriter[Row] = { + epochId: Long): DataWriter[InternalRow] = { new PackedRowDataWriter() } } @@ -43,15 +44,16 @@ case object PackedRowWriterFactory extends DataWriterFactory[Row] { * Commit message for a [[PackedRowDataWriter]], containing all the rows written in the most * recent interval. */ -case class PackedRowCommitMessage(rows: Array[Row]) extends WriterCommitMessage +case class PackedRowCommitMessage(rows: Array[InternalRow]) extends WriterCommitMessage /** * A simple [[DataWriter]] that just sends all the rows it's received as a commit message.
*/ -class PackedRowDataWriter() extends DataWriter[Row] with Logging { - private val data = mutable.Buffer[Row]() +class PackedRowDataWriter() extends DataWriter[InternalRow] with Logging { - private val data = mutable.Buffer[Row]() + private val data = mutable.Buffer[InternalRow]() - override def write(row: Row): Unit = data.append(row) + // Spark reuses the same `InternalRow` instance, so we copy it before buffering it. + override def write(row: InternalRow): Unit = data.append(row.copy()) override def commit(): PackedRowCommitMessage = { val msg = PackedRowCommitMessage(data.toArray) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala new file mode 100644 index 0000000000000..90680ea38fbd6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset} + +// A special `MicroBatchReadSupport` whose latest offset is computed relative to a given start offset.
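[Editor's note] The trait defined just below separates "how far may this batch advance from a given start" from the parameterless latestOffset(). A hedged sketch of an implementor; `maxRowsPerBatch` and `bufferedEnd` are hypothetical, not from the patch:

    import org.apache.spark.sql.execution.streaming.LongOffset
    import org.apache.spark.sql.sources.v2.reader.streaming.Offset

    // Illustrative only: cap each micro-batch at `maxRowsPerBatch` offsets past the start.
    abstract class BoundedReadSupport(maxRowsPerBatch: Long)
      extends RateControlMicroBatchReadSupport {

      @volatile private var bufferedEnd: Long = 0L  // highest offset buffered so far

      override def latestOffset(start: Offset): Offset = {
        val begin = start.asInstanceOf[LongOffset].offset
        LongOffset(math.min(begin + maxRowsPerBatch, bufferedEnd))
      }
      // The remaining MicroBatchReadSupport methods are left abstract in this sketch.
    }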
+trait RateControlMicroBatchReadSupport extends MicroBatchReadSupport { + + override def latestOffset(): Offset = { + throw new IllegalAccessException( + "latestOffset should not be called for RateControlMicroBatchReadSupport") + } + + def latestOffset(start: Offset): Offset +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala similarity index 76% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala index b393c48baee8d..f5364047adff1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala @@ -19,26 +19,24 @@ package org.apache.spark.sql.execution.streaming.sources import java.io._ import java.nio.charset.StandardCharsets -import java.util.Optional import java.util.concurrent.TimeUnit -import scala.collection.JavaConverters._ - import org.apache.commons.io.IOUtils import org.apache.spark.internal.Logging import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{ManualClock, SystemClock} -class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String) - extends MicroBatchReader with Logging { +class RateStreamMicroBatchReadSupport(options: DataSourceOptions, checkpointLocation: String) + extends MicroBatchReadSupport with Logging { import RateStreamProvider._ private[sources] val clock = { @@ -105,38 +103,30 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: @volatile private var lastTimeMs: Long = creationTimeMs - private var start: LongOffset = _ - private var end: LongOffset = _ - - override def readSchema(): StructType = SCHEMA + override def initialOffset(): Offset = LongOffset(0L) - override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = { - this.start = start.orElse(LongOffset(0L)).asInstanceOf[LongOffset] - this.end = end.orElse { - val now = clock.getTimeMillis() - if (lastTimeMs < now) { - lastTimeMs = now - } - LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs)) - }.asInstanceOf[LongOffset] - } - - override def getStartOffset(): Offset = { - if (start == null) throw new IllegalStateException("start offset not set") - start - } - override def getEndOffset(): Offset = { - if (end == null) throw new IllegalStateException("end offset not set") - end + override def latestOffset(): Offset = { + val now = clock.getTimeMillis() + if (lastTimeMs < now) { + lastTimeMs = now + } + LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs)) } override def 
deserializeOffset(json: String): Offset = { LongOffset(json.toLong) } - override def planInputPartitions(): java.util.List[InputPartition[Row]] = { - val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L) - val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) + override def fullSchema(): StructType = SCHEMA + + override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) + } + + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val sc = config.asInstanceOf[SimpleStreamingScanConfig] + val startSeconds = sc.start.asInstanceOf[LongOffset].offset + val endSeconds = sc.end.get.asInstanceOf[LongOffset].offset assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") if (endSeconds > maxSeconds) { throw new ArithmeticException("Integer overflow. Max offset with " + @@ -152,7 +142,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") if (rangeStart == rangeEnd) { - return List.empty.asJava + return Array.empty } val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) @@ -169,8 +159,11 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: (0 until numPartitions).map { p => new RateStreamMicroBatchInputPartition( p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) - : InputPartition[Row] - }.toList.asJava + }.toArray + } + + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + RateStreamMicroBatchReaderFactory } override def commit(end: Offset): Unit = {} @@ -182,41 +175,40 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" } -class RateStreamMicroBatchInputPartition( +case class RateStreamMicroBatchInputPartition( partitionId: Int, numPartitions: Int, rangeStart: Long, rangeEnd: Long, localStartTimeMs: Long, - relativeMsPerValue: Double) extends InputPartition[Row] { + relativeMsPerValue: Double) extends InputPartition - override def createPartitionReader(): InputPartitionReader[Row] = - new RateStreamMicroBatchInputPartitionReader(partitionId, numPartitions, rangeStart, rangeEnd, - localStartTimeMs, relativeMsPerValue) +object RateStreamMicroBatchReaderFactory extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val p = partition.asInstanceOf[RateStreamMicroBatchInputPartition] + new RateStreamMicroBatchPartitionReader(p.partitionId, p.numPartitions, p.rangeStart, + p.rangeEnd, p.localStartTimeMs, p.relativeMsPerValue) + } } -class RateStreamMicroBatchInputPartitionReader( +class RateStreamMicroBatchPartitionReader( partitionId: Int, numPartitions: Int, rangeStart: Long, rangeEnd: Long, localStartTimeMs: Long, - relativeMsPerValue: Double) extends InputPartitionReader[Row] { + relativeMsPerValue: Double) extends PartitionReader[InternalRow] { private var count: Long = 0 override def next(): Boolean = { rangeStart + partitionId + numPartitions * count < rangeEnd } - override def get(): Row = { + override def get(): InternalRow = { val currValue = rangeStart + partitionId + numPartitions * count count += 1 val relative = math.round((currValue - rangeStart) * relativeMsPerValue) - Row( - DateTimeUtils.toJavaTimestamp( - DateTimeUtils.fromMillis(relative 
+ localStartTimeMs)), - currValue - ) + InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), currValue) } override def close(): Unit = {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala index 6bdd492f0cb35..6942dfbfe0ecf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -17,14 +17,11 @@ package org.apache.spark.sql.execution.streaming.sources -import java.util.Optional - import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader +import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReadSupport import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.types._ /** @@ -42,13 +39,12 @@ import org.apache.spark.sql.types._ * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. */ class RateStreamProvider extends DataSourceV2 - with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister { + with MicroBatchReadSupportProvider with ContinuousReadSupportProvider with DataSourceRegister { import RateStreamProvider._ - override def createMicroBatchReader( - schema: Optional[StructType], + override def createMicroBatchReadSupport( checkpointLocation: String, - options: DataSourceOptions): MicroBatchReader = { + options: DataSourceOptions): MicroBatchReadSupport = { if (options.get(ROWS_PER_SECOND).isPresent) { val rowsPerSecond = options.get(ROWS_PER_SECOND).get().toLong if (rowsPerSecond <= 0) { @@ -74,17 +70,14 @@ class RateStreamProvider extends DataSourceV2 } } - if (schema.isPresent) { - throw new AnalysisException("The rate source does not support a user-specified schema.") - } - - new RateStreamMicroBatchReader(options, checkpointLocation) + new RateStreamMicroBatchReadSupport(options, checkpointLocation) } - override def createContinuousReader( - schema: Optional[StructType], + override def createContinuousReadSupport( checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = new RateStreamContinuousReader(options) + options: DataSourceOptions): ContinuousReadSupport = { + new RateStreamContinuousReadSupport(options) + } override def shortName(): String = "rate" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 29f8cca476722..c50dc7bcb8da1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -25,14 +25,16 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.Attribute import 
org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamingWriteSupportProvider} import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -40,13 +42,15 @@ import org.apache.spark.sql.types.StructType * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkBase with Logging { - override def createStreamWriter( +class MemorySinkV2 extends DataSourceV2 with StreamingWriteSupportProvider + with MemorySinkBase with Logging { + + override def createStreamingWriteSupport( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamWriter = { - new MemoryStreamWriter(this, mode, options) + options: DataSourceOptions): StreamingWriteSupport = { + new MemoryStreamingWriteSupport(this, mode, schema) } private case class AddedData(batchId: Long, data: Array[Row]) @@ -55,9 +59,6 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB @GuardedBy("this") private val batches = new ArrayBuffer[AddedData]() - /** The number of rows in this MemorySink. */ - private var numRows = 0 - /** Returns all rows that are stored in this [[Sink]]. 
*/ def allData: Seq[Row] = synchronized { batches.flatMap(_.data) @@ -84,11 +85,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB }.mkString("\n") } - def write( - batchId: Long, - outputMode: OutputMode, - newRows: Array[Row], - sinkCapacity: Int): Unit = { + def write(batchId: Long, outputMode: OutputMode, newRows: Array[Row]): Unit = { val notCommitted = synchronized { latestBatchId.isEmpty || batchId > latestBatchId.get } @@ -96,21 +93,14 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB logDebug(s"Committing batch $batchId to $this") outputMode match { case Append | Update => - synchronized { - val rowsToAdd = - truncateRowsIfNeeded(newRows, sinkCapacity - numRows, sinkCapacity, batchId) - val rows = AddedData(batchId, rowsToAdd) - batches += rows - numRows += rowsToAdd.length - } + val rows = AddedData(batchId, newRows) + synchronized { batches += rows } case Complete => + val rows = AddedData(batchId, newRows) synchronized { - val rowsToAdd = truncateRowsIfNeeded(newRows, sinkCapacity, sinkCapacity, batchId) - val rows = AddedData(batchId, rowsToAdd) batches.clear() batches += rows - numRows = rowsToAdd.length } case _ => @@ -124,52 +114,27 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB def clear(): Unit = synchronized { batches.clear() - numRows = 0 } override def toString(): String = "MemorySinkV2" } -case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} - -class MemoryWriter( - sink: MemorySinkV2, - batchId: Long, - outputMode: OutputMode, - options: DataSourceOptions) - extends DataSourceWriter with Logging { +case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) + extends WriterCommitMessage {} - val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options) - - override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) - - def commit(messages: Array[WriterCommitMessage]): Unit = { - val newRows = messages.flatMap { - case message: MemoryWriterCommitMessage => message.data - } - sink.write(batchId, outputMode, newRows, sinkCapacity) - } +class MemoryStreamingWriteSupport( + val sink: MemorySinkV2, outputMode: OutputMode, schema: StructType) + extends StreamingWriteSupport { - override def abort(messages: Array[WriterCommitMessage]): Unit = { - // Don't accept any of the new input. 
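+ // [editor's note] Commit path, for orientation: each task buffers its rows in a
+ // MemoryDataWriter and ships them back in a MemoryWriterCommitMessage; on commit the
+ // driver concatenates the rows from all messages and appends them to the sink for
+ // this epoch (replacing the previous contents under Complete mode).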
+ override def createStreamingWriterFactory: MemoryWriterFactory = { + MemoryWriterFactory(outputMode, schema) } -} - -class MemoryStreamWriter( - val sink: MemorySinkV2, - outputMode: OutputMode, - options: DataSourceOptions) - extends StreamWriter { - - val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options) - - override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { case message: MemoryWriterCommitMessage => message.data } - sink.write(epochId, outputMode, newRows, sinkCapacity) + sink.write(epochId, outputMode, newRows) } override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { @@ -177,22 +142,32 @@ class MemoryStreamWriter( } } -case class MemoryWriterFactory(outputMode: OutputMode) extends DataWriterFactory[Row] { - override def createDataWriter( +case class MemoryWriterFactory(outputMode: OutputMode, schema: StructType) + extends DataWriterFactory with StreamingDataWriterFactory { + + override def createWriter( + partitionId: Int, + taskId: Long): DataWriter[InternalRow] = { + new MemoryDataWriter(partitionId, outputMode, schema) + } + + override def createWriter( partitionId: Int, taskId: Long, - epochId: Long): DataWriter[Row] = { - new MemoryDataWriter(partitionId, outputMode) + epochId: Long): DataWriter[InternalRow] = { + createWriter(partitionId, taskId) } } -class MemoryDataWriter(partition: Int, outputMode: OutputMode) - extends DataWriter[Row] with Logging { +class MemoryDataWriter(partition: Int, outputMode: OutputMode, schema: StructType) + extends DataWriter[InternalRow] with Logging { private val data = mutable.Buffer[Row]() - override def write(row: Row): Unit = { - data.append(row) + private val encoder = RowEncoder(schema).resolveAndBind() + + override def write(row: InternalRow): Unit = { + data.append(encoder.fromRow(row)) } override def commit(): MemoryWriterCommitMessage = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala index 91e3b7179c34a..b2a573eae504a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala @@ -19,26 +19,28 @@ package org.apache.spark.sql.execution.streaming.sources import java.io.{BufferedReader, InputStreamReader, IOException} import java.net.Socket -import java.sql.Timestamp import java.text.SimpleDateFormat -import java.util.{Calendar, List => JList, Locale, Optional} +import java.util.{Calendar, Locale} import java.util.concurrent.atomic.AtomicBoolean import javax.annotation.concurrent.GuardedBy -import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import scala.util.{Failure, Success, Try} import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.execution.streaming.LongOffset +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.streaming.{LongOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder} +import org.apache.spark.sql.execution.streaming.continuous.TextSocketContinuousReadSupport import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceOptions, 
DataSourceV2, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, DataSourceV2, MicroBatchReadSupportProvider} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport, Offset} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} +import org.apache.spark.unsafe.types.UTF8String -object TextSocketMicroBatchReader { +object TextSocketReader { val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil) val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) :: StructField("timestamp", TimestampType) :: Nil) @@ -46,14 +48,12 @@ object TextSocketMicroBatchReader { } /** - * A MicroBatchReader that reads text lines through a TCP socket, designed only for tutorials and - * debugging. This MicroBatchReader will *not* work in production applications due to multiple - * reasons, including no support for fault recovery. + * A MicroBatchReadSupport that reads text lines through a TCP socket, designed only for tutorials + * and debugging. This MicroBatchReadSupport will *not* work in production applications due to + * multiple reasons, including no support for fault recovery. */ -class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchReader with Logging { - - private var startOffset: Offset = _ - private var endOffset: Offset = _ +class TextSocketMicroBatchReadSupport(options: DataSourceOptions) + extends MicroBatchReadSupport with Logging { private val host: String = options.get("host").get() private val port: Int = options.get("port").get().toInt @@ -69,7 +69,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR * Stored in a ListBuffer to facilitate removing committed batches. 
*/ @GuardedBy("this") - private val batches = new ListBuffer[(String, Timestamp)] + private val batches = new ListBuffer[(UTF8String, Long)] @GuardedBy("this") private var currentOffset: LongOffset = LongOffset(-1L) @@ -99,10 +99,10 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR logWarning(s"Stream closed by $host:$port") return } - TextSocketMicroBatchReader.this.synchronized { - val newData = (line, - Timestamp.valueOf( - TextSocketMicroBatchReader.DATE_FORMAT.format(Calendar.getInstance().getTime())) + TextSocketMicroBatchReadSupport.this.synchronized { + val newData = ( + UTF8String.fromString(line), + DateTimeUtils.fromMillis(Calendar.getInstance().getTimeInMillis) ) currentOffset += 1 batches.append(newData) @@ -116,37 +116,30 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR readThread.start() } - override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = synchronized { - startOffset = start.orElse(LongOffset(-1L)) - endOffset = end.orElse(currentOffset) - } + override def initialOffset(): Offset = LongOffset(-1L) - override def getStartOffset(): Offset = { - Option(startOffset).getOrElse(throw new IllegalStateException("start offset not set")) - } - - override def getEndOffset(): Offset = { - Option(endOffset).getOrElse(throw new IllegalStateException("end offset not set")) - } + override def latestOffset(): Offset = currentOffset override def deserializeOffset(json: String): Offset = { LongOffset(json.toLong) } - override def readSchema(): StructType = { + override def fullSchema(): StructType = { if (options.getBoolean("includeTimestamp", false)) { - TextSocketMicroBatchReader.SCHEMA_TIMESTAMP + TextSocketReader.SCHEMA_TIMESTAMP } else { - TextSocketMicroBatchReader.SCHEMA_REGULAR + TextSocketReader.SCHEMA_REGULAR } } - override def planInputPartitions(): JList[InputPartition[Row]] = { - assert(startOffset != null && endOffset != null, - "start offset and end offset should already be set before create read tasks.") + override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) + } - val startOrdinal = LongOffset.convert(startOffset).get.offset.toInt + 1 - val endOrdinal = LongOffset.convert(endOffset).get.offset.toInt + 1 + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val sc = config.asInstanceOf[SimpleStreamingScanConfig] + val startOrdinal = sc.start.asInstanceOf[LongOffset].offset.toInt + 1 + val endOrdinal = sc.end.get.asInstanceOf[LongOffset].offset.toInt + 1 // Internal buffer only holds the batches after lastOffsetCommitted val rawList = synchronized { @@ -163,31 +156,34 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR val spark = SparkSession.getActiveSession.get val numPartitions = spark.sparkContext.defaultParallelism - val slices = Array.fill(numPartitions)(new ListBuffer[(String, Timestamp)]) + val slices = Array.fill(numPartitions)(new ListBuffer[(UTF8String, Long)]) rawList.zipWithIndex.foreach { case (r, idx) => slices(idx % numPartitions).append(r) } - (0 until numPartitions).map { i => - val slice = slices(i) - new InputPartition[Row] { - override def createPartitionReader(): InputPartitionReader[Row] = - new InputPartitionReader[Row] { - private var currentIdx = -1 + slices.map(TextSocketInputPartition) + } - override def next(): Boolean = { - currentIdx += 1 - currentIdx < slice.size - } + override def 
createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + new PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val slice = partition.asInstanceOf[TextSocketInputPartition].slice + new PartitionReader[InternalRow] { + private var currentIdx = -1 - override def get(): Row = { - Row(slice(currentIdx)._1, slice(currentIdx)._2) - } + override def next(): Boolean = { + currentIdx += 1 + currentIdx < slice.size + } - override def close(): Unit = {} + override def get(): InternalRow = { + InternalRow(slice(currentIdx)._1, slice(currentIdx)._2) } + + override def close(): Unit = {} + } } - }.toList.asJava + } } override def commit(end: Offset): Unit = synchronized { @@ -223,8 +219,11 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR override def toString: String = s"TextSocketV2[host: $host, port: $port]" } +case class TextSocketInputPartition(slice: ListBuffer[(UTF8String, Long)]) extends InputPartition + class TextSocketSourceProvider extends DataSourceV2 - with MicroBatchReadSupport with DataSourceRegister with Logging { + with MicroBatchReadSupportProvider with ContinuousReadSupportProvider + with DataSourceRegister with Logging { private def checkParameters(params: DataSourceOptions): Unit = { logWarning("The socket source should not be used for production applications! " + @@ -244,16 +243,18 @@ class TextSocketSourceProvider extends DataSourceV2 } } - override def createMicroBatchReader( - schema: Optional[StructType], + override def createMicroBatchReadSupport( checkpointLocation: String, - options: DataSourceOptions): MicroBatchReader = { + options: DataSourceOptions): MicroBatchReadSupport = { checkParameters(options) - if (schema.isPresent) { - throw new AnalysisException("The socket source does not support a user-specified schema.") - } + new TextSocketMicroBatchReadSupport(options) + } - new TextSocketMicroBatchReader(options) + override def createContinuousReadSupport( + checkpointLocation: String, + options: DataSourceOptions): ContinuousReadSupport = { + checkParameters(options) + new TextSocketContinuousReadSupport(options) } /** String that represents the format that this data source provider uses. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala new file mode 100644 index 0000000000000..0a16a3819b778 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala @@ -0,0 +1,247 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.ObjectOperator +import org.apache.spark.sql.execution.streaming.GroupStateImpl +import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP +import org.apache.spark.sql.types._ + + +object FlatMapGroupsWithStateExecHelper { + + val supportedVersions = Seq(1, 2) + val legacyVersion = 1 + + /** + * Class to capture deserialized state and the timestamp returned by the state manager. + * This is intended for reuse. + */ + case class StateData( + var keyRow: UnsafeRow = null, + var stateRow: UnsafeRow = null, + var stateObj: Any = null, + var timeoutTimestamp: Long = -1) { + + private[FlatMapGroupsWithStateExecHelper] def withNew( + newKeyRow: UnsafeRow, + newStateRow: UnsafeRow, + newStateObj: Any, + newTimeout: Long): this.type = { + keyRow = newKeyRow + stateRow = newStateRow + stateObj = newStateObj + timeoutTimestamp = newTimeout + this + } + } + + /** Interface for interacting with the state data of FlatMapGroupsWithState */ + sealed trait StateManager extends Serializable { + def stateSchema: StructType + def getState(store: StateStore, keyRow: UnsafeRow): StateData + def putState(store: StateStore, keyRow: UnsafeRow, state: Any, timeoutTimestamp: Long): Unit + def removeState(store: StateStore, keyRow: UnsafeRow): Unit + def getAllState(store: StateStore): Iterator[StateData] + } + + def createStateManager( + stateEncoder: ExpressionEncoder[Any], + shouldStoreTimestamp: Boolean, + stateFormatVersion: Int): StateManager = { + stateFormatVersion match { + case 1 => new StateManagerImplV1(stateEncoder, shouldStoreTimestamp) + case 2 => new StateManagerImplV2(stateEncoder, shouldStoreTimestamp) + case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid") + } + } + + // =============================================================================================== + // =========================== Private implementations of StateManager =========================== + // =============================================================================================== + + /** Common methods for StateManager implementations */ + private abstract class StateManagerImplBase(shouldStoreTimestamp: Boolean) + extends StateManager { + + protected def stateSerializerExprs: Seq[Expression] + protected def stateDeserializerExpr: Expression + protected def timeoutTimestampOrdinalInRow: Int + + /** Get deserialized state and corresponding timeout timestamp for a key */ + override def getState(store: StateStore, keyRow: UnsafeRow): StateData = { + val stateRow = store.get(keyRow) + stateDataForGets.withNew(keyRow, stateRow, getStateObject(stateRow), getTimestamp(stateRow)) + } + + /** Put state and timeout timestamp for a key */ + override def putState(store: StateStore, key: UnsafeRow, state: Any, timestamp: Long): Unit = { + val stateRow = getStateRow(state) + setTimestamp(stateRow, timestamp) + store.put(key, stateRow) + } + + override def removeState(store: StateStore, keyRow: UnsafeRow): Unit = { + store.remove(keyRow) + } + + override def getAllState(store: StateStore): Iterator[StateData] = { + val stateData = StateData() + store.getRange(None, None).map { p => + stateData.withNew(p.key, p.value, getStateObject(p.value), getTimestamp(p.value)) + } + } + + private lazy val stateSerializerFunc = ObjectOperator.serializeObjectToRow(stateSerializerExprs)
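+ // [editor's note] These helpers are lazy so they are built only after the subclass has
+ // supplied stateSerializerExprs/stateDeserializerExpr, and are rebuilt on each executor
+ // after this (Serializable) manager is shipped. `stateDataForGets` below is a single
+ // mutable instance that getState() reuses to avoid allocating a StateData per lookup.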
+ private lazy val stateDeserializerFunc = { + ObjectOperator.deserializeRowToObject(stateDeserializerExpr, stateSchema.toAttributes) + } + private lazy val stateDataForGets = StateData() + + protected def getStateObject(row: UnsafeRow): Any = { + if (row != null) stateDeserializerFunc(row) else null + } + + protected def getStateRow(obj: Any): UnsafeRow = { + stateSerializerFunc(obj) + } + + /** Returns the timeout timestamp of a state row, if it is set */ + private def getTimestamp(stateRow: UnsafeRow): Long = { + if (shouldStoreTimestamp && stateRow != null) { + stateRow.getLong(timeoutTimestampOrdinalInRow) + } else NO_TIMESTAMP + } + + /** Set the timestamp in a state row */ + private def setTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = { + if (shouldStoreTimestamp) stateRow.setLong(timeoutTimestampOrdinalInRow, timeoutTimestamps) + } + } + + /** + * Version 1 of the StateManager which stores the user-defined state as flattened columns in + * the UnsafeRow. Say the user-defined state has 3 fields - col1, col2, col3. The + * unsafe rows will look like this. + * + * UnsafeRow[ col1 | col2 | col3 | timestamp ] + * + * The limitation of this format is that the timestamp cannot be set when the user-defined + * state has been removed. This is because the columns cannot be collectively marked to be + * empty/null. + */ + private class StateManagerImplV1( + stateEncoder: ExpressionEncoder[Any], + shouldStoreTimestamp: Boolean) extends StateManagerImplBase(shouldStoreTimestamp) { + + private val timestampTimeoutAttribute = + AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)() + + private val stateAttributes: Seq[Attribute] = { + val encSchemaAttribs = stateEncoder.schema.toAttributes + if (shouldStoreTimestamp) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs + } + + override val stateSchema: StructType = stateAttributes.toStructType + + override val timeoutTimestampOrdinalInRow: Int = { + stateAttributes.indexOf(timestampTimeoutAttribute) + } + + override val stateSerializerExprs: Seq[Expression] = { + val encoderSerializer = stateEncoder.namedExpressions + if (shouldStoreTimestamp) { + encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP) + } else { + encoderSerializer + } + } + + override val stateDeserializerExpr: Expression = { + // Note that this must be done in the driver, as resolving and binding of deserializer + // expressions to the encoded type can be safely done only in the driver. + stateEncoder.resolveAndBind().deserializer + } + + override protected def getStateRow(obj: Any): UnsafeRow = { + require(obj != null, "State object cannot be null") + super.getStateRow(obj) + } + }
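[Editor's note] A small sketch, not part of the patch, of what the two layouts look like for a hypothetical state type with fields col1, col2, col3. It only uses StructType and mirrors the schema construction in the two managers; note that V1 keeps the legacy IntegerType for the timestamp attribute, as in the code above, while V2 uses LongType:

    import org.apache.spark.sql.types._

    val userStateSchema = new StructType()
      .add("col1", IntegerType).add("col2", LongType).add("col3", DoubleType)

    // V1: user state flattened next to the timestamp; the state columns cannot be
    // collectively null, so the timestamp slot is lost when state is removed.
    val v1StateSchema = userStateSchema.add("timeoutTimestamp", IntegerType, nullable = false)

    // V2: user state nested as a single nullable struct, so it can be null while the
    // timestamp column stays writable.
    val v2StateSchema = new StructType()
      .add("groupState", userStateSchema, nullable = true)
      .add("timeoutTimestamp", LongType, nullable = false)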
+ /** + * Version 2 of the StateManager which stores the user-defined state as a nested struct + * in the UnsafeRow. Say the user-defined state has 3 fields - col1, col2, col3. The + * unsafe rows will look like this. + * ___________________________ + * | | + * | V + * UnsafeRow[ nested-struct | timestamp | UnsafeRow[ col1 | col2 | col3 ] ] + * + * This allows the entire user-defined state to be collectively marked as empty/null, + * thus allowing the timestamp to be set without requiring the state to be present. + */ + private class StateManagerImplV2( + stateEncoder: ExpressionEncoder[Any], + shouldStoreTimestamp: Boolean) extends StateManagerImplBase(shouldStoreTimestamp) { + + /** Schema of the state rows saved in the state store */ + override val stateSchema: StructType = { + var schema = new StructType().add("groupState", stateEncoder.schema, nullable = true) + if (shouldStoreTimestamp) schema = schema.add("timeoutTimestamp", LongType, nullable = false) + schema + } + + // Ordinals of the information stored in the state row + private val nestedStateOrdinal = 0 + override val timeoutTimestampOrdinalInRow = 1 + + override val stateSerializerExprs: Seq[Expression] = { + val boundRefToSpecificInternalRow = BoundReference( + 0, stateEncoder.serializer.head.collect { case b: BoundReference => b.dataType }.head, true) + + val nestedStateSerExpr = + CreateNamedStruct(stateEncoder.namedExpressions.flatMap(e => Seq(Literal(e.name), e))) + + val nullSafeNestedStateSerExpr = { + val nullLiteral = Literal(null, nestedStateSerExpr.dataType) + CaseWhen(Seq(IsNull(boundRefToSpecificInternalRow) -> nullLiteral), nestedStateSerExpr) + } + + if (shouldStoreTimestamp) { + Seq(nullSafeNestedStateSerExpr, Literal(GroupStateImpl.NO_TIMESTAMP)) + } else { + Seq(nullSafeNestedStateSerExpr) + } + } + + override val stateDeserializerExpr: Expression = { + // Note that this must be done in the driver, as resolving and binding of deserializer + // expressions to the encoded type can be safely done only in the driver. + val boundRefToNestedState = + BoundReference(nestedStateOrdinal, stateEncoder.schema, nullable = true) + val deserExpr = stateEncoder.resolveAndBind().deserializer.transformUp { + case BoundReference(ordinal, _, _) => GetStructField(boundRefToNestedState, ordinal) + } + val nullLiteral = Literal(null, deserExpr.dataType) + CaseWhen(Seq(IsNull(boundRefToNestedState) -> nullLiteral), elseValue = deserExpr) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 118c82aa75e68..92a2480e8b017 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.execution.streaming.state import java.io._ +import java.util import java.util.Locale +import java.util.concurrent.atomic.LongAdder import scala.collection.JavaConverters._ import scala.collection.mutable @@ -164,7 +166,16 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } override def metrics: StateStoreMetrics = { - StateStoreMetrics(mapToUpdate.size(), SizeEstimator.estimate(mapToUpdate), Map.empty) + // NOTE: we provide an estimation of the cache size as "memoryUsedBytes", and the size of + // state for the current version as "stateOnCurrentVersionSizeBytes" + val metricsFromProvider: Map[String, Long] = getMetricsForProvider() + + val customMetrics = metricsFromProvider.flatMap { case (name, value) => + // just search the list, since it is small enough + supportedCustomMetrics.find(_.name == name).map(_ -> value) + } + (metricStateOnCurrentVersionSizeBytes -> SizeEstimator.estimate(mapToUpdate)) + + StateStoreMetrics(mapToUpdate.size(), metricsFromProvider("memoryUsedBytes"), customMetrics) } /** @@ -179,6 +190,12 @@
private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } } + def getMetricsForProvider(): Map[String, Long] = synchronized { + Map("memoryUsedBytes" -> SizeEstimator.estimate(loadedMaps), + metricLoadedMapCacheHit.name -> loadedMapCacheHitCount.sum(), + metricLoadedMapCacheMiss.name -> loadedMapCacheMissCount.sum()) + } + /** Get the state store for making updates to create a new `version` of the store. */ override def getStore(version: Long): StateStore = synchronized { require(version >= 0, "Version cannot be less than 0") @@ -203,6 +220,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit this.valueSchema = valueSchema this.storeConf = storeConf this.hadoopConf = hadoopConf + this.numberOfVersionsToRetainInMemory = storeConf.maxVersionsToRetainInMemory fm.mkdirs(baseDir) } @@ -220,11 +238,12 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } override def close(): Unit = { - loadedMaps.values.foreach(_.clear()) + loadedMaps.values.asScala.foreach(_.clear()) } override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = { - Nil + metricStateOnCurrentVersionSizeBytes :: metricLoadedMapCacheHit :: metricLoadedMapCacheMiss :: + Nil } override def toString(): String = { @@ -239,18 +258,34 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit @volatile private var valueSchema: StructType = _ @volatile private var storeConf: StateStoreConf = _ @volatile private var hadoopConf: Configuration = _ + @volatile private var numberOfVersionsToRetainInMemory: Int = _ - private lazy val loadedMaps = new mutable.HashMap[Long, MapType] + private lazy val loadedMaps = new util.TreeMap[Long, MapType](Ordering[Long].reverse) private lazy val baseDir = stateStoreId.storeCheckpointLocation() private lazy val fm = CheckpointFileManager.create(baseDir, hadoopConf) private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) + private val loadedMapCacheHitCount: LongAdder = new LongAdder + private val loadedMapCacheMissCount: LongAdder = new LongAdder + + private lazy val metricStateOnCurrentVersionSizeBytes: StateStoreCustomSizeMetric = + StateStoreCustomSizeMetric("stateOnCurrentVersionSizeBytes", + "estimated size of state only on current version") + + private lazy val metricLoadedMapCacheHit: StateStoreCustomMetric = + StateStoreCustomSumMetric("loadedMapCacheHitCount", + "count of cache hit on states cache in provider") + + private lazy val metricLoadedMapCacheMiss: StateStoreCustomMetric = + StateStoreCustomSumMetric("loadedMapCacheMissCount", + "count of cache miss on states cache in provider") + private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean) private def commitUpdates(newVersion: Long, map: MapType, output: DataOutputStream): Unit = { synchronized { finalizeDeltaFile(output) - loadedMaps.put(newVersion, map) + putStateIntoStateCacheMap(newVersion, map) } } @@ -260,7 +295,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit */ private[state] def latestIterator(): Iterator[UnsafeRowPair] = synchronized { val versionsInFiles = fetchFiles().map(_.version).toSet - val versionsLoaded = loadedMaps.keySet + val versionsLoaded = loadedMaps.keySet.asScala val allKnownVersions = versionsInFiles ++ versionsLoaded val unsafeRowTuple = new UnsafeRowPair() if (allKnownVersions.nonEmpty) { @@ -270,12 +305,45 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } else 
Iterator.empty } + /** This method is intended to be used only in unit tests. DO NOT TOUCH ELEMENTS IN MAP! */ + private[state] def getLoadedMaps(): util.SortedMap[Long, MapType] = synchronized { + // shallow copy as a minimal guard + loadedMaps.clone().asInstanceOf[util.SortedMap[Long, MapType]] + } + + private def putStateIntoStateCacheMap(newVersion: Long, map: MapType): Unit = synchronized { + if (numberOfVersionsToRetainInMemory <= 0) { + if (loadedMaps.size() > 0) loadedMaps.clear() + return + } + + while (loadedMaps.size() > numberOfVersionsToRetainInMemory) { + loadedMaps.remove(loadedMaps.lastKey()) + } + + val size = loadedMaps.size() + if (size == numberOfVersionsToRetainInMemory) { + val versionIdForLastKey = loadedMaps.lastKey() + if (versionIdForLastKey > newVersion) { + // this is the only case where we can avoid putting, because the new version would be + // placed at the last key and evicted right away + return + } else if (versionIdForLastKey < newVersion) { + // this case requires removing the last key before putting the new one + loadedMaps.remove(versionIdForLastKey) + } + } + + loadedMaps.put(newVersion, map) + } + /** Load the required version of the map data from the backing files */ private def loadMap(version: Long): MapType = { // Shortcut if the map for this version is already there to avoid a redundant put. - val loadedCurrentVersionMap = synchronized { loadedMaps.get(version) } + val loadedCurrentVersionMap = synchronized { Option(loadedMaps.get(version)) } if (loadedCurrentVersionMap.isDefined) { + loadedMapCacheHitCount.increment() return loadedCurrentVersionMap.get }
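[Editor's note] The retention policy above is easy to misread because the TreeMap is ordered by descending version, so lastKey() is the smallest (oldest) cached version. A self-contained sketch of the same policy, assuming nothing from Spark:

    import java.util

    val retain = 2
    val cache = new util.TreeMap[Long, String](Ordering[Long].reverse)

    def putVersion(version: Long, map: String): Unit = {
      while (cache.size() > retain) cache.remove(cache.lastKey())
      if (cache.size() == retain) {
        val oldest = cache.lastKey()                // reverse ordering: lastKey is the oldest
        if (oldest > version) return                // new entry would be evicted at once; skip
        if (oldest < version) cache.remove(oldest)  // make room for the newer version
      }
      cache.put(version, map)
    }

    putVersion(1, "v1"); putVersion(2, "v2"); putVersion(3, "v3")
    // cache now holds versions 3 and 2; version 1 was evicted.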
@@ -283,10 +351,12 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit "Reading snapshot file and delta files if needed..." + "Note that this is normal for the first batch of a starting query.") + + loadedMapCacheMissCount.increment() + val (result, elapsedMs) = Utils.timeTakenMs { val snapshotCurrentVersionMap = readSnapshotFile(version) if (snapshotCurrentVersionMap.isDefined) { - synchronized { loadedMaps.put(version, snapshotCurrentVersionMap.get) } + synchronized { putStateIntoStateCacheMap(version, snapshotCurrentVersionMap.get) } return snapshotCurrentVersionMap.get } @@ -302,7 +372,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit lastAvailableMap = Some(new MapType) } else { lastAvailableMap = - synchronized { loadedMaps.get(lastAvailableVersion) } + synchronized { Option(loadedMaps.get(lastAvailableVersion)) } .orElse(readSnapshotFile(lastAvailableVersion)) } } @@ -314,7 +384,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit updateFromDeltaFile(deltaVersion, resultMap) } - synchronized { loadedMaps.put(version, resultMap) } + synchronized { putStateIntoStateCacheMap(version, resultMap) } resultMap } @@ -506,7 +576,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit val lastVersion = files.last.version val deltaFilesForLastVersion = filesForVersion(files, lastVersion).filter(_.isSnapshot == false) - synchronized { loadedMaps.get(lastVersion) } match { + synchronized { Option(loadedMaps.get(lastVersion)) } match { case Some(map) => if (deltaFilesForLastVersion.size > storeConf.minDeltasForSnapshot) { val (_, e2) = Utils.timeTakenMs(writeSnapshotFile(lastVersion, map)) @@ -536,10 +606,6 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit val earliestVersionToRetain = files.last.version - storeConf.minVersionsToRetain if (earliestVersionToRetain > 0) { val earliestFileToRetain = filesForVersion(files, earliestVersionToRetain).head - synchronized { - val mapsToRemove = loadedMaps.keys.filter(_ < earliestVersionToRetain).toSeq - mapsToRemove.foreach(loadedMaps.remove) - } val filesToDelete = files.filter(_.version < earliestFileToRetain.version) val (_, e2) = Utils.timeTakenMs { filesToDelete.foreach { f => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 7eb68c21569ba..d3313b8a315c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -138,6 +138,8 @@ trait StateStoreCustomMetric { def name: String def desc: String } + +case class StateStoreCustomSumMetric(name: String, desc: String) extends StateStoreCustomMetric case class StateStoreCustomSizeMetric(name: String, desc: String) extends StateStoreCustomMetric case class StateStoreCustomTimingMetric(name: String, desc: String) extends StateStoreCustomMetric diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index 765ff076cb467..d145082a39b57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -34,6 +34,9 @@ class StateStoreConf(@transient private val sqlConf: SQLConf) /** Minimum versions a State Store implementation should retain to allow
rollbacks */ val minVersionsToRetain: Int = sqlConf.minBatchesToRetain + /** Maximum count of versions a State Store implementation should retain in memory */ + val maxVersionsToRetainInMemory: Int = sqlConf.maxBatchesToRetainInMemory + /** * Optional fully qualified name of the subclass of [[StateStoreProvider]] * managing state data. That is, the implementation of the State Store to use. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 3f11b8f79943c..4a69a48fed75f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -23,6 +23,7 @@ import scala.reflect.ClassTag import org.apache.spark.{Partition, TaskContext} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.EpochTracker import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.StructType @@ -74,9 +75,14 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( // If we're in continuous processing mode, we should get the store version for the current // epoch rather than the one at planning time. - val currentVersion = EpochTracker.getCurrentEpoch match { - case None => storeVersion - case Some(value) => value + val isContinuous = Option(ctxt.getLocalProperty(StreamExecution.IS_CONTINUOUS_PROCESSING)) + .map(_.toBoolean).getOrElse(false) + val currentVersion = if (isContinuous) { + val epoch = EpochTracker.getCurrentEpoch + assert(epoch.isDefined, "Current epoch must be defined for continuous processing streams.") + epoch.get + } else { + storeVersion } store = StateStore.get( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala new file mode 100644 index 0000000000000..9bfb9561b42a1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner} +import org.apache.spark.sql.types.StructType +
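[Editor's note] The file below introduces versioned state formats for streaming aggregation. A hedged sketch of how an operator would obtain a manager; `keyAttrs` and `rowAttrs` are stand-ins, not names from the patch:

    import org.apache.spark.sql.catalyst.expressions.Attribute

    val keyAttrs: Seq[Attribute] = Nil  // stand-in; real operators pass the grouping key attributes
    val rowAttrs: Seq[Attribute] = Nil  // stand-in; real operators pass the full input row attributes
    val manager = StreamingAggregationStateManager.createStateManager(
      keyAttrs, rowAttrs, stateFormatVersion = 2)
    // stateFormatVersion = 1 reproduces the legacy layout (value = whole input row);
    // version 2 stores only the non-key columns as the value.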
+/** + * Base trait for a state manager intended to be used by streaming aggregations. + */ +sealed trait StreamingAggregationStateManager extends Serializable { + + /** Extract the key columns from the input row and return a new row containing only the key. */ + def getKey(row: UnsafeRow): UnsafeRow + + /** Calculate the schema of the state value. The schema is mainly passed to the StateStoreRDD. */ + def getStateValueSchema: StructType + + /** Get the current value of a non-null key from the target state store. */ + def get(store: StateStore, key: UnsafeRow): UnsafeRow + + /** + * Put a new value for a non-null key to the target state store. Note that the key will be + * extracted from the input row, and will be the same as the result of getKey(inputRow). + */ + def put(store: StateStore, row: UnsafeRow): Unit + + /** + * Commit all the updates that have been made to the target state store, and return the + * new version. + */ + def commit(store: StateStore): Long + + /** Remove a single non-null key from the target state store. */ + def remove(store: StateStore, key: UnsafeRow): Unit + + /** Return an iterator containing all the key-value pairs in the target state store. */ + def iterator(store: StateStore): Iterator[UnsafeRowPair] + + /** Return an iterator containing all the keys in the target state store. */ + def keys(store: StateStore): Iterator[UnsafeRow] + + /** Return an iterator containing all the values in the target state store. */ + def values(store: StateStore): Iterator[UnsafeRow] +} + +object StreamingAggregationStateManager extends Logging { + val supportedVersions = Seq(1, 2) + val legacyVersion = 1 + + def createStateManager( + keyExpressions: Seq[Attribute], + inputRowAttributes: Seq[Attribute], + stateFormatVersion: Int): StreamingAggregationStateManager = { + stateFormatVersion match { + case 1 => new StreamingAggregationStateManagerImplV1(keyExpressions, inputRowAttributes) + case 2 => new StreamingAggregationStateManagerImplV2(keyExpressions, inputRowAttributes) + case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid") + } + } +} + +abstract class StreamingAggregationStateManagerBaseImpl( + protected val keyExpressions: Seq[Attribute], + protected val inputRowAttributes: Seq[Attribute]) extends StreamingAggregationStateManager { + + @transient protected lazy val keyProjector = + GenerateUnsafeProjection.generate(keyExpressions, inputRowAttributes) + + override def getKey(row: UnsafeRow): UnsafeRow = keyProjector(row) + + override def commit(store: StateStore): Long = store.commit() + + override def remove(store: StateStore, key: UnsafeRow): Unit = store.remove(key) + + override def keys(store: StateStore): Iterator[UnsafeRow] = { + // discard and don't convert values to avoid computation + store.getRange(None, None).map(_.key) + } +}
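+// [editor's note] Concretely, with key = [k] and input rows = [k, cnt]:
+//   V1 value schema: [k, cnt]  (key columns duplicated into the value)
+//   V2 value schema: [cnt]     (value = inputRowAttributes.diff(keyExpressions))
+// On read, V2 re-joins key and value and, when the joined column order differs from the
+// input row order, projects back to it (see needToProjectToRestoreValue below).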
+abstract class StreamingAggregationStateManagerBaseImpl( + protected val keyExpressions: Seq[Attribute], + protected val inputRowAttributes: Seq[Attribute]) extends StreamingAggregationStateManager { + + @transient protected lazy val keyProjector = + GenerateUnsafeProjection.generate(keyExpressions, inputRowAttributes) + + override def getKey(row: UnsafeRow): UnsafeRow = keyProjector(row) + + override def commit(store: StateStore): Long = store.commit() + + override def remove(store: StateStore, key: UnsafeRow): Unit = store.remove(key) + + override def keys(store: StateStore): Iterator[UnsafeRow] = { + // discard and don't convert values to avoid computation + store.getRange(None, None).map(_.key) + } +} + +/** + * The implementation of StreamingAggregationStateManager for state version 1. + * In state version 1, the schemas of the key and value in the state are as follows: + * + * - key: Same as key expressions. + * - value: Same as input row attributes. The schema of the value contains the key expressions as well. + * + * @param keyExpressions The attributes of the keys. + * @param inputRowAttributes The attributes of the input row. + */ +class StreamingAggregationStateManagerImplV1( + keyExpressions: Seq[Attribute], + inputRowAttributes: Seq[Attribute]) + extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) { + + override def getStateValueSchema: StructType = inputRowAttributes.toStructType + + override def get(store: StateStore, key: UnsafeRow): UnsafeRow = { + store.get(key) + } + + override def put(store: StateStore, row: UnsafeRow): Unit = { + store.put(getKey(row), row) + } + + override def iterator(store: StateStore): Iterator[UnsafeRowPair] = { + store.iterator() + } + + override def values(store: StateStore): Iterator[UnsafeRow] = { + store.iterator().map(_.value) + } +} + +/** + * The implementation of StreamingAggregationStateManager for state version 2. + * In state version 2, the schemas of the key and value in the state are as follows: + * + * - key: Same as key expressions. + * - value: The diff between input row attributes and key expressions. + * + * The schema of the value is changed to optimize the memory/space usage of the state, by removing + * the columns duplicated between the key and the value. Hence key columns are excluded from the schema of the value. + * + * @param keyExpressions The attributes of the keys. + * @param inputRowAttributes The attributes of the input row. + */ +class StreamingAggregationStateManagerImplV2( + keyExpressions: Seq[Attribute], + inputRowAttributes: Seq[Attribute]) + extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) { + + private val valueExpressions: Seq[Attribute] = inputRowAttributes.diff(keyExpressions) + private val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions + + // flag to check whether the row needs to be projected back into the input row attributes after + // the join, e.g. if the fields in the joined row are not in the expected order + private val needToProjectToRestoreValue: Boolean = + keyValueJoinedExpressions != inputRowAttributes + + @transient private lazy val valueProjector = + GenerateUnsafeProjection.generate(valueExpressions, inputRowAttributes) + + @transient private lazy val joiner = + GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions), + StructType.fromAttributes(valueExpressions)) + @transient private lazy val restoreValueProjector = GenerateUnsafeProjection.generate( + inputRowAttributes, keyValueJoinedExpressions) + + override def getStateValueSchema: StructType = valueExpressions.toStructType + + override def get(store: StateStore, key: UnsafeRow): UnsafeRow = { + val savedState = store.get(key) + if (savedState == null) { + return savedState + } + + restoreOriginalRow(key, savedState) + } + + override def put(store: StateStore, row: UnsafeRow): Unit = { + val key = keyProjector(row) + val value = valueProjector(row) + store.put(key, value) + } + + override def iterator(store: StateStore): Iterator[UnsafeRowPair] = { + store.iterator().map(rowPair => new UnsafeRowPair(rowPair.key, restoreOriginalRow(rowPair))) + } + + override def values(store: StateStore): Iterator[UnsafeRow] = { + store.iterator().map(rowPair => restoreOriginalRow(rowPair)) + } + + private def restoreOriginalRow(rowPair: UnsafeRowPair): UnsafeRow = { + restoreOriginalRow(rowPair.key, rowPair.value) + } + + private def restoreOriginalRow(key: UnsafeRow, value: UnsafeRow): UnsafeRow = { + val joinedRow = joiner.join(key, value) + if (needToProjectToRestoreValue) { + restoreValueProjector(joinedRow) + } else { + joinedRow + } + } +}
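To make the layout difference between the two versions concrete, here is a minimal sketch using the internal Catalyst expression DSL (test-scope helpers; the attribute names `a`, `b` and `sum` are invented for illustration):

  import org.apache.spark.sql.catalyst.dsl.expressions._
  import org.apache.spark.sql.catalyst.expressions.Attribute

  // A streaming aggregation grouped by (a, b) with a single aggregation buffer column.
  val key: Seq[Attribute] = Seq('a.int, 'b.string)
  val input: Seq[Attribute] = key :+ 'sum.long

  val v1 = StreamingAggregationStateManager.createStateManager(key, input, stateFormatVersion = 1)
  val v2 = StreamingAggregationStateManager.createStateManager(key, input, stateFormatVersion = 2)

  // V1 stores the whole input row as the value, so `a` and `b` are written twice
  // per entry (once in the key, once in the value).
  println(v1.getStateValueSchema.fieldNames.mkString(", "))  // a, b, sum
  // V2 stores only the non-key columns and rebuilds the original row on read.
  println(v2.getStateValueSchema.fieldNames.mkString(", "))  // sum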
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index 6b386308c79fb..352b3d3616fba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -269,10 +269,15 @@ class SymmetricHashJoinStateManager( keyWithIndexToValueMetrics.numKeys, // represent each buffered row only once keyToNumValuesMetrics.memoryUsedBytes + keyWithIndexToValueMetrics.memoryUsedBytes, keyWithIndexToValueMetrics.customMetrics.map { + case (s @ StateStoreCustomSumMetric(_, desc), value) => + s.copy(desc = newDesc(desc)) -> value case (s @ StateStoreCustomSizeMetric(_, desc), value) => s.copy(desc = newDesc(desc)) -> value case (s @ StateStoreCustomTimingMetric(_, desc), value) => s.copy(desc = newDesc(desc)) -> value + case (s, _) => + throw new IllegalArgumentException( + s"Unknown state store custom metric found: $s") } ) } @@ -290,7 +295,7 @@ class SymmetricHashJoinStateManager( private val keyWithIndexToValue = new KeyWithIndexToValueStore() // Clean up any state store resources if necessary at the end of the task - Option(TaskContext.get()).foreach { _.addTaskCompletionListener { _ => abortIfNeeded() } } + Option(TaskContext.get()).foreach { _.addTaskCompletionListener[Unit] { _ => abortIfNeeded() } } /** Helper trait for invoking common functionalities of a state store. */ private abstract class StateStoreHandler(stateStoreType: StateStoreType) extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 0b32327e51dbf..b6021438e902b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -61,7 +61,7 @@ package object state { val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) val wrappedF = (store: StateStore, iter: Iterator[T]) => { // Abort the state store in case of error - TaskContext.get().addTaskCompletionListener(_ => { + TaskContext.get().addTaskCompletionListener[Unit](_ => { if (!store.hasCommitted) store.abort() }) cleanedF(store, iter) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 6759fb42b4052..c11af345b0248 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -90,10 +90,18 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan => * the driver after this SparkPlan has been executed and metrics have been updated.
*/ def getProgress(): StateOperatorProgress = { + val customMetrics = stateStoreCustomMetrics + .map(entry => entry._1 -> longMetric(entry._1).value) + + val javaConvertedCustomMetrics: java.util.HashMap[String, java.lang.Long] = + new java.util.HashMap(customMetrics.mapValues(long2Long).asJava) + new StateOperatorProgress( numRowsTotal = longMetric("numTotalStateRows").value, numRowsUpdated = longMetric("numUpdatedStateRows").value, - memoryUsedBytes = longMetric("stateMemory").value) + memoryUsedBytes = longMetric("stateMemory").value, + javaConvertedCustomMetrics + ) } /** Records the duration of running `body` for the next query progress update. */ @@ -115,6 +123,8 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan => private def stateStoreCustomMetrics: Map[String, SQLMetric] = { val provider = StateStoreProvider.create(sqlContext.conf.stateStoreProviderClass) provider.supportedCustomMetrics.map { + case StateStoreCustomSumMetric(name, desc) => + name -> SQLMetrics.createMetric(sparkContext, desc) case StateStoreCustomSizeMetric(name, desc) => name -> SQLMetrics.createSizeMetric(sparkContext, desc) case StateStoreCustomTimingMetric(name, desc) => @@ -167,6 +177,18 @@ trait WatermarkSupport extends UnaryExecNode { } } } + + protected def removeKeysOlderThanWatermark( + storeManager: StreamingAggregationStateManager, + store: StateStore): Unit = { + if (watermarkPredicateForKeys.nonEmpty) { + storeManager.keys(store).foreach { keyRow => + if (watermarkPredicateForKeys.get.eval(keyRow)) { + storeManager.remove(store, keyRow) + } + } + } + } } object WatermarkSupport { @@ -201,20 +223,23 @@ object WatermarkSupport { case class StateStoreRestoreExec( keyExpressions: Seq[Attribute], stateInfo: Option[StatefulOperatorStateInfo], + stateFormatVersion: Int, child: SparkPlan) extends UnaryExecNode with StateStoreReader { + private[sql] val stateManager = StreamingAggregationStateManager.createStateManager( + keyExpressions, child.output, stateFormatVersion) + override protected def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") child.execute().mapPartitionsWithStateStore( getStateInfo, keyExpressions.toStructType, - child.output.toStructType, + stateManager.getStateValueSchema, indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => - val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) val hasInput = iter.hasNext if (!hasInput && keyExpressions.isEmpty) { // If our `keyExpressions` are empty, we're getting a global aggregation. 
In that case @@ -224,10 +249,10 @@ case class StateStoreRestoreExec( store.iterator().map(_.value) } else { iter.flatMap { row => - val key = getKey(row) - val savedState = store.get(key) + val key = stateManager.getKey(row.asInstanceOf[UnsafeRow]) + val restoredRow = stateManager.get(store, key) numOutputRows += 1 - Option(savedState).toSeq :+ row + Option(restoredRow).toSeq :+ row } } } @@ -254,9 +279,13 @@ case class StateStoreSaveExec( stateInfo: Option[StatefulOperatorStateInfo] = None, outputMode: Option[OutputMode] = None, eventTimeWatermark: Option[Long] = None, + stateFormatVersion: Int, child: SparkPlan) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { + private[sql] val stateManager = StreamingAggregationStateManager.createStateManager( + keyExpressions, child.output, stateFormatVersion) + override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver assert(outputMode.nonEmpty, @@ -265,11 +294,10 @@ case class StateStoreSaveExec( child.execute().mapPartitionsWithStateStore( getStateInfo, keyExpressions.toStructType, - child.output.toStructType, + stateManager.getStateValueSchema, indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => - val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) val numOutputRows = longMetric("numOutputRows") val numUpdatedStateRows = longMetric("numUpdatedStateRows") val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") @@ -282,19 +310,18 @@ case class StateStoreSaveExec( allUpdatesTimeMs += timeTakenMs { while (iter.hasNext) { val row = iter.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key, row) + stateManager.put(store, row) numUpdatedStateRows += 1 } } allRemovalsTimeMs += 0 commitTimeMs += timeTakenMs { - store.commit() + stateManager.commit(store) } setStoreMetrics(store) - store.iterator().map { rowPair => + stateManager.values(store).map { valueRow => numOutputRows += 1 - rowPair.value + valueRow } // Update and output only rows being evicted from the StateStore @@ -304,14 +331,13 @@ case class StateStoreSaveExec( val filteredIter = iter.filter(row => !watermarkPredicateForData.get.eval(row)) while (filteredIter.hasNext) { val row = filteredIter.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key, row) + stateManager.put(store, row) numUpdatedStateRows += 1 } } val removalStartTimeNs = System.nanoTime - val rangeIter = store.getRange(None, None) + val rangeIter = stateManager.iterator(store) new NextIterator[InternalRow] { override protected def getNext(): InternalRow = { @@ -319,7 +345,7 @@ case class StateStoreSaveExec( while(rangeIter.hasNext && removedValueRow == null) { val rowPair = rangeIter.next() if (watermarkPredicateForKeys.get.eval(rowPair.key)) { - store.remove(rowPair.key) + stateManager.remove(store, rowPair.key) removedValueRow = rowPair.value } } @@ -333,7 +359,7 @@ case class StateStoreSaveExec( override protected def close(): Unit = { allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs) - commitTimeMs += timeTakenMs { store.commit() } + commitTimeMs += timeTakenMs { stateManager.commit(store) } setStoreMetrics(store) } } @@ -352,8 +378,7 @@ case class StateStoreSaveExec( override protected def getNext(): InternalRow = { if (baseIterator.hasNext) { val row = baseIterator.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key, row) + stateManager.put(store, row) numOutputRows += 1 numUpdatedStateRows += 1 row @@ 
-367,8 +392,10 @@ case class StateStoreSaveExec( allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) // Remove old aggregates if watermark specified - allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) } - commitTimeMs += timeTakenMs { store.commit() } + allRemovalsTimeMs += timeTakenMs { + removeKeysOlderThanWatermark(stateManager, store) + } + commitTimeMs += timeTakenMs { stateManager.commit(store) } setStoreMetrics(store) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 884f945815e0f..e57d080dadf78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -202,7 +202,7 @@ private[ui] class SparkPlanGraphCluster( /** - * Represent an edge in the SparkPlan tree. `fromId` is the parent node id, and `toId` is the child + * Represent an edge in the SparkPlan tree. `fromId` is the child node id, and `toId` is the parent * node id. */ private[ui] case class SparkPlanGraphEdge(fromId: Long, toId: Long) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 626f39d9e95cc..fede0f3e92d67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -323,8 +323,6 @@ case class WindowExec( fetchNextRow() // Manage the current partition. - val inputFields = child.output.length - val buffer: ExternalAppendOnlyUnsafeRowArray = new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index bdc4bb4422ae7..697757f8a73ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.expressions import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.ScalaUDF import org.apache.spark.sql.types.DataType @@ -46,6 +47,10 @@ case class UserDefinedFunction protected[sql] ( private var _nullable: Boolean = true private var _deterministic: Boolean = true + // This is a `var` instead of a constructor parameter, to preserve the backward compatibility of + // this case class. + // TODO: revisit this case class in Spark 3.0, and narrow down the public surface. + private[sql] var nullableTypes: Option[Seq[Boolean]] = None +
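A hedged illustration of what the new field ends up holding: ScalaReflection reports Scala primitives as non-nullable and reference types as nullable, so a UDF built from a typed closure carries one flag per argument.

  import org.apache.spark.sql.functions.udf

  // For (i: Int, s: String), Int is a primitive (cannot hold null) and String is a
  // reference type, so nullableTypes becomes Some(Seq(false, true)) via
  // SparkUserDefinedFunction.create further down in this patch.
  val strTimes = udf((i: Int, s: String) => s * i)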
/** * Returns true when the UDF can return a nullable value. * @@ -68,6 +73,10 @@ case class UserDefinedFunction protected[sql] ( */ @scala.annotation.varargs def apply(exprs: Column*): Column = { + if (inputTypes.isDefined && nullableTypes.isDefined) { + require(inputTypes.get.length == nullableTypes.get.length) + } + Column(ScalaUDF( f, dataType, @@ -75,7 +84,8 @@ case class UserDefinedFunction protected[sql] ( inputTypes.getOrElse(Nil), udfName = _nameOption, nullable = _nullable, - udfDeterministic = _deterministic)) + udfDeterministic = _deterministic, + nullableTypes = nullableTypes.getOrElse(Nil))) } private def copyAll(): UserDefinedFunction = { @@ -83,6 +93,7 @@ case class UserDefinedFunction protected[sql] ( udf._nameOption = _nameOption udf._nullable = _nullable udf._deterministic = _deterministic + udf.nullableTypes = nullableTypes udf } @@ -127,3 +138,17 @@ case class UserDefinedFunction protected[sql] ( } } } + +// We have to use a name different from `UserDefinedFunction` here, to avoid breaking the binary +// compatibility of the auto-generated UserDefinedFunction object. +private[sql] object SparkUserDefinedFunction { + + def create( + f: AnyRef, + dataType: DataType, + inputSchemas: Option[Seq[ScalaReflection.Schema]]): UserDefinedFunction = { + val udf = new UserDefinedFunction(f, dataType, inputSchemas.map(_.map(_.dataType))) + udf.nullableTypes = inputSchemas.map(_.map(_.nullable)) + udf + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index acca9572cb14c..10b67d7a1ca54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -32,14 +32,28 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, ResolvedHint} import org.apache.spark.sql.execution.SparkSqlParser -import org.apache.spark.sql.expressions.UserDefinedFunction +import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedFunction} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils /** - * Functions available for DataFrame operations. + * Commonly used functions available for DataFrame operations. Using functions defined here provides + * a little bit more compile-time safety to make sure the function exists. + * + * Spark also includes more built-in functions that are less common and are not defined here. + * You can still access them (and all the functions defined here) using the `functions.expr()` API + * and calling them through a SQL expression string. You can find the entire list of functions + * at SQL API documentation. + * + * As an example, `isnan` is a function that is defined here. You can use `isnan(col("myCol"))` + * to invoke the `isnan` function. This way the programming language's compiler ensures `isnan` + * exists and is of the proper form. You can also use `expr("isnan(myCol)")` to invoke the + * same function. In this case, Spark itself will ensure `isnan` exists when it analyzes the query. + * + * `regr_count` is an example of a function that is built-in but not defined here, because it is + * less commonly used. To invoke it, use `expr("regr_count(yCol, xCol)")`.
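The contrast described above fits in three lines; `df` is assumed to be a DataFrame with columns `myCol`, `yCol` and `xCol`:

  import org.apache.spark.sql.functions._

  // Compile-time checked: the Scala compiler verifies that `isnan` exists and takes a Column.
  df.select(isnan(col("myCol")))
  // Analysis-time checked: Spark resolves `isnan` while analyzing the query.
  df.select(expr("isnan(myCol)"))
  // Built-in but with no wrapper here, so only reachable via an expression string.
  df.select(expr("regr_count(yCol, xCol)"))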
* * @groupname udf_funcs UDF functions * @groupname agg_funcs Aggregate functions @@ -1646,7 +1660,7 @@ object functions { def expm1(e: Column): Column = withExpr { Expm1(e.expr) } /** - * Computes the exponential of the given column. + * Computes the exponential of the given column minus one. * * @group math_funcs * @since 1.4.0 @@ -2612,8 +2626,12 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Returns the date that is numMonths after startDate. + * Returns the date that is `numMonths` after `startDate`. * + * @param startDate A date, timestamp or string. If a string, the data must be in a format that + * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param numMonths The number of months to add to `startDate`, can be negative to subtract months + * @return A date, or null if `startDate` was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2641,12 +2659,15 @@ object functions { * Converts a date/timestamp/string to a value of string in the format specified by the date * format given by the second argument. * - * A pattern `dd.MM.yyyy` would return a string like `18.03.1993`. - * All pattern letters of `java.text.SimpleDateFormat` can be used. + * See [[java.text.SimpleDateFormat]] for valid date and time format patterns * + * @param dateExpr A date, timestamp or string. If a string, the data must be in a format that + * can be cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param format A pattern `dd.MM.yyyy` would return a string like `18.03.1993` + * @return A string, or null if `dateExpr` was a string that could not be cast to a timestamp * @note Use specialized functions like [[year]] whenever possible as they benefit from a * specialized implementation. - * + * @throws IllegalArgumentException if the `format` pattern is invalid * @group datetime_funcs * @since 1.5.0 */ @@ -2656,6 +2677,11 @@ object functions { /** * Returns the date that is `days` days after `start` + * + * @param start A date, timestamp or string. If a string, the data must be in a format that + * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param days The number of days to add to `start`, can be negative to subtract days + * @return A date, or null if `start` was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2663,6 +2689,11 @@ object functions { /** * Returns the date that is `days` days before `start` + * + * @param start A date, timestamp or string. If a string, the data must be in a format that + * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param days The number of days to subtract from `start`, can be negative to add days + * @return A date, or null if `start` was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2670,6 +2701,19 @@ object functions { /** * Returns the number of days from `start` to `end`. + * + * Only considers the date part of the input. For example: + * {{{ + * datediff("2018-01-10 00:00:00", "2018-01-09 23:59:59") + * // returns 1 + * }}} + * + * @param end A date, timestamp or string. If a string, the data must be in a format that + * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param start A date, timestamp or string.
If a string, the data must be in a format that + * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @return An integer, or null if either `end` or `start` were strings that could not be cast to + * a date. Negative if `end` is before `start` * @group datetime_funcs * @since 1.5.0 */ @@ -2677,6 +2721,7 @@ object functions { /** * Extracts the year as an integer from a given date/timestamp/string. + * @return An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2684,6 +2729,7 @@ object functions { /** * Extracts the quarter as an integer from a given date/timestamp/string. + * @return An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2691,6 +2737,7 @@ object functions { /** * Extracts the month as an integer from a given date/timestamp/string. + * @return An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2698,6 +2745,8 @@ object functions { /** * Extracts the day of the week as an integer from a given date/timestamp/string. + * Ranges from 1 for a Sunday through to 7 for a Saturday + * @return An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 2.3.0 */ @@ -2705,6 +2754,7 @@ object functions { /** * Extracts the day of the month as an integer from a given date/timestamp/string. + * @return An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2712,6 +2762,7 @@ object functions { /** * Extracts the day of the year as an integer from a given date/timestamp/string. + * @return An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2719,16 +2770,20 @@ object functions { /** * Extracts the hours as an integer from a given date/timestamp/string. + * @return An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ def hour(e: Column): Column = withExpr { Hour(e.expr) } /** - * Given a date column, returns the last day of the month which the given date belongs to. + * Returns the last day of the month which the given date belongs to. * For example, input "2015-07-27" returns "2015-07-31" since July 31 is the last day of the * month in July 2015. * + * @param e A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @return A date, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2736,46 +2791,60 @@ object functions { /** * Extracts the minutes as an integer from a given date/timestamp/string. + * @return An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ def minute(e: Column): Column = withExpr { Minute(e.expr) } /** - * Returns number of months between dates `date1` and `date2`. - * If `date1` is later than `date2`, then the result is positive. - * If `date1` and `date2` are on the same day of month, or both are the last day of month, - * time of day will be ignored. + * Returns number of months between dates `start` and `end`. * - * Otherwise, the difference is calculated based on 31 days per month, and rounded to - * 8 digits. 
+ * A whole number is returned if both inputs have the same day of month or both are the last day + * of their respective months. Otherwise, the difference is calculated assuming 31 days per month. + * + * For example: + * {{{ + * months_between("2017-11-14", "2017-07-14") // returns 4.0 + * months_between("2017-01-01", "2017-01-10") // returns 0.29032258 + * months_between("2017-06-01", "2017-06-16 12:00:00") // returns -0.5 + * }}} + * + * @param end A date, timestamp or string. If a string, the data must be in a format that can + * be cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param start A date, timestamp or string. If a string, the data must be in a format that can + * be cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @return A double, or null if either `end` or `start` were strings that could not be cast to a + * timestamp. Negative if `end` is before `start` * @group datetime_funcs * @since 1.5.0 */ - def months_between(date1: Column, date2: Column): Column = withExpr { - new MonthsBetween(date1.expr, date2.expr) + def months_between(end: Column, start: Column): Column = withExpr { + new MonthsBetween(end.expr, start.expr) } /** - * Returns number of months between dates `date1` and `date2`. If `roundOff` is set to true, the + * Returns number of months between dates `end` and `start`. If `roundOff` is set to true, the * result is rounded off to 8 digits; it is not rounded otherwise. * @group datetime_funcs * @since 2.4.0 */ - def months_between(date1: Column, date2: Column, roundOff: Boolean): Column = withExpr { - MonthsBetween(date1.expr, date2.expr, lit(roundOff).expr) + def months_between(end: Column, start: Column, roundOff: Boolean): Column = withExpr { + MonthsBetween(end.expr, start.expr, lit(roundOff).expr) } /** - * Given a date column, returns the first date which is later than the value of the date column - * that is on the specified day of the week. + * Returns the first date which is later than the value of the `date` column that is on the + * specified day of the week. * * For example, `next_day('2015-07-27', "Sunday")` returns 2015-08-02 because that is the first * Sunday after 2015-07-27. * - * Day of the week parameter is case insensitive, and accepts: - * "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". - * + * @param date A date, timestamp or string. If a string, the data must be in a format that + * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param dayOfWeek Case insensitive, and accepts: "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun" + * @return A date, or null if `date` was a string that could not be cast to a date or if + * `dayOfWeek` was an invalid value * @group datetime_funcs * @since 1.5.0 */ @@ -2785,6 +2854,7 @@ object functions { /** * Extracts the seconds as an integer from a given date/timestamp/string. + * @return An integer, or null if the input was a string that could not be cast to a timestamp * @group datetime_funcs * @since 1.5.0 */ @@ -2792,6 +2862,11 @@ object functions { /** * Extracts the week number as an integer from a given date/timestamp/string.
+ * + * A week is considered to start on a Monday and week 1 is the first week with more than 3 days, + * as defined by ISO 8601 + * + * @return An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2799,8 +2874,12 @@ object functions { /** * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string - * representing the timestamp of that moment in the current system time zone in the given - * format. + * representing the timestamp of that moment in the current system time zone in the + * yyyy-MM-dd HH:mm:ss format. + * + * @param ut A number of a type that is castable to a long, such as string or integer. Can be + * negative for timestamps before the unix epoch + * @return A string, or null if the input was a string that could not be cast to a long * @group datetime_funcs * @since 1.5.0 */ @@ -2812,6 +2891,14 @@ object functions { * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string * representing the timestamp of that moment in the current system time zone in the given * format. + * + * See [[java.text.SimpleDateFormat]] for valid date and time format patterns + * + * @param ut A number of a type that is castable to a long, such as string or integer. Can be + * negative for timestamps before the unix epoch + * @param f A date time pattern that the input will be formatted to + * @return A string, or null if `ut` was a string that could not be cast to a long or `f` was + * an invalid date time pattern * @group datetime_funcs * @since 1.5.0 */ @@ -2820,7 +2907,7 @@ object functions { } /** - * Returns the current Unix timestamp (in seconds). + * Returns the current Unix timestamp (in seconds) as a long. * * @note All calls of `unix_timestamp` within the same query return the same value * (i.e. the current timestamp is calculated at the start of query evaluation). @@ -2835,8 +2922,10 @@ object functions { /** * Converts time string in format yyyy-MM-dd HH:mm:ss to Unix timestamp (in seconds), * using the default timezone and the default locale. - * Returns `null` if fails. * + * @param s A date, timestamp or string. If a string, the data must be in the + * `yyyy-MM-dd HH:mm:ss` format + * @return A long, or null if the input was a string not of the correct format * @group datetime_funcs * @since 1.5.0 */ @@ -2846,17 +2935,25 @@ object functions { /** * Converts time string with given pattern to Unix timestamp (in seconds). - * Returns `null` if fails. * - * @see - * Customizing Formats + * See [[java.text.SimpleDateFormat]] for valid date and time format patterns + * + * @param s A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param p A date time pattern detailing the format of `s` when `s` is a string + * @return A long, or null if `s` was a string that could not be cast to a date or `p` was + * an invalid format * @group datetime_funcs * @since 1.5.0 */ def unix_timestamp(s: Column, p: String): Column = withExpr { UnixTimestamp(s.expr, Literal(p)) } /** - * Convert time string to a Unix timestamp (in seconds) by casting rules to `TimestampType`. + * Converts to a timestamp by casting rules to `TimestampType`. + * + * @param s A date, timestamp or string. 
If a string, the data must be in a format that can be + * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @return A timestamp, or null if the input was a string that could not be cast to a timestamp * @group datetime_funcs * @since 2.2.0 */ @@ -2865,9 +2962,15 @@ object functions { } /** - * Convert time string to a Unix timestamp (in seconds) with a specified format - * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) - * to Unix timestamp (in seconds), return null if fail. + * Converts time string with the given pattern to timestamp. + * + * See [[java.text.SimpleDateFormat]] for valid date and time format patterns + * + * @param s A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param fmt A date time pattern detailing the format of `s` when `s` is a string + * @return A timestamp, or null if `s` was a string that could not be cast to a timestamp or + * `fmt` was an invalid format * @group datetime_funcs * @since 2.2.0 */ @@ -2885,9 +2988,14 @@ object functions { /** * Converts the column into a `DateType` with a specified format - * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) - * return null if fail. * + * See [[java.text.SimpleDateFormat]] for valid date and time format patterns + * + * @param e A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param fmt A date time pattern detailing the format of `e` when `e` is a string + * @return A date, or null if `e` was a string that could not be cast to a date or `fmt` was an + * invalid format * @group datetime_funcs * @since 2.2.0 */ @@ -2898,9 +3006,15 @@ object functions { /** * Returns date truncated to the unit specified by the format. * + * For example, `trunc("2018-11-19 12:01:19", "year")` returns 2018-01-01 + * + * @param date A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` * @param format: 'year', 'yyyy', 'yy' for truncate by year, * or 'month', 'mon', 'mm' for truncate by month * + * @return A date, or null if `date` was a string that could not be cast to a date or `format` + * was an invalid value * @group datetime_funcs * @since 1.5.0 */ @@ -2911,11 +3025,16 @@ object functions { /** * Returns timestamp truncated to the unit specified by the format. * + * For example, `date_trunc("2018-11-19 12:01:19", "year")` returns 2018-01-01 00:00:00 + * * @param format: 'year', 'yyyy', 'yy' for truncate by year, * 'month', 'mon', 'mm' for truncate by month, * 'day', 'dd' for truncate by day, * Other options are: 'second', 'minute', 'hour', 'week', 'month', 'quarter' - * + * @param timestamp A date, timestamp or string. If a string, the data must be in a format that + * can be cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @return A timestamp, or null if `timestamp` was a string that could not be cast to a timestamp + * or `format` was an invalid value * @group datetime_funcs * @since 2.3.0 */ @@ -2927,6 +3046,13 @@ object functions { * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield * '2017-07-14 03:40:00.0'. + * + * @param ts A date, timestamp or string.
If a string, the data must be in a format that can be + * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param tz A string detailing the time zone that the input should be adjusted to, such as + * `Europe/London`, `PST` or `GMT+5` + * @return A timestamp, or null if `ts` was a string that could not be cast to a timestamp or + * `tz` was an invalid value * @group datetime_funcs * @since 1.5.0 */ @@ -2934,10 +3060,28 @@ object functions { FromUTCTimestamp(ts.expr, Literal(tz)) } + /** + * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders + * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield + * '2017-07-14 03:40:00.0'. + * @group datetime_funcs + * @since 2.4.0 + */ + def from_utc_timestamp(ts: Column, tz: Column): Column = withExpr { + FromUTCTimestamp(ts.expr, tz.expr) + } + /** * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in the given time * zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield * '2017-07-14 01:40:00.0'. + * + * @param ts A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param tz A string detailing the time zone that the input belongs to, such as `Europe/London`, + * `PST` or `GMT+5` + * @return A timestamp, or null if `ts` was a string that could not be cast to a timestamp or + * `tz` was an invalid value * @group datetime_funcs * @since 1.5.0 */ @@ -2945,6 +3089,17 @@ object functions { ToUTCTimestamp(ts.expr, Literal(tz)) } + /** + * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in the given time + * zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield + * '2017-07-14 01:40:00.0'. + * @group datetime_funcs + * @since 2.4.0 + */ + def to_utc_timestamp(ts: Column, tz: Column): Column = withExpr { + ToUTCTimestamp(ts.expr, tz.expr) + } + /** * Bucketize rows into one or more time windows given a timestamp specifying column. Window * starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window @@ -3182,6 +3337,7 @@ object functions { /** * Remove all elements that are equal to `element` from the given array. + * * @group collection_funcs * @since 2.4.0 */ @@ -3196,6 +3352,38 @@ */ def array_distinct(e: Column): Column = withExpr { ArrayDistinct(e.expr) } + /** + * Returns an array of the elements in the intersection of the given two arrays, + * without duplicates. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_intersect(col1: Column, col2: Column): Column = withExpr { + ArrayIntersect(col1.expr, col2.expr) + } + + /** + * Returns an array of the elements in the union of the given two arrays, without duplicates. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_union(col1: Column, col2: Column): Column = withExpr { + ArrayUnion(col1.expr, col2.expr) + } + + /** + * Returns an array of the elements in the first array but not in the second array, + * without duplicates. The order of the elements in the result is not determined. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_except(col1: Column, col2: Column): Column = withExpr { + ArrayExcept(col1.expr, col2.expr) + } +
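A small illustration of the three set-like array operations added above (assuming a SparkSession named `spark`):

  import org.apache.spark.sql.functions._
  import spark.implicits._

  val df = Seq((Seq(1, 2, 3), Seq(2, 3, 4))).toDF("a", "b")
  df.select(
    array_intersect($"a", $"b"),  // [2, 3]
    array_union($"a", $"b"),      // [1, 2, 3, 4]
    array_except($"a", $"b")      // [1]; element order is not guaranteed
  ).show()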
/** * Creates a new row for each element in the given array or map column. * @@ -3270,7 +3458,7 @@ object functions { /** * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType` - * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * as keys type, `StructType` or `ArrayType` with the specified schema. * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. @@ -3282,7 +3470,7 @@ object functions { * @since 2.2.0 */ def from_json(e: Column, schema: DataType, options: Map[String, String]): Column = withExpr { - new JsonToStructs(schema, options, e.expr) + JsonToStructs(schema, options, e.expr) } /** @@ -3302,7 +3490,7 @@ object functions { /** * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType` - * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * as keys type, `StructType` or `ArrayType` with the specified schema. * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. @@ -3331,7 +3519,7 @@ object functions { /** * Parses a column containing a JSON string into a `MapType` with `StringType` as keys type, - * `StructType` or `ArrayType` of `StructType`s with the specified schema. + * `StructType` or `ArrayType` with the specified schema. * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. @@ -3345,7 +3533,7 @@ object functions { /** * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType` - * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * as keys type, `StructType` or `ArrayType` with the specified schema. * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. @@ -3362,7 +3550,7 @@ object functions { /** * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType` - * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * as keys type, `StructType` or `ArrayType` with the specified schema. * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. @@ -3382,11 +3570,53 @@ object functions { } /** - * (Scala-specific) Converts a column containing a `StructType`, `ArrayType` of `StructType`s, - * a `MapType` or `ArrayType` of `MapType`s into a JSON string with the specified schema. + * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. + * + * @param e a string column containing JSON data. + * @param schema the schema to use when parsing the json string + * + * @group collection_funcs + * @since 2.4.0 + */ + def from_json(e: Column, schema: Column): Column = { + from_json(e, schema, Map.empty[String, String].asJava) + } + + /** + * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. + * + * @param e a string column containing JSON data. + * @param schema the schema to use when parsing the json string + * @param options options to control how the json is parsed. accepts the same options as the + * json data source. + * + * @group collection_funcs + * @since 2.4.0 + */ + def from_json(e: Column, schema: Column, options: java.util.Map[String, String]): Column = { + withExpr(new JsonToStructs(e.expr, schema.expr, options.asScala.toMap)) + } + + /** + * Parses a column containing a JSON string and infers its schema. + * + * @param e a string column containing JSON data. + * + * @group collection_funcs + * @since 2.4.0 + */ + def schema_of_json(e: Column): Column = withExpr(new SchemaOfJson(e.expr))
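A minimal sketch of the two new entry points (assuming a SparkSession named `spark`; the schema column must be a foldable string literal):

  import org.apache.spark.sql.functions._
  import spark.implicits._

  val df = Seq("""{"a": 1, "b": "x"}""").toDF("json")

  // Infer a schema from a sample document: yields a DDL string such as
  // struct<a:bigint,b:string>.
  df.select(schema_of_json(lit("""{"a": 1, "b": "x"}"""))).show(false)

  // Parse with a schema given as a literal string column.
  df.select(from_json($"json", lit("a INT, b STRING"))).show()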
+ + /** + * (Scala-specific) Converts a column containing a `StructType`, `ArrayType` or + * a `MapType` into a JSON string with the specified schema. * Throws an exception, in the case of an unsupported type. * - * @param e a column containing a struct or array of the structs. + * @param e a column containing a struct, an array or a map. * @param options options to control how the struct column is converted into a json string. * accepts the same options as the json data source. * @@ -3398,11 +3628,11 @@ object functions { } /** - * (Java-specific) Converts a column containing a `StructType`, `ArrayType` of `StructType`s, - * a `MapType` or `ArrayType` of `MapType`s into a JSON string with the specified schema. + * (Java-specific) Converts a column containing a `StructType`, `ArrayType` or + * a `MapType` into a JSON string with the specified schema. * Throws an exception, in the case of an unsupported type. * - * @param e a column containing a struct or array of the structs. + * @param e a column containing a struct, an array or a map. * @param options options to control how the struct column is converted into a json string. * accepts the same options as the json data source. * @@ -3413,11 +3643,11 @@ object functions { to_json(e, options.asScala.toMap) /** - * Converts a column containing a `StructType`, `ArrayType` of `StructType`s, - * a `MapType` or `ArrayType` of `MapType`s into a JSON string with the specified schema. + * Converts a column containing a `StructType`, `ArrayType` or + * a `MapType` into a JSON string with the specified schema. * Throws an exception, in the case of an unsupported type. * - * @param e a column containing a struct or array of the structs. + * @param e a column containing a struct, an array or a map. * * @group collection_funcs * @since 2.1.0 @@ -3431,7 +3661,7 @@ object functions { * @group collection_funcs * @since 1.5.0 */ - def size(e: Column): Column = withExpr { new Size(e.expr) } + def size(e: Column): Column = withExpr { Size(e.expr) } /** * Sorts the input array for the given column in ascending order, @@ -3470,6 +3700,16 @@ object functions { */ def array_max(e: Column): Column = withExpr { ArrayMax(e.expr) } + /** + * Returns a random permutation of the given array. + * + * @note The function is non-deterministic. + * + * @group collection_funcs + * @since 2.4.0 + */ + def shuffle(e: Column): Column = withExpr { Shuffle(e.expr) } + /** * Returns a reversed string or an array with reverse order of elements. * @group collection_funcs @@ -3563,124 +3803,13 @@ object functions { @scala.annotation.varargs def arrays_zip(e: Column*): Column = withExpr { ArraysZip(e.map(_.expr)) } - ////////////////////////////////////////////////////////////////////////////////////////////// - // Mask functions - ////////////////////////////////////////////////////////////////////////////////////////////// - /** - * Returns a string which is the masked representation of the input.
- * @group mask_funcs - * @since 2.4.0 - */ - def mask(e: Column): Column = withExpr { new Mask(e.expr) } - /** - * Returns a string which is the masked representation of the input, using `upper`, `lower` and - * `digit` as replacement characters. - * @group mask_funcs - * @since 2.4.0 - */ - def mask(e: Column, upper: String, lower: String, digit: String): Column = withExpr { - Mask(e.expr, upper, lower, digit) - } - - /** - * Returns a string with the first `n` characters masked. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_first_n(e: Column, n: Int): Column = withExpr { new MaskFirstN(e.expr, Literal(n)) } - - /** - * Returns a string with the first `n` characters masked, using `upper`, `lower` and `digit` as - * replacement characters. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_first_n( - e: Column, - n: Int, - upper: String, - lower: String, - digit: String): Column = withExpr { - MaskFirstN(e.expr, n, upper, lower, digit) - } - - /** - * Returns a string with the last `n` characters masked. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_last_n(e: Column, n: Int): Column = withExpr { new MaskLastN(e.expr, Literal(n)) } - - /** - * Returns a string with the last `n` characters masked, using `upper`, `lower` and `digit` as - * replacement characters. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_last_n( - e: Column, - n: Int, - upper: String, - lower: String, - digit: String): Column = withExpr { - MaskLastN(e.expr, n, upper, lower, digit) - } - - /** - * Returns a string with all but the first `n` characters masked. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_show_first_n(e: Column, n: Int): Column = withExpr { - new MaskShowFirstN(e.expr, Literal(n)) - } - - /** - * Returns a string with all but the first `n` characters masked, using `upper`, `lower` and - * `digit` as replacement characters. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_show_first_n( - e: Column, - n: Int, - upper: String, - lower: String, - digit: String): Column = withExpr { - MaskShowFirstN(e.expr, n, upper, lower, digit) - } - - /** - * Returns a string with all but the last `n` characters masked. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_show_last_n(e: Column, n: Int): Column = withExpr { - new MaskShowLastN(e.expr, Literal(n)) - } - - /** - * Returns a string with all but the last `n` characters masked, using `upper`, `lower` and - * `digit` as replacement characters. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_show_last_n( - e: Column, - n: Int, - upper: String, - lower: String, - digit: String): Column = withExpr { - MaskShowLastN(e.expr, n, upper, lower, digit) - } - - /** - * Returns a hashed value based on the input column. - * @group mask_funcs + * Returns the union of all the given maps. 
+ * @group collection_funcs + * @since 2.4.0 */ - def mask_hash(e: Column): Column = withExpr { MaskHash(e.expr) } + @scala.annotation.varargs + def map_concat(cols: Column*): Column = withExpr { MapConcat(cols.map(_.expr)) } // scalastyle:off line.size.limit // scalastyle:off parameter.number @@ -3690,7 +3819,7 @@ object functions { (0 to 10).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) - val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor(typeTag[A$i]).dataType :: $s"}) + val inputSchemas = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor(typeTag[A$i]) :: $s"}) println(s""" |/** | * Defines a Scala closure of $x arguments as user-defined function (UDF). @@ -3703,8 +3832,8 @@ | */ |def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - | val inputTypes = Try($inputTypes).toOption - | val udf = UserDefinedFunction(f, dataType, inputTypes) + | val inputSchemas = Try($inputSchemas).toOption + | val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) | if (nullable) udf else udf.asNonNullable() |}""".stripMargin) } @@ -3727,7 +3856,7 @@ object functions { | */ |def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = { | val func = f$anyCast.call($anyParams) - | UserDefinedFunction($funcCall, returnType, inputTypes = None) + | SparkUserDefinedFunction.create($funcCall, returnType, inputSchemas = None) |}""".stripMargin) } @@ -3748,8 +3877,8 @@ */ def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(Nil).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) + val inputSchemas = Try(Nil).toOption + val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3764,8 +3893,8 @@ */ def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: Nil).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: Nil).toOption + val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3780,8 +3909,8 @@ */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: Nil).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: Nil).toOption + val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3796,8 +3925,8 @@ */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes =
Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: Nil).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: Nil).toOption + val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3812,8 +3941,8 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: Nil).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: Nil).toOption + val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3828,8 +3957,8 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: Nil).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: Nil).toOption + val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3844,8 +3973,8 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: Nil).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: Nil).toOption + val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3860,8 +3989,8 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: 
TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: Nil).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: Nil).toOption + val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3876,8 +4005,8 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: Nil).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: ScalaReflection.schemaFor(typeTag[A8]) :: Nil).toOption + val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3892,8 +4021,8 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: Nil).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: 
ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: ScalaReflection.schemaFor(typeTag[A8]) :: ScalaReflection.schemaFor(typeTag[A9]) :: Nil).toOption + val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3908,8 +4037,8 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: ScalaReflection.schemaFor(typeTag[A10]).dataType :: Nil).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: ScalaReflection.schemaFor(typeTag[A8]) :: ScalaReflection.schemaFor(typeTag[A9]) :: ScalaReflection.schemaFor(typeTag[A10]) :: Nil).toOption + val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3928,7 +4057,7 @@ object functions { */ def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF0[Any]].call() - UserDefinedFunction(() => func, returnType, inputTypes = None) + SparkUserDefinedFunction.create(() => func, returnType, inputSchemas = None) } /** @@ -3942,7 +4071,7 @@ object functions { */ def udf(f: UDF1[_, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any) - UserDefinedFunction(func, returnType, inputTypes = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) } /** @@ -3956,7 +4085,7 @@ object functions { */ def udf(f: UDF2[_, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) - UserDefinedFunction(func, returnType, inputTypes = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) } /** @@ -3970,7 +4099,7 @@ object functions { */ def udf(f: UDF3[_, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) - UserDefinedFunction(func, returnType, inputTypes = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) } /** @@ -3984,7 +4113,7 @@ object functions { */ def udf(f: UDF4[_, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) - UserDefinedFunction(func, returnType, inputTypes = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) } /** @@ -3998,7 +4127,7 @@ object functions { */ 
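Every Scala `udf` overload in these hunks follows the same shape: collect a full `ScalaReflection.Schema` (data type plus nullability) for each argument instead of a bare `DataType`, hand the list to `SparkUserDefinedFunction.create`, and narrow the result with `asNonNullable()` when the return type's schema is non-nullable. A minimal usage sketch of the user-visible effect, assuming only a plain local `SparkSession` (the names here are illustrative, not part of the patch):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf

val spark = SparkSession.builder().master("local[*]").appName("udf-nullability-sketch").getOrCreate()

// Int is a primitive type, so ScalaReflection reports nullable = false for the
// return schema and the UDF is narrowed with asNonNullable(); downstream plans
// can then skip null checks on columns produced by plusOne.
val plusOne = udf((i: Int) => i + 1)

// String is a reference type, so the result column stays nullable.
val greet = udf((name: String) => s"hello, $name")
```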
def udf(f: UDF5[_, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) - UserDefinedFunction(func, returnType, inputTypes = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) } /** @@ -4012,7 +4141,7 @@ object functions { */ def udf(f: UDF6[_, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - UserDefinedFunction(func, returnType, inputTypes = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) } /** @@ -4026,7 +4155,7 @@ object functions { */ def udf(f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - UserDefinedFunction(func, returnType, inputTypes = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) } /** @@ -4040,7 +4169,7 @@ object functions { */ def udf(f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - UserDefinedFunction(func, returnType, inputTypes = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) } /** @@ -4054,7 +4183,7 @@ object functions { */ def udf(f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - UserDefinedFunction(func, returnType, inputTypes = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) } /** @@ -4068,7 +4197,7 @@ object functions { */ def udf(f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - UserDefinedFunction(func, returnType, inputTypes = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) } // scalastyle:on parameter.number @@ -4087,7 +4216,7 @@ object functions { * @since 2.0.0 */ def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = { - UserDefinedFunction(f, dataType, None) + SparkUserDefinedFunction.create(f, dataType, inputSchemas = None) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala index 8b92c8b4f56b5..3a3246a1b1d13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala @@ -64,7 +64,16 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect } } - override def getTruncateQuery(table: String): String = { - dialects.head.getTruncateQuery(table) + /** + * The SQL query used to truncate a table. + * @param table The table to truncate. + * @param cascade Whether or not to cascade the truncation. 
Default value is the + * value of isCascadingTruncateTable() + * @return The SQL query to use for truncating a table + */ + override def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + dialects.head.getTruncateQuery(table, cascade) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala index 84f68e779c38c..d13c29ed46bd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala @@ -41,4 +41,6 @@ private object DerbyDialect extends JdbcDialect { Option(JdbcType("DECIMAL(31,5)", java.sql.Types.DECIMAL)) case _ => None } + + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 83d87a11810c1..f76c1fae562c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -22,6 +22,7 @@ import java.sql.{Connection, Date, Timestamp} import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since} +import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions import org.apache.spark.sql.types._ /** @@ -120,12 +121,27 @@ abstract class JdbcDialect extends Serializable { * The SQL query that should be used to truncate a table. Dialects can override this method to * return a query that is suitable for a particular database. For PostgreSQL, for instance, * a different query is used to prevent "TRUNCATE" affecting other tables. - * @param table The name of the table. + * @param table The table to truncate * @return The SQL query to use for truncating a table */ @Since("2.3.0") def getTruncateQuery(table: String): String = { - s"TRUNCATE TABLE $table" + getTruncateQuery(table, isCascadingTruncateTable) + } + + /** + * The SQL query that should be used to truncate a table. Dialects can override this method to + * return a query that is suitable for a particular database. For PostgreSQL, for instance, + * a different query is used to prevent "TRUNCATE" affecting other tables. + * @param table The table to truncate + * @param cascade Whether or not to cascade the truncation + * @return The SQL query to use for truncating a table + */ + @Since("2.4.0") + def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + s"TRUNCATE TABLE $table" } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index 6ef77f24460be..f4a6d0a4d2e44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -95,4 +95,20 @@ private case object OracleDialect extends JdbcDialect { } override def isCascadingTruncateTable(): Option[Boolean] = Some(false) + + /** + * The SQL query used to truncate a table. + * @param table The table to truncate + * @param cascade Whether or not to cascade the truncation. 
Default value is the + * value of isCascadingTruncateTable() + * @return The SQL query to use for truncating a table + */ + override def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + cascade match { + case Some(true) => s"TRUNCATE TABLE $table CASCADE" + case _ => s"TRUNCATE TABLE $table" + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 13a2035f4d0c4..f8d2bc8e0f13f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -85,15 +85,27 @@ private object PostgresDialect extends JdbcDialect { s"SELECT 1 FROM $table LIMIT 1" } + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) + /** - * The SQL query used to truncate a table. For Postgres, the default behaviour is to - * also truncate any descendant tables. As this is a (possibly unwanted) side-effect, - * the Postgres dialect adds 'ONLY' to truncate only the table in question - * @param table The name of the table. - * @return The SQL query to use for truncating a table - */ - override def getTruncateQuery(table: String): String = { - s"TRUNCATE TABLE ONLY $table" + * The SQL query used to truncate a table. For Postgres, the default behaviour is to + * also truncate any descendant tables. As this is a (possibly unwanted) side-effect, + * the Postgres dialect adds 'ONLY' to truncate only the table in question + * @param table The table to truncate + * @param cascade Whether or not to cascade the truncation. Default value is the value of + * isCascadingTruncateTable(). Cascading a truncation will truncate tables + * with a foreign key relationship to the target table. However, it will not + * truncate tables with an inheritance relationship to the target table, as + * the truncate query always includes "ONLY" to prevent this behaviour. + * @return The SQL query to use for truncating a table + */ + override def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + cascade match { + case Some(true) => s"TRUNCATE TABLE ONLY $table CASCADE" + case _ => s"TRUNCATE TABLE ONLY $table" + } } override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = { @@ -110,5 +122,4 @@ private object PostgresDialect extends JdbcDialect { } } - override def isCascadingTruncateTable(): Option[Boolean] = Some(false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala index 5749b791fca25..6c17bd7ed9ec4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala @@ -31,4 +31,22 @@ private case object TeradataDialect extends JdbcDialect { case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR)) case _ => None } + + // Teradata does not support cascading a truncation + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) + + /** + * The SQL query used to truncate a table. Teradata does not support the 'TRUNCATE' syntax that + * other dialects use. Instead, we need to use a 'DELETE FROM' statement. + * @param table The table to truncate. + * @param cascade Whether or not to cascade the truncation. 
Default value is the + * value of isCascadingTruncateTable(). Teradata does not support cascading a + * 'DELETE FROM' statement (and as mentioned, does not support 'TRUNCATE' syntax) + * @return The SQL query to use for truncating a table + */ + override def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + s"DELETE FROM $table ALL" + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index 2499e9b604f3e..bdd8c4da6bd30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -199,7 +199,7 @@ case class StringStartsWith(attribute: String, value: String) extends Filter { /** * A filter that evaluates to `true` iff the attribute evaluates to - * a string that starts with `value`. + * a string that ends with `value`. * * @since 1.3.1 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index ef8dc3a325a33..2a4db4afbe005 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import java.util.{Locale, Optional} +import java.util.Locale import scala.collection.JavaConverters._ @@ -28,8 +28,8 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} import org.apache.spark.sql.sources.StreamSourceProvider -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, MicroBatchReadSupportProvider} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -172,19 +172,21 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo case _ => None } ds match { - case s: MicroBatchReadSupport => - var tempReader: MicroBatchReader = null + case s: MicroBatchReadSupportProvider => + var tempReadSupport: MicroBatchReadSupport = null val schema = try { - tempReader = s.createMicroBatchReader( - Optional.ofNullable(userSpecifiedSchema.orNull), - Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, - options) - tempReader.readSchema() + val tmpCheckpointPath = Utils.createTempDir(namePrefix = s"tempCP").getCanonicalPath + tempReadSupport = if (userSpecifiedSchema.isDefined) { + s.createMicroBatchReadSupport(userSpecifiedSchema.get, tmpCheckpointPath, options) + } else { + s.createMicroBatchReadSupport(tmpCheckpointPath, options) + } + tempReadSupport.fullSchema() } finally { // Stop tempReader to avoid side-effect thing - if (tempReader != null) { - tempReader.stop() - tempReader = null + if (tempReadSupport != null) { + tempReadSupport.stop() + tempReadSupport = null } } Dataset.ofRows( @@ -192,16 +194,28 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo StreamingRelationV2( s, source, extraOptions.toMap, 
schema.toAttributes, v1Relation)(sparkSession)) - case s: ContinuousReadSupport => - val tempReader = s.createContinuousReader( - Optional.ofNullable(userSpecifiedSchema.orNull), - Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, - options) + case s: ContinuousReadSupportProvider => + var tempReadSupport: ContinuousReadSupport = null + val schema = try { + val tmpCheckpointPath = Utils.createTempDir(namePrefix = s"tempCP").getCanonicalPath + tempReadSupport = if (userSpecifiedSchema.isDefined) { + s.createContinuousReadSupport(userSpecifiedSchema.get, tmpCheckpointPath, options) + } else { + s.createContinuousReadSupport(tmpCheckpointPath, options) + } + tempReadSupport.fullSchema() + } finally { + // Stop tempReader to avoid side-effect thing + if (tempReadSupport != null) { + tempReadSupport.stop() + tempReadSupport = null + } + } Dataset.ofRows( sparkSession, StreamingRelationV2( s, source, extraOptions.toMap, - tempReader.readSchema().toAttributes, v1Relation)(sparkSession)) + schema.toAttributes, v1Relation)(sparkSession)) case _ => // Code path for data source v1. Dataset.ofRows(sparkSession, StreamingRelation(v1DataSource)) @@ -313,6 +327,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * whitespaces from values being read should be skipped. *
 * <li>`nullValue` (default empty string): sets the string representation of a null value. Since
 * 2.0.1, this applies to all supported types including the string type.</li>
+ * <li>`emptyValue` (default empty string): sets the string representation of an empty value.</li>
 * <li>`nanValue` (default `NaN`): sets the string representation of a non-number value.</li>
 * <li>`positiveInf` (default `Inf`): sets the string representation of a positive infinity
 * value.</li>
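The `emptyValue` entry added above sits alongside the existing `nullValue` and `nanValue` options. A short, hedged sketch of setting these options on a streaming CSV source (the schema and the `/tmp/streaming-input` path are placeholders, not from the patch):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{StringType, StructField, StructType}

val spark = SparkSession.builder().master("local[*]").appName("csv-options-sketch").getOrCreate()
val schema = StructType(StructField("name", StringType) :: Nil)

val input = spark.readStream
  .format("csv")
  .schema(schema)                // streaming file sources require an explicit schema
  .option("nullValue", "NULL")   // this token is read back as SQL NULL
  .option("emptyValue", "")      // the option documented by this patch
  .option("nanValue", "NaN")
  .load("/tmp/streaming-input")
```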
    • diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 926c0b69a03fd..7866e4f70f14b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -21,7 +21,7 @@ import java.util.Locale import scala.collection.JavaConverters._ -import org.apache.spark.annotation.{InterfaceStability, Since} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.api.java.function.VoidFunction2 import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources._ -import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -250,7 +250,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { val r = Dataset.ofRows(df.sparkSession, new MemoryPlanV2(s, df.schema.toAttributes)) (s, r) case _ => - val s = new MemorySink(df.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) + val s = new MemorySink(df.schema, outputMode) val r = Dataset.ofRows(df.sparkSession, new MemoryPlan(s)) (s, r) } @@ -270,7 +270,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { query } else if (source == "foreach") { assertNotPartitioned("foreach") - val sink = ForeachWriterProvider[T](foreachWriter, ds.exprEnc) + val sink = ForeachWriteSupportProvider[T](foreachWriter, ds.exprEnc) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), @@ -299,7 +299,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",") val sink = ds.newInstance() match { - case w: StreamWriteSupport if !disabledSources.contains(w.getClass.getCanonicalName) => w + case w: StreamingWriteSupportProvider + if !disabledSources.contains(w.getClass.getCanonicalName) => w case _ => val ds = DataSource( df.sparkSession, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 25bb05212d66f..cd52d991d55c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.STREAMING_QUERY_LISTENERS -import org.apache.spark.sql.sources.v2.StreamWriteSupport +import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider import org.apache.spark.util.{Clock, SystemClock, Utils} /** @@ -256,7 +256,7 @@ class StreamingQueryManager 
private[sql] (sparkSession: SparkSession) extends Lo } (sink, trigger) match { - case (v2Sink: StreamWriteSupport, trigger: ContinuousTrigger) => + case (v2Sink: StreamingWriteSupportProvider, trigger: ContinuousTrigger) => if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) { UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala index 0dcb666e2c3e4..f2173aa1e59c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala @@ -38,7 +38,8 @@ import org.apache.spark.annotation.InterfaceStability class StateOperatorProgress private[sql]( val numRowsTotal: Long, val numRowsUpdated: Long, - val memoryUsedBytes: Long + val memoryUsedBytes: Long, + val customMetrics: ju.Map[String, JLong] = new ju.HashMap() ) extends Serializable { /** The compact JSON representation of this progress. */ @@ -48,12 +49,20 @@ class StateOperatorProgress private[sql]( def prettyJson: String = pretty(render(jsonValue)) private[sql] def copy(newNumRowsUpdated: Long): StateOperatorProgress = - new StateOperatorProgress(numRowsTotal, newNumRowsUpdated, memoryUsedBytes) + new StateOperatorProgress(numRowsTotal, newNumRowsUpdated, memoryUsedBytes, customMetrics) private[sql] def jsonValue: JValue = { ("numRowsTotal" -> JInt(numRowsTotal)) ~ ("numRowsUpdated" -> JInt(numRowsUpdated)) ~ - ("memoryUsedBytes" -> JInt(memoryUsedBytes)) + ("memoryUsedBytes" -> JInt(memoryUsedBytes)) ~ + ("customMetrics" -> { + if (!customMetrics.isEmpty) { + val keys = customMetrics.keySet.asScala.toSeq.sorted + keys.map { k => k -> JInt(customMetrics.get(k).toLong) : JObject }.reduce(_ ~ _) + } else { + JNothing + } + }) } override def toString: String = prettyJson diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java new file mode 100644 index 0000000000000..38d606c5e108e --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark.sql; + +import org.apache.spark.api.java.function.FilterFunction; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.test.TestSparkSession; +import org.apache.spark.sql.types.StructType; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.*; + +import static org.apache.spark.sql.types.DataTypes.*; + +public class JavaColumnExpressionSuite { + private transient TestSparkSession spark; + + @Before + public void setUp() { + spark = new TestSparkSession(); + } + + @After + public void tearDown() { + spark.stop(); + spark = null; + } + + @Test + public void isInCollectionWorksCorrectlyOnJava() { + List rows = Arrays.asList( + RowFactory.create(1, "x"), + RowFactory.create(2, "y"), + RowFactory.create(3, "z")); + StructType schema = createStructType(Arrays.asList( + createStructField("a", IntegerType, false), + createStructField("b", StringType, false))); + Dataset df = spark.createDataFrame(rows, schema); + // Test with different types of collections + Assert.assertTrue(Arrays.equals( + (Row[]) df.filter(df.col("a").isInCollection(Arrays.asList(1, 2))).collect(), + (Row[]) df.filter((FilterFunction) r -> r.getInt(0) == 1 || r.getInt(0) == 2).collect() + )); + Assert.assertTrue(Arrays.equals( + (Row[]) df.filter(df.col("a").isInCollection(new HashSet<>(Arrays.asList(1, 2)))).collect(), + (Row[]) df.filter((FilterFunction) r -> r.getInt(0) == 1 || r.getInt(0) == 2).collect() + )); + Assert.assertTrue(Arrays.equals( + (Row[]) df.filter(df.col("a").isInCollection(new ArrayList<>(Arrays.asList(3, 1)))).collect(), + (Row[]) df.filter((FilterFunction) r -> r.getInt(0) == 3 || r.getInt(0) == 1).collect() + )); + } + + @Test + public void isInCollectionCheckExceptionMessage() { + List rows = Arrays.asList( + RowFactory.create(1, Arrays.asList(1)), + RowFactory.create(2, Arrays.asList(2)), + RowFactory.create(3, Arrays.asList(3))); + StructType schema = createStructType(Arrays.asList( + createStructField("a", IntegerType, false), + createStructField("b", createArrayType(IntegerType, false), false))); + Dataset df = spark.createDataFrame(rows, schema); + try { + df.filter(df.col("a").isInCollection(Arrays.asList(new Column("b")))); + Assert.fail("Expected org.apache.spark.sql.AnalysisException"); + } catch (Exception e) { + Arrays.asList("cannot resolve", + "due to data type mismatch: Arguments must be same type but were") + .forEach(s -> Assert.assertTrue( + e.getMessage().toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))); + } + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java index a19ddbdbadba2..97f3dc588ecc5 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java @@ -253,4 +253,70 @@ public void testBinaryComparatorForNullColumns() throws Exception { assert(compare(0, 0) == 0); assert(compare(0, 1) > 0); } + + @Test + public void testBinaryComparatorWhenSubtractionIsDivisibleByMaxIntValue() throws Exception { + int numFields = 1; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, 
computeSizeInBytes(numFields * 8)); + row1.setLong(0, 11); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + row2.setLong(0, 11L + Integer.MAX_VALUE); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 1) < 0); + } + + @Test + public void testBinaryComparatorWhenSubtractionCanOverflowLongValue() throws Exception { + int numFields = 1; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + row1.setLong(0, Long.MIN_VALUE); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + row2.setLong(0, 1); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 1) < 0); + } + + @Test + public void testBinaryComparatorWhenOnlyTheLastColumnDiffers() throws Exception { + int numFields = 4; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + row1.setInt(0, 11); + row1.setDouble(1, 3.14); + row1.setInt(2, -1); + row1.setLong(3, 0); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + row2.setInt(0, 11); + row2.setDouble(1, 3.14); + row2.setInt(2, -1); + row2.setLong(3, 1); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 1) < 0); + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 445cb29f5ee3a..5602310219a74 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -20,33 +20,75 @@ import java.io.IOException; import java.util.*; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; -public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport { +public class JavaAdvancedDataSourceV2 implements DataSourceV2, BatchReadSupportProvider { - public class Reader implements DataSourceReader, SupportsPushDownRequiredColumns, - SupportsPushDownFilters { + public class ReadSupport extends JavaSimpleReadSupport { + @Override + public ScanConfigBuilder newScanConfigBuilder() { + return new AdvancedScanConfigBuilder(); + } + + @Override + public InputPartition[] planInputPartitions(ScanConfig config) { + Filter[] filters = ((AdvancedScanConfigBuilder) config).filters; + List res = new ArrayList<>(); + + Integer lowerBound = null; + for (Filter filter : filters) { + if (filter instanceof GreaterThan) { + GreaterThan f = (GreaterThan) filter; + if ("i".equals(f.attribute()) && f.value() instanceof Integer) { + lowerBound = (Integer) f.value(); + break; + } + } + } + + if (lowerBound == null) { 
+ res.add(new JavaRangeInputPartition(0, 5)); + res.add(new JavaRangeInputPartition(5, 10)); + } else if (lowerBound < 4) { + res.add(new JavaRangeInputPartition(lowerBound + 1, 5)); + res.add(new JavaRangeInputPartition(5, 10)); + } else if (lowerBound < 9) { + res.add(new JavaRangeInputPartition(lowerBound + 1, 10)); + } + + return res.stream().toArray(InputPartition[]::new); + } + + @Override + public PartitionReaderFactory createReaderFactory(ScanConfig config) { + StructType requiredSchema = ((AdvancedScanConfigBuilder) config).requiredSchema; + return new AdvancedReaderFactory(requiredSchema); + } + } + + public static class AdvancedScanConfigBuilder implements ScanConfigBuilder, ScanConfig, + SupportsPushDownFilters, SupportsPushDownRequiredColumns { // Exposed for testing. public StructType requiredSchema = new StructType().add("i", "int").add("j", "int"); public Filter[] filters = new Filter[0]; @Override - public StructType readSchema() { - return requiredSchema; + public void pruneColumns(StructType requiredSchema) { + this.requiredSchema = requiredSchema; } @Override - public void pruneColumns(StructType requiredSchema) { - this.requiredSchema = requiredSchema; + public StructType readSchema() { + return requiredSchema; } @Override @@ -79,79 +121,54 @@ public Filter[] pushedFilters() { } @Override - public List> planInputPartitions() { - List> res = new ArrayList<>(); - - Integer lowerBound = null; - for (Filter filter : filters) { - if (filter instanceof GreaterThan) { - GreaterThan f = (GreaterThan) filter; - if ("i".equals(f.attribute()) && f.value() instanceof Integer) { - lowerBound = (Integer) f.value(); - break; - } - } - } - - if (lowerBound == null) { - res.add(new JavaAdvancedInputPartition(0, 5, requiredSchema)); - res.add(new JavaAdvancedInputPartition(5, 10, requiredSchema)); - } else if (lowerBound < 4) { - res.add(new JavaAdvancedInputPartition(lowerBound + 1, 5, requiredSchema)); - res.add(new JavaAdvancedInputPartition(5, 10, requiredSchema)); - } else if (lowerBound < 9) { - res.add(new JavaAdvancedInputPartition(lowerBound + 1, 10, requiredSchema)); - } - - return res; + public ScanConfig build() { + return this; } } - static class JavaAdvancedInputPartition implements InputPartition, - InputPartitionReader { - private int start; - private int end; - private StructType requiredSchema; + static class AdvancedReaderFactory implements PartitionReaderFactory { + StructType requiredSchema; - JavaAdvancedInputPartition(int start, int end, StructType requiredSchema) { - this.start = start; - this.end = end; + AdvancedReaderFactory(StructType requiredSchema) { this.requiredSchema = requiredSchema; } @Override - public InputPartitionReader createPartitionReader() { - return new JavaAdvancedInputPartition(start - 1, end, requiredSchema); - } - - @Override - public boolean next() { - start += 1; - return start < end; - } + public PartitionReader createReader(InputPartition partition) { + JavaRangeInputPartition p = (JavaRangeInputPartition) partition; + return new PartitionReader() { + private int current = p.start - 1; + + @Override + public boolean next() throws IOException { + current += 1; + return current < p.end; + } - @Override - public Row get() { - Object[] values = new Object[requiredSchema.size()]; - for (int i = 0; i < values.length; i++) { - if ("i".equals(requiredSchema.apply(i).name())) { - values[i] = start; - } else if ("j".equals(requiredSchema.apply(i).name())) { - values[i] = -start; + @Override + public InternalRow get() { + Object[] values = new 
Object[requiredSchema.size()]; + for (int i = 0; i < values.length; i++) { + if ("i".equals(requiredSchema.apply(i).name())) { + values[i] = current; + } else if ("j".equals(requiredSchema.apply(i).name())) { + values[i] = -current; + } + } + return new GenericInternalRow(values); } - } - return new GenericRow(values); - } - @Override - public void close() throws IOException { + @Override + public void close() throws IOException { + } + }; } } @Override - public DataSourceReader createReader(DataSourceOptions options) { - return new Reader(); + public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { + return new ReadSupport(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java deleted file mode 100644 index 97d6176d02559..0000000000000 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package test.org.apache.spark.sql.sources.v2; - -import java.io.IOException; -import java.util.List; - -import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.ReadSupport; -import org.apache.spark.sql.sources.v2.reader.*; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.vectorized.ColumnVector; -import org.apache.spark.sql.vectorized.ColumnarBatch; - - -public class JavaBatchDataSourceV2 implements DataSourceV2, ReadSupport { - - class Reader implements DataSourceReader, SupportsScanColumnarBatch { - private final StructType schema = new StructType().add("i", "int").add("j", "int"); - - @Override - public StructType readSchema() { - return schema; - } - - @Override - public List> planBatchInputPartitions() { - return java.util.Arrays.asList( - new JavaBatchInputPartition(0, 50), new JavaBatchInputPartition(50, 90)); - } - } - - static class JavaBatchInputPartition - implements InputPartition, InputPartitionReader { - private int start; - private int end; - - private static final int BATCH_SIZE = 20; - - private OnHeapColumnVector i; - private OnHeapColumnVector j; - private ColumnarBatch batch; - - JavaBatchInputPartition(int start, int end) { - this.start = start; - this.end = end; - } - - @Override - public InputPartitionReader createPartitionReader() { - this.i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); - this.j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); - ColumnVector[] vectors = new ColumnVector[2]; - vectors[0] = i; - vectors[1] = j; - this.batch = new ColumnarBatch(vectors); - return this; - } - - @Override - public boolean next() { - i.reset(); - j.reset(); - int count = 0; - while (start < end && count < BATCH_SIZE) { - i.putInt(count, start); - j.putInt(count, -start); - start += 1; - count += 1; - } - - if (count == 0) { - return false; - } else { - batch.setNumRows(count); - return true; - } - } - - @Override - public ColumnarBatch get() { - return batch; - } - - @Override - public void close() throws IOException { - batch.close(); - } - } - - - @Override - public DataSourceReader createReader(DataSourceOptions options) { - return new Reader(); - } -} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java new file mode 100644 index 0000000000000..28a9330398310 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; +import org.apache.spark.sql.sources.v2.DataSourceOptions; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; + + +public class JavaColumnarDataSourceV2 implements DataSourceV2, BatchReadSupportProvider { + + class ReadSupport extends JavaSimpleReadSupport { + + @Override + public InputPartition[] planInputPartitions(ScanConfig config) { + InputPartition[] partitions = new InputPartition[2]; + partitions[0] = new JavaRangeInputPartition(0, 50); + partitions[1] = new JavaRangeInputPartition(50, 90); + return partitions; + } + + @Override + public PartitionReaderFactory createReaderFactory(ScanConfig config) { + return new ColumnarReaderFactory(); + } + } + + static class ColumnarReaderFactory implements PartitionReaderFactory { + private static final int BATCH_SIZE = 20; + + @Override + public boolean supportColumnarReads(InputPartition partition) { + return true; + } + + @Override + public PartitionReader createReader(InputPartition partition) { + throw new UnsupportedOperationException(""); + } + + @Override + public PartitionReader createColumnarReader(InputPartition partition) { + JavaRangeInputPartition p = (JavaRangeInputPartition) partition; + OnHeapColumnVector i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); + OnHeapColumnVector j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); + ColumnVector[] vectors = new ColumnVector[2]; + vectors[0] = i; + vectors[1] = j; + ColumnarBatch batch = new ColumnarBatch(vectors); + + return new PartitionReader() { + private int current = p.start; + + @Override + public boolean next() throws IOException { + i.reset(); + j.reset(); + int count = 0; + while (current < p.end && count < BATCH_SIZE) { + i.putInt(count, current); + j.putInt(count, -current); + current += 1; + count += 1; + } + + if (count == 0) { + return false; + } else { + batch.setNumRows(count); + return true; + } + } + + @Override + public ColumnarBatch get() { + return batch; + } + + @Override + public void close() throws IOException { + batch.close(); + } + }; + } + } + + @Override + public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { + return new ReadSupport(); + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java index e49c8cf8b9e16..18a11dde82198 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -19,38 +19,34 @@ import java.io.IOException; import java.util.Arrays; -import java.util.List; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.catalyst.expressions.GenericRow; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.DataSourceV2; -import 
org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.sources.v2.*; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.sources.v2.reader.partitioning.ClusteredDistribution; import org.apache.spark.sql.sources.v2.reader.partitioning.Distribution; import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; -import org.apache.spark.sql.types.StructType; -public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport { +public class JavaPartitionAwareDataSource implements DataSourceV2, BatchReadSupportProvider { - class Reader implements DataSourceReader, SupportsReportPartitioning { - private final StructType schema = new StructType().add("a", "int").add("b", "int"); + class ReadSupport extends JavaSimpleReadSupport implements SupportsReportPartitioning { @Override - public StructType readSchema() { - return schema; + public InputPartition[] planInputPartitions(ScanConfig config) { + InputPartition[] partitions = new InputPartition[2]; + partitions[0] = new SpecificInputPartition(new int[]{1, 1, 3}, new int[]{4, 4, 6}); + partitions[1] = new SpecificInputPartition(new int[]{2, 4, 4}, new int[]{6, 2, 2}); + return partitions; } @Override - public List> planInputPartitions() { - return java.util.Arrays.asList( - new SpecificInputPartition(new int[]{1, 1, 3}, new int[]{4, 4, 6}), - new SpecificInputPartition(new int[]{2, 4, 4}, new int[]{6, 2, 2})); + public PartitionReaderFactory createReaderFactory(ScanConfig config) { + return new SpecificReaderFactory(); } @Override - public Partitioning outputPartitioning() { + public Partitioning outputPartitioning(ScanConfig config) { return new MyPartitioning(); } } @@ -66,48 +62,53 @@ public int numPartitions() { public boolean satisfy(Distribution distribution) { if (distribution instanceof ClusteredDistribution) { String[] clusteredCols = ((ClusteredDistribution) distribution).clusteredColumns; - return Arrays.asList(clusteredCols).contains("a"); + return Arrays.asList(clusteredCols).contains("i"); } return false; } } - static class SpecificInputPartition implements InputPartition, InputPartitionReader { - private int[] i; - private int[] j; - private int current = -1; + static class SpecificInputPartition implements InputPartition { + int[] i; + int[] j; SpecificInputPartition(int[] i, int[] j) { assert i.length == j.length; this.i = i; this.j = j; } + } - @Override - public boolean next() throws IOException { - current += 1; - return current < i.length; - } - - @Override - public Row get() { - return new GenericRow(new Object[] {i[current], j[current]}); - } - - @Override - public void close() throws IOException { - - } + static class SpecificReaderFactory implements PartitionReaderFactory { @Override - public InputPartitionReader createPartitionReader() { - return this; + public PartitionReader createReader(InputPartition partition) { + SpecificInputPartition p = (SpecificInputPartition) partition; + return new PartitionReader() { + private int current = -1; + + @Override + public boolean next() throws IOException { + current += 1; + return current < p.i.length; + } + + @Override + public InternalRow get() { + return new GenericInternalRow(new Object[] {p.i[current], p.j[current]}); + } + + @Override + public void close() throws IOException { + + } + }; } } @Override - public DataSourceReader createReader(DataSourceOptions options) { - return new Reader(); 
+ public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { + return new ReadSupport(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java index 80eeffd95f83b..cc9ac04a0dad3 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -17,38 +17,39 @@ package test.org.apache.spark.sql.sources.v2; -import java.util.List; - -import org.apache.spark.sql.Row; +import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; -import org.apache.spark.sql.sources.v2.reader.DataSourceReader; -import org.apache.spark.sql.sources.v2.reader.InputPartition; +import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; -public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupportWithSchema { +public class JavaSchemaRequiredDataSource implements DataSourceV2, BatchReadSupportProvider { - class Reader implements DataSourceReader { + class ReadSupport extends JavaSimpleReadSupport { private final StructType schema; - Reader(StructType schema) { + ReadSupport(StructType schema) { this.schema = schema; } @Override - public StructType readSchema() { + public StructType fullSchema() { return schema; } @Override - public List> planInputPartitions() { - return java.util.Collections.emptyList(); + public InputPartition[] planInputPartitions(ScanConfig config) { + return new InputPartition[0]; } } @Override - public DataSourceReader createReader(StructType schema, DataSourceOptions options) { - return new Reader(schema); + public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { + throw new IllegalArgumentException("requires a user-supplied schema"); + } + + @Override + public BatchReadSupport createBatchReadSupport(StructType schema, DataSourceOptions options) { + return new ReadSupport(schema); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java index 8522a63898a3b..2cdbba84ec4a4 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -17,70 +17,26 @@ package test.org.apache.spark.sql.sources.v2; -import java.io.IOException; -import java.util.List; - -import org.apache.spark.sql.Row; -import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.ReadSupport; -import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; -import org.apache.spark.sql.sources.v2.reader.InputPartition; -import org.apache.spark.sql.sources.v2.reader.DataSourceReader; -import org.apache.spark.sql.types.StructType; - -public class JavaSimpleDataSourceV2 implements DataSourceV2, ReadSupport { - - class Reader implements DataSourceReader { - private final StructType schema = new 
StructType().add("i", "int").add("j", "int"); +import org.apache.spark.sql.sources.v2.reader.*; - @Override - public StructType readSchema() { - return schema; - } +public class JavaSimpleDataSourceV2 implements DataSourceV2, BatchReadSupportProvider { - @Override - public List> planInputPartitions() { - return java.util.Arrays.asList( - new JavaSimpleInputPartition(0, 5), - new JavaSimpleInputPartition(5, 10)); - } - } - - static class JavaSimpleInputPartition implements InputPartition, InputPartitionReader { - private int start; - private int end; - - JavaSimpleInputPartition(int start, int end) { - this.start = start; - this.end = end; - } + class ReadSupport extends JavaSimpleReadSupport { @Override - public InputPartitionReader createPartitionReader() { - return new JavaSimpleInputPartition(start - 1, end); - } - - @Override - public boolean next() { - start += 1; - return start < end; - } - - @Override - public Row get() { - return new GenericRow(new Object[] {start, -start}); - } - - @Override - public void close() throws IOException { - + public InputPartition[] planInputPartitions(ScanConfig config) { + InputPartition[] partitions = new InputPartition[2]; + partitions[0] = new JavaRangeInputPartition(0, 5); + partitions[1] = new JavaRangeInputPartition(5, 10); + return partitions; } } @Override - public DataSourceReader createReader(DataSourceOptions options) { - return new Reader(); + public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { + return new ReadSupport(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java new file mode 100644 index 0000000000000..685f9b9747e85 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.types.StructType; + +abstract class JavaSimpleReadSupport implements BatchReadSupport { + + @Override + public StructType fullSchema() { + return new StructType().add("i", "int").add("j", "int"); + } + + @Override + public ScanConfigBuilder newScanConfigBuilder() { + return new JavaNoopScanConfigBuilder(fullSchema()); + } + + @Override + public PartitionReaderFactory createReaderFactory(ScanConfig config) { + return new JavaSimpleReaderFactory(); + } +} + +class JavaNoopScanConfigBuilder implements ScanConfigBuilder, ScanConfig { + + private StructType schema; + + JavaNoopScanConfigBuilder(StructType schema) { + this.schema = schema; + } + + @Override + public ScanConfig build() { + return this; + } + + @Override + public StructType readSchema() { + return schema; + } +} + +class JavaSimpleReaderFactory implements PartitionReaderFactory { + + @Override + public PartitionReader createReader(InputPartition partition) { + JavaRangeInputPartition p = (JavaRangeInputPartition) partition; + return new PartitionReader() { + private int current = p.start - 1; + + @Override + public boolean next() throws IOException { + current += 1; + return current < p.end; + } + + @Override + public InternalRow get() { + return new GenericInternalRow(new Object[] {current, -current}); + } + + @Override + public void close() throws IOException { + + } + }; + } +} + +class JavaRangeInputPartition implements InputPartition { + int start; + int end; + + JavaRangeInputPartition(int start, int end) { + this.start = start; + this.end = end; + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java deleted file mode 100644 index 3ad8e7a0104ce..0000000000000 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package test.org.apache.spark.sql.sources.v2; - -import java.io.IOException; -import java.util.List; - -import org.apache.spark.sql.catalyst.expressions.UnsafeRow; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.ReadSupport; -import org.apache.spark.sql.sources.v2.reader.*; -import org.apache.spark.sql.types.StructType; - -public class JavaUnsafeRowDataSourceV2 implements DataSourceV2, ReadSupport { - - class Reader implements DataSourceReader, SupportsScanUnsafeRow { - private final StructType schema = new StructType().add("i", "int").add("j", "int"); - - @Override - public StructType readSchema() { - return schema; - } - - @Override - public List<InputPartition<UnsafeRow>> planUnsafeInputPartitions() { - return java.util.Arrays.asList( - new JavaUnsafeRowInputPartition(0, 5), - new JavaUnsafeRowInputPartition(5, 10)); - } - } - - static class JavaUnsafeRowInputPartition - implements InputPartition<UnsafeRow>, InputPartitionReader<UnsafeRow> { - private int start; - private int end; - private UnsafeRow row; - - JavaUnsafeRowInputPartition(int start, int end) { - this.start = start; - this.end = end; - this.row = new UnsafeRow(2); - row.pointTo(new byte[8 * 3], 8 * 3); - } - - @Override - public InputPartitionReader<UnsafeRow> createPartitionReader() { - return new JavaUnsafeRowInputPartition(start - 1, end); - } - - @Override - public boolean next() { - start += 1; - return start < end; - } - - @Override - public UnsafeRow get() { - row.setInt(0, start); - row.setInt(1, -start); - return row; - } - - @Override - public void close() throws IOException { - - } - } - - @Override - public DataSourceReader createReader(DataSourceOptions options) { - return new Reader(); - } -} diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 46b38bed1c0fb..a36b0cfa6ff18 100644 --- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -9,6 +9,6 @@ org.apache.spark.sql.streaming.sources.FakeReadMicroBatchOnly org.apache.spark.sql.streaming.sources.FakeReadContinuousOnly org.apache.spark.sql.streaming.sources.FakeReadBothModes org.apache.spark.sql.streaming.sources.FakeReadNeitherMode -org.apache.spark.sql.streaming.sources.FakeWrite +org.apache.spark.sql.streaming.sources.FakeWriteSupportProvider org.apache.spark.sql.streaming.sources.FakeNoWrite -org.apache.spark.sql.streaming.sources.FakeWriteV1Fallback +org.apache.spark.sql.streaming.sources.FakeWriteSupportProviderV1Fallback diff --git a/sql/core/src/test/resources/sql-tests/inputs/columnresolution-views.sql b/sql/core/src/test/resources/sql-tests/inputs/columnresolution-views.sql index d3f928751757c..83c32a5bf2435 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/columnresolution-views.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/columnresolution-views.sql @@ -13,10 +13,8 @@ DROP VIEW view1; -- Test scenario with Global Temp view CREATE OR REPLACE GLOBAL TEMPORARY VIEW view1 as SELECT 1 as i1; SELECT * FROM global_temp.view1; --- TODO: Support this scenario SELECT global_temp.view1.* FROM global_temp.view1; SELECT i1 FROM global_temp.view1; --- TODO: Support this scenario SELECT global_temp.view1.i1 FROM global_temp.view1; SELECT view1.i1 FROM global_temp.view1; SELECT a.i1 FROM
global_temp.view1 AS a; diff --git a/sql/core/src/test/resources/sql-tests/inputs/columnresolution.sql b/sql/core/src/test/resources/sql-tests/inputs/columnresolution.sql index 79e90ad3de91d..d001185a73931 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/columnresolution.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/columnresolution.sql @@ -14,9 +14,7 @@ SELECT i1 FROM mydb1.t1; SELECT t1.i1 FROM t1; SELECT t1.i1 FROM mydb1.t1; --- TODO: Support this scenario SELECT mydb1.t1.i1 FROM t1; --- TODO: Support this scenario SELECT mydb1.t1.i1 FROM mydb1.t1; USE mydb2; @@ -24,7 +22,6 @@ SELECT i1 FROM t1; SELECT i1 FROM mydb1.t1; SELECT t1.i1 FROM t1; SELECT t1.i1 FROM mydb1.t1; --- TODO: Support this scenario SELECT mydb1.t1.i1 FROM mydb1.t1; -- Scenario: resolve fully qualified table name in star expansion @@ -34,7 +31,6 @@ SELECT mydb1.t1.* FROM mydb1.t1; SELECT t1.* FROM mydb1.t1; USE mydb2; SELECT t1.* FROM t1; --- TODO: Support this scenario SELECT mydb1.t1.* FROM mydb1.t1; SELECT t1.* FROM mydb1.t1; SELECT a.* FROM mydb1.t1 AS a; @@ -47,21 +43,17 @@ CREATE TABLE t4 USING parquet AS SELECT * FROM VALUES (4,1), (2,1) AS t4(c2, c3) SELECT * FROM t3 WHERE c1 IN (SELECT c2 FROM t4 WHERE t4.c3 = t3.c2); --- TODO: Support this scenario SELECT * FROM mydb1.t3 WHERE c1 IN (SELECT mydb1.t4.c2 FROM mydb1.t4 WHERE mydb1.t4.c3 = mydb1.t3.c2); -- Scenario: column resolution scenarios in join queries SET spark.sql.crossJoin.enabled = true; --- TODO: Support this scenario SELECT mydb1.t1.i1 FROM t1, mydb2.t1; --- TODO: Support this scenario SELECT mydb1.t1.i1 FROM mydb1.t1, mydb2.t1; USE mydb2; --- TODO: Support this scenario SELECT mydb1.t1.i1 FROM t1, mydb1.t1; SET spark.sql.crossJoin.enabled = false; @@ -75,12 +67,10 @@ SELECT t5.t5.i1 FROM mydb1.t5; SELECT t5.i1 FROM mydb1.t5; SELECT t5.* FROM mydb1.t5; SELECT t5.t5.* FROM mydb1.t5; --- TODO: Support this scenario SELECT mydb1.t5.t5.i1 FROM mydb1.t5; --- TODO: Support this scenario SELECT mydb1.t5.t5.i2 FROM mydb1.t5; --- TODO: Support this scenario SELECT mydb1.t5.* FROM mydb1.t5; +SELECT mydb1.t5.* FROM t5; -- Cleanup and Reset USE default; diff --git a/sql/core/src/test/resources/sql-tests/inputs/except-all.sql b/sql/core/src/test/resources/sql-tests/inputs/except-all.sql new file mode 100644 index 0000000000000..e28f0721a6449 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/except-all.sql @@ -0,0 +1,160 @@ +CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES + (0), (1), (2), (2), (2), (2), (3), (null), (null) AS tab1(c1); +CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES + (1), (2), (2), (3), (5), (5), (null) AS tab2(c1); +CREATE TEMPORARY VIEW tab3 AS SELECT * FROM VALUES + (1, 2), + (1, 2), + (1, 3), + (2, 3), + (2, 2) + AS tab3(k, v); +CREATE TEMPORARY VIEW tab4 AS SELECT * FROM VALUES + (1, 2), + (2, 3), + (2, 2), + (2, 2), + (2, 20) + AS tab4(k, v); + +-- Basic EXCEPT ALL +SELECT * FROM tab1 +EXCEPT ALL +SELECT * FROM tab2; + +-- MINUS ALL (synonym for EXCEPT ALL) +SELECT * FROM tab1 +MINUS ALL +SELECT * FROM tab2; + +-- EXCEPT ALL with a filtered right branch +SELECT * FROM tab1 +EXCEPT ALL +SELECT * FROM tab2 WHERE c1 IS NOT NULL; + +-- Empty left relation +SELECT * FROM tab1 WHERE c1 > 5 +EXCEPT ALL +SELECT * FROM tab2; + +-- Empty right relation +SELECT * FROM tab1 +EXCEPT ALL +SELECT * FROM tab2 WHERE c1 > 6; + +-- Type Coerced ExceptAll +SELECT * FROM tab1 +EXCEPT ALL +SELECT CAST(1 AS BIGINT); + +-- Error as the types of the two sides are not compatible +SELECT * FROM tab1 +EXCEPT ALL +SELECT array(1); + +-- Basic
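+-- (Context: EXCEPT ALL uses multiset semantics, so a row occurring m times in the left branch +-- and n times in the right branch survives max(0, m - n) times. As an illustrative sketch, kept +-- in comment form so it does not change the generated results: VALUES (1), (1), (2) EXCEPT ALL +-- VALUES (1) would keep one 1 and one 2.)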
+SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4; + +-- Basic +SELECT * FROM tab4 +EXCEPT ALL +SELECT * FROM tab3; + +-- EXCEPT ALL + INTERSECT +SELECT * FROM tab4 +EXCEPT ALL +SELECT * FROM tab3 +INTERSECT DISTINCT +SELECT * FROM tab4; + +-- EXCEPT ALL + EXCEPT +SELECT * FROM tab4 +EXCEPT ALL +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4; + +-- Chain of set operations +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4 +UNION ALL +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4; + +-- Mismatch on number of columns across both branches +SELECT k FROM tab3 +EXCEPT ALL +SELECT k, v FROM tab4; + +-- Chain of set operations +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4 +UNION +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4; + +-- Using MINUS ALL +SELECT * FROM tab3 +MINUS ALL +SELECT * FROM tab4 +UNION +SELECT * FROM tab3 +MINUS DISTINCT +SELECT * FROM tab4; + +-- Chain of set operations +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4 +EXCEPT DISTINCT +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4; + +-- Join under EXCEPT ALL. Should produce an empty result set since the left and +-- right sets are the same. +SELECT * +FROM (SELECT tab3.k, + tab4.v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k) +EXCEPT ALL +SELECT * +FROM (SELECT tab3.k, + tab4.v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k); + +-- Join under EXCEPT ALL (2) +SELECT * +FROM (SELECT tab3.k, + tab4.v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k) +EXCEPT ALL +SELECT * +FROM (SELECT tab4.v AS k, + tab3.k AS v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k); + +-- Group by under ExceptAll +SELECT v FROM tab3 GROUP BY v +EXCEPT ALL +SELECT k FROM tab4 GROUP BY k; + +-- Clean-up +DROP VIEW IF EXISTS tab1; +DROP VIEW IF EXISTS tab2; +DROP VIEW IF EXISTS tab3; +DROP VIEW IF EXISTS tab4; diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql index 928f766b4add2..3144833b608be 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql @@ -38,7 +38,9 @@ select a, b, sum(b) from data group by 3; select a, b, sum(b) + 2 from data group by 3; -- negative case: nondeterministic expression -select a, rand(0), sum(b) from data group by a, 2; +select a, rand(0), sum(b) +from +(select /*+ REPARTITION(1) */ a, b from data) group by a, 2; -- negative case: star select * from data group by a, b, 1; diff --git a/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql b/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql index 3594283505280..6bbde9f38d657 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql @@ -13,5 +13,41 @@ SELECT a, b, c, count(d) FROM grouping GROUP BY a, b, c GROUPING SETS ((a)); -- SPARK-17849: grouping set throws NPE #3 SELECT a, b, c, count(d) FROM grouping GROUP BY a, b, c GROUPING SETS ((c)); +-- Grouping sets without an explicit GROUP BY +SELECT c1, sum(c2) FROM (VALUES ('x', 10, 0), ('y', 20, 0)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1); +-- Grouping sets without GROUP BY, with grouping() +SELECT c1, sum(c2), grouping(c1) FROM (VALUES ('x', 10, 0), ('y', 20, 0)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1); + +-- Multiple groupings within a grouping set +SELECT c1, c2, Sum(c3), grouping__id +FROM (VALUES ('x', 'a', 10), ('y', 'b', 20) ) AS t (c1, c2, c3) +GROUP BY GROUPING SETS ( ( c1 ), ( c2 ) )
+HAVING GROUPING__ID > 1; + +-- Grouping sets without an explicit GROUP BY +SELECT grouping(c1) FROM (VALUES ('x', 'a', 10), ('y', 'b', 20)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1,c2); + +-- Multiple groupings within a grouping set +SELECT -c1 AS c1 FROM (values (1,2), (3,2)) t(c1, c2) GROUP BY GROUPING SETS ((c1), (c1, c2)); + +-- complex expression in grouping sets +SELECT a + b, b, sum(c) FROM (VALUES (1,1,1),(2,2,2)) AS t(a,b,c) GROUP BY GROUPING SETS ( (a + b), (b)); + +-- complex expression in grouping sets +SELECT a + b, b, sum(c) FROM (VALUES (1,1,1),(2,2,2)) AS t(a,b,c) GROUP BY GROUPING SETS ( (a + b), (b + a), (b)); + +-- more query constructs with grouping sets +SELECT c1 AS col1, c2 AS col2 +FROM (VALUES (1, 2), (3, 2)) t(c1, c2) +GROUP BY GROUPING SETS ( ( c1 ), ( c1, c2 ) ) +HAVING col2 IS NOT NULL +ORDER BY -col1; + +-- negative tests - must have at least one grouping expression +SELECT a, b, c, count(d) FROM grouping GROUP BY WITH ROLLUP; + +SELECT a, b, c, count(d) FROM grouping GROUP BY WITH CUBE; + +SELECT c1 FROM (values (1,2), (3,2)) t(c1, c2) GROUP BY GROUPING SETS (()); diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql new file mode 100644 index 0000000000000..02ad5e3538689 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql @@ -0,0 +1,85 @@ +create or replace temporary view nested as values + (1, array(32, 97), array(array(12, 99), array(123, 42), array(1))), + (2, array(77, -76), array(array(6, 96, 65), array(-1, -2))), + (3, array(12), array(array(17))) + as t(x, ys, zs); + +-- Only allow lambdas in higher-order functions. +select upper(x -> x) as v; + +-- Identity transform an array +select transform(zs, z -> z) as v from nested; + +-- Transform an array +select transform(ys, y -> y * y) as v from nested; + +-- Transform an array with index +select transform(ys, (y, i) -> y + i) as v from nested; + +-- Transform an array with reference +select transform(zs, z -> concat(ys, z)) as v from nested; + +-- Transform an array to an array of 0s +select transform(ys, 0) as v from nested; + +-- Transform a null array +select transform(cast(null as array<int>), x -> x + 1) as v; + +-- Filter. +select filter(ys, y -> y > 30) as v from nested; + +-- Filter a null array +select filter(cast(null as array<int>), y -> true) as v; + +-- Filter nested arrays +select transform(zs, z -> filter(z, zz -> zz > 50)) as v from nested; + +-- Aggregate. +select aggregate(ys, 0, (y, a) -> y + a + x) as v from nested; + +-- Aggregate average.
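+-- aggregate(expr, start, merge [, finish]) folds the array into the start value with the merge +-- lambda and then applies the optional finish lambda to the final accumulator. The query below +-- keeps a (sum, n) struct as its accumulator and divides on finish; for ys = [32, 97] it builds +-- (129, 2) and returns 64.5.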
+select aggregate(ys, (0 as sum, 0 as n), (acc, x) -> (acc.sum + x, acc.n + 1), acc -> acc.sum / acc.n) as v from nested; + +-- Aggregate nested arrays +select transform(zs, z -> aggregate(z, 1, (acc, val) -> acc * val * size(z))) as v from nested; + +-- Aggregate a null array +select aggregate(cast(null as array<int>), 0, (a, y) -> a + y + 1, a -> a + 2) as v; + +-- Check for element existence +select exists(ys, y -> y > 30) as v from nested; + +-- Check for element existence in a null array +select exists(cast(null as array<int>), y -> y > 30) as v; + +-- Zip with array +select zip_with(ys, zs, (a, b) -> a + size(b)) as v from nested; + +-- Zip with array with concat +select zip_with(array('a', 'b', 'c'), array('d', 'e', 'f'), (x, y) -> concat(x, y)) as v; + +-- Zip with array coalesce +select zip_with(array('a'), array('d', null, 'f'), (x, y) -> coalesce(x, y)) as v; + +create or replace temporary view nested as values + (1, map(1, 1, 2, 2, 3, 3)), + (2, map(4, 4, 5, 5, 6, 6)) + as t(x, ys); + +-- Identity Transform Keys in a map +select transform_keys(ys, (k, v) -> k) as v from nested; + +-- Transform Keys in a map by adding constant +select transform_keys(ys, (k, v) -> k + 1) as v from nested; + +-- Transform Keys in a map using values +select transform_keys(ys, (k, v) -> k + v) as v from nested; + +-- Identity Transform values in a map +select transform_values(ys, (k, v) -> v) as v from nested; + +-- Transform values in a map by adding constant +select transform_values(ys, (k, v) -> v + 1) as v from nested; + +-- Transform values in a map using values +select transform_values(ys, (k, v) -> k + v) as v from nested; diff --git a/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql b/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql new file mode 100644 index 0000000000000..b0b2244048caa --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql @@ -0,0 +1,160 @@ +CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES + (1, 2), + (1, 2), + (1, 3), + (1, 3), + (2, 3), + (null, null), + (null, null) + AS tab1(k, v); +CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES + (1, 2), + (1, 2), + (2, 3), + (3, 4), + (null, null), + (null, null) + AS tab2(k, v); + +-- Basic INTERSECT ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2; + +-- INTERSECT ALL same table in both branches +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab1 WHERE k = 1; + +-- Empty left relation +SELECT * FROM tab1 WHERE k > 2 +INTERSECT ALL +SELECT * FROM tab2; + +-- Empty right relation +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 WHERE k > 3; + +-- Type Coerced INTERSECT ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT CAST(1 AS BIGINT), CAST(2 AS BIGINT); + +-- Error as the types of the two sides are not compatible +SELECT * FROM tab1 +INTERSECT ALL +SELECT array(1), 2; + +-- Mismatch on number of columns across both branches +SELECT k FROM tab1 +INTERSECT ALL +SELECT k, v FROM tab2; + +-- Basic +SELECT * FROM tab2 +INTERSECT ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2; + +-- Chain of different set operations +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +UNION ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +; + +-- Chain of different set operations +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +EXCEPT +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +; + +-- Test using parentheses to control the order of evaluation +( + ( + ( + SELECT * FROM tab1 + EXCEPT + SELECT * FROM tab2 + ) + EXCEPT + SELECT * FROM tab1 + ) + INTERSECT ALL + SELECT *
FROM tab2 +) +; + +-- Join under intersect all +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +INTERSECT ALL +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k); + +-- Join under intersect all (2) +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +INTERSECT ALL +SELECT * +FROM (SELECT tab2.v AS k, + tab1.k AS v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k); + +-- Group by under intersect all +SELECT v FROM tab1 GROUP BY v +INTERSECT ALL +SELECT k FROM tab2 GROUP BY k; + +-- Test pre-Spark 2.4 behaviour of set operation precedence +-- All the set operators are given equal precedence and are evaluated +-- from left to right as they appear in the query. + +-- Set the property +SET spark.sql.legacy.setopsPrecedence.enabled = true; + +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +UNION ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2; + +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +UNION ALL +SELECT * FROM tab1 +INTERSECT +SELECT * FROM tab2; + +-- Restore the property +SET spark.sql.legacy.setopsPrecedence.enabled = false; + +-- Clean-up +DROP VIEW IF EXISTS tab1; +DROP VIEW IF EXISTS tab2; diff --git a/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql b/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql index 8afa3270f4de4..2e6a5f362a8fa 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql @@ -1,3 +1,8 @@ +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false + CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1) AS GROUPING(a); CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1) AS GROUPING(a); diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql index dc15d13cd1dd3..0f22c0eeed581 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql @@ -35,3 +35,24 @@ DROP VIEW IF EXISTS jsonTable; -- from_json - complex types select from_json('{"a":1, "b":2}', 'map<string, int>'); select from_json('{"a":1, "b":"2"}', 'struct<a:int,b:string>'); + +-- infer schema of json literal +select schema_of_json('{"c1":0, "c2":[1]}'); +select from_json('{"c1":[1, 2, 3]}', schema_of_json('{"c1":[0]}')); + +-- from_json - array type +select from_json('[1, 2, 3]', 'array<int>'); +select from_json('[1, "2", 3]', 'array<int>'); +select from_json('[1, 2, null]', 'array<int>'); + +select from_json('[{"a": 1}, {"a":2}]', 'array<struct<a:int>>'); +select from_json('{"a": 1}', 'array<struct<a:int>>'); +select from_json('[null, {"a":2}]', 'array<struct<a:int>>'); + +select from_json('[{"a": 1}, {"b":2}]', 'array<map<string,int>>'); +select from_json('[{"a": 1}, 2]', 'array<map<string,int>>'); + +-- to_json - array type +select to_json(array('1', '2', '3')); +select to_json(array(array(1, 2, 3), array(4))); + diff --git a/sql/core/src/test/resources/sql-tests/inputs/limit.sql b/sql/core/src/test/resources/sql-tests/inputs/limit.sql index f21912a042716..e33cd819f281f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/limit.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/limit.sql @@ -1,3 +1,5 @@ +-- Disable the parallel (flat) global limit +set
spark.sql.limit.flatGlobalLimit=false; -- limit on various data types SELECT * FROM testdata LIMIT 2; @@ -13,6 +15,11 @@ SELECT * FROM testdata LIMIT CAST(1 AS int); SELECT * FROM testdata LIMIT -1; SELECT * FROM testData TABLESAMPLE (-1 ROWS); + +SELECT * FROM testdata LIMIT CAST(1 AS INT); +-- evaluated limit must not be null +SELECT * FROM testdata LIMIT CAST(NULL AS INT); + -- limit must be foldable SELECT * FROM testdata LIMIT key > 3; diff --git a/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql b/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql index 71a50157b766c..e0abeda3eb44f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql @@ -1,3 +1,8 @@ +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false + create temporary view nt1 as select * from values ("one", 1), ("two", 2), diff --git a/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql b/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql index cdc6c81e10047..ce09c21568f13 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql @@ -1,3 +1,8 @@ +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false + -- SPARK-17099: Incorrect result when HAVING clause is added to group by query CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (-234), (145), (367), (975), (298) diff --git a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql index 01dea6c81c11b..1f607b334dc18 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql @@ -11,6 +11,11 @@ create temporary view years as select * from values (2013, 2) as years(y, s); +create temporary view yearsWithComplexTypes as select * from values + (2012, array(1, 1), map('1', 1), struct(1, 'a')), + (2013, array(2, 2), map('2', 2), struct(2, 'b')) + as yearsWithComplexTypes(y, a, m, s); + -- pivot courses SELECT * FROM ( SELECT year, course, earnings FROM courseSales @@ -96,6 +101,15 @@ PIVOT ( FOR y IN (2012, 2013) ); +-- pivot with projection and value aliases +SELECT firstYear_s, secondYear_s, firstYear_a, secondYear_a, c FROM ( + SELECT year y, course c, earnings e FROM courseSales +) +PIVOT ( + sum(e) s, avg(e) a + FOR y IN (2012 as firstYear, 2013 secondYear) +); + -- pivot years with non-aggregate function SELECT * FROM courseSales PIVOT ( @@ -103,6 +117,15 @@ PIVOT ( FOR year IN (2012, 2013) ); +-- pivot with one of the expressions as non-aggregate function +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), year + FOR course IN ('dotNET', 'Java') +); + -- pivot with unresolvable columns SELECT * FROM ( SELECT course, earnings FROM courseSales @@ -111,3 +134,156 @@ PIVOT ( sum(earnings) FOR year IN (2012, 2013) ); + +-- pivot with complex aggregate expressions +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + 
ceil(sum(earnings)), avg(earnings) + 1 as a1 + FOR course IN ('dotNET', 'Java') +); + +-- pivot with invalid arguments in aggregate expressions +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(avg(earnings)) + FOR course IN ('dotNET', 'Java') +); + +-- pivot on multiple pivot columns +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, year) IN (('dotNET', 2012), ('Java', 2013)) +); + +-- pivot on multiple pivot columns with aliased values +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, s) IN (('dotNET', 2) as c1, ('Java', 1) as c2) +); + +-- pivot on multiple pivot columns with values of wrong data types +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, year) IN ('dotNET', 'Java') +); + +-- pivot with unresolvable values +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (s, 2013) +); + +-- pivot with non-literal values +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (course, 2013) +); + +-- pivot on join query with columns of complex data types +SELECT * FROM ( + SELECT course, year, a + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + min(a) + FOR course IN ('dotNET', 'Java') +); + +-- pivot on multiple pivot columns with agg columns of complex data types +SELECT * FROM ( + SELECT course, year, y, a + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + max(a) + FOR (y, course) IN ((2012, 'dotNET'), (2013, 'Java')) +); + +-- pivot on pivot column of array type +SELECT * FROM ( + SELECT earnings, year, a + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR a IN (array(1, 1), array(2, 2)) +); + +-- pivot on multiple pivot columns containing array type +SELECT * FROM ( + SELECT course, earnings, year, a + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, a) IN (('dotNET', array(1, 1)), ('Java', array(2, 2))) +); + +-- pivot on pivot column of struct type +SELECT * FROM ( + SELECT earnings, year, s + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR s IN ((1, 'a'), (2, 'b')) +); + +-- pivot on multiple pivot columns containing struct type +SELECT * FROM ( + SELECT course, earnings, year, s + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, s) IN (('dotNET', (1, 'a')), ('Java', (2, 'b'))) +); + +-- pivot on pivot column of map type +SELECT * FROM ( + SELECT earnings, year, m + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR m IN (map('1', 1), map('2', 2)) +); + +-- pivot on multiple pivot columns containing map type +SELECT * FROM ( + SELECT course, earnings, year, m + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, m) IN (('dotNET', map('1', 1)), ('Java', map('2', 2))) +); diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-joins-and-set-ops.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-joins-and-set-ops.sql index cc4ed64affec7..cefc3fe6272ab 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-joins-and-set-ops.sql +++ 
b/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-joins-and-set-ops.sql @@ -1,5 +1,9 @@ -- Tests EXISTS subquery support. Tests EXISTS subquery -- used in joins (both when joins occur in outer and subquery blocks) +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false CREATE TEMPORARY VIEW EMP AS SELECT * FROM VALUES (100, "emp 1", date "2005-01-01", 100.00D, 10), diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-basic.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-basic.sql new file mode 100644 index 0000000000000..f4ffc20086386 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-basic.sql @@ -0,0 +1,14 @@ +create temporary view tab_a as select * from values (1, 1) as tab_a(a1, b1); +create temporary view tab_b as select * from values (1, 1) as tab_b(a2, b2); +create temporary view struct_tab as select struct(col1 as a, col2 as b) as record from + values (1, 1), (1, 2), (2, 1), (2, 2); + +select 1 from tab_a where (a1, b1) not in (select a2, b2 from tab_b); +-- Invalid query, see SPARK-24341 +select 1 from tab_a where (a1, b1) not in (select (a2, b2) from tab_b); + +-- Aliasing is needed as a workaround for SPARK-24443 +select count(*) from struct_tab where record in + (select (a2 as a, b2 as b) from tab_b); +select count(*) from struct_tab where record not in + (select (a2 as a, b2 as b) from tab_b); diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql index 880175fd7add0..22f3eafd6a02d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql @@ -1,5 +1,9 @@ -- A test suite for IN JOINS in parent side, subquery, and both predicate subquery -- It includes correlated cases. +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql index a40ee082ba3b9..a862e0985b20c 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql @@ -1,6 +1,9 @@ -- A test suite for IN LIMIT in parent side, subquery, and both predicate subquery -- It includes correlated cases.
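+-- spark.sql.limit.flatGlobalLimit is disabled below (as in limit.sql) so that which rows +-- survive a LIMIT is deterministic and the generated answer files stay stable across runs.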
+-- Disable global limit optimization +set spark.sql.limit.flatGlobalLimit=false; + create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), ("val1b", 8S, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), @@ -97,4 +100,4 @@ WHERE t1d NOT IN (SELECT t2d LIMIT 1) GROUP BY t1b ORDER BY t1b NULLS last -LIMIT 1; \ No newline at end of file +LIMIT 1; diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql index e09b91f18de0a..4f8ca8bfb27c1 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql @@ -1,5 +1,9 @@ -- A test suite for not-in-joins in parent side, subquery, and both predicate subquery -- It includes correlated cases. +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql index b15f4da81dd93..95b115a8dd094 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql @@ -13,6 +13,14 @@ CREATE TEMPORARY VIEW t3 AS SELECT * FROM VALUES (3, 1, 2) AS t3(t3a, t3b, t3c); +CREATE TEMPORARY VIEW t4 AS SELECT * FROM VALUES + (CAST(1 AS DOUBLE), CAST(2 AS STRING), CAST(3 AS STRING)) +AS t1(t4a, t4b, t4c); + +CREATE TEMPORARY VIEW t5 AS SELECT * FROM VALUES + (CAST(1 AS DECIMAL(18, 0)), CAST(2 AS STRING), CAST(3 AS BIGINT)) +AS t1(t5a, t5b, t5c); + -- TC 01.01 SELECT ( SELECT max(t2b), min(t2b) @@ -44,4 +52,10 @@ WHERE (t1a, t1b) IN (SELECT t2a FROM t2 WHERE t1a = t2a); - +-- TC 01.05 +SELECT * FROM t4 +WHERE +(t4a, t4b, t4c) IN (SELECT t5a, + t5b, + t5c + FROM t5); diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql index db00a18f2e7e9..99f46dd19d0e2 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql @@ -148,6 +148,8 @@ SELECT (tinyint_array1 || smallint_array2) ts_array, (smallint_array1 || int_array2) si_array, (int_array1 || bigint_array2) ib_array, + (bigint_array1 || decimal_array2) bd_array, + (decimal_array1 || double_array2) dd_array, (double_array1 || float_array2) df_array, (string_array1 || data_array2) std_array, (timestamp_array1 || string_array2) tst_array, diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql new file mode 100644 index 0000000000000..1727ee725db2e --- /dev/null +++ 
b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql @@ -0,0 +1,78 @@ +CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( + map(true, false), + map(2Y, 1Y), + map(2S, 1S), + map(2, 1), + map(2L, 1L), + map(922337203685477897945456575809789456, 922337203685477897945456575809789456), + map(9.22337203685477897945456575809789456, 9.22337203685477897945456575809789456), + map(2.0D, 1.0D), + map(float(2.0), float(1.0)), + map(date '2016-03-14', date '2016-03-13'), + map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + map('true', 'false', '2', '1'), + map('2016-03-14', '2016-03-13'), + map('2016-11-15 20:54:00.000', '2016-11-12 20:54:00.000'), + map('922337203685477897945456575809789456', 'text'), + map(array(1L, 2L), array(1L, 2L)), map(array(1, 2), array(1, 2)), + map(struct(1S, 2L), struct(1S, 2L)), map(struct(1, 2), struct(1, 2)) +) AS various_maps( + boolean_map, + tinyint_map, + smallint_map, + int_map, + bigint_map, + decimal_map1, decimal_map2, + double_map, + float_map, + date_map, + timestamp_map, + string_map1, string_map2, string_map3, string_map4, + array_map1, array_map2, + struct_map1, struct_map2 +); + +SELECT map_zip_with(tinyint_map, smallint_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(smallint_map, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(int_map, bigint_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(double_map, float_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(decimal_map1, decimal_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(decimal_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(decimal_map1, double_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(decimal_map2, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(decimal_map2, double_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(string_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(string_map2, date_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(timestamp_map, string_map3, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(decimal_map1, string_map4, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(array_map1, array_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(struct_map1, struct_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql new file mode 100644 index 0000000000000..69da67fc66fc0 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql @@ -0,0 +1,95 @@ +CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( + map(true, false), map(false, true), + map(1Y, 2Y), map(3Y, 4Y), + map(1S, 2S), map(3S, 4S), + map(4, 6), map(7, 8), + map(6L, 7L), map(8L, 9L), + map(9223372036854775809, 9223372036854775808), map(9223372036854775808, 9223372036854775809), + map(1.0D, 2.0D), map(3.0D, 4.0D), + map(float(1.0D), float(2.0D)), map(float(3.0D), float(4.0D)), + map(date '2016-03-14', date '2016-03-13'), map(date 
'2016-03-12', date '2016-03-11'), + map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + map(timestamp '2016-11-11 20:54:00.000', timestamp '2016-11-09 20:54:00.000'), + map('a', 'b'), map('c', 'd'), + map(array('a', 'b'), array('c', 'd')), map(array('e'), array('f')), + map(struct('a', 1), struct('b', 2)), map(struct('c', 3), struct('d', 4)), + map(map('a', 1), map('b', 2)), map(map('c', 3), map('d', 4)), + map('a', 1), map('c', 2), + map(1, 'a'), map(2, 'c') +) AS various_maps ( + boolean_map1, boolean_map2, + tinyint_map1, tinyint_map2, + smallint_map1, smallint_map2, + int_map1, int_map2, + bigint_map1, bigint_map2, + decimal_map1, decimal_map2, + double_map1, double_map2, + float_map1, float_map2, + date_map1, date_map2, + timestamp_map1, + timestamp_map2, + string_map1, string_map2, + array_map1, array_map2, + struct_map1, struct_map2, + map_map1, map_map2, + string_int_map1, string_int_map2, + int_string_map1, int_string_map2 +); + +-- Concatenate maps of the same type +SELECT + map_concat(boolean_map1, boolean_map2) boolean_map, + map_concat(tinyint_map1, tinyint_map2) tinyint_map, + map_concat(smallint_map1, smallint_map2) smallint_map, + map_concat(int_map1, int_map2) int_map, + map_concat(bigint_map1, bigint_map2) bigint_map, + map_concat(decimal_map1, decimal_map2) decimal_map, + map_concat(float_map1, float_map2) float_map, + map_concat(double_map1, double_map2) double_map, + map_concat(date_map1, date_map2) date_map, + map_concat(timestamp_map1, timestamp_map2) timestamp_map, + map_concat(string_map1, string_map2) string_map, + map_concat(array_map1, array_map2) array_map, + map_concat(struct_map1, struct_map2) struct_map, + map_concat(map_map1, map_map2) map_map, + map_concat(string_int_map1, string_int_map2) string_int_map, + map_concat(int_string_map1, int_string_map2) int_string_map +FROM various_maps; + +-- Concatenate maps of different types +SELECT + map_concat(tinyint_map1, smallint_map2) ts_map, + map_concat(smallint_map1, int_map2) si_map, + map_concat(int_map1, bigint_map2) ib_map, + map_concat(bigint_map1, decimal_map2) bd_map, + map_concat(decimal_map1, float_map2) df_map, + map_concat(string_map1, date_map2) std_map, + map_concat(timestamp_map1, string_map2) tst_map, + map_concat(string_map1, int_map2) sti_map, + map_concat(int_string_map1, tinyint_map2) istt_map +FROM various_maps; + +-- Concatenate map of incompatible types 1 +SELECT + map_concat(tinyint_map1, map_map2) tm_map +FROM various_maps; + +-- Concatenate map of incompatible types 2 +SELECT + map_concat(boolean_map1, int_map2) bi_map +FROM various_maps; + +-- Concatenate map of incompatible types 3 +SELECT + map_concat(int_map1, struct_map2) is_map +FROM various_maps; + +-- Concatenate map of incompatible types 4 +SELECT + map_concat(map_map1, array_map2) ma_map +FROM various_maps; + +-- Concatenate map of incompatible types 5 +SELECT + map_concat(map_map1, struct_map2) ms_map +FROM various_maps; diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out index 539f673c9d679..9fc97f0c39149 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out @@ -72,7 +72,7 @@ SELECT i1 FROM t1, mydb1.t1 struct<> -- !query 8 output org.apache.spark.sql.AnalysisException -Reference 'i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 +Reference 'i1' is 
ambiguous, could be: mydb1.t1.i1, mydb1.t1.i1.; line 1 pos 7 -- !query 9 @@ -81,7 +81,7 @@ SELECT t1.i1 FROM t1, mydb1.t1 struct<> -- !query 9 output org.apache.spark.sql.AnalysisException -Reference 't1.i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 +Reference 't1.i1' is ambiguous, could be: mydb1.t1.i1, mydb1.t1.i1.; line 1 pos 7 -- !query 10 @@ -90,7 +90,7 @@ SELECT mydb1.t1.i1 FROM t1, mydb1.t1 struct<> -- !query 10 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 +Reference 'mydb1.t1.i1' is ambiguous, could be: mydb1.t1.i1, mydb1.t1.i1.; line 1 pos 7 -- !query 11 @@ -99,7 +99,7 @@ SELECT i1 FROM t1, mydb2.t1 struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -Reference 'i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 +Reference 'i1' is ambiguous, could be: mydb1.t1.i1, mydb2.t1.i1.; line 1 pos 7 -- !query 12 @@ -108,7 +108,7 @@ SELECT t1.i1 FROM t1, mydb2.t1 struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -Reference 't1.i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 +Reference 't1.i1' is ambiguous, could be: mydb1.t1.i1, mydb2.t1.i1.; line 1 pos 7 -- !query 13 @@ -125,7 +125,7 @@ SELECT i1 FROM t1, mydb1.t1 struct<> -- !query 14 output org.apache.spark.sql.AnalysisException -Reference 'i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 +Reference 'i1' is ambiguous, could be: mydb2.t1.i1, mydb1.t1.i1.; line 1 pos 7 -- !query 15 @@ -134,7 +134,7 @@ SELECT t1.i1 FROM t1, mydb1.t1 struct<> -- !query 15 output org.apache.spark.sql.AnalysisException -Reference 't1.i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 +Reference 't1.i1' is ambiguous, could be: mydb2.t1.i1, mydb1.t1.i1.; line 1 pos 7 -- !query 16 @@ -143,7 +143,7 @@ SELECT i1 FROM t1, mydb2.t1 struct<> -- !query 16 output org.apache.spark.sql.AnalysisException -Reference 'i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 +Reference 'i1' is ambiguous, could be: mydb2.t1.i1, mydb2.t1.i1.; line 1 pos 7 -- !query 17 @@ -152,7 +152,7 @@ SELECT t1.i1 FROM t1, mydb2.t1 struct<> -- !query 17 output org.apache.spark.sql.AnalysisException -Reference 't1.i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 +Reference 't1.i1' is ambiguous, could be: mydb2.t1.i1, mydb2.t1.i1.; line 1 pos 7 -- !query 18 @@ -161,7 +161,7 @@ SELECT db1.t1.i1 FROM t1, mydb2.t1 struct<> -- !query 18 output org.apache.spark.sql.AnalysisException -cannot resolve '`db1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 +cannot resolve '`db1.t1.i1`' given input columns: [mydb2.t1.i1, mydb2.t1.i1]; line 1 pos 7 -- !query 19 @@ -186,7 +186,7 @@ SELECT mydb1.t1 FROM t1 struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1`' given input columns: [t1.i1]; line 1 pos 7 +cannot resolve '`mydb1.t1`' given input columns: [mydb1.t1.i1]; line 1 pos 7 -- !query 22 @@ -204,7 +204,7 @@ SELECT t1 FROM mydb1.t1 struct<> -- !query 23 output org.apache.spark.sql.AnalysisException -cannot resolve '`t1`' given input columns: [t1.i1]; line 1 pos 7 +cannot resolve '`t1`' given input columns: [mydb1.t1.i1]; line 1 pos 7 -- !query 24 @@ -221,7 +221,7 @@ SELECT mydb1.t1.i1 FROM t1 struct<> -- !query 25 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1]; line 1 pos 7 +cannot resolve '`mydb1.t1.i1`' given input columns: [mydb2.t1.i1]; line 1 pos 7 -- !query 26 diff --git 
a/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out index 2092119600954..3d8fb661afe55 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out @@ -85,10 +85,9 @@ struct<i1:int> -- !query 10 SELECT global_temp.view1.* FROM global_temp.view1 -- !query 10 schema -struct<> +struct<i1:int> -- !query 10 output -org.apache.spark.sql.AnalysisException -cannot resolve 'global_temp.view1.*' given input columns 'i1'; +1 -- !query 11 @@ -102,10 +101,9 @@ struct<i1:int> -- !query 12 SELECT global_temp.view1.i1 FROM global_temp.view1 -- !query 12 schema -struct<> +struct<i1:int> -- !query 12 output -org.apache.spark.sql.AnalysisException -cannot resolve '`global_temp.view1.i1`' given input columns: [view1.i1]; line 1 pos 7 +1 -- !query 13 diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out index e10f516ad6e5b..73e3fdc08232c 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 54 +-- Number of queries: 55 -- !query 0 @@ -93,19 +93,17 @@ struct<i1:int> -- !query 11 SELECT mydb1.t1.i1 FROM t1 -- !query 11 schema -struct<> +struct<i1:int> -- !query 11 output -org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1]; line 1 pos 7 +1 -- !query 12 SELECT mydb1.t1.i1 FROM mydb1.t1 -- !query 12 schema -struct<> +struct<i1:int> -- !query 12 output -org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1]; line 1 pos 7 +1 -- !query 13 @@ -151,10 +149,9 @@ struct<i1:int> -- !query 18 SELECT mydb1.t1.i1 FROM mydb1.t1 -- !query 18 schema -struct<> +struct<i1:int> -- !query 18 output -org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1]; line 1 pos 7 +1 -- !query 19 @@ -176,10 +173,9 @@ struct<i1:int> -- !query 21 SELECT mydb1.t1.* FROM mydb1.t1 -- !query 21 schema -struct<> +struct<i1:int> -- !query 21 output -org.apache.spark.sql.AnalysisException -cannot resolve 'mydb1.t1.*' given input columns 'i1'; +1 -- !query 22 @@ -209,10 +205,9 @@ struct<i1:int> -- !query 25 SELECT mydb1.t1.* FROM mydb1.t1 -- !query 25 schema -struct<> +struct<i1:int> -- !query 25 output -org.apache.spark.sql.AnalysisException -cannot resolve 'mydb1.t1.*' given input columns 'i1'; +1 -- !query 26 @@ -267,10 +262,9 @@ struct<c1:int,c2:int> SELECT * FROM mydb1.t3 WHERE c1 IN (SELECT mydb1.t4.c2 FROM mydb1.t4 WHERE mydb1.t4.c3 = mydb1.t3.c2) -- !query 32 schema -struct<> +struct<c1:int,c2:int> -- !query 32 output -org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t4.c3`' given input columns: [t4.c2, t4.c3]; line 2 pos 42 +4 1 -- !query 33 @@ -284,19 +278,17 @@ spark.sql.crossJoin.enabled true -- !query 34 SELECT mydb1.t1.i1 FROM t1, mydb2.t1 -- !query 34 schema -struct<> +struct<i1:int> -- !query 34 output -org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 +1 -- !query 35 SELECT mydb1.t1.i1 FROM mydb1.t1, mydb2.t1 -- !query 35 schema -struct<> +struct<i1:int> -- !query 35 output -org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 +1 -- !query 36 @@ -310,10 +302,9 @@ struct<> -- !query 37 SELECT
mydb1.t1.i1 FROM t1, mydb1.t1 -- !query 37 schema -struct<> +struct<i1:int> -- !query 37 output -org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 +1 -- !query 38 @@ -399,40 +390,37 @@ struct<i1:int,i2:int> -- !query 48 SELECT mydb1.t5.t5.i1 FROM mydb1.t5 -- !query 48 schema -struct<> +struct<i1:int> -- !query 48 output -org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t5.t5.i1`' given input columns: [t5.i1, t5.t5]; line 1 pos 7 +2 -- !query 49 SELECT mydb1.t5.t5.i2 FROM mydb1.t5 -- !query 49 schema -struct<> +struct<i2:int> -- !query 49 output -org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t5.t5.i2`' given input columns: [t5.i1, t5.t5]; line 1 pos 7 +3 -- !query 50 SELECT mydb1.t5.* FROM mydb1.t5 -- !query 50 schema -struct<> +struct<i1:int,t5:struct<i1:int,i2:int>> -- !query 50 output -org.apache.spark.sql.AnalysisException -cannot resolve 'mydb1.t5.*' given input columns 'i1, t5'; +1 {"i1":2,"i2":3} -- !query 51 -USE default +SELECT mydb1.t5.* FROM t5 -- !query 51 schema -struct<> +struct<i1:int,t5:struct<i1:int,i2:int>> -- !query 51 output - +1 {"i1":2,"i2":3} -- !query 52 -DROP DATABASE mydb1 CASCADE +USE default -- !query 52 schema struct<> -- !query 52 output @@ -440,8 +428,16 @@ struct<> -- !query 53 -DROP DATABASE mydb2 CASCADE +DROP DATABASE mydb1 CASCADE -- !query 53 schema struct<> -- !query 53 output + + +-- !query 54 +DROP DATABASE mydb2 CASCADE +-- !query 54 schema +struct<> +-- !query 54 output + diff --git a/sql/core/src/test/resources/sql-tests/results/except-all.sql.out b/sql/core/src/test/resources/sql-tests/results/except-all.sql.out new file mode 100644 index 0000000000000..01091a2f751ce --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/except-all.sql.out @@ -0,0 +1,346 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 27 + + +-- !query 0 +CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES + (0), (1), (2), (2), (2), (2), (3), (null), (null) AS tab1(c1) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES + (1), (2), (2), (3), (5), (5), (null) AS tab2(c1) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +CREATE TEMPORARY VIEW tab3 AS SELECT * FROM VALUES + (1, 2), + (1, 2), + (1, 3), + (2, 3), + (2, 2) + AS tab3(k, v) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +CREATE TEMPORARY VIEW tab4 AS SELECT * FROM VALUES + (1, 2), + (2, 3), + (2, 2), + (2, 2), + (2, 20) + AS tab4(k, v) +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +SELECT * FROM tab1 +EXCEPT ALL +SELECT * FROM tab2 +-- !query 4 schema +struct<c1:int> +-- !query 4 output +0 +2 +2 +NULL + + +-- !query 5 +SELECT * FROM tab1 +MINUS ALL +SELECT * FROM tab2 +-- !query 5 schema +struct<c1:int> +-- !query 5 output +0 +2 +2 +NULL + + +-- !query 6 +SELECT * FROM tab1 +EXCEPT ALL +SELECT * FROM tab2 WHERE c1 IS NOT NULL +-- !query 6 schema +struct<c1:int> +-- !query 6 output +0 +2 +2 +NULL +NULL + + +-- !query 7 +SELECT * FROM tab1 WHERE c1 > 5 +EXCEPT ALL +SELECT * FROM tab2 +-- !query 7 schema +struct<c1:int> +-- !query 7 output + + + +-- !query 8 +SELECT * FROM tab1 +EXCEPT ALL +SELECT * FROM tab2 WHERE c1 > 6 +-- !query 8 schema +struct<c1:int> +-- !query 8 output +0 +1 +2 +2 +2 +2 +3 +NULL +NULL + + +-- !query 9 +SELECT * FROM tab1 +EXCEPT ALL +SELECT CAST(1 AS BIGINT) +-- !query 9 schema +struct<c1:bigint> +-- !query 9 output +0 +2 +2 +2 +2 +3 +NULL +NULL + + +-- !query 10 +SELECT * FROM tab1 +EXCEPT ALL +SELECT array(1) +-- !query 10 schema +struct<> +-- !query 10 output
+org.apache.spark.sql.AnalysisException +ExceptAll can only be performed on tables with the compatible column types. array<int> <> int at the first column of the second table; + + +-- !query 11 +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4 +-- !query 11 schema +struct<k:int,v:int> +-- !query 11 output +1 2 +1 3 + + +-- !query 12 +SELECT * FROM tab4 +EXCEPT ALL +SELECT * FROM tab3 +-- !query 12 schema +struct<k:int,v:int> +-- !query 12 output +2 2 +2 20 + + +-- !query 13 +SELECT * FROM tab4 +EXCEPT ALL +SELECT * FROM tab3 +INTERSECT DISTINCT +SELECT * FROM tab4 +-- !query 13 schema +struct<k:int,v:int> +-- !query 13 output +2 2 +2 20 + + +-- !query 14 +SELECT * FROM tab4 +EXCEPT ALL +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4 +-- !query 14 schema +struct<k:int,v:int> +-- !query 14 output + + + +-- !query 15 +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4 +UNION ALL +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4 +-- !query 15 schema +struct<k:int,v:int> +-- !query 15 output +1 3 + + +-- !query 16 +SELECT k FROM tab3 +EXCEPT ALL +SELECT k, v FROM tab4 +-- !query 16 schema +struct<> +-- !query 16 output +org.apache.spark.sql.AnalysisException +ExceptAll can only be performed on tables with the same number of columns, but the first table has 1 columns and the second table has 2 columns; + + +-- !query 17 +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4 +UNION +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4 +-- !query 17 schema +struct<k:int,v:int> +-- !query 17 output +1 3 + + +-- !query 18 +SELECT * FROM tab3 +MINUS ALL +SELECT * FROM tab4 +UNION +SELECT * FROM tab3 +MINUS DISTINCT +SELECT * FROM tab4 +-- !query 18 schema +struct<k:int,v:int> +-- !query 18 output +1 3 + + +-- !query 19 +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4 +EXCEPT DISTINCT +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4 +-- !query 19 schema +struct<k:int,v:int> +-- !query 19 output + + + +-- !query 20 +SELECT * +FROM (SELECT tab3.k, + tab4.v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k) +EXCEPT ALL +SELECT * +FROM (SELECT tab3.k, + tab4.v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k) +-- !query 20 schema +struct<k:int,v:int> +-- !query 20 output + + + +-- !query 21 +SELECT * +FROM (SELECT tab3.k, + tab4.v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k) +EXCEPT ALL +SELECT * +FROM (SELECT tab4.v AS k, + tab3.k AS v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k) +-- !query 21 schema +struct<k:int,v:int> +-- !query 21 output +1 2 +1 2 +1 2 +2 20 +2 20 +2 3 +2 3 + + +-- !query 22 +SELECT v FROM tab3 GROUP BY v +EXCEPT ALL +SELECT k FROM tab4 GROUP BY k +-- !query 22 schema +struct<v:int> +-- !query 22 output +3 + + +-- !query 23 +DROP VIEW IF EXISTS tab1 +-- !query 23 schema +struct<> +-- !query 23 output + + + +-- !query 24 +DROP VIEW IF EXISTS tab2 +-- !query 24 schema +struct<> +-- !query 24 output + + + +-- !query 25 +DROP VIEW IF EXISTS tab3 +-- !query 25 schema +struct<> +-- !query 25 output + + + +-- !query 26 +DROP VIEW IF EXISTS tab4 +-- !query 26 schema +struct<> +-- !query 26 output + diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out index 9ecbe19078dd6..cf5add6a71af2 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out @@ -135,7 +135,9 @@ aggregate functions are not allowed in GROUP BY, but found (sum(CAST(data.`b` AS -- !query 13 -select a, rand(0), sum(b) from data group by a, 2 +select a, rand(0), sum(b) +from +(select /*+ REPARTITION(1) */ a, b from data) group by a, 2 -- !query
13 schema struct<a:int,rand(0):double,sum(b):bigint> -- !query 13 output diff --git a/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out b/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out index edb38a52b7514..34ab09c5e3bba 100644 --- a/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 4 +-- Number of queries: 15 -- !query 0 @@ -40,3 +40,127 @@ struct<a:string,b:string,c:string,count(d):bigint> NULL NULL 3 1 NULL NULL 6 1 NULL NULL 9 1 + + +-- !query 4 +SELECT c1, sum(c2) FROM (VALUES ('x', 10, 0), ('y', 20, 0)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1) +-- !query 4 schema +struct<c1:string,sum(c2):bigint> +-- !query 4 output +x 10 +y 20 + + +-- !query 5 +SELECT c1, sum(c2), grouping(c1) FROM (VALUES ('x', 10, 0), ('y', 20, 0)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1) +-- !query 5 schema +struct<c1:string,sum(c2):bigint,grouping(c1):tinyint> +-- !query 5 output +x 10 0 +y 20 0 + + +-- !query 6 +SELECT c1, c2, Sum(c3), grouping__id +FROM (VALUES ('x', 'a', 10), ('y', 'b', 20) ) AS t (c1, c2, c3) +GROUP BY GROUPING SETS ( ( c1 ), ( c2 ) ) +HAVING GROUPING__ID > 1 +-- !query 6 schema +struct<c1:string,c2:string,sum(c3):bigint,grouping__id:int> +-- !query 6 output +NULL a 10 2 +NULL b 20 2 + + +-- !query 7 +SELECT grouping(c1) FROM (VALUES ('x', 'a', 10), ('y', 'b', 20)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1,c2) +-- !query 7 schema +struct<grouping(c1):tinyint> +-- !query 7 output +0 +0 +1 +1 + + +-- !query 8 +SELECT -c1 AS c1 FROM (values (1,2), (3,2)) t(c1, c2) GROUP BY GROUPING SETS ((c1), (c1, c2)) +-- !query 8 schema +struct<c1:int> +-- !query 8 output +-1 +-1 +-3 +-3 + + +-- !query 9 +SELECT a + b, b, sum(c) FROM (VALUES (1,1,1),(2,2,2)) AS t(a,b,c) GROUP BY GROUPING SETS ( (a + b), (b)) +-- !query 9 schema +struct<(a + b):int,b:int,sum(c):bigint> +-- !query 9 output +2 NULL 1 +4 NULL 2 +NULL 1 1 +NULL 2 2 + + +-- !query 10 +SELECT a + b, b, sum(c) FROM (VALUES (1,1,1),(2,2,2)) AS t(a,b,c) GROUP BY GROUPING SETS ( (a + b), (b + a), (b)) +-- !query 10 schema +struct<(a + b):int,b:int,sum(c):bigint> +-- !query 10 output +2 NULL 2 +4 NULL 4 +NULL 1 1 +NULL 2 2 + + +-- !query 11 +SELECT c1 AS col1, c2 AS col2 +FROM (VALUES (1, 2), (3, 2)) t(c1, c2) +GROUP BY GROUPING SETS ( ( c1 ), ( c1, c2 ) ) +HAVING col2 IS NOT NULL +ORDER BY -col1 +-- !query 11 schema +struct<col1:int,col2:int> +-- !query 11 output +3 2 +1 2 + + +-- !query 12 +SELECT a, b, c, count(d) FROM grouping GROUP BY WITH ROLLUP +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.catalyst.parser.ParseException + +extraneous input 'ROLLUP' expecting <EOF>(line 1, pos 53) + +== SQL == +SELECT a, b, c, count(d) FROM grouping GROUP BY WITH ROLLUP +-----------------------------------------------------^^^ + + +-- !query 13 +SELECT a, b, c, count(d) FROM grouping GROUP BY WITH CUBE +-- !query 13 schema +struct<> +-- !query 13 output +org.apache.spark.sql.catalyst.parser.ParseException + +extraneous input 'CUBE' expecting <EOF>(line 1, pos 53) + +== SQL == +SELECT a, b, c, count(d) FROM grouping GROUP BY WITH CUBE +-----------------------------------------------------^^^ + + +-- !query 14 +SELECT c1 FROM (values (1,2), (3,2)) t(c1, c2) GROUP BY GROUPING SETS (()) +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +expression '`c1`' is neither present in the group by, nor is it an aggregate function.
Add to group by or wrap in first() (or first_value) if you don't care which value you get.; diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out new file mode 100644 index 0000000000000..32d20d1b73415 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out @@ -0,0 +1,255 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 27 + + +-- !query 0 +create or replace temporary view nested as values + (1, array(32, 97), array(array(12, 99), array(123, 42), array(1))), + (2, array(77, -76), array(array(6, 96, 65), array(-1, -2))), + (3, array(12), array(array(17))) + as t(x, ys, zs) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select upper(x -> x) as v +-- !query 1 schema +struct<> +-- !query 1 output +org.apache.spark.sql.AnalysisException +A lambda function should only be used in a higher order function. However, its class is org.apache.spark.sql.catalyst.expressions.Upper, which is not a higher order function.; line 1 pos 7 + + +-- !query 2 +select transform(zs, z -> z) as v from nested +-- !query 2 schema +struct>> +-- !query 2 output +[[12,99],[123,42],[1]] +[[17]] +[[6,96,65],[-1,-2]] + + +-- !query 3 +select transform(ys, y -> y * y) as v from nested +-- !query 3 schema +struct> +-- !query 3 output +[1024,9409] +[144] +[5929,5776] + + +-- !query 4 +select transform(ys, (y, i) -> y + i) as v from nested +-- !query 4 schema +struct> +-- !query 4 output +[12] +[32,98] +[77,-75] + + +-- !query 5 +select transform(zs, z -> concat(ys, z)) as v from nested +-- !query 5 schema +struct>> +-- !query 5 output +[[12,17]] +[[32,97,12,99],[32,97,123,42],[32,97,1]] +[[77,-76,6,96,65],[77,-76,-1,-2]] + + +-- !query 6 +select transform(ys, 0) as v from nested +-- !query 6 schema +struct> +-- !query 6 output +[0,0] +[0,0] +[0] + + +-- !query 7 +select transform(cast(null as array), x -> x + 1) as v +-- !query 7 schema +struct> +-- !query 7 output +NULL + + +-- !query 8 +select filter(ys, y -> y > 30) as v from nested +-- !query 8 schema +struct> +-- !query 8 output +[32,97] +[77] +[] + + +-- !query 9 +select filter(cast(null as array), y -> true) as v +-- !query 9 schema +struct> +-- !query 9 output +NULL + + +-- !query 10 +select transform(zs, z -> filter(z, zz -> zz > 50)) as v from nested +-- !query 10 schema +struct>> +-- !query 10 output +[[96,65],[]] +[[99],[123],[]] +[[]] + + +-- !query 11 +select aggregate(ys, 0, (y, a) -> y + a + x) as v from nested +-- !query 11 schema +struct +-- !query 11 output +131 +15 +5 + + +-- !query 12 +select aggregate(ys, (0 as sum, 0 as n), (acc, x) -> (acc.sum + x, acc.n + 1), acc -> acc.sum / acc.n) as v from nested +-- !query 12 schema +struct +-- !query 12 output +0.5 +12.0 +64.5 + + +-- !query 13 +select transform(zs, z -> aggregate(z, 1, (acc, val) -> acc * val * size(z))) as v from nested +-- !query 13 schema +struct> +-- !query 13 output +[1010880,8] +[17] +[4752,20664,1] + + +-- !query 14 +select aggregate(cast(null as array), 0, (a, y) -> a + y + 1, a -> a + 2) as v +-- !query 14 schema +struct +-- !query 14 output +NULL + + +-- !query 15 +select exists(ys, y -> y > 30) as v from nested +-- !query 15 schema +struct +-- !query 15 output +false +true +true + + +-- !query 16 +select exists(cast(null as array), y -> y > 30) as v +-- !query 16 schema +struct +-- !query 16 output +NULL + + +-- !query 17 +select zip_with(ys, zs, (a, b) -> a + size(b)) as v from nested 
+-- !query 17 schema +struct> +-- !query 17 output +[13] +[34,99,null] +[80,-74] + + +-- !query 18 +select zip_with(array('a', 'b', 'c'), array('d', 'e', 'f'), (x, y) -> concat(x, y)) as v +-- !query 18 schema +struct> +-- !query 18 output +["ad","be","cf"] + + +-- !query 19 +select zip_with(array('a'), array('d', null, 'f'), (x, y) -> coalesce(x, y)) as v +-- !query 19 schema +struct> +-- !query 19 output +["a",null,"f"] + + +-- !query 20 +create or replace temporary view nested as values + (1, map(1, 1, 2, 2, 3, 3)), + (2, map(4, 4, 5, 5, 6, 6)) + as t(x, ys) +-- !query 20 schema +struct<> +-- !query 20 output + + +-- !query 21 +select transform_keys(ys, (k, v) -> k) as v from nested +-- !query 21 schema +struct> +-- !query 21 output +{1:1,2:2,3:3} +{4:4,5:5,6:6} + + +-- !query 22 +select transform_keys(ys, (k, v) -> k + 1) as v from nested +-- !query 22 schema +struct> +-- !query 22 output +{2:1,3:2,4:3} +{5:4,6:5,7:6} + + +-- !query 23 +select transform_keys(ys, (k, v) -> k + v) as v from nested +-- !query 23 schema +struct> +-- !query 23 output +{10:5,12:6,8:4} +{2:1,4:2,6:3} + + +-- !query 24 +select transform_values(ys, (k, v) -> v) as v from nested +-- !query 24 schema +struct> +-- !query 24 output +{1:1,2:2,3:3} +{4:4,5:5,6:6} + + +-- !query 25 +select transform_values(ys, (k, v) -> v + 1) as v from nested +-- !query 25 schema +struct> +-- !query 25 output +{1:2,2:3,3:4} +{4:5,5:6,6:7} + + +-- !query 26 +select transform_values(ys, (k, v) -> k + v) as v from nested +-- !query 26 schema +struct> +-- !query 26 output +{1:2,2:4,3:6} +{4:8,5:10,6:12} diff --git a/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out b/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out new file mode 100644 index 0000000000000..63dd56ce468bc --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out @@ -0,0 +1,307 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 22 + + +-- !query 0 +CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES + (1, 2), + (1, 2), + (1, 3), + (1, 3), + (2, 3), + (null, null), + (null, null) + AS tab1(k, v) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES + (1, 2), + (1, 2), + (2, 3), + (3, 4), + (null, null), + (null, null) + AS tab2(k, v) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +-- !query 2 schema +struct +-- !query 2 output +1 2 +1 2 +2 3 +NULL NULL +NULL NULL + + +-- !query 3 +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab1 WHERE k = 1 +-- !query 3 schema +struct +-- !query 3 output +1 2 +1 2 +1 3 +1 3 + + +-- !query 4 +SELECT * FROM tab1 WHERE k > 2 +INTERSECT ALL +SELECT * FROM tab2 +-- !query 4 schema +struct +-- !query 4 output + + + +-- !query 5 +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 WHERE k > 3 +-- !query 5 schema +struct +-- !query 5 output + + + +-- !query 6 +SELECT * FROM tab1 +INTERSECT ALL +SELECT CAST(1 AS BIGINT), CAST(2 AS BIGINT) +-- !query 6 schema +struct +-- !query 6 output +1 2 + + +-- !query 7 +SELECT * FROM tab1 +INTERSECT ALL +SELECT array(1), 2 +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.AnalysisException +IntersectAll can only be performed on tables with the compatible column types. 
array<int> <> int at the first column of the second table; + + +-- !query 8 +SELECT k FROM tab1 +INTERSECT ALL +SELECT k, v FROM tab2 +-- !query 8 schema +struct<> +-- !query 8 output +org.apache.spark.sql.AnalysisException +IntersectAll can only be performed on tables with the same number of columns, but the first table has 1 columns and the second table has 2 columns; + + +-- !query 9 +SELECT * FROM tab2 +INTERSECT ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +-- !query 9 schema +struct<k:int,v:int> +-- !query 9 output +1 2 +1 2 +2 3 +NULL NULL +NULL NULL + + +-- !query 10 +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +UNION ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +-- !query 10 schema +struct<k:int,v:int> +-- !query 10 output +1 2 +1 2 +1 3 +2 3 +NULL NULL +NULL NULL + + +-- !query 11 +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +EXCEPT +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +-- !query 11 schema +struct<k:int,v:int> +-- !query 11 output +1 3 + + +-- !query 12 +( + ( + ( + SELECT * FROM tab1 + EXCEPT + SELECT * FROM tab2 + ) + EXCEPT + SELECT * FROM tab1 + ) + INTERSECT ALL + SELECT * FROM tab2 +) +-- !query 12 schema +struct<k:int,v:int> +-- !query 12 output + + + +-- !query 13 +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +INTERSECT ALL +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +-- !query 13 schema +struct<k:int,v:int> +-- !query 13 output +1 2 +1 2 +1 2 +1 2 +1 2 +1 2 +1 2 +1 2 +2 3 + + +-- !query 14 +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +INTERSECT ALL +SELECT * +FROM (SELECT tab2.v AS k, + tab1.k AS v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +-- !query 14 schema +struct<k:int,v:int> +-- !query 14 output + + + +-- !query 15 +SELECT v FROM tab1 GROUP BY v +INTERSECT ALL +SELECT k FROM tab2 GROUP BY k +-- !query 15 schema +struct<v:int> +-- !query 15 output +2 +3 +NULL + + +-- !query 16 +SET spark.sql.legacy.setopsPrecedence.enabled= true +-- !query 16 schema +struct<key:string,value:string> +-- !query 16 output +spark.sql.legacy.setopsPrecedence.enabled true + + +-- !query 17 +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +UNION ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +-- !query 17 schema +struct<k:int,v:int> +-- !query 17 output +1 2 +1 2 +2 3 +NULL NULL +NULL NULL + + +-- !query 18 +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +UNION ALL +SELECT * FROM tab1 +INTERSECT +SELECT * FROM tab2 +-- !query 18 schema +struct<k:int,v:int> +-- !query 18 output +1 2 +2 3 +NULL NULL + + +-- !query 19 +SET spark.sql.legacy.setopsPrecedence.enabled = false +-- !query 19 schema +struct<key:string,value:string> +-- !query 19 output +spark.sql.legacy.setopsPrecedence.enabled false + + +-- !query 20 +DROP VIEW IF EXISTS tab1 +-- !query 20 schema +struct<> +-- !query 20 output + + + +-- !query 21 +DROP VIEW IF EXISTS tab2 +-- !query 21 schema +struct<> +-- !query 21 output + diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index 2b3288dc5a137..e550b43e08c28 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 28 +-- Number of queries: 40 -- !query 0 @@ -9,7 +9,7 @@ struct -- !query 0 output Class: org.apache.spark.sql.catalyst.expressions.StructsToJson Function: to_json -Usage: to_json(expr[, options]) - Returns a json string with a given struct value +Usage:
to_json(expr[, options]) - Returns a JSON string with a given struct value -- !query 1 @@ -24,7 +24,7 @@ Extended Usage: {"a":1,"b":2} > SELECT to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); {"time":"26/08/2015"} - > SELECT to_json(array(named_struct('a', 1, 'b', 2)); + > SELECT to_json(array(named_struct('a', 1, 'b', 2))); [{"a":1,"b":2}] > SELECT to_json(map('a', named_struct('b', 1))); {"a":{"b":1}} @@ -38,7 +38,7 @@ Extended Usage: Since: 2.2.0 Function: to_json -Usage: to_json(expr[, options]) - Returns a json string with a given struct value +Usage: to_json(expr[, options]) - Returns a JSON string with a given struct value -- !query 2 @@ -120,7 +120,7 @@ select to_json(named_struct('a', 1, 'b', 2), map('mode', 1)) struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7 +A type of keys and values in map() must be string, but got map;; line 1 pos 7 -- !query 12 @@ -183,7 +183,7 @@ select from_json('{"a":1}', 1) struct<> -- !query 17 output org.apache.spark.sql.AnalysisException -Expected a string literal instead of 1;; line 1 pos 7 +Schema should be specified in DDL format as a string literal or output of the schema_of_json function instead of 1;; line 1 pos 7 -- !query 18 @@ -216,7 +216,7 @@ select from_json('{"a":1}', 'a INT', map('mode', 1)) struct<> -- !query 20 output org.apache.spark.sql.AnalysisException -A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7 +A type of keys and values in map() must be string, but got map;; line 1 pos 7 -- !query 21 @@ -274,3 +274,99 @@ select from_json('{"a":1, "b":"2"}', 'struct') struct> -- !query 27 output {"a":1,"b":"2"} + + +-- !query 28 +select schema_of_json('{"c1":0, "c2":[1]}') +-- !query 28 schema +struct +-- !query 28 output +struct> + + +-- !query 29 +select from_json('{"c1":[1, 2, 3]}', schema_of_json('{"c1":[0]}')) +-- !query 29 schema +struct>> +-- !query 29 output +{"c1":[1,2,3]} + + +-- !query 30 +select from_json('[1, 2, 3]', 'array') +-- !query 30 schema +struct> +-- !query 30 output +[1,2,3] + + +-- !query 31 +select from_json('[1, "2", 3]', 'array') +-- !query 31 schema +struct> +-- !query 31 output +NULL + + +-- !query 32 +select from_json('[1, 2, null]', 'array') +-- !query 32 schema +struct> +-- !query 32 output +[1,2,null] + + +-- !query 33 +select from_json('[{"a": 1}, {"a":2}]', 'array>') +-- !query 33 schema +struct>> +-- !query 33 output +[{"a":1},{"a":2}] + + +-- !query 34 +select from_json('{"a": 1}', 'array>') +-- !query 34 schema +struct>> +-- !query 34 output +[{"a":1}] + + +-- !query 35 +select from_json('[null, {"a":2}]', 'array>') +-- !query 35 schema +struct>> +-- !query 35 output +[null,{"a":2}] + + +-- !query 36 +select from_json('[{"a": 1}, {"b":2}]', 'array>') +-- !query 36 schema +struct>> +-- !query 36 output +[{"a":1},{"b":2}] + + +-- !query 37 +select from_json('[{"a": 1}, 2]', 'array>') +-- !query 37 schema +struct>> +-- !query 37 output +NULL + + +-- !query 38 +select to_json(array('1', '2', '3')) +-- !query 38 schema +struct +-- !query 38 output +["1","2","3"] + + +-- !query 39 +select to_json(array(array(1, 2, 3), array(4))) +-- !query 39 schema +struct +-- !query 39 output +[[1,2,3],[4]] diff --git a/sql/core/src/test/resources/sql-tests/results/limit.sql.out b/sql/core/src/test/resources/sql-tests/results/limit.sql.out index 
146abe6cbd058..187f3bd6858fe 100644 --- a/sql/core/src/test/resources/sql-tests/results/limit.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/limit.sql.out @@ -1,109 +1,134 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 12 +-- Number of queries: 15 -- !query 0 -SELECT * FROM testdata LIMIT 2 +set spark.sql.limit.flatGlobalLimit=false -- !query 0 schema -struct +struct -- !query 0 output +spark.sql.limit.flatGlobalLimit false + + +-- !query 1 +SELECT * FROM testdata LIMIT 2 +-- !query 1 schema +struct +-- !query 1 output 1 1 2 2 --- !query 1 +-- !query 2 SELECT * FROM arraydata LIMIT 2 --- !query 1 schema +-- !query 2 schema struct,nestedarraycol:array>> --- !query 1 output +-- !query 2 output [1,2,3] [[1,2,3]] [2,3,4] [[2,3,4]] --- !query 2 +-- !query 3 SELECT * FROM mapdata LIMIT 2 --- !query 2 schema +-- !query 3 schema struct> --- !query 2 output +-- !query 3 output {1:"a1",2:"b1",3:"c1",4:"d1",5:"e1"} {1:"a2",2:"b2",3:"c2",4:"d2"} --- !query 3 +-- !query 4 SELECT * FROM testdata LIMIT 2 + 1 --- !query 3 schema +-- !query 4 schema struct --- !query 3 output +-- !query 4 output 1 1 2 2 3 3 --- !query 4 +-- !query 5 SELECT * FROM testdata LIMIT CAST(1 AS int) --- !query 4 schema +-- !query 5 schema struct --- !query 4 output +-- !query 5 output 1 1 --- !query 5 +-- !query 6 SELECT * FROM testdata LIMIT -1 --- !query 5 schema +-- !query 6 schema struct<> --- !query 5 output +-- !query 6 output org.apache.spark.sql.AnalysisException The limit expression must be equal to or greater than 0, but got -1; --- !query 6 +-- !query 7 SELECT * FROM testData TABLESAMPLE (-1 ROWS) --- !query 6 schema +-- !query 7 schema struct<> --- !query 6 output +-- !query 7 output org.apache.spark.sql.AnalysisException The limit expression must be equal to or greater than 0, but got -1; --- !query 7 +-- !query 8 +SELECT * FROM testdata LIMIT CAST(1 AS INT) +-- !query 8 schema +struct +-- !query 8 output +1 1 + + +-- !query 9 +SELECT * FROM testdata LIMIT CAST(NULL AS INT) +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +The evaluated limit expression must not be null, but got CAST(NULL AS INT); + + +-- !query 10 SELECT * FROM testdata LIMIT key > 3 --- !query 7 schema +-- !query 10 schema struct<> --- !query 7 output +-- !query 10 output org.apache.spark.sql.AnalysisException The limit expression must evaluate to a constant value, but got (testdata.`key` > 3); --- !query 8 +-- !query 11 SELECT * FROM testdata LIMIT true --- !query 8 schema +-- !query 11 schema struct<> --- !query 8 output +-- !query 11 output org.apache.spark.sql.AnalysisException The limit expression must be integer type, but got boolean; --- !query 9 +-- !query 12 SELECT * FROM testdata LIMIT 'a' --- !query 9 schema +-- !query 12 schema struct<> --- !query 9 output +-- !query 12 output org.apache.spark.sql.AnalysisException The limit expression must be integer type, but got string; --- !query 10 +-- !query 13 SELECT * FROM (SELECT * FROM range(10) LIMIT 5) WHERE id > 3 --- !query 10 schema +-- !query 13 schema struct --- !query 10 output +-- !query 13 output 4 --- !query 11 +-- !query 14 SELECT * FROM testdata WHERE key < 3 LIMIT ALL --- !query 11 schema +-- !query 14 schema struct --- !query 11 output +-- !query 14 output 1 1 2 2 diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out index b8c91dc8b59a4..7f301614523b2 100644 --- 
a/sql/core/src/test/resources/sql-tests/results/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out @@ -147,7 +147,7 @@ struct<> -- !query 15 output org.apache.spark.sql.catalyst.parser.ParseException -DecimalType can only support precision up to 38 +decimal can only support precision up to 38 == SQL == select 1234567890123456789012345678901234567890 @@ -159,7 +159,7 @@ struct<> -- !query 16 output org.apache.spark.sql.catalyst.parser.ParseException -DecimalType can only support precision up to 38 +decimal can only support precision up to 38 == SQL == select 1234567890123456789012345678901234567890.0 @@ -379,7 +379,7 @@ struct<> -- !query 39 output org.apache.spark.sql.catalyst.parser.ParseException -DecimalType can only support precision up to 38(line 1, pos 7) +decimal can only support precision up to 38(line 1, pos 7) == SQL == select 1.20E-38BD diff --git a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out index 85e3488990e20..2dd92930f92aa 100644 --- a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 13 +-- Number of queries: 31 -- !query 0 @@ -28,6 +28,17 @@ struct<> -- !query 2 +create temporary view yearsWithComplexTypes as select * from values + (2012, array(1, 1), map('1', 1), struct(1, 'a')), + (2013, array(2, 2), map('2', 2), struct(2, 'b')) + as yearsWithComplexTypes(y, a, m, s) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 SELECT * FROM ( SELECT year, course, earnings FROM courseSales ) @@ -35,27 +46,27 @@ PIVOT ( sum(earnings) FOR course IN ('dotNET', 'Java') ) --- !query 2 schema +-- !query 3 schema struct --- !query 2 output +-- !query 3 output 2012 15000 20000 2013 48000 30000 --- !query 3 +-- !query 4 SELECT * FROM courseSales PIVOT ( sum(earnings) FOR year IN (2012, 2013) ) --- !query 3 schema +-- !query 4 schema struct --- !query 3 output +-- !query 4 output Java 20000 30000 dotNET 15000 48000 --- !query 4 +-- !query 5 SELECT * FROM ( SELECT year, course, earnings FROM courseSales ) @@ -63,14 +74,14 @@ PIVOT ( sum(earnings), avg(earnings) FOR course IN ('dotNET', 'Java') ) --- !query 4 schema +-- !query 5 schema struct --- !query 4 output +-- !query 5 output 2012 15000 7500.0 20000 20000.0 2013 48000 48000.0 30000 30000.0 --- !query 5 +-- !query 6 SELECT * FROM ( SELECT course, earnings FROM courseSales ) @@ -78,13 +89,13 @@ PIVOT ( sum(earnings) FOR course IN ('dotNET', 'Java') ) --- !query 5 schema +-- !query 6 schema struct --- !query 5 output +-- !query 6 output 63000 50000 --- !query 6 +-- !query 7 SELECT * FROM ( SELECT year, course, earnings FROM courseSales ) @@ -92,13 +103,13 @@ PIVOT ( sum(earnings), min(year) FOR course IN ('dotNET', 'Java') ) --- !query 6 schema +-- !query 7 schema struct --- !query 6 output +-- !query 7 output 63000 2012 50000 2012 --- !query 7 +-- !query 8 SELECT * FROM ( SELECT course, year, earnings, s FROM courseSales @@ -108,16 +119,16 @@ PIVOT ( sum(earnings) FOR s IN (1, 2) ) --- !query 7 schema +-- !query 8 schema struct --- !query 7 output +-- !query 8 output Java 2012 20000 NULL Java 2013 NULL 30000 dotNET 2012 15000 NULL dotNET 2013 NULL 48000 --- !query 8 +-- !query 9 SELECT * FROM ( SELECT course, year, earnings, s FROM courseSales @@ -127,14 +138,14 @@ PIVOT ( sum(earnings), min(s) FOR course IN ('dotNET', 'Java') ) --- !query 8 schema 
+-- !query 9 schema struct --- !query 8 output +-- !query 9 output 2012 15000 1 20000 1 2013 48000 2 30000 2 --- !query 9 +-- !query 10 SELECT * FROM ( SELECT course, year, earnings, s FROM courseSales @@ -144,14 +155,14 @@ PIVOT ( sum(earnings * s) FOR course IN ('dotNET', 'Java') ) --- !query 9 schema +-- !query 10 schema struct --- !query 9 output +-- !query 10 output 2012 15000 20000 2013 96000 60000 --- !query 10 +-- !query 11 SELECT 2012_s, 2013_s, 2012_a, 2013_a, c FROM ( SELECT year y, course c, earnings e FROM courseSales ) @@ -159,27 +170,57 @@ PIVOT ( sum(e) s, avg(e) a FOR y IN (2012, 2013) ) --- !query 10 schema +-- !query 11 schema struct<2012_s:bigint,2013_s:bigint,2012_a:double,2013_a:double,c:string> --- !query 10 output +-- !query 11 output 15000 48000 7500.0 48000.0 dotNET 20000 30000 20000.0 30000.0 Java --- !query 11 +-- !query 12 +SELECT firstYear_s, secondYear_s, firstYear_a, secondYear_a, c FROM ( + SELECT year y, course c, earnings e FROM courseSales +) +PIVOT ( + sum(e) s, avg(e) a + FOR y IN (2012 as firstYear, 2013 secondYear) +) +-- !query 12 schema +struct +-- !query 12 output +15000 48000 7500.0 48000.0 dotNET +20000 30000 20000.0 30000.0 Java + + +-- !query 13 SELECT * FROM courseSales PIVOT ( abs(earnings) FOR year IN (2012, 2013) ) --- !query 11 schema +-- !query 13 schema struct<> --- !query 11 output +-- !query 13 output org.apache.spark.sql.AnalysisException -Aggregate expression required for pivot, found 'abs(earnings#x)'; +Aggregate expression required for pivot, but 'coursesales.`earnings`' did not appear in any aggregate function.; --- !query 12 +-- !query 14 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), year + FOR course IN ('dotNET', 'Java') +) +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +Aggregate expression required for pivot, but '__auto_generated_subquery_name.`year`' did not appear in any aggregate function.; + + +-- !query 15 SELECT * FROM ( SELECT course, earnings FROM courseSales ) @@ -187,8 +228,251 @@ PIVOT ( sum(earnings) FOR year IN (2012, 2013) ) --- !query 12 schema +-- !query 15 schema struct<> --- !query 12 output +-- !query 15 output org.apache.spark.sql.AnalysisException cannot resolve '`year`' given input columns: [__auto_generated_subquery_name.course, __auto_generated_subquery_name.earnings]; line 4 pos 0 + + +-- !query 16 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + ceil(sum(earnings)), avg(earnings) + 1 as a1 + FOR course IN ('dotNET', 'Java') +) +-- !query 16 schema +struct +-- !query 16 output +2012 15000 7501.0 20000 20001.0 +2013 48000 48001.0 30000 30001.0 + + +-- !query 17 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(avg(earnings)) + FOR course IN ('dotNET', 'Java') +) +-- !query 17 schema +struct<> +-- !query 17 output +org.apache.spark.sql.AnalysisException +It is not allowed to use an aggregate function in the argument of another aggregate function. 
Please use the inner aggregate function in a sub-query.; + + +-- !query 18 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, year) IN (('dotNET', 2012), ('Java', 2013)) +) +-- !query 18 schema +struct +-- !query 18 output +1 15000 NULL +2 NULL 30000 + + +-- !query 19 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, s) IN (('dotNET', 2) as c1, ('Java', 1) as c2) +) +-- !query 19 schema +struct +-- !query 19 output +2012 NULL 20000 +2013 48000 NULL + + +-- !query 20 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, year) IN ('dotNET', 'Java') +) +-- !query 20 schema +struct<> +-- !query 20 output +org.apache.spark.sql.AnalysisException +Invalid pivot value 'dotNET': value data type string does not match pivot column data type struct; + + +-- !query 21 +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (s, 2013) +) +-- !query 21 schema +struct<> +-- !query 21 output +org.apache.spark.sql.AnalysisException +cannot resolve '`s`' given input columns: [coursesales.course, coursesales.year, coursesales.earnings]; line 4 pos 15 + + +-- !query 22 +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (course, 2013) +) +-- !query 22 schema +struct<> +-- !query 22 output +org.apache.spark.sql.AnalysisException +Literal expressions required for pivot values, found 'course#x'; + + +-- !query 23 +SELECT * FROM ( + SELECT course, year, a + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + min(a) + FOR course IN ('dotNET', 'Java') +) +-- !query 23 schema +struct,Java:array> +-- !query 23 output +2012 [1,1] [1,1] +2013 [2,2] [2,2] + + +-- !query 24 +SELECT * FROM ( + SELECT course, year, y, a + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + max(a) + FOR (y, course) IN ((2012, 'dotNET'), (2013, 'Java')) +) +-- !query 24 schema +struct,[2013, Java]:array> +-- !query 24 output +2012 [1,1] NULL +2013 NULL [2,2] + + +-- !query 25 +SELECT * FROM ( + SELECT earnings, year, a + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR a IN (array(1, 1), array(2, 2)) +) +-- !query 25 schema +struct +-- !query 25 output +2012 35000 NULL +2013 NULL 78000 + + +-- !query 26 +SELECT * FROM ( + SELECT course, earnings, year, a + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, a) IN (('dotNET', array(1, 1)), ('Java', array(2, 2))) +) +-- !query 26 schema +struct +-- !query 26 output +2012 15000 NULL +2013 NULL 30000 + + +-- !query 27 +SELECT * FROM ( + SELECT earnings, year, s + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR s IN ((1, 'a'), (2, 'b')) +) +-- !query 27 schema +struct +-- !query 27 output +2012 35000 NULL +2013 NULL 78000 + + +-- !query 28 +SELECT * FROM ( + SELECT course, earnings, year, s + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, s) IN (('dotNET', (1, 'a')), ('Java', (2, 'b'))) +) +-- !query 28 schema +struct +-- !query 28 output +2012 15000 NULL +2013 NULL 30000 + + +-- !query 29 +SELECT * FROM ( + SELECT earnings, year, m + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR m IN (map('1', 1), map('2', 2)) +) +-- !query 29 schema +struct<> 
+-- !query 29 output +org.apache.spark.sql.AnalysisException +Invalid pivot column 'm#x'. Pivot columns must be comparable.; + + +-- !query 30 +SELECT * FROM ( + SELECT course, earnings, year, m + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, m) IN (('dotNET', map('1', 1)), ('Java', map('2', 2))) +) +-- !query 30 schema +struct<> +-- !query 30 output +org.apache.spark.sql.AnalysisException +Invalid pivot column 'named_struct(course, course#x, m, m#x)'. Pivot columns must be comparable.; diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index d5f8705a35ed6..7b3dc84388889 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -36,14 +36,14 @@ struct -- !query 3 output == Parsed Logical Plan == 'Project [concat(concat(concat('col1, 'col2), 'col3), 'col4) AS col#x] -+- 'SubqueryAlias __auto_generated_subquery_name ++- 'SubqueryAlias `__auto_generated_subquery_name` +- 'Project ['id AS col1#x, 'id AS col2#x, 'id AS col3#x, 'id AS col4#x] +- 'UnresolvedTableValuedFunction range, [10] == Analyzed Logical Plan == col: string Project [concat(concat(concat(cast(col1#xL as string), cast(col2#xL as string)), cast(col3#xL as string)), cast(col4#xL as string)) AS col#x] -+- SubqueryAlias __auto_generated_subquery_name ++- SubqueryAlias `__auto_generated_subquery_name` +- Project [id#xL AS col1#xL, id#xL AS col2#xL, id#xL AS col3#xL, id#xL AS col4#xL] +- Range (0, 10, step=1, splits=None) diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out new file mode 100644 index 0000000000000..088db55d66406 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out @@ -0,0 +1,70 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 7 + + +-- !query 0 +create temporary view tab_a as select * from values (1, 1) as tab_a(a1, b1) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +create temporary view tab_b as select * from values (1, 1) as tab_b(a2, b2) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +create temporary view struct_tab as select struct(col1 as a, col2 as b) as record from + values (1, 1), (1, 2), (2, 1), (2, 2) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +select 1 from tab_a where (a1, b1) not in (select a2, b2 from tab_b) +-- !query 3 schema +struct<1:int> +-- !query 3 output + + + +-- !query 4 +select 1 from tab_a where (a1, b1) not in (select (a2, b2) from tab_b) +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +Cannot analyze (named_struct('a1', tab_a.`a1`, 'b1', tab_a.`b1`) IN (listquery())). +The number of columns in the left hand side of an IN subquery does not match the +number of columns in the output of subquery. 
+#columns in left hand side: 2 +#columns in right hand side: 1 +Left side columns: +[tab_a.`a1`, tab_a.`b1`] +Right side columns: +[`named_struct(a2, a2, b2, b2)`]; + + +-- !query 5 +select count(*) from struct_tab where record in + (select (a2 as a, b2 as b) from tab_b) +-- !query 5 schema +struct +-- !query 5 output +1 + + +-- !query 6 +select count(*) from struct_tab where record not in + (select (a2 as a, b2 as b) from tab_b) +-- !query 6 schema +struct +-- !query 6 output +3 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out index 71ca1f8649475..9eb5b3383e734 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out @@ -1,8 +1,16 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 8 +-- Number of queries: 9 -- !query 0 +set spark.sql.limit.flatGlobalLimit=false +-- !query 0 schema +struct +-- !query 0 output +spark.sql.limit.flatGlobalLimit false + + +-- !query 1 create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), ("val1b", 8S, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), @@ -17,13 +25,13 @@ create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:02:00.001', date '2014-04-04'), ("val1e", 10S, null, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04') as t1(t1a, t1b, t1c, t1d, t1e, t1f, t1g, t1h, t1i) --- !query 0 schema +-- !query 1 schema struct<> --- !query 0 output +-- !query 1 output --- !query 1 +-- !query 2 create temporary view t2 as select * from values ("val2a", 6S, 12, 14L, float(15), 20D, 20E2, timestamp '2014-04-04 01:01:00.000', date '2014-04-04'), ("val1b", 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), @@ -39,13 +47,13 @@ create temporary view t2 as select * from values ("val1f", 19S, null, 19L, float(17), 25D, 26E2, timestamp '2014-10-04 01:01:00.000', date '2014-10-04'), ("val1b", null, 16, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', null) as t2(t2a, t2b, t2c, t2d, t2e, t2f, t2g, t2h, t2i) --- !query 1 schema +-- !query 2 schema struct<> --- !query 1 output +-- !query 2 output --- !query 2 +-- !query 3 create temporary view t3 as select * from values ("val3a", 6S, 12, 110L, float(15), 20D, 20E2, timestamp '2014-04-04 01:02:00.000', date '2014-04-04'), ("val3a", 6S, 12, 10L, float(15), 20D, 20E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), @@ -60,27 +68,27 @@ create temporary view t3 as select * from values ("val3b", 8S, null, 719L, float(17), 25D, 26E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), ("val3b", 8S, null, 19L, float(17), 25D, 26E2, timestamp '2015-05-04 01:02:00.000', date '2015-05-04') as t3(t3a, t3b, t3c, t3d, t3e, t3f, t3g, t3h, t3i) --- !query 2 schema +-- !query 3 schema struct<> --- !query 2 output +-- !query 3 output --- !query 3 +-- !query 4 SELECT * FROM t1 WHERE t1a IN (SELECT t2a FROM t2 WHERE t1d = t2d) LIMIT 2 --- !query 3 schema +-- !query 4 schema struct --- !query 3 output +-- !query 4 output val1b 8 16 19 17.0 25.0 2600 2014-05-04 01:01:00 2014-05-04 val1c 8 16 19 17.0 25.0 2600 2014-05-04 01:02:00.001 2014-05-05 --- !query 4 +-- !query 
5 SELECT * FROM t1 WHERE t1c IN (SELECT t2c @@ -88,16 +96,16 @@ WHERE t1c IN (SELECT t2c WHERE t2b >= 8 LIMIT 2) LIMIT 4 --- !query 4 schema +-- !query 5 schema struct --- !query 4 output +-- !query 5 output val1a 16 12 10 15.0 20.0 2000 2014-07-04 01:01:00 2014-07-04 val1a 16 12 21 15.0 20.0 2000 2014-06-04 01:02:00.001 2014-06-04 val1b 8 16 19 17.0 25.0 2600 2014-05-04 01:01:00 2014-05-04 val1c 8 16 19 17.0 25.0 2600 2014-05-04 01:02:00.001 2014-05-05 --- !query 5 +-- !query 6 SELECT Count(DISTINCT( t1a )), t1b FROM t1 @@ -108,29 +116,29 @@ WHERE t1d IN (SELECT t2d GROUP BY t1b ORDER BY t1b DESC NULLS FIRST LIMIT 1 --- !query 5 schema +-- !query 6 schema struct --- !query 5 output +-- !query 6 output 1 NULL --- !query 6 +-- !query 7 SELECT * FROM t1 WHERE t1b NOT IN (SELECT t2b FROM t2 WHERE t2b > 6 LIMIT 2) --- !query 6 schema +-- !query 7 schema struct --- !query 6 output +-- !query 7 output val1a 16 12 10 15.0 20.0 2000 2014-07-04 01:01:00 2014-07-04 val1a 16 12 21 15.0 20.0 2000 2014-06-04 01:02:00.001 2014-06-04 val1a 6 8 10 15.0 20.0 2000 2014-04-04 01:00:00 2014-04-04 val1a 6 8 10 15.0 20.0 2000 2014-04-04 01:02:00.001 2014-04-04 --- !query 7 +-- !query 8 SELECT Count(DISTINCT( t1a )), t1b FROM t1 @@ -141,7 +149,7 @@ WHERE t1d NOT IN (SELECT t2d GROUP BY t1b ORDER BY t1b NULLS last LIMIT 1 --- !query 7 schema +-- !query 8 schema struct --- !query 7 output +-- !query 8 output 1 6 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out index 2586f26f71c35..e49978ddb1ce2 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out @@ -109,8 +109,8 @@ struct<> org.apache.spark.sql.AnalysisException Expressions referencing the outer query are not supported outside of WHERE/HAVING clauses: Aggregate [min(outer(t2a#x)) AS min(outer())#x] -+- SubqueryAlias t3 ++- SubqueryAlias `t3` +- Project [t3a#x, t3b#x, t3c#x] - +- SubqueryAlias t3 + +- SubqueryAlias `t3` +- LocalRelation [t3a#x, t3b#x, t3c#x] ; diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out index 70aeb9373f3c7..c52e5706deeee 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 7 +-- Number of queries: 10 -- !query 0 @@ -33,6 +33,26 @@ struct<> -- !query 3 +CREATE TEMPORARY VIEW t4 AS SELECT * FROM VALUES + (CAST(1 AS DOUBLE), CAST(2 AS STRING), CAST(3 AS STRING)) +AS t1(t4a, t4b, t4c) +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +CREATE TEMPORARY VIEW t5 AS SELECT * FROM VALUES + (CAST(1 AS DECIMAL(18, 0)), CAST(2 AS STRING), CAST(3 AS BIGINT)) +AS t1(t5a, t5b, t5c) +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 SELECT ( SELECT max(t2b), min(t2b) FROM t2 @@ -40,14 +60,14 @@ SELECT GROUP BY t2.t2b ) FROM t1 --- !query 3 schema +-- !query 5 schema struct<> --- !query 3 output +-- !query 5 output org.apache.spark.sql.AnalysisException Scalar subquery must return only one column, 
but got 2; --- !query 4 +-- !query 6 SELECT ( SELECT max(t2b), min(t2b) FROM t2 @@ -55,50 +75,72 @@ SELECT GROUP BY t2.t2b ) FROM t1 --- !query 4 schema +-- !query 6 schema struct<> --- !query 4 output +-- !query 6 output org.apache.spark.sql.AnalysisException Scalar subquery must return only one column, but got 2; --- !query 5 +-- !query 7 SELECT * FROM t1 WHERE t1a IN (SELECT t2a, t2b FROM t2 WHERE t1a = t2a) --- !query 5 schema +-- !query 7 schema struct<> --- !query 5 output +-- !query 7 output org.apache.spark.sql.AnalysisException -cannot resolve '(t1.`t1a` IN (listquery(t1.`t1a`)))' due to data type mismatch: +Cannot analyze (t1.`t1a` IN (listquery(t1.`t1a`))). The number of columns in the left hand side of an IN subquery does not match the number of columns in the output of subquery. -#columns in left hand side: 1. -#columns in right hand side: 2. +#columns in left hand side: 1 +#columns in right hand side: 2 Left side columns: -[t1.`t1a`]. +[t1.`t1a`] Right side columns: -[t2.`t2a`, t2.`t2b`].; +[t2.`t2a`, t2.`t2b`]; --- !query 6 +-- !query 8 SELECT * FROM T1 WHERE (t1a, t1b) IN (SELECT t2a FROM t2 WHERE t1a = t2a) --- !query 6 schema +-- !query 8 schema struct<> --- !query 6 output +-- !query 8 output org.apache.spark.sql.AnalysisException -cannot resolve '(named_struct('t1a', t1.`t1a`, 't1b', t1.`t1b`) IN (listquery(t1.`t1a`)))' due to data type mismatch: +Cannot analyze (named_struct('t1a', t1.`t1a`, 't1b', t1.`t1b`) IN (listquery(t1.`t1a`))). The number of columns in the left hand side of an IN subquery does not match the number of columns in the output of subquery. -#columns in left hand side: 2. -#columns in right hand side: 1. +#columns in left hand side: 2 +#columns in right hand side: 1 Left side columns: -[t1.`t1a`, t1.`t1b`]. +[t1.`t1a`, t1.`t1b`] Right side columns: -[t2.`t2a`].; +[t2.`t2a`]; + + +-- !query 9 +SELECT * FROM t4 +WHERE +(t4a, t4b, t4c) IN (SELECT t5a, + t5b, + t5c + FROM t5) +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +cannot resolve '(named_struct('t4a', t4.`t4a`, 't4b', t4.`t4b`, 't4c', t4.`t4c`) IN (listquery()))' due to data type mismatch: +The data type of one or more elements in the left hand side of an IN subquery +is not compatible with the data type of the output of the subquery +Mismatched columns: +[(t4.`t4a`:double, t5.`t5a`:decimal(18,0)), (t4.`t4c`:string, t5.`t5c`:bigint)] +Left side: +[double, string, string]. 
+Right side: +[decimal(18,0), string, bigint].; diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out index a8bc6faf11262..94af9181225d6 100644 --- a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out @@ -83,8 +83,13 @@ select * from range(1, null) -- !query 6 schema struct<> -- !query 6 output -java.lang.IllegalArgumentException -Invalid arguments for resolved function: 1, null +org.apache.spark.sql.AnalysisException +error: table-valued function range with alternatives: + (end: long) + (start: long, end: long) + (start: long, end: long, step: long) + (start: long, end: long, step: long, numPartitions: integer) +cannot be applied to: (integer, null); line 1 pos 14 -- !query 7 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out index be637b66abc86..6c6d3110d7d0d 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out @@ -306,12 +306,14 @@ SELECT (tinyint_array1 || smallint_array2) ts_array, (smallint_array1 || int_array2) si_array, (int_array1 || bigint_array2) ib_array, + (bigint_array1 || decimal_array2) bd_array, + (decimal_array1 || double_array2) dd_array, (double_array1 || float_array2) df_array, (string_array1 || data_array2) std_array, (timestamp_array1 || string_array2) tst_array, (string_array1 || int_array2) sti_array FROM various_arrays -- !query 13 schema -struct,si_array:array,ib_array:array,df_array:array,std_array:array,tst_array:array,sti_array:array> +struct,si_array:array,ib_array:array,bd_array:array,dd_array:array,df_array:array,std_array:array,tst_array:array,sti_array:array> -- !query 13 output -[2,1,3,4] [2,1,3,4] [2,1,3,4] [2.0,1.0,3.0,4.0] ["a","b","2016-03-12","2016-03-11"] ["2016-11-15 20:54:00","2016-11-12 20:54:00","c","d"] ["a","b","3","4"] +[2,1,3,4] [2,1,3,4] [2,1,3,4] [2,1,9223372036854775808,9223372036854775809] [9.223372036854776E18,9.223372036854776E18,3.0,4.0] [2.0,1.0,3.0,4.0] ["a","b","2016-03-12","2016-03-11"] ["2016-11-15 20:54:00","2016-11-12 20:54:00","c","d"] ["a","b","3","4"] diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out new file mode 100644 index 0000000000000..35740094ba53e --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out @@ -0,0 +1,179 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 16 + + +-- !query 0 +CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( + map(true, false), + map(2Y, 1Y), + map(2S, 1S), + map(2, 1), + map(2L, 1L), + map(922337203685477897945456575809789456, 922337203685477897945456575809789456), + map(9.22337203685477897945456575809789456, 9.22337203685477897945456575809789456), + map(2.0D, 1.0D), + map(float(2.0), float(1.0)), + map(date '2016-03-14', date '2016-03-13'), + map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + map('true', 'false', '2', '1'), + map('2016-03-14', '2016-03-13'), + map('2016-11-15 20:54:00.000', '2016-11-12 20:54:00.000'), + map('922337203685477897945456575809789456', 'text'), 
+ map(array(1L, 2L), array(1L, 2L)), map(array(1, 2), array(1, 2)), + map(struct(1S, 2L), struct(1S, 2L)), map(struct(1, 2), struct(1, 2)) +) AS various_maps( + boolean_map, + tinyint_map, + smallint_map, + int_map, + bigint_map, + decimal_map1, decimal_map2, + double_map, + float_map, + date_map, + timestamp_map, + string_map1, string_map2, string_map3, string_map4, + array_map1, array_map2, + struct_map1, struct_map2 +) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT map_zip_with(tinyint_map, smallint_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 1 schema +struct>> +-- !query 1 output +{2:{"k":2,"v1":1,"v2":1}} + + +-- !query 2 +SELECT map_zip_with(smallint_map, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 2 schema +struct>> +-- !query 2 output +{2:{"k":2,"v1":1,"v2":1}} + + +-- !query 3 +SELECT map_zip_with(int_map, bigint_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 3 schema +struct>> +-- !query 3 output +{2:{"k":2,"v1":1,"v2":1}} + + +-- !query 4 +SELECT map_zip_with(double_map, float_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 4 schema +struct>> +-- !query 4 output +{2.0:{"k":2.0,"v1":1.0,"v2":1.0}} + + +-- !query 5 +SELECT map_zip_with(decimal_map1, decimal_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(named_struct(NamePlaceholder(), `k`, NamePlaceholder(), `v1`, NamePlaceholder(), `v2`), `k`, `v1`, `v2`))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7 + + +-- !query 6 +SELECT map_zip_with(decimal_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 6 schema +struct>> +-- !query 6 output +{2:{"k":2,"v1":null,"v2":1},922337203685477897945456575809789456:{"k":922337203685477897945456575809789456,"v1":922337203685477897945456575809789456,"v2":null}} + + +-- !query 7 +SELECT map_zip_with(decimal_map1, double_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 7 schema +struct>> +-- !query 7 output +{2.0:{"k":2.0,"v1":null,"v2":1.0},9.223372036854779E35:{"k":9.223372036854779E35,"v1":922337203685477897945456575809789456,"v2":null}} + + +-- !query 8 +SELECT map_zip_with(decimal_map2, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 8 schema +struct<> +-- !query 8 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_zip_with(various_maps.`decimal_map2`, various_maps.`int_map`, lambdafunction(named_struct(NamePlaceholder(), `k`, NamePlaceholder(), `v1`, NamePlaceholder(), `v2`), `k`, `v1`, `v2`))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7 + + +-- !query 9 +SELECT map_zip_with(decimal_map2, double_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 9 schema +struct>> +-- !query 9 output +{2.0:{"k":2.0,"v1":null,"v2":1.0},9.223372036854778:{"k":9.223372036854778,"v1":9.22337203685477897945456575809789456,"v2":null}} + + +-- !query 10 +SELECT map_zip_with(string_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 10 schema +struct>> +-- 
!query 10 output +{"2":{"k":"2","v1":"1","v2":1},"true":{"k":"true","v1":"false","v2":null}} + + +-- !query 11 +SELECT map_zip_with(string_map2, date_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 11 schema +struct>> +-- !query 11 output +{"2016-03-14":{"k":"2016-03-14","v1":"2016-03-13","v2":2016-03-13}} + + +-- !query 12 +SELECT map_zip_with(timestamp_map, string_map3, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 12 schema +struct>> +-- !query 12 output +{"2016-11-15 20:54:00":{"k":"2016-11-15 20:54:00","v1":2016-11-12 20:54:00.0,"v2":null},"2016-11-15 20:54:00.000":{"k":"2016-11-15 20:54:00.000","v1":null,"v2":"2016-11-12 20:54:00.000"}} + + +-- !query 13 +SELECT map_zip_with(decimal_map1, string_map4, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 13 schema +struct>> +-- !query 13 output +{"922337203685477897945456575809789456":{"k":"922337203685477897945456575809789456","v1":922337203685477897945456575809789456,"v2":"text"}} + + +-- !query 14 +SELECT map_zip_with(array_map1, array_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 14 schema +struct,struct,v1:array,v2:array>>> +-- !query 14 output +{[1,2]:{"k":[1,2],"v1":[1,2],"v2":[1,2]}} + + +-- !query 15 +SELECT map_zip_with(struct_map1, struct_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 15 schema +struct,struct,v1:struct,v2:struct>>> +-- !query 15 output +{{"col1":1,"col2":2}:{"k":{"col1":1,"col2":2},"v1":{"col1":1,"col2":2},"v2":{"col1":1,"col2":2}}} diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out new file mode 100644 index 0000000000000..efc88e47209a6 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out @@ -0,0 +1,144 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 8 + + +-- !query 0 +CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( + map(true, false), map(false, true), + map(1Y, 2Y), map(3Y, 4Y), + map(1S, 2S), map(3S, 4S), + map(4, 6), map(7, 8), + map(6L, 7L), map(8L, 9L), + map(9223372036854775809, 9223372036854775808), map(9223372036854775808, 9223372036854775809), + map(1.0D, 2.0D), map(3.0D, 4.0D), + map(float(1.0D), float(2.0D)), map(float(3.0D), float(4.0D)), + map(date '2016-03-14', date '2016-03-13'), map(date '2016-03-12', date '2016-03-11'), + map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + map(timestamp '2016-11-11 20:54:00.000', timestamp '2016-11-09 20:54:00.000'), + map('a', 'b'), map('c', 'd'), + map(array('a', 'b'), array('c', 'd')), map(array('e'), array('f')), + map(struct('a', 1), struct('b', 2)), map(struct('c', 3), struct('d', 4)), + map(map('a', 1), map('b', 2)), map(map('c', 3), map('d', 4)), + map('a', 1), map('c', 2), + map(1, 'a'), map(2, 'c') +) AS various_maps ( + boolean_map1, boolean_map2, + tinyint_map1, tinyint_map2, + smallint_map1, smallint_map2, + int_map1, int_map2, + bigint_map1, bigint_map2, + decimal_map1, decimal_map2, + double_map1, double_map2, + float_map1, float_map2, + date_map1, date_map2, + timestamp_map1, + timestamp_map2, + string_map1, string_map2, + array_map1, array_map2, + struct_map1, struct_map2, + map_map1, map_map2, + string_int_map1, string_int_map2, + int_string_map1, int_string_map2 +) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT + map_concat(boolean_map1, boolean_map2) 
boolean_map, + map_concat(tinyint_map1, tinyint_map2) tinyint_map, + map_concat(smallint_map1, smallint_map2) smallint_map, + map_concat(int_map1, int_map2) int_map, + map_concat(bigint_map1, bigint_map2) bigint_map, + map_concat(decimal_map1, decimal_map2) decimal_map, + map_concat(float_map1, float_map2) float_map, + map_concat(double_map1, double_map2) double_map, + map_concat(date_map1, date_map2) date_map, + map_concat(timestamp_map1, timestamp_map2) timestamp_map, + map_concat(string_map1, string_map2) string_map, + map_concat(array_map1, array_map2) array_map, + map_concat(struct_map1, struct_map2) struct_map, + map_concat(map_map1, map_map2) map_map, + map_concat(string_int_map1, string_int_map2) string_int_map, + map_concat(int_string_map1, int_string_map2) int_string_map +FROM various_maps +-- !query 1 schema +struct,tinyint_map:map,smallint_map:map,int_map:map,bigint_map:map,decimal_map:map,float_map:map,double_map:map,date_map:map,timestamp_map:map,string_map:map,array_map:map,array>,struct_map:map,struct>,map_map:map,map>,string_int_map:map,int_string_map:map> +-- !query 1 output +{false:true,true:false} {1:2,3:4} {1:2,3:4} {4:6,7:8} {6:7,8:9} {9223372036854775808:9223372036854775809,9223372036854775809:9223372036854775808} {1.0:2.0,3.0:4.0} {1.0:2.0,3.0:4.0} {2016-03-12:2016-03-11,2016-03-14:2016-03-13} {2016-11-11 20:54:00.0:2016-11-09 20:54:00.0,2016-11-15 20:54:00.0:2016-11-12 20:54:00.0} {"a":"b","c":"d"} {["a","b"]:["c","d"],["e"]:["f"]} {{"col1":"a","col2":1}:{"col1":"b","col2":2},{"col1":"c","col2":3}:{"col1":"d","col2":4}} {{"a":1}:{"b":2},{"c":3}:{"d":4}} {"a":1,"c":2} {1:"a",2:"c"} + + +-- !query 2 +SELECT + map_concat(tinyint_map1, smallint_map2) ts_map, + map_concat(smallint_map1, int_map2) si_map, + map_concat(int_map1, bigint_map2) ib_map, + map_concat(bigint_map1, decimal_map2) bd_map, + map_concat(decimal_map1, float_map2) df_map, + map_concat(string_map1, date_map2) std_map, + map_concat(timestamp_map1, string_map2) tst_map, + map_concat(string_map1, int_map2) sti_map, + map_concat(int_string_map1, tinyint_map2) istt_map +FROM various_maps +-- !query 2 schema +struct,si_map:map,ib_map:map,bd_map:map,df_map:map,std_map:map,tst_map:map,sti_map:map,istt_map:map> +-- !query 2 output +{1:2,3:4} {1:2,7:8} {4:6,8:9} {6:7,9223372036854775808:9223372036854775809} {3.0:4.0,9.223372036854776E18:9.223372036854776E18} {"2016-03-12":"2016-03-11","a":"b"} {"2016-11-15 20:54:00":"2016-11-12 20:54:00","c":"d"} {"7":"8","a":"b"} {1:"a",3:"4"} + + +-- !query 3 +SELECT + map_concat(tinyint_map1, map_map2) tm_map +FROM various_maps +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`tinyint_map1`, various_maps.`map_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map,map>]; line 2 pos 4 + + +-- !query 4 +SELECT + map_concat(boolean_map1, int_map2) bi_map +FROM various_maps +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`boolean_map1`, various_maps.`int_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map]; line 2 pos 4 + + +-- !query 5 +SELECT + map_concat(int_map1, struct_map2) is_map +FROM various_maps +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`int_map1`, various_maps.`struct_map2`)' due to data type mismatch: 
input to function map_concat should all be the same type, but it's [map, map,struct>]; line 2 pos 4 + + +-- !query 6 +SELECT + map_concat(map_map1, array_map2) ma_map +FROM various_maps +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`map_map1`, various_maps.`array_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map,map>, map,array>]; line 2 pos 4 + + +-- !query 7 +SELECT + map_concat(map_map1, struct_map2) ms_map +FROM various_maps +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`map_map1`, various_maps.`struct_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map,map>, map,struct>]; line 2 pos 4 diff --git a/sql/core/src/test/resources/sql-tests/results/udaf.sql.out b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out index 4815a578b1029..87824ab81cdf7 100644 --- a/sql/core/src/test/resources/sql-tests/results/udaf.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out @@ -33,8 +33,8 @@ SELECT default.myDoubleAvg(int_col1, 3) as my_avg from t1 -- !query 3 schema struct<> -- !query 3 output -java.lang.AssertionError -assertion failed: Incorrect number of children +org.apache.spark.sql.AnalysisException +Invalid number of arguments for function default.myDoubleAvg. Expected: 1; Found: 2; line 1 pos 7 -- !query 4 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/0 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/0 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/1 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/1 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/metadata new file mode 100644 index 0000000000000..372180b2096ee --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/metadata @@ -0,0 +1 @@ +{"id":"04d960cd-d38f-4ce6-b8d0-ebcf84c9dccc"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/0 new file mode 100644 index 0000000000000..807d7b0063b96 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/0 @@ 
-0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1531292029003,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/1 new file mode 100644 index 0000000000000..cce541073fb4b --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/1 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":5000,"batchTimestampMs":1531292030005,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +1 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/1.delta new file mode 100644 index 0000000000000..193524ffe15b5 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/1.delta differ diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/0 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/0 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/1 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/1 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/metadata 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/metadata new file mode 100644 index 0000000000000..d6be7fbffa9b7 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/metadata @@ -0,0 +1 @@ +{"id":"549eeb1a-d762-420c-bb44-3fd6d73a5268"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/0 new file mode 100644 index 0000000000000..43db49d052894 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/0 @@ -0,0 +1,4 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1531172902041,"conf":{"spark.sql.shuffle.partitions":"10","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +0 +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/1 new file mode 100644 index 0000000000000..8cc898e81017f --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/1 @@ -0,0 +1,4 @@ +v1 +{"batchWatermarkMs":10000,"batchTimestampMs":1531172902217,"conf":{"spark.sql.shuffle.partitions":"10","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +1 +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/0 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/0 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/1 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/1 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/metadata new file mode 100644 index 0000000000000..c160d737278e1 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/metadata @@ -0,0 +1 @@ +{"id":"2f32aca2-1b97-458f-a48f-109328724f09"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/0 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/0 new file mode 100644 index 0000000000000..acdc6e69e975a --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/0 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1533784347136,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/1 new file mode 100644 index 0000000000000..27353e8724507 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/1 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1533784349160,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +1 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/1.delta new file mode 100644 index 0000000000000..281b21e960909 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/2.delta new file mode 100644 index 0000000000000..b701841d71535 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/1.delta new file mode 
100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/2.delta new file mode 100644 index 0000000000000..f4fb2520a4ac4 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/commits/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/commits/0 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/commits/0 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/commits/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/commits/1 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/commits/1 @@ 
-0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/metadata new file mode 100644 index 0000000000000..f205857e6876f --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/metadata @@ -0,0 +1 @@ +{"id":"73f7f943-0a08-4ffb-a504-9fa88ff7612a"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/offsets/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/offsets/0 new file mode 100644 index 0000000000000..8fa80bedc2285 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/offsets/0 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1531991874513,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/offsets/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/offsets/1 new file mode 100644 index 0000000000000..2248a58fea006 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/offsets/1 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":5000,"batchTimestampMs":1531991878604,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +1 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/0/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/0/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/0/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/0/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/0/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/0/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/1/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/1/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/1/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/1/2.delta 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/1/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/1/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/2/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/2/1.delta new file mode 100644 index 0000000000000..171aa58a06e21 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/2/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/2/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/2/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/2/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/3/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/3/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/3/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/3/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/3/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/3/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/4/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/4/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/4/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/4/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/4/2.delta new file mode 100644 index 0000000000000..cfb3a481deb59 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/4/2.delta differ diff --git a/sql/core/src/test/resources/test-data/cars-empty-value.csv b/sql/core/src/test/resources/test-data/cars-empty-value.csv new file mode 100644 index 0000000000000..0f20a2f23ac06 --- /dev/null +++ b/sql/core/src/test/resources/test-data/cars-empty-value.csv @@ -0,0 +1,4 @@ +year,make,model,comment,blank +"2012","Tesla","S","","" +1997,Ford,E350,"Go get one now they are going fast", +2015,Chevy,Volt,,"" 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala index e51aad021fcbf..d95794d624033 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala @@ -54,7 +54,7 @@ abstract class BenchmarkQueryTest extends QueryTest with SharedSQLContext with B plan foreach { case s: WholeStageCodegenExec => codegenSubtrees += s - case s => s + case _ => } codegenSubtrees.toSeq.foreach { subtree => val code = subtree.doCodeGen()._2 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 2182bd7eadd63..2917c56dbeb56 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -436,27 +436,6 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } } - test("isInCollection: Java Collection") { - val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") - // Test with different types of collections - checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).asJava)), - df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet.asJava)), - df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList.asJava)), - df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - - val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") - - val e = intercept[AnalysisException] { - df2.filter($"a".isInCollection(Seq($"b").asJava)) - } - Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") - .foreach { s => - assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) - } - } - test("&&") { checkAnswer( booleanData.filter($"a" && true), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index f495a949ebc5a..ed110f751645d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -557,11 +557,13 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("SPARK-18004 limit + aggregates") { - val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value") - val limit2Df = df.limit(2) - checkAnswer( - limit2Df.groupBy("id").count().select($"id"), - limit2Df.select($"id")) + withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") { + val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value").repartition(1) + val limit2Df = df.limit(2) + checkAnswer( + limit2Df.groupBy("id").count().select($"id"), + limit2Df.select($"id")) + } } test("SPARK-17237 remove backticks in a pivot result schema") { @@ -717,4 +719,14 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(1, 2, 1) :: Row(2, 2, 2) :: Row(3, 2, 3) :: Nil) } + test("SPARK-24788: RelationalGroupedDataset.toString with unresolved exprs should not fail") { + // Checks if these raise no exception + assert(testData.groupBy('key).toString.contains( + "[grouping expressions: [key], value: [key: int, value: string], type: GroupBy]")) + 
assert(testData.groupBy(col("key")).toString.contains( + "[grouping expressions: [key], value: [key: int, value: string], type: GroupBy]")) + assert(testData.groupBy(current_date()).toString.contains( + "grouping expressions: [current_date(None)], value: [key: int, value: string], " + + "type: GroupBy]")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 4c28e2f1cd909..121db442c77f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -85,14 +85,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } val df5 = Seq((Seq("a", null), Seq(1, 2))).toDF("k", "v") - intercept[RuntimeException] { + val msg1 = intercept[Exception] { df5.select(map_from_arrays($"k", $"v")).collect - } + }.getMessage + assert(msg1.contains("Cannot use null as map key!")) val df6 = Seq((Seq(1, 2), Seq("a"))).toDF("k", "v") - intercept[RuntimeException] { + val msg2 = intercept[Exception] { df6.select(map_from_arrays($"k", $"v")).collect - } + }.getMessage + assert(msg2.contains("The given two arrays should have the same length")) } test("struct with column name") { @@ -309,113 +311,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } - test("mask functions") { - val df = Seq("TestString-123", "", null).toDF("a") - checkAnswer(df.select(mask($"a")), Seq(Row("XxxxXxxxxx-nnn"), Row(""), Row(null))) - checkAnswer(df.select(mask_first_n($"a", 4)), Seq(Row("XxxxString-123"), Row(""), Row(null))) - checkAnswer(df.select(mask_last_n($"a", 4)), Seq(Row("TestString-nnn"), Row(""), Row(null))) - checkAnswer(df.select(mask_show_first_n($"a", 4)), - Seq(Row("TestXxxxxx-nnn"), Row(""), Row(null))) - checkAnswer(df.select(mask_show_last_n($"a", 4)), - Seq(Row("XxxxXxxxxx-123"), Row(""), Row(null))) - checkAnswer(df.select(mask_hash($"a")), - Seq(Row("dd78d68ad1b23bde126812482dd70ac6"), - Row("d41d8cd98f00b204e9800998ecf8427e"), - Row(null))) - - checkAnswer(df.select(mask($"a", "U", "l", "#")), - Seq(Row("UlllUlllll-###"), Row(""), Row(null))) - checkAnswer(df.select(mask_first_n($"a", 4, "U", "l", "#")), - Seq(Row("UlllString-123"), Row(""), Row(null))) - checkAnswer(df.select(mask_last_n($"a", 4, "U", "l", "#")), - Seq(Row("TestString-###"), Row(""), Row(null))) - checkAnswer(df.select(mask_show_first_n($"a", 4, "U", "l", "#")), - Seq(Row("TestUlllll-###"), Row(""), Row(null))) - checkAnswer(df.select(mask_show_last_n($"a", 4, "U", "l", "#")), - Seq(Row("UlllUlllll-123"), Row(""), Row(null))) - - checkAnswer( - df.selectExpr("mask(a)", "mask(a, 'U')", "mask(a, 'U', 'l')", "mask(a, 'U', 'l', '#')"), - Seq(Row("XxxxXxxxxx-nnn", "UxxxUxxxxx-nnn", "UlllUlllll-nnn", "UlllUlllll-###"), - Row("", "", "", ""), - Row(null, null, null, null))) - checkAnswer(sql("select mask(null)"), Row(null)) - checkAnswer(sql("select mask('AAaa11', null, null, null)"), Row("XXxxnn")) - intercept[AnalysisException] { - checkAnswer(df.selectExpr("mask(a, a)"), Seq(Row("XxxxXxxxxx-nnn"), Row(""), Row(null))) - } - - checkAnswer( - df.selectExpr( - "mask_first_n(a)", - "mask_first_n(a, 6)", - "mask_first_n(a, 6, 'U')", - "mask_first_n(a, 6, 'U', 'l')", - "mask_first_n(a, 6, 'U', 'l', '#')"), - Seq(Row("XxxxString-123", "XxxxXxring-123", "UxxxUxring-123", "UlllUlring-123", - "UlllUlring-123"), - Row("", "", "", "", ""), - Row(null, null, null, null, null))) - 
checkAnswer(sql("select mask_first_n(null)"), Row(null)) - checkAnswer(sql("select mask_first_n('A1aA1a', null, null, null, null)"), Row("XnxX1a")) - intercept[AnalysisException] { - checkAnswer(spark.range(1).selectExpr("mask_first_n('A1aA1a', id)"), Row("XnxX1a")) - } - - checkAnswer( - df.selectExpr( - "mask_last_n(a)", - "mask_last_n(a, 6)", - "mask_last_n(a, 6, 'U')", - "mask_last_n(a, 6, 'U', 'l')", - "mask_last_n(a, 6, 'U', 'l', '#')"), - Seq(Row("TestString-nnn", "TestStrixx-nnn", "TestStrixx-nnn", "TestStrill-nnn", - "TestStrill-###"), - Row("", "", "", "", ""), - Row(null, null, null, null, null))) - checkAnswer(sql("select mask_last_n(null)"), Row(null)) - checkAnswer(sql("select mask_last_n('A1aA1a', null, null, null, null)"), Row("A1xXnx")) - intercept[AnalysisException] { - checkAnswer(spark.range(1).selectExpr("mask_last_n('A1aA1a', id)"), Row("A1xXnx")) - } - - checkAnswer( - df.selectExpr( - "mask_show_first_n(a)", - "mask_show_first_n(a, 6)", - "mask_show_first_n(a, 6, 'U')", - "mask_show_first_n(a, 6, 'U', 'l')", - "mask_show_first_n(a, 6, 'U', 'l', '#')"), - Seq(Row("TestXxxxxx-nnn", "TestStxxxx-nnn", "TestStxxxx-nnn", "TestStllll-nnn", - "TestStllll-###"), - Row("", "", "", "", ""), - Row(null, null, null, null, null))) - checkAnswer(sql("select mask_show_first_n(null)"), Row(null)) - checkAnswer(sql("select mask_show_first_n('A1aA1a', null, null, null, null)"), Row("A1aAnx")) - intercept[AnalysisException] { - checkAnswer(spark.range(1).selectExpr("mask_show_first_n('A1aA1a', id)"), Row("A1aAnx")) - } - - checkAnswer( - df.selectExpr( - "mask_show_last_n(a)", - "mask_show_last_n(a, 6)", - "mask_show_last_n(a, 6, 'U')", - "mask_show_last_n(a, 6, 'U', 'l')", - "mask_show_last_n(a, 6, 'U', 'l', '#')"), - Seq(Row("XxxxXxxxxx-123", "XxxxXxxxng-123", "UxxxUxxxng-123", "UlllUlllng-123", - "UlllUlllng-123"), - Row("", "", "", "", ""), - Row(null, null, null, null, null))) - checkAnswer(sql("select mask_show_last_n(null)"), Row(null)) - checkAnswer(sql("select mask_show_last_n('A1aA1a', null, null, null, null)"), Row("XnaA1a")) - intercept[AnalysisException] { - checkAnswer(spark.range(1).selectExpr("mask_show_last_n('A1aA1a', id)"), Row("XnaA1a")) - } - - checkAnswer(sql("select mask_hash(null)"), Row(null)) - } - test("sort_array/array_sort functions") { val df = Seq( (Array[Int](2, 1, 3), Array("b", "c", "a")), @@ -614,8 +509,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } test("map_entries") { - val dummyFilter = (c: Column) => c.isNotNull || c.isNull - // Primitive-type elements val idf = Seq( Map[Int, Int](1 -> 100, 2 -> 200, 3 -> 300), @@ -628,15 +521,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null) ) - checkAnswer(idf.select(map_entries('m)), iExpected) - checkAnswer(idf.selectExpr("map_entries(m)"), iExpected) - checkAnswer(idf.filter(dummyFilter('m)).select(map_entries('m)), iExpected) - checkAnswer( - spark.range(1).selectExpr("map_entries(map(1, null, 2, null))"), - Seq(Row(Seq(Row(1, null), Row(2, null))))) - checkAnswer( - spark.range(1).filter(dummyFilter('id)).selectExpr("map_entries(map(1, null, 2, null))"), - Seq(Row(Seq(Row(1, null), Row(2, null))))) + def testPrimitiveType(): Unit = { + checkAnswer(idf.select(map_entries('m)), iExpected) + checkAnswer(idf.selectExpr("map_entries(m)"), iExpected) + checkAnswer(idf.selectExpr("map_entries(map(1, null, 2, null))"), + Seq.fill(iExpected.length)(Row(Seq(Row(1, null), Row(2, null))))) + } + + // Test with local relation, the Project will be 
evaluated without codegen + testPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + idf.cache() + testPrimitiveType() // Non-primitive-type elements val sdf = Seq( @@ -652,15 +548,97 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null) ) - checkAnswer(sdf.select(map_entries('m)), sExpected) - checkAnswer(sdf.selectExpr("map_entries(m)"), sExpected) - checkAnswer(sdf.filter(dummyFilter('m)).select(map_entries('m)), sExpected) + def testNonPrimitiveType(): Unit = { + checkAnswer(sdf.select(map_entries('m)), sExpected) + checkAnswer(sdf.selectExpr("map_entries(m)"), sExpected) + } + + // Test with local relation, the Project will be evaluated without codegen + testNonPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + sdf.cache() + testNonPrimitiveType() } - test("map_from_entries function") { - def dummyFilter(c: Column): Column = c.isNull || c.isNotNull - val oneRowDF = Seq(3215).toDF("i") + test("map_concat function") { + val df1 = Seq( + (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 4 -> 400)), + (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 1 -> 400)), + (null, Map[Int, Int](3 -> 300, 4 -> 400)) + ).toDF("map1", "map2") + + val expected1a = Seq( + Row(Map(1 -> 100, 2 -> 200, 3 -> 300, 4 -> 400)), + Row(Map(1 -> 400, 2 -> 200, 3 -> 300)), + Row(null) + ) + + checkAnswer(df1.selectExpr("map_concat(map1, map2)"), expected1a) + checkAnswer(df1.select(map_concat('map1, 'map2)), expected1a) + + val expected1b = Seq( + Row(Map(1 -> 100, 2 -> 200)), + Row(Map(1 -> 100, 2 -> 200)), + Row(null) + ) + + checkAnswer(df1.selectExpr("map_concat(map1)"), expected1b) + checkAnswer(df1.select(map_concat('map1)), expected1b) + + val df2 = Seq( + ( + Map[Array[Int], Int](Array(1) -> 100, Array(2) -> 200), + Map[String, Int]("3" -> 300, "4" -> 400) + ) + ).toDF("map1", "map2") + + val expected2 = Seq(Row(Map())) + checkAnswer(df2.selectExpr("map_concat()"), expected2) + checkAnswer(df2.select(map_concat()), expected2) + + val df3 = { + val schema = StructType( + StructField("map1", MapType(StringType, IntegerType, true), false) :: + StructField("map2", MapType(StringType, IntegerType, false), false) :: Nil + ) + val data = Seq( + Row(Map[String, Any]("a" -> 1, "b" -> null), Map[String, Any]("c" -> 3, "d" -> 4)), + Row(Map[String, Any]("a" -> 1, "b" -> 2), Map[String, Any]("c" -> 3, "d" -> 4)) + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + + val expected3 = Seq( + Row(Map[String, Any]("a" -> 1, "b" -> null, "c" -> 3, "d" -> 4)), + Row(Map[String, Any]("a" -> 1, "b" -> 2, "c" -> 3, "d" -> 4)) + ) + + checkAnswer(df3.selectExpr("map_concat(map1, map2)"), expected3) + checkAnswer(df3.select(map_concat('map1, 'map2)), expected3) + + val expectedMessage1 = "input to function map_concat should all be the same type" + + assert(intercept[AnalysisException] { + df2.selectExpr("map_concat(map1, map2)").collect() + }.getMessage().contains(expectedMessage1)) + + assert(intercept[AnalysisException] { + df2.select(map_concat('map1, 'map2)).collect() + }.getMessage().contains(expectedMessage1)) + + val expectedMessage2 = "input to function map_concat should all be of type map" + + assert(intercept[AnalysisException] { + df2.selectExpr("map_concat(map1, 12)").collect() + }.getMessage().contains(expectedMessage2)) + + assert(intercept[AnalysisException] { + df2.select(map_concat('map1, lit(12))).collect() + }.getMessage().contains(expectedMessage2)) + } 
+ + test("map_from_entries function") { // Test cases with primitive-type keys and values val idf = Seq( Seq((1, 10), (2, 20), (3, 10)), @@ -674,18 +652,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Map.empty), Row(null)) - checkAnswer(idf.select(map_from_entries('a)), iExpected) - checkAnswer(idf.selectExpr("map_from_entries(a)"), iExpected) - checkAnswer(idf.filter(dummyFilter('a)).select(map_from_entries('a)), iExpected) - checkAnswer( - oneRowDF.selectExpr("map_from_entries(array(struct(1, null), struct(2, null)))"), - Seq(Row(Map(1 -> null, 2 -> null))) - ) - checkAnswer( - oneRowDF.filter(dummyFilter('i)) - .selectExpr("map_from_entries(array(struct(1, null), struct(2, null)))"), - Seq(Row(Map(1 -> null, 2 -> null))) - ) + def testPrimitiveType(): Unit = { + checkAnswer(idf.select(map_from_entries('a)), iExpected) + checkAnswer(idf.selectExpr("map_from_entries(a)"), iExpected) + checkAnswer(idf.selectExpr("map_from_entries(array(struct(1, null), struct(2, null)))"), + Seq.fill(iExpected.length)(Row(Map(1 -> null, 2 -> null)))) + } + + // Test with local relation, the Project will be evaluated without codegen + testPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + idf.cache() + testPrimitiveType() // Test cases with non-primitive-type keys and values val sdf = Seq( @@ -702,9 +680,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Map.empty), Row(null)) - checkAnswer(sdf.select(map_from_entries('a)), sExpected) - checkAnswer(sdf.selectExpr("map_from_entries(a)"), sExpected) - checkAnswer(sdf.filter(dummyFilter('a)).select(map_from_entries('a)), sExpected) + def testNonPrimitiveType(): Unit = { + checkAnswer(sdf.select(map_from_entries('a)), sExpected) + checkAnswer(sdf.selectExpr("map_from_entries(a)"), sExpected) + } + + // Test with local relation, the Project will be evaluated without codegen + testNonPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + sdf.cache() + testNonPrimitiveType() } test("array contains function") { @@ -918,63 +903,76 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } } - test("reverse function") { - val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codegen on - - // String test cases + test("reverse function - string") { val oneRowDF = Seq(("Spark", 3215)).toDF("s", "i") + def testString(): Unit = { + checkAnswer(oneRowDF.select(reverse('s)), Seq(Row("krapS"))) + checkAnswer(oneRowDF.selectExpr("reverse(s)"), Seq(Row("krapS"))) + checkAnswer(oneRowDF.select(reverse('i)), Seq(Row("5123"))) + checkAnswer(oneRowDF.selectExpr("reverse(i)"), Seq(Row("5123"))) + checkAnswer(oneRowDF.selectExpr("reverse(null)"), Seq(Row(null))) + } - checkAnswer( - oneRowDF.select(reverse('s)), - Seq(Row("krapS")) - ) - checkAnswer( - oneRowDF.selectExpr("reverse(s)"), - Seq(Row("krapS")) - ) - checkAnswer( - oneRowDF.select(reverse('i)), - Seq(Row("5123")) - ) - checkAnswer( - oneRowDF.selectExpr("reverse(i)"), - Seq(Row("5123")) - ) - checkAnswer( - oneRowDF.selectExpr("reverse(null)"), - Seq(Row(null)) - ) + // Test with local relation, the Project will be evaluated without codegen + testString() + // Test with cached relation, the Project will be evaluated with codegen + oneRowDF.cache() + testString() + } - // Array test cases (primitive-type elements) - val idf = Seq( + test("reverse function - array for primitive type not containing null") { + val idfNotContainsNull = Seq( Seq(1, 9, 
8, 7), Seq(5, 8, 9, 7, 2), Seq.empty, null ).toDF("i") - checkAnswer( - idf.select(reverse('i)), - Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) - ) - checkAnswer( - idf.filter(dummyFilter('i)).select(reverse('i)), - Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) - ) - checkAnswer( - idf.selectExpr("reverse(i)"), - Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) - ) - checkAnswer( - oneRowDF.selectExpr("reverse(array(1, null, 2, null))"), - Seq(Row(Seq(null, 2, null, 1))) - ) - checkAnswer( - oneRowDF.filter(dummyFilter('i)).selectExpr("reverse(array(1, null, 2, null))"), - Seq(Row(Seq(null, 2, null, 1))) - ) + def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { + checkAnswer( + idfNotContainsNull.select(reverse('i)), + Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idfNotContainsNull.selectExpr("reverse(i)"), + Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) + ) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeNotContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + idfNotContainsNull.cache() + testArrayOfPrimitiveTypeNotContainsNull() + } + + test("reverse function - array for primitive type containing null") { + val idfContainsNull = Seq[Seq[Integer]]( + Seq(1, 9, 8, null, 7), + Seq(null, 5, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeContainsNull(): Unit = { + checkAnswer( + idfContainsNull.select(reverse('i)), + Seq(Row(Seq(7, null, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5, null)), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idfContainsNull.selectExpr("reverse(i)"), + Seq(Row(Seq(7, null, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5, null)), Row(Seq.empty), Row(null)) + ) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + idfContainsNull.cache() + testArrayOfPrimitiveTypeContainsNull() + } - // Array test cases (non-primitive-type elements) + test("reverse function - array for non-primitive type") { val sdf = Seq( Seq("c", "a", "b"), Seq("b", null, "c", null), @@ -982,34 +980,38 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { null ).toDF("s") - checkAnswer( - sdf.select(reverse('s)), - Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) - ) - checkAnswer( - sdf.filter(dummyFilter('s)).select(reverse('s)), - Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) - ) - checkAnswer( - sdf.selectExpr("reverse(s)"), - Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) - ) - checkAnswer( - oneRowDF.selectExpr("reverse(array(array(1, 2), array(3, 4)))"), - Seq(Row(Seq(Seq(3, 4), Seq(1, 2)))) - ) - checkAnswer( - oneRowDF.filter(dummyFilter('s)).selectExpr("reverse(array(array(1, 2), array(3, 4)))"), - Seq(Row(Seq(Seq(3, 4), Seq(1, 2)))) - ) + def testArrayOfNonPrimitiveType(): Unit = { + checkAnswer( + sdf.select(reverse('s)), + Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) + ) + checkAnswer( + sdf.selectExpr("reverse(s)"), + Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) + ) + checkAnswer( + sdf.selectExpr("reverse(array(array(1, 2), 
array(3, 4)))"), + Seq.fill(sdf.count().toInt)(Row(Seq(Seq(3, 4), Seq(1, 2)))) + ) + } - // Error test cases - intercept[AnalysisException] { - oneRowDF.selectExpr("reverse(struct(1, 'a'))") + // Test with local relation, the Project will be evaluated without codegen + testArrayOfNonPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + sdf.cache() + testArrayOfNonPrimitiveType() + } + + test("reverse function - data type mismatch") { + val ex1 = intercept[AnalysisException] { + sql("select reverse(struct(1, 'a'))") } - intercept[AnalysisException] { - oneRowDF.selectExpr("reverse(map(1, 'a'))") + assert(ex1.getMessage.contains("data type mismatch")) + + val ex2 = intercept[AnalysisException] { + sql("select reverse(map(1, 'a'))") } + assert(ex2.getMessage.contains("data type mismatch")) } test("array position function") { @@ -1120,69 +1122,122 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { "argument 1 requires (array or map) type, however, '`_1`' is of string type")) } + test("array_union functions") { + val df1 = Seq((Array(1, 2, 3), Array(4, 2))).toDF("a", "b") + val ans1 = Row(Seq(1, 2, 3, 4)) + checkAnswer(df1.select(array_union($"a", $"b")), ans1) + checkAnswer(df1.selectExpr("array_union(a, b)"), ans1) + + val df2 = Seq((Array[Integer](1, 2, null, 4, 5), Array(-5, 4, -3, 2, -1))).toDF("a", "b") + val ans2 = Row(Seq(1, 2, null, 4, 5, -5, -3, -1)) + checkAnswer(df2.select(array_union($"a", $"b")), ans2) + checkAnswer(df2.selectExpr("array_union(a, b)"), ans2) + + val df3 = Seq((Array(1L, 2L, 3L), Array(4L, 2L))).toDF("a", "b") + val ans3 = Row(Seq(1L, 2L, 3L, 4L)) + checkAnswer(df3.select(array_union($"a", $"b")), ans3) + checkAnswer(df3.selectExpr("array_union(a, b)"), ans3) + + val df4 = Seq((Array[java.lang.Long](1L, 2L, null, 4L, 5L), Array(-5L, 4L, -3L, 2L, -1L))) + .toDF("a", "b") + val ans4 = Row(Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L)) + checkAnswer(df4.select(array_union($"a", $"b")), ans4) + checkAnswer(df4.selectExpr("array_union(a, b)"), ans4) + + val df5 = Seq((Array("b", "a", "c"), Array("b", null, "a", "g"))).toDF("a", "b") + val ans5 = Row(Seq("b", "a", "c", null, "g")) + checkAnswer(df5.select(array_union($"a", $"b")), ans5) + checkAnswer(df5.selectExpr("array_union(a, b)"), ans5) + + val df6 = Seq((null, Array("a"))).toDF("a", "b") + assert(intercept[AnalysisException] { + df6.select(array_union($"a", $"b")) + }.getMessage.contains("data type mismatch")) + assert(intercept[AnalysisException] { + df6.selectExpr("array_union(a, b)") + }.getMessage.contains("data type mismatch")) + + val df7 = Seq((null, null)).toDF("a", "b") + assert(intercept[AnalysisException] { + df7.select(array_union($"a", $"b")) + }.getMessage.contains("data type mismatch")) + assert(intercept[AnalysisException] { + df7.selectExpr("array_union(a, b)") + }.getMessage.contains("data type mismatch")) + + val df8 = Seq((Array(Array(1)), Array("a"))).toDF("a", "b") + assert(intercept[AnalysisException] { + df8.select(array_union($"a", $"b")) + }.getMessage.contains("data type mismatch")) + assert(intercept[AnalysisException] { + df8.selectExpr("array_union(a, b)") + }.getMessage.contains("data type mismatch")) + } + test("concat function - arrays") { val nseqi : Seq[Int] = null val nseqs : Seq[String] = null val df = Seq( - (Seq(1), Seq(2, 3), Seq(5L, 6L), nseqi, Seq("a", "b", "c"), Seq("d", "e"), Seq("f"), nseqs), (Seq(1, 0), Seq.empty[Int], Seq(2L), nseqi, Seq("a"), Seq.empty[String], Seq(null), nseqs) ).toDF("i1", "i2", "i3", "in", 
"s1", "s2", "s3", "sn") - val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codeGen on - // Simple test cases - checkAnswer( - df.selectExpr("array(1, 2, 3L)"), - Seq(Row(Seq(1L, 2L, 3L)), Row(Seq(1L, 2L, 3L))) - ) + def simpleTest(): Unit = { + checkAnswer ( + df.select(concat($"i1", $"s1")), + Seq(Row(Seq("1", "a", "b", "c")), Row(Seq("1", "0", "a"))) + ) + checkAnswer( + df.select(concat($"i1", $"i2", $"i3")), + Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) + ) + checkAnswer( + df.selectExpr("concat(array(1, null), i2, i3)"), + Seq(Row(Seq(1, null, 2, 3, 5, 6)), Row(Seq(1, null, 2))) + ) + checkAnswer( + df.select(concat($"s1", $"s2", $"s3")), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) + checkAnswer( + df.selectExpr("concat(s1, s2, s3)"), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) + } - checkAnswer ( - df.select(concat($"i1", $"s1")), - Seq(Row(Seq("1", "a", "b", "c")), Row(Seq("1", "0", "a"))) - ) - checkAnswer( - df.select(concat($"i1", $"i2", $"i3")), - Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) - ) - checkAnswer( - df.filter(dummyFilter($"i1")).select(concat($"i1", $"i2", $"i3")), - Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) - ) - checkAnswer( - df.selectExpr("concat(array(1, null), i2, i3)"), - Seq(Row(Seq(1, null, 2, 3, 5, 6)), Row(Seq(1, null, 2))) - ) - checkAnswer( - df.select(concat($"s1", $"s2", $"s3")), - Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) - ) - checkAnswer( - df.selectExpr("concat(s1, s2, s3)"), - Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) - ) - checkAnswer( - df.filter(dummyFilter($"s1"))select(concat($"s1", $"s2", $"s3")), - Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) - ) + // Test with local relation, the Project will be evaluated without codegen + simpleTest() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + simpleTest() // Null test cases - checkAnswer( - df.select(concat($"i1", $"in")), - Seq(Row(null), Row(null)) - ) - checkAnswer( - df.select(concat($"in", $"i1")), - Seq(Row(null), Row(null)) - ) - checkAnswer( - df.select(concat($"s1", $"sn")), - Seq(Row(null), Row(null)) - ) - checkAnswer( - df.select(concat($"sn", $"s1")), - Seq(Row(null), Row(null)) - ) + def nullTest(): Unit = { + checkAnswer( + df.select(concat($"i1", $"in")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat($"in", $"i1")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat($"s1", $"sn")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat($"sn", $"s1")), + Seq(Row(null), Row(null)) + ) + } + + // Test with local relation, the Project will be evaluated without codegen + df.unpersist() + nullTest() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + nullTest() // Type error test cases intercept[AnalysisException] { @@ -1200,9 +1255,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } test("flatten function") { - val dummyFilter = (c: Column) => c.isNull || c.isNotNull // to switch codeGen on - val oneRowDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr") - // Test cases with a primitive type val intDF = Seq( (Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6))), @@ -1225,12 +1277,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null), Row(null)) - checkAnswer(intDF.select(flatten($"i")), intDFResult) - checkAnswer(intDF.filter(dummyFilter($"i"))select(flatten($"i")), 
intDFResult) - checkAnswer(intDF.selectExpr("flatten(i)"), intDFResult) - checkAnswer( - oneRowDF.selectExpr("flatten(array(arr, array(null, 5), array(6, null)))"), - Seq(Row(Seq(1, 2, 3, null, 5, 6, null)))) + def testInt(): Unit = { + checkAnswer(intDF.select(flatten($"i")), intDFResult) + checkAnswer(intDF.selectExpr("flatten(i)"), intDFResult) + } + + // Test with local relation, the Project will be evaluated without codegen + testInt() + // Test with cached relation, the Project will be evaluated with codegen + intDF.cache() + testInt() // Test cases with non-primitive types val strDF = Seq( @@ -1256,14 +1312,36 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null), Row(null)) - checkAnswer(strDF.select(flatten($"s")), strDFResult) - checkAnswer(strDF.filter(dummyFilter($"s")).select(flatten($"s")), strDFResult) - checkAnswer(strDF.selectExpr("flatten(s)"), strDFResult) - checkAnswer( - oneRowDF.selectExpr("flatten(array(array(arr, arr), array(arr)))"), - Seq(Row(Seq(Seq(1, 2, 3), Seq(1, 2, 3), Seq(1, 2, 3))))) + def testString(): Unit = { + checkAnswer(strDF.select(flatten($"s")), strDFResult) + checkAnswer(strDF.selectExpr("flatten(s)"), strDFResult) + } + + // Test with local relation, the Project will be evaluated without codegen + testString() + // Test with cached relation, the Project will be evaluated with codegen + strDF.cache() + testString() + + val arrDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr") + + def testArray(): Unit = { + checkAnswer( + arrDF.selectExpr("flatten(array(arr, array(null, 5), array(6, null)))"), + Seq(Row(Seq(1, 2, 3, null, 5, 6, null)))) + checkAnswer( + arrDF.selectExpr("flatten(array(array(arr, arr), array(arr)))"), + Seq(Row(Seq(Seq(1, 2, 3), Seq(1, 2, 3), Seq(1, 2, 3))))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArray() + // Test with cached relation, the Project will be evaluated with codegen + arrDF.cache() + testArray() // Error test cases + val oneRowDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr") intercept[AnalysisException] { oneRowDF.select(flatten($"arr")) } @@ -1279,7 +1357,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } test("array_repeat function") { - val dummyFilter = (c: Column) => c.isNull || c.isNotNull // to switch codeGen on val strDF = Seq( ("hi", 2), (null, 2) @@ -1290,12 +1367,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(null, null)) ) - checkAnswer(strDF.select(array_repeat($"a", 2)), strDFTwiceResult) - checkAnswer(strDF.filter(dummyFilter($"a")).select(array_repeat($"a", 2)), strDFTwiceResult) - checkAnswer(strDF.select(array_repeat($"a", $"b")), strDFTwiceResult) - checkAnswer(strDF.filter(dummyFilter($"a")).select(array_repeat($"a", $"b")), strDFTwiceResult) - checkAnswer(strDF.selectExpr("array_repeat(a, 2)"), strDFTwiceResult) - checkAnswer(strDF.selectExpr("array_repeat(a, b)"), strDFTwiceResult) + def testString(): Unit = { + checkAnswer(strDF.select(array_repeat($"a", 2)), strDFTwiceResult) + checkAnswer(strDF.select(array_repeat($"a", $"b")), strDFTwiceResult) + checkAnswer(strDF.selectExpr("array_repeat(a, 2)"), strDFTwiceResult) + checkAnswer(strDF.selectExpr("array_repeat(a, b)"), strDFTwiceResult) + } + + // Test with local relation, the Project will be evaluated without codegen + testString() + // Test with cached relation, the Project will be evaluated with codegen + strDF.cache() + testString() val intDF = { val schema = StructType(Seq( @@ 
-1313,12 +1396,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(null, null)) ) - checkAnswer(intDF.select(array_repeat($"a", 2)), intDFTwiceResult) - checkAnswer(intDF.filter(dummyFilter($"a")).select(array_repeat($"a", 2)), intDFTwiceResult) - checkAnswer(intDF.select(array_repeat($"a", $"b")), intDFTwiceResult) - checkAnswer(intDF.filter(dummyFilter($"a")).select(array_repeat($"a", $"b")), intDFTwiceResult) - checkAnswer(intDF.selectExpr("array_repeat(a, 2)"), intDFTwiceResult) - checkAnswer(intDF.selectExpr("array_repeat(a, b)"), intDFTwiceResult) + def testInt(): Unit = { + checkAnswer(intDF.select(array_repeat($"a", 2)), intDFTwiceResult) + checkAnswer(intDF.select(array_repeat($"a", $"b")), intDFTwiceResult) + checkAnswer(intDF.selectExpr("array_repeat(a, 2)"), intDFTwiceResult) + checkAnswer(intDF.selectExpr("array_repeat(a, b)"), intDFTwiceResult) + } + + // Test with local relation, the Project will be evaluated without codegen + testInt() + // Test with cached relation, the Project will be evaluated with codegen + intDF.cache() + testInt() val nullCountDF = { val schema = StructType(Seq( @@ -1331,13 +1420,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { spark.createDataFrame(spark.sparkContext.parallelize(data), schema) } - checkAnswer( - nullCountDF.select(array_repeat($"a", $"b")), - Seq( - Row(null), - Row(null) + def testNull(): Unit = { + checkAnswer( + nullCountDF.select(array_repeat($"a", $"b")), + Seq(Row(null), Row(null)) ) - ) + } + + // Test with local relation, the Project will be evaluated without codegen + testNull() + // Test with cached relation, the Project will be evaluated with codegen + nullCountDF.cache() + testNull() // Error test cases val invalidTypeDF = Seq(("hi", "1")).toDF("a", "b") @@ -1421,40 +1515,1150 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } - private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { - import DataFrameFunctionsSuite.CodegenFallbackExpr - for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { - val c = if (codegenFallback) { - Column(CodegenFallbackExpr(v.expr)) - } else { - v - } - withSQLConf( - (SQLConf.CODEGEN_FALLBACK.key, codegenFallback.toString), - (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString)) { - val df = spark.range(0, 4, 1, 4).withColumn("c", c) - val rows = df.collect() - val rowsAfterCoalesce = df.coalesce(2).collect() - assert(rows === rowsAfterCoalesce, "Values changed after coalesce when " + - s"codegenFallback=$codegenFallback and wholeStage=$wholeStage.") + // Shuffle expressions should produce same results at retries in the same DataFrame. 
+ private def checkShuffleResult(df: DataFrame): Unit = { + checkAnswer(df, df.collect()) + } - val df1 = spark.range(0, 2, 1, 2).withColumn("c", c) - val rows1 = df1.collect() - val df2 = spark.range(2, 4, 1, 2).withColumn("c", c) - val rows2 = df2.collect() - val rowsAfterUnion = df1.union(df2).collect() - assert(rowsAfterUnion === rows1 ++ rows2, "Values changed after union when " + - s"codegenFallback=$codegenFallback and wholeStage=$wholeStage.") - } + test("shuffle function - array for primitive type not containing null") { + val idfNotContainsNull = Seq( + Seq(1, 9, 8, 7), + Seq(5, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { + checkShuffleResult(idfNotContainsNull.select(shuffle('i))) + checkShuffleResult(idfNotContainsNull.selectExpr("shuffle(i)")) } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeNotContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + idfNotContainsNull.cache() + testArrayOfPrimitiveTypeNotContainsNull() } - test("SPARK-14393: values generated by non-deterministic functions shouldn't change after " + - "coalesce or union") { - Seq( - monotonically_increasing_id(), spark_partition_id(), - rand(Random.nextLong()), randn(Random.nextLong()) - ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_)) + test("shuffle function - array for primitive type containing null") { + val idfContainsNull = Seq[Seq[Integer]]( + Seq(1, 9, 8, null, 7), + Seq(null, 5, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeContainsNull(): Unit = { + checkShuffleResult(idfContainsNull.select(shuffle('i))) + checkShuffleResult(idfContainsNull.selectExpr("shuffle(i)")) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + idfContainsNull.cache() + testArrayOfPrimitiveTypeContainsNull() + } + + test("shuffle function - array for non-primitive type") { + val sdf = Seq( + Seq("c", "a", "b"), + Seq("b", null, "c", null), + Seq.empty, + null + ).toDF("s") + + def testNonPrimitiveType(): Unit = { + checkShuffleResult(sdf.select(shuffle('s))) + checkShuffleResult(sdf.selectExpr("shuffle(s)")) + } + + // Test with local relation, the Project will be evaluated without codegen + testNonPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + sdf.cache() + testNonPrimitiveType() + } + + test("array_except functions") { + val df1 = Seq((Array(1, 2, 4), Array(4, 2))).toDF("a", "b") + val ans1 = Row(Seq(1)) + checkAnswer(df1.select(array_except($"a", $"b")), ans1) + checkAnswer(df1.selectExpr("array_except(a, b)"), ans1) + + val df2 = Seq((Array[Integer](1, 2, null, 4, 5), Array[Integer](-5, 4, null, 2, -1))) + .toDF("a", "b") + val ans2 = Row(Seq(1, 5)) + checkAnswer(df2.select(array_except($"a", $"b")), ans2) + checkAnswer(df2.selectExpr("array_except(a, b)"), ans2) + + val df3 = Seq((Array(1L, 2L, 4L), Array(4L, 2L))).toDF("a", "b") + val ans3 = Row(Seq(1L)) + checkAnswer(df3.select(array_except($"a", $"b")), ans3) + checkAnswer(df3.selectExpr("array_except(a, b)"), ans3) + + val df4 = Seq( + (Array[java.lang.Long](1L, 2L, null, 4L, 5L), Array[java.lang.Long](-5L, 4L, null, 2L, -1L))) + .toDF("a", "b") + val ans4 = Row(Seq(1L, 5L)) + checkAnswer(df4.select(array_except($"a", $"b")), ans4) + checkAnswer(df4.selectExpr("array_except(a, b)"), 
ans4) + + val df5 = Seq((Array("c", null, "a", "f"), Array("b", null, "a", "g"))).toDF("a", "b") + val ans5 = Row(Seq("c", "f")) + checkAnswer(df5.select(array_except($"a", $"b")), ans5) + checkAnswer(df5.selectExpr("array_except(a, b)"), ans5) + + val df6 = Seq((null, null)).toDF("a", "b") + intercept[AnalysisException] { + df6.select(array_except($"a", $"b")) + } + intercept[AnalysisException] { + df6.selectExpr("array_except(a, b)") + } + val df7 = Seq((Array(1), Array("a"))).toDF("a", "b") + intercept[AnalysisException] { + df7.select(array_except($"a", $"b")) + } + intercept[AnalysisException] { + df7.selectExpr("array_except(a, b)") + } + val df8 = Seq((Array("a"), null)).toDF("a", "b") + intercept[AnalysisException] { + df8.select(array_except($"a", $"b")) + } + intercept[AnalysisException] { + df8.selectExpr("array_except(a, b)") + } + val df9 = Seq((null, Array("a"))).toDF("a", "b") + intercept[AnalysisException] { + df9.select(array_except($"a", $"b")) + } + intercept[AnalysisException] { + df9.selectExpr("array_except(a, b)") + } + + val df10 = Seq( + (Array[Integer](1, 2), Array[Integer](2)), + (Array[Integer](1, 2), Array[Integer](1, null)), + (Array[Integer](1, null, 3), Array[Integer](1, 2)), + (Array[Integer](1, null), Array[Integer](2, null)) + ).toDF("a", "b") + val result10 = df10.select(array_except($"a", $"b")) + val expectedType10 = ArrayType(IntegerType, containsNull = true) + assert(result10.first.schema(0).dataType === expectedType10) + } + + test("array_intersect functions") { + val df1 = Seq((Array(1, 2, 4), Array(4, 2))).toDF("a", "b") + val ans1 = Row(Seq(2, 4)) + checkAnswer(df1.select(array_intersect($"a", $"b")), ans1) + checkAnswer(df1.selectExpr("array_intersect(a, b)"), ans1) + + val df2 = Seq((Array[Integer](1, 2, null, 4, 5), Array[Integer](-5, 4, null, 2, -1))) + .toDF("a", "b") + val ans2 = Row(Seq(2, null, 4)) + checkAnswer(df2.select(array_intersect($"a", $"b")), ans2) + checkAnswer(df2.selectExpr("array_intersect(a, b)"), ans2) + + val df3 = Seq((Array(1L, 2L, 4L), Array(4L, 2L))).toDF("a", "b") + val ans3 = Row(Seq(2L, 4L)) + checkAnswer(df3.select(array_intersect($"a", $"b")), ans3) + checkAnswer(df3.selectExpr("array_intersect(a, b)"), ans3) + + val df4 = Seq( + (Array[java.lang.Long](1L, 2L, null, 4L, 5L), Array[java.lang.Long](-5L, 4L, null, 2L, -1L))) + .toDF("a", "b") + val ans4 = Row(Seq(2L, null, 4L)) + checkAnswer(df4.select(array_intersect($"a", $"b")), ans4) + checkAnswer(df4.selectExpr("array_intersect(a, b)"), ans4) + + val df5 = Seq((Array("c", null, "a", "f"), Array("b", "a", null, "g"))).toDF("a", "b") + val ans5 = Row(Seq(null, "a")) + checkAnswer(df5.select(array_intersect($"a", $"b")), ans5) + checkAnswer(df5.selectExpr("array_intersect(a, b)"), ans5) + + val df6 = Seq((null, null)).toDF("a", "b") + assert(intercept[AnalysisException] { + df6.select(array_intersect($"a", $"b")) + }.getMessage.contains("data type mismatch")) + assert(intercept[AnalysisException] { + df6.selectExpr("array_intersect(a, b)") + }.getMessage.contains("data type mismatch")) + + val df7 = Seq((Array(1), Array("a"))).toDF("a", "b") + assert(intercept[AnalysisException] { + df7.select(array_intersect($"a", $"b")) + }.getMessage.contains("data type mismatch")) + assert(intercept[AnalysisException] { + df7.selectExpr("array_intersect(a, b)") + }.getMessage.contains("data type mismatch")) + + val df8 = Seq((null, Array("a"))).toDF("a", "b") + assert(intercept[AnalysisException] { + df8.select(array_intersect($"a", $"b")) + }.getMessage.contains("data type 
mismatch")) + assert(intercept[AnalysisException] { + df8.selectExpr("array_intersect(a, b)") + }.getMessage.contains("data type mismatch")) + } + + test("transform function - array for primitive type not containing null") { + val df = Seq( + Seq(1, 9, 8, 7), + Seq(5, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { + checkAnswer(df.selectExpr("transform(i, x -> x + 1)"), + Seq( + Row(Seq(2, 10, 9, 8)), + Row(Seq(6, 9, 10, 8, 3)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.selectExpr("transform(i, (x, i) -> x + i)"), + Seq( + Row(Seq(1, 10, 10, 10)), + Row(Seq(5, 9, 11, 10, 6)), + Row(Seq.empty), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeNotContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeNotContainsNull() + } + + test("transform function - array for primitive type containing null") { + val df = Seq[Seq[Integer]]( + Seq(1, 9, 8, null, 7), + Seq(5, null, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeContainsNull(): Unit = { + checkAnswer(df.selectExpr("transform(i, x -> x + 1)"), + Seq( + Row(Seq(2, 10, 9, null, 8)), + Row(Seq(6, null, 9, 10, 8, 3)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.selectExpr("transform(i, (x, i) -> x + i)"), + Seq( + Row(Seq(1, 10, 10, null, 11)), + Row(Seq(5, null, 10, 12, 11, 7)), + Row(Seq.empty), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeContainsNull() + } + + test("transform function - array for non-primitive type") { + val df = Seq( + Seq("c", "a", "b"), + Seq("b", null, "c", null), + Seq.empty, + null + ).toDF("s") + + def testNonPrimitiveType(): Unit = { + checkAnswer(df.selectExpr("transform(s, x -> concat(x, x))"), + Seq( + Row(Seq("cc", "aa", "bb")), + Row(Seq("bb", null, "cc", null)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.selectExpr("transform(s, (x, i) -> concat(x, i))"), + Seq( + Row(Seq("c0", "a1", "b2")), + Row(Seq("b0", null, "c2", null)), + Row(Seq.empty), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testNonPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testNonPrimitiveType() + } + + test("transform function - special cases") { + val df = Seq( + Seq("c", "a", "b"), + Seq("b", null, "c", null), + Seq.empty, + null + ).toDF("arg") + + def testSpecialCases(): Unit = { + checkAnswer(df.selectExpr("transform(arg, arg -> arg)"), + Seq( + Row(Seq("c", "a", "b")), + Row(Seq("b", null, "c", null)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.selectExpr("transform(arg, arg)"), + Seq( + Row(Seq(Seq("c", "a", "b"), Seq("c", "a", "b"), Seq("c", "a", "b"))), + Row(Seq( + Seq("b", null, "c", null), + Seq("b", null, "c", null), + Seq("b", null, "c", null), + Seq("b", null, "c", null))), + Row(Seq.empty), + Row(null))) + checkAnswer(df.selectExpr("transform(arg, x -> concat(arg, array(x)))"), + Seq( + Row(Seq(Seq("c", "a", "b", "c"), Seq("c", "a", "b", "a"), Seq("c", "a", "b", "b"))), + Row(Seq( + Seq("b", null, "c", null, "b"), + Seq("b", null, "c", null, null), + Seq("b", null, "c", null, "c"), + Seq("b", null, "c", null, null))), + Row(Seq.empty), + Row(null))) + } + 
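+ // The cases above hinge on lexical scoping: in transform(arg, arg -> arg) the
+ // lambda parameter arg shadows the column of the same name, so the identity is
+ // returned, while in transform(arg, arg) there is no lambda parameter, so arg
+ // resolves to the column and every element is replaced by the whole array.
+ // A minimal sketch, assuming spark.implicits._ is in scope:
+ //   Seq(Seq(1, 2)).toDF("arg").selectExpr("transform(arg, arg)").first()
+ //   // yields Row(Seq(Seq(1, 2), Seq(1, 2))): one copy of the array per element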
+ // Test with local relation, the Project will be evaluated without codegen + testSpecialCases() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testSpecialCases() + } + + test("transform function - invalid") { + val df = Seq( + (Seq("c", "a", "b"), 1), + (Seq("b", null, "c", null), 2), + (Seq.empty, 3), + (null, 4) + ).toDF("s", "i") + + val ex1 = intercept[AnalysisException] { + df.selectExpr("transform(s, (x, y, z) -> x + y + z)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '3' does not match")) + + val ex2 = intercept[AnalysisException] { + df.selectExpr("transform(i, x -> x)") + } + assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) + + val ex3 = intercept[AnalysisException] { + df.selectExpr("transform(a, x -> x)") + } + assert(ex3.getMessage.contains("cannot resolve '`a`'")) + } + + test("map_filter") { + val dfInts = Seq( + Map(1 -> 10, 2 -> 20, 3 -> 30), + Map(1 -> -1, 2 -> -2, 3 -> -3), + Map(1 -> 10, 2 -> 5, 3 -> -3)).toDF("m") + + checkAnswer(dfInts.selectExpr( + "map_filter(m, (k, v) -> k * 10 = v)", "map_filter(m, (k, v) -> k = -v)"), + Seq( + Row(Map(1 -> 10, 2 -> 20, 3 -> 30), Map()), + Row(Map(), Map(1 -> -1, 2 -> -2, 3 -> -3)), + Row(Map(1 -> 10), Map(3 -> -3)))) + + val dfComplex = Seq( + Map(1 -> Seq(Some(1)), 2 -> Seq(Some(1), Some(2)), 3 -> Seq(Some(1), Some(2), Some(3))), + Map(1 -> null, 2 -> Seq(Some(-2), Some(-2)), 3 -> Seq[Option[Int]](None))).toDF("m") + + checkAnswer(dfComplex.selectExpr( + "map_filter(m, (k, v) -> k = v[0])", "map_filter(m, (k, v) -> k = size(v))"), + Seq( + Row(Map(1 -> Seq(1)), Map(1 -> Seq(1), 2 -> Seq(1, 2), 3 -> Seq(1, 2, 3))), + Row(Map(), Map(2 -> Seq(-2, -2))))) + + // Invalid use cases + val df = Seq( + (Map(1 -> "a"), 1), + (Map.empty[Int, String], 2), + (null, 3) + ).toDF("s", "i") + + val ex1 = intercept[AnalysisException] { + df.selectExpr("map_filter(s, (x, y, z) -> x + y + z)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '3' does not match")) + + val ex2 = intercept[AnalysisException] { + df.selectExpr("map_filter(s, x -> x)") + } + assert(ex2.getMessage.contains("The number of lambda function arguments '1' does not match")) + + val ex3 = intercept[AnalysisException] { + df.selectExpr("map_filter(i, (k, v) -> k > v)") + } + assert(ex3.getMessage.contains("data type mismatch: argument 1 requires map type")) + + val ex4 = intercept[AnalysisException] { + df.selectExpr("map_filter(a, (k, v) -> k > v)") + } + assert(ex4.getMessage.contains("cannot resolve '`a`'")) + } + + test("filter function - array for primitive type not containing null") { + val df = Seq( + Seq(1, 9, 8, 7), + Seq(5, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { + checkAnswer(df.selectExpr("filter(i, x -> x % 2 == 0)"), + Seq( + Row(Seq(8)), + Row(Seq(8, 2)), + Row(Seq.empty), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeNotContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeNotContainsNull() + } + + test("filter function - array for primitive type containing null") { + val df = Seq[Seq[Integer]]( + Seq(1, 9, 8, null, 7), + Seq(5, null, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeContainsNull(): Unit = { + checkAnswer(df.selectExpr("filter(i, x -> x % 2 == 0)"), + Seq( + 
Row(Seq(8)), + Row(Seq(8, 2)), + Row(Seq.empty), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeContainsNull() + } + + test("filter function - array for non-primitive type") { + val df = Seq( + Seq("c", "a", "b"), + Seq("b", null, "c", null), + Seq.empty, + null + ).toDF("s") + + def testNonPrimitiveType(): Unit = { + checkAnswer(df.selectExpr("filter(s, x -> x is not null)"), + Seq( + Row(Seq("c", "a", "b")), + Row(Seq("b", "c")), + Row(Seq.empty), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testNonPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testNonPrimitiveType() + } + + test("filter function - invalid") { + val df = Seq( + (Seq("c", "a", "b"), 1), + (Seq("b", null, "c", null), 2), + (Seq.empty, 3), + (null, 4) + ).toDF("s", "i") + + val ex1 = intercept[AnalysisException] { + df.selectExpr("filter(s, (x, y) -> x + y)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '2' does not match")) + + val ex2 = intercept[AnalysisException] { + df.selectExpr("filter(i, x -> x)") + } + assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) + + val ex3 = intercept[AnalysisException] { + df.selectExpr("filter(s, x -> x)") + } + assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + + val ex4 = intercept[AnalysisException] { + df.selectExpr("filter(a, x -> x)") + } + assert(ex4.getMessage.contains("cannot resolve '`a`'")) + } + + test("exists function - array for primitive type not containing null") { + val df = Seq( + Seq(1, 9, 8, 7), + Seq(5, 9, 7), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { + checkAnswer(df.selectExpr("exists(i, x -> x % 2 == 0)"), + Seq( + Row(true), + Row(false), + Row(false), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeNotContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeNotContainsNull() + } + + test("exists function - array for primitive type containing null") { + val df = Seq[Seq[Integer]]( + Seq(1, 9, 8, null, 7), + Seq(5, null, null, 9, 7, null), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeContainsNull(): Unit = { + checkAnswer(df.selectExpr("exists(i, x -> x % 2 == 0)"), + Seq( + Row(true), + Row(false), + Row(false), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeContainsNull() + } + + test("exists function - array for non-primitive type") { + val df = Seq( + Seq("c", "a", "b"), + Seq("b", null, "c", null), + Seq.empty, + null + ).toDF("s") + + def testNonPrimitiveType(): Unit = { + checkAnswer(df.selectExpr("exists(s, x -> x is null)"), + Seq( + Row(false), + Row(true), + Row(false), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testNonPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testNonPrimitiveType() + } + + test("exists 
function - invalid") { + val df = Seq( + (Seq("c", "a", "b"), 1), + (Seq("b", null, "c", null), 2), + (Seq.empty, 3), + (null, 4) + ).toDF("s", "i") + + val ex1 = intercept[AnalysisException] { + df.selectExpr("exists(s, (x, y) -> x + y)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '2' does not match")) + + val ex2 = intercept[AnalysisException] { + df.selectExpr("exists(i, x -> x)") + } + assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) + + val ex3 = intercept[AnalysisException] { + df.selectExpr("exists(s, x -> x)") + } + assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + + val ex4 = intercept[AnalysisException] { + df.selectExpr("exists(a, x -> x)") + } + assert(ex4.getMessage.contains("cannot resolve '`a`'")) + } + + test("aggregate function - array for primitive type not containing null") { + val df = Seq( + Seq(1, 9, 8, 7), + Seq(5, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { + checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x)"), + Seq( + Row(25), + Row(31), + Row(0), + Row(null))) + checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x, acc -> acc * 10)"), + Seq( + Row(250), + Row(310), + Row(0), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeNotContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeNotContainsNull() + } + + test("aggregate function - array for primitive type containing null") { + val df = Seq[Seq[Integer]]( + Seq(1, 9, 8, 7), + Seq(5, null, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeContainsNull(): Unit = { + checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x)"), + Seq( + Row(25), + Row(null), + Row(0), + Row(null))) + checkAnswer( + df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x, acc -> coalesce(acc, 0) * 10)"), + Seq( + Row(250), + Row(0), + Row(0), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeContainsNull() + } + + test("aggregate function - array for non-primitive type") { + val df = Seq( + (Seq("c", "a", "b"), "a"), + (Seq("b", null, "c", null), "b"), + (Seq.empty, "c"), + (null, "d") + ).toDF("ss", "s") + + def testNonPrimitiveType(): Unit = { + checkAnswer(df.selectExpr("aggregate(ss, s, (acc, x) -> concat(acc, x))"), + Seq( + Row("acab"), + Row(null), + Row("c"), + Row(null))) + checkAnswer( + df.selectExpr("aggregate(ss, s, (acc, x) -> concat(acc, x), acc -> coalesce(acc , ''))"), + Seq( + Row("acab"), + Row(""), + Row("c"), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testNonPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testNonPrimitiveType() + } + + test("aggregate function - invalid") { + val df = Seq( + (Seq("c", "a", "b"), 1), + (Seq("b", null, "c", null), 2), + (Seq.empty, 3), + (null, 4) + ).toDF("s", "i") + + val ex1 = intercept[AnalysisException] { + df.selectExpr("aggregate(s, '', x -> x)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match")) + + val ex2 = intercept[AnalysisException] { + 
df.selectExpr("aggregate(s, '', (acc, x) -> x, (acc, x) -> x)") + } + assert(ex2.getMessage.contains("The number of lambda function arguments '2' does not match")) + + val ex3 = intercept[AnalysisException] { + df.selectExpr("aggregate(i, 0, (acc, x) -> x)") + } + assert(ex3.getMessage.contains("data type mismatch: argument 1 requires array type")) + + val ex4 = intercept[AnalysisException] { + df.selectExpr("aggregate(s, 0, (acc, x) -> x)") + } + assert(ex4.getMessage.contains("data type mismatch: argument 3 requires int type")) + + val ex5 = intercept[AnalysisException] { + df.selectExpr("aggregate(a, 0, (acc, x) -> x)") + } + assert(ex5.getMessage.contains("cannot resolve '`a`'")) + } + + test("map_zip_with function - map of primitive types") { + val df = Seq( + (Map(8 -> 6L, 3 -> 5L, 6 -> 2L), Map[Integer, Integer]((6, 4), (8, 2), (3, 2))), + (Map(10 -> 6L, 8 -> 3L), Map[Integer, Integer]((8, 4), (4, null))), + (Map.empty[Int, Long], Map[Integer, Integer]((5, 1))), + (Map(5 -> 1L), null) + ).toDF("m1", "m2") + + checkAnswer(df.selectExpr("map_zip_with(m1, m2, (k, v1, v2) -> k == v1 + v2)"), + Seq( + Row(Map(8 -> true, 3 -> false, 6 -> true)), + Row(Map(10 -> null, 8 -> false, 4 -> null)), + Row(Map(5 -> null)), + Row(null))) + } + + test("map_zip_with function - map of non-primitive types") { + val df = Seq( + (Map("z" -> "a", "y" -> "b", "x" -> "c"), Map("x" -> "a", "z" -> "c")), + (Map("b" -> "a", "c" -> "d"), Map("c" -> "a", "b" -> null, "d" -> "k")), + (Map("a" -> "d"), Map.empty[String, String]), + (Map("a" -> "d"), null) + ).toDF("m1", "m2") + + checkAnswer(df.selectExpr("map_zip_with(m1, m2, (k, v1, v2) -> (v1, v2))"), + Seq( + Row(Map("z" -> Row("a", "c"), "y" -> Row("b", null), "x" -> Row("c", "a"))), + Row(Map("b" -> Row("a", null), "c" -> Row("d", "a"), "d" -> Row(null, "k"))), + Row(Map("a" -> Row("d", null))), + Row(null))) + } + + test("map_zip_with function - invalid") { + val df = Seq( + (Map(1 -> 2), Map(1 -> "a"), Map("a" -> "b"), Map(Map(1 -> 2) -> 2), 1) + ).toDF("mii", "mis", "mss", "mmi", "i") + + val ex1 = intercept[AnalysisException] { + df.selectExpr("map_zip_with(mii, mis, (x, y) -> x + y)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '2' does not match")) + + val ex2 = intercept[AnalysisException] { + df.selectExpr("map_zip_with(mis, mmi, (x, y, z) -> concat(x, y, z))") + } + assert(ex2.getMessage.contains("The input to function map_zip_with should have " + + "been two maps with compatible key types")) + + val ex3 = intercept[AnalysisException] { + df.selectExpr("map_zip_with(i, mis, (x, y, z) -> concat(x, y, z))") + } + assert(ex3.getMessage.contains("type mismatch: argument 1 requires map type")) + + val ex4 = intercept[AnalysisException] { + df.selectExpr("map_zip_with(mis, i, (x, y, z) -> concat(x, y, z))") + } + assert(ex4.getMessage.contains("type mismatch: argument 2 requires map type")) + + val ex5 = intercept[AnalysisException] { + df.selectExpr("map_zip_with(mmi, mmi, (x, y, z) -> x)") + } + assert(ex5.getMessage.contains("function map_zip_with does not support ordering on type map")) + } + + test("transform keys function - primitive data types") { + val dfExample1 = Seq( + Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7) + ).toDF("i") + + val dfExample2 = Seq( + Map[Int, Double](1 -> 1.0, 2 -> 1.40, 3 -> 1.70) + ).toDF("j") + + val dfExample3 = Seq( + Map[Int, Boolean](25 -> true, 26 -> false) + ).toDF("x") + + val dfExample4 = Seq( + Map[Array[Int], Boolean](Array(1, 2) -> false) + ).toDF("y") + + + def 
testMapOfPrimitiveTypesCombination(): Unit = { + checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> k + v)"), + Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7)))) + + checkAnswer(dfExample2.selectExpr("transform_keys(j, " + + "(k, v) -> map_from_arrays(ARRAY(1, 2, 3), ARRAY('one', 'two', 'three'))[k])"), + Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7)))) + + checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> CAST(v * 2 AS BIGINT) + k)"), + Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7)))) + + checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> k + v)"), + Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7)))) + + checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)"), + Seq(Row(Map(true -> true, true -> false)))) + + checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"), + Seq(Row(Map(50 -> true, 78 -> false)))) + + checkAnswer(dfExample4.selectExpr("transform_keys(y, (k, v) -> array_contains(k, 3) AND v)"), + Seq(Row(Map(false -> false)))) + } + + // Test with local relation, the Project will be evaluated without codegen + testMapOfPrimitiveTypesCombination() + dfExample1.cache() + dfExample2.cache() + dfExample3.cache() + dfExample4.cache() + // Test with cached relation, the Project will be evaluated with codegen + testMapOfPrimitiveTypesCombination() + } + + test("transform keys function - Invalid lambda functions and exceptions") { + val dfExample1 = Seq( + Map[String, String]("a" -> null) + ).toDF("i") + + val dfExample2 = Seq( + Seq(1, 2, 3, 4) + ).toDF("j") + + val ex1 = intercept[AnalysisException] { + dfExample1.selectExpr("transform_keys(i, k -> k)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match")) + + val ex2 = intercept[AnalysisException] { + dfExample1.selectExpr("transform_keys(i, (k, v, x) -> k + 1)") + } + assert(ex2.getMessage.contains( + "The number of lambda function arguments '3' does not match")) + + val ex3 = intercept[Exception] { + dfExample1.selectExpr("transform_keys(i, (k, v) -> v)").show() + } + assert(ex3.getMessage.contains("Cannot use null as map key!")) + + val ex4 = intercept[AnalysisException] { + dfExample2.selectExpr("transform_keys(j, (k, v) -> k + 1)") + } + assert(ex4.getMessage.contains( + "data type mismatch: argument 1 requires map type")) + } + + test("transform values function - test primitive data types") { + val dfExample1 = Seq( + Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7) + ).toDF("i") + + val dfExample2 = Seq( + Map[Boolean, String](false -> "abc", true -> "def") + ).toDF("x") + + val dfExample3 = Seq( + Map[String, Int]("a" -> 1, "b" -> 2, "c" -> 3) + ).toDF("y") + + val dfExample4 = Seq( + Map[Int, Double](1 -> 1.0, 2 -> 1.40, 3 -> 1.70) + ).toDF("z") + + val dfExample5 = Seq( + Map[Int, Array[Int]](1 -> Array(1, 2)) + ).toDF("c") + + def testMapOfPrimitiveTypesCombination(): Unit = { + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> k + v)"), + Seq(Row(Map(1 -> 2, 9 -> 18, 8 -> 16, 7 -> 14)))) + + checkAnswer(dfExample2.selectExpr( + "transform_values(x, (k, v) -> if(k, v, CAST(k AS String)))"), + Seq(Row(Map(false -> "false", true -> "def")))) + + checkAnswer(dfExample2.selectExpr("transform_values(x, (k, v) -> NOT k AND v = 'abc')"), + Seq(Row(Map(false -> true, true -> false)))) + + checkAnswer(dfExample3.selectExpr("transform_values(y, (k, v) 
-> v * v)"), + Seq(Row(Map("a" -> 1, "b" -> 4, "c" -> 9)))) + + checkAnswer(dfExample3.selectExpr( + "transform_values(y, (k, v) -> k || ':' || CAST(v as String))"), + Seq(Row(Map("a" -> "a:1", "b" -> "b:2", "c" -> "c:3")))) + + checkAnswer( + dfExample3.selectExpr("transform_values(y, (k, v) -> concat(k, cast(v as String)))"), + Seq(Row(Map("a" -> "a1", "b" -> "b2", "c" -> "c3")))) + + checkAnswer( + dfExample4.selectExpr( + "transform_values(" + + "z,(k, v) -> map_from_arrays(ARRAY(1, 2, 3), " + + "ARRAY('one', 'two', 'three'))[k] || '_' || CAST(v AS String))"), + Seq(Row(Map(1 -> "one_1.0", 2 -> "two_1.4", 3 ->"three_1.7")))) + + checkAnswer( + dfExample4.selectExpr("transform_values(z, (k, v) -> k-v)"), + Seq(Row(Map(1 -> 0.0, 2 -> 0.6000000000000001, 3 -> 1.3)))) + + checkAnswer( + dfExample5.selectExpr("transform_values(c, (k, v) -> k + cardinality(v))"), + Seq(Row(Map(1 -> 3)))) + } + + // Test with local relation, the Project will be evaluated without codegen + testMapOfPrimitiveTypesCombination() + dfExample1.cache() + dfExample2.cache() + dfExample3.cache() + dfExample4.cache() + dfExample5.cache() + // Test with cached relation, the Project will be evaluated with codegen + testMapOfPrimitiveTypesCombination() + } + + test("transform values function - test empty") { + val dfExample1 = Seq( + Map.empty[Integer, Integer] + ).toDF("i") + + val dfExample2 = Seq( + Map.empty[BigInt, String] + ).toDF("j") + + def testEmpty(): Unit = { + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> NULL)"), + Seq(Row(Map.empty[Integer, Integer]))) + + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> k)"), + Seq(Row(Map.empty[Integer, Integer]))) + + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> v)"), + Seq(Row(Map.empty[Integer, Integer]))) + + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> 0)"), + Seq(Row(Map.empty[Integer, Integer]))) + + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> 'value')"), + Seq(Row(Map.empty[Integer, String]))) + + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> true)"), + Seq(Row(Map.empty[Integer, Boolean]))) + + checkAnswer(dfExample2.selectExpr("transform_values(j, (k, v) -> k + cast(v as BIGINT))"), + Seq(Row(Map.empty[BigInt, BigInt]))) + } + + testEmpty() + dfExample1.cache() + dfExample2.cache() + testEmpty() + } + + test("transform values function - test null values") { + val dfExample1 = Seq( + Map[Int, Integer](1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4) + ).toDF("a") + + val dfExample2 = Seq( + Map[Int, String](1 -> "a", 2 -> "b", 3 -> null) + ).toDF("b") + + def testNullValue(): Unit = { + checkAnswer(dfExample1.selectExpr("transform_values(a, (k, v) -> null)"), + Seq(Row(Map[Int, Integer](1 -> null, 2 -> null, 3 -> null, 4 -> null)))) + + checkAnswer(dfExample2.selectExpr( + "transform_values(b, (k, v) -> IF(v IS NULL, k + 1, k + 2))"), + Seq(Row(Map(1 -> 3, 2 -> 4, 3 -> 4)))) + } + + testNullValue() + dfExample1.cache() + dfExample2.cache() + testNullValue() + } + + test("transform values function - test invalid functions") { + val dfExample1 = Seq( + Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7) + ).toDF("i") + + val dfExample2 = Seq( + Map[String, String]("a" -> "b") + ).toDF("j") + + val dfExample3 = Seq( + Seq(1, 2, 3, 4) + ).toDF("x") + + def testInvalidLambdaFunctions(): Unit = { + + val ex1 = intercept[AnalysisException] { + dfExample1.selectExpr("transform_values(i, k -> k)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '1' 
does not match")) + + val ex2 = intercept[AnalysisException] { + dfExample2.selectExpr("transform_values(j, (k, v, x) -> k + 1)") + } + assert(ex2.getMessage.contains("The number of lambda function arguments '3' does not match")) + + val ex3 = intercept[AnalysisException] { + dfExample3.selectExpr("transform_values(x, (k, v) -> k + 1)") + } + assert(ex3.getMessage.contains( + "data type mismatch: argument 1 requires map type")) + } + + testInvalidLambdaFunctions() + dfExample1.cache() + dfExample2.cache() + dfExample3.cache() + testInvalidLambdaFunctions() + } + + test("arrays zip_with function - for primitive types") { + val df1 = Seq[(Seq[Integer], Seq[Integer])]( + (Seq(9001, 9002, 9003), Seq(4, 5, 6)), + (Seq(1, 2), Seq(3, 4)), + (Seq.empty, Seq.empty), + (null, null) + ).toDF("val1", "val2") + val df2 = Seq[(Seq[Integer], Seq[Long])]( + (Seq(1, null, 3), Seq(1L, 2L)), + (Seq(1, 2, 3), Seq(4L, 11L)) + ).toDF("val1", "val2") + val expectedValue1 = Seq( + Row(Seq(9005, 9007, 9009)), + Row(Seq(4, 6)), + Row(Seq.empty), + Row(null)) + checkAnswer(df1.selectExpr("zip_with(val1, val2, (x, y) -> x + y)"), expectedValue1) + val expectedValue2 = Seq( + Row(Seq(Row(1L, 1), Row(2L, null), Row(null, 3))), + Row(Seq(Row(4L, 1), Row(11L, 2), Row(null, 3)))) + checkAnswer(df2.selectExpr("zip_with(val1, val2, (x, y) -> (y, x))"), expectedValue2) + } + + test("arrays zip_with function - for non-primitive types") { + val df = Seq( + (Seq("a"), Seq("x", "y", "z")), + (Seq("a", null), Seq("x", "y")), + (Seq.empty[String], Seq.empty[String]), + (Seq("a", "b", "c"), null) + ).toDF("val1", "val2") + val expectedValue1 = Seq( + Row(Seq(Row("x", "a"), Row("y", null), Row("z", null))), + Row(Seq(Row("x", "a"), Row("y", null))), + Row(Seq.empty), + Row(null)) + checkAnswer(df.selectExpr("zip_with(val1, val2, (x, y) -> (y, x))"), expectedValue1) + } + + test("arrays zip_with function - invalid") { + val df = Seq( + (Seq("c", "a", "b"), Seq("x", "y", "z"), 1), + (Seq("b", null, "c", null), Seq("x"), 2), + (Seq.empty, Seq("x", "z"), 3), + (null, Seq("x", "z"), 4) + ).toDF("a1", "a2", "i") + val ex1 = intercept[AnalysisException] { + df.selectExpr("zip_with(a1, a2, x -> x)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match")) + val ex2 = intercept[AnalysisException] { + df.selectExpr("zip_with(a1, a2, (acc, x) -> x, (acc, x) -> x)") + } + assert(ex2.getMessage.contains("Invalid number of arguments for function zip_with")) + val ex3 = intercept[AnalysisException] { + df.selectExpr("zip_with(i, a2, (acc, x) -> x)") + } + assert(ex3.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex4 = intercept[AnalysisException] { + df.selectExpr("zip_with(a1, a, (acc, x) -> x)") + } + assert(ex4.getMessage.contains("cannot resolve '`a`'")) + } + + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { + import DataFrameFunctionsSuite.CodegenFallbackExpr + for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { + val c = if (codegenFallback) { + Column(CodegenFallbackExpr(v.expr)) + } else { + v + } + withSQLConf( + (SQLConf.CODEGEN_FALLBACK.key, codegenFallback.toString), + (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString)) { + val df = spark.range(0, 4, 1, 4).withColumn("c", c) + val rows = df.collect() + val rowsAfterCoalesce = df.coalesce(2).collect() + assert(rows === rowsAfterCoalesce, "Values changed after coalesce when " + + s"codegenFallback=$codegenFallback and 
wholeStage=$wholeStage.") + + val df1 = spark.range(0, 2, 1, 2).withColumn("c", c) + val rows1 = df1.collect() + val df2 = spark.range(2, 4, 1, 2).withColumn("c", c) + val rows2 = df2.collect() + val rowsAfterUnion = df1.union(df2).collect() + assert(rowsAfterUnion === rows1 ++ rows2, "Values changed after union when " + + s"codegenFallback=$codegenFallback and wholeStage=$wholeStage.") + } + } + } + + test("SPARK-14393: values generated by non-deterministic functions shouldn't change after " + + "coalesce or union") { + Seq( + monotonically_increasing_id(), spark_partition_id(), + rand(Random.nextLong()), randn(Random.nextLong()) + ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_)) } test("SPARK-21281 use string types by default if array and map have no argument") { @@ -1473,8 +2677,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { val funcsMustHaveAtLeastOneArg = ("coalesce", (df: DataFrame) => df.select(coalesce())) :: ("coalesce", (df: DataFrame) => df.selectExpr("coalesce()")) :: - ("named_struct", (df: DataFrame) => df.select(struct())) :: - ("named_struct", (df: DataFrame) => df.selectExpr("named_struct()")) :: ("hash", (df: DataFrame) => df.select(hash())) :: ("hash", (df: DataFrame) => df.selectExpr("hash()")) :: Nil funcsMustHaveAtLeastOneArg.foreach { case (name, func) => @@ -1492,6 +2694,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(errMsg.contains(s"input to function $name requires at least two arguments")) } } + + test("SPARK-24734: Fix containsNull of Concat for array type") { + val df = Seq((Seq(1), Seq[Integer](null), Seq("a", "b"))).toDF("k1", "k2", "v") + val ex = intercept[Exception] { + df.select(map_from_arrays(concat($"k1", $"k2"), $"v")).show() + } + assert(ex.getMessage.contains("Cannot use null as map key")) + } } object DataFrameFunctionsSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala index 0dd5bdcba2e4c..7ef8b542c79a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala @@ -59,4 +59,14 @@ class DataFrameHintSuite extends AnalysisTest with SharedSQLContext { ) ) } + + test("coalesce and repartition hint") { + check( + df.hint("COALESCE", 10), + UnresolvedHint("COALESCE", Seq(10), df.logicalPlan)) + + check( + df.hint("REPARTITION", 100), + UnresolvedHint("REPARTITION", Seq(100), df.logicalPlan)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 0d9eeabb397a1..e6b30f9956daf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -196,7 +196,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as("b") // outer -> left - val outerJoin2Left = df.join(df2, $"a.int" === $"b.int", "outer").where($"a.int" === 3) + val outerJoin2Left = df.join(df2, $"a.int" === $"b.int", "outer").where($"a.int" >= 3) assert(outerJoin2Left.queryExecution.optimizedPlan.collect { case j @ Join(_, _, LeftOuter, _) => j }.size === 1) checkAnswer( @@ -204,7 +204,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(3, 4, "3", null, null, null) :: Nil) // outer -> right - val 
outerJoin2Right = df.join(df2, $"a.int" === $"b.int", "outer").where($"b.int" === 5) + val outerJoin2Right = df.join(df2, $"a.int" === $"b.int", "outer").where($"b.int" >= 3) assert(outerJoin2Right.queryExecution.optimizedPlan.collect { case j @ Join(_, _, RightOuter, _) => j }.size === 1) checkAnswer( @@ -221,7 +221,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(1, 2, "1", 1, 3, "1") :: Nil) // right -> inner - val rightJoin2Inner = df.join(df2, $"a.int" === $"b.int", "right").where($"a.int" === 1) + val rightJoin2Inner = df.join(df2, $"a.int" === $"b.int", "right").where($"a.int" > 0) assert(rightJoin2Inner.queryExecution.optimizedPlan.collect { case j @ Join(_, _, Inner, _) => j }.size === 1) checkAnswer( @@ -229,7 +229,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(1, 2, "1", 1, 3, "1") :: Nil) // left -> inner - val leftJoin2Inner = df.join(df2, $"a.int" === $"b.int", "left").where($"b.int2" === 3) + val leftJoin2Inner = df.join(df2, $"a.int" === $"b.int", "left").where($"b.int2" > 0) assert(leftJoin2Inner.queryExecution.optimizedPlan.collect { case j @ Join(_, _, Inner, _) => j }.size === 1) checkAnswer( @@ -287,4 +287,12 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { dfOne.join(dfTwo, $"a" === $"b", "left").queryExecution.optimizedPlan } } + + test("SPARK-24385: Resolve ambiguity in self-joins with EqualNullSafe") { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") { + val df = spark.range(2) + // this throws an exception before the fix + df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 6ca9ee57e8f49..b972b9ef93e5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -27,28 +27,40 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("pivot courses") { + val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil checkAnswer( courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java")) .agg(sum($"earnings")), - Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil - ) + expected) + checkAnswer( + courseSales.groupBy($"year").pivot($"course", Seq("dotNET", "Java")) + .agg(sum($"earnings")), + expected) } test("pivot year") { + val expected = Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil checkAnswer( courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")), - Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil - ) + expected) + checkAnswer( + courseSales.groupBy('course).pivot('year, Seq(2012, 2013)).agg(sum('earnings)), + expected) } test("pivot courses with multiple aggregations") { + val expected = Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) :: + Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil checkAnswer( courseSales.groupBy($"year") .pivot("course", Seq("dotNET", "Java")) .agg(sum($"earnings"), avg($"earnings")), - Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) :: - Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil - ) + expected) + checkAnswer( + courseSales.groupBy($"year") + .pivot($"course", Seq("dotNET", "Java")) + .agg(sum($"earnings"), avg($"earnings")), + expected) } test("pivot year with string values (cast)") { @@ -67,17 +79,23 
@@ class DataFramePivotSuite extends QueryTest with SharedSQLContext { test("pivot courses with no values") { // Note Java comes before dotNet in sorted order + val expected = Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil checkAnswer( courseSales.groupBy("year").pivot("course").agg(sum($"earnings")), - Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil - ) + expected) + checkAnswer( + courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")), + expected) } test("pivot year with no values") { + val expected = Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil checkAnswer( courseSales.groupBy("course").pivot("year").agg(sum($"earnings")), - Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil - ) + expected) + checkAnswer( + courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")), + expected) } test("pivot max values enforced") { @@ -181,10 +199,13 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext { } test("pivot with datatype not supported by PivotFirst") { + val expected = Row(Seq(1, 1, 1), Seq(2, 2, 2)) :: Nil checkAnswer( complexData.groupBy().pivot("b", Seq(true, false)).agg(max("a")), - Row(Seq(1, 1, 1), Seq(2, 2, 2)) :: Nil - ) + expected) + checkAnswer( + complexData.groupBy().pivot('b, Seq(true, false)).agg(max('a)), + expected) } test("pivot with datatype not supported by PivotFirst 2") { @@ -246,4 +267,45 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext { checkAnswer(df.select($"a".cast(StringType)), Row(tsWithZone)) } } + + test("SPARK-24722: pivoting nested columns") { + val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil + val df = trainingSales + .groupBy($"sales.year") + .pivot(lower($"sales.course"), Seq("dotNet", "Java").map(_.toLowerCase)) + .agg(sum($"sales.earnings")) + + checkAnswer(df, expected) + } + + test("SPARK-24722: references to multiple columns in the pivot column") { + val expected = Row(2012, 10000.0) :: Row(2013, 48000.0) :: Nil + val df = trainingSales + .groupBy($"sales.year") + .pivot(concat_ws("-", $"training", $"sales.course"), Seq("Experts-dotNET")) + .agg(sum($"sales.earnings")) + + checkAnswer(df, expected) + } + + test("SPARK-24722: pivoting by a constant") { + val expected = Row(2012, 35000.0) :: Row(2013, 78000.0) :: Nil + val df1 = trainingSales + .groupBy($"sales.year") + .pivot(lit(123), Seq(123)) + .agg(sum($"sales.earnings")) + + checkAnswer(df1, expected) + } + + test("SPARK-24722: aggregate as the pivot column") { + val exception = intercept[AnalysisException] { + trainingSales + .groupBy($"sales.year") + .pivot(min($"training"), Seq("Experts")) + .agg(sum($"sales.earnings")) + } + + assert(exception.getMessage.contains("aggregate functions are not allowed")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index ea00d22bff001..f001b138f4b8e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -27,18 +27,20 @@ import scala.util.Random import org.scalatest.Matchers._ import org.apache.spark.SparkException +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.Uuid import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union} -import 
org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec} +import org.apache.spark.sql.execution.{FilterExec, QueryExecution, TakeOrderedAndProjectExec, WholeStageCodegenExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} -import org.apache.spark.sql.test.SQLTestData.TestData2 +import org.apache.spark.sql.test.SQLTestData.{NullInts, NullStrings, TestData2} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +import org.apache.spark.util.random.XORShiftRandom class DataFrameSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -629,6 +631,74 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df4.schema.forall(!_.nullable)) } + test("except all") { + checkAnswer( + lowerCaseData.exceptAll(upperCaseData), + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.exceptAll(lowerCaseData), Nil) + checkAnswer(upperCaseData.exceptAll(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.exceptAll(nullInts.filter("0 = 1")), + nullInts) + checkAnswer( + nullInts.exceptAll(nullInts), + Nil) + + // check that duplicate values are preserved + checkAnswer( + allNulls.exceptAll(allNulls.filter("0 = 1")), + Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil) + checkAnswer( + allNulls.exceptAll(allNulls.limit(2)), + Row(null) :: Row(null) :: Nil) + + // check that duplicates are retained. + val df = spark.sparkContext.parallelize( + NullStrings(1, "id1") :: + NullStrings(1, "id1") :: + NullStrings(2, "id1") :: + NullStrings(3, null) :: Nil).toDF("id", "value") + + checkAnswer( + df.exceptAll(df.filter("0 = 1")), + Row(1, "id1") :: + Row(1, "id1") :: + Row(2, "id1") :: + Row(3, null) :: Nil) + + // check if the empty set on the left side works + checkAnswer( + allNulls.filter("0 = 1").exceptAll(allNulls), + Nil) + + } + + test("exceptAll - nullability") { + val nonNullableInts = Seq(Tuple1(11), Tuple1(3)).toDF() + assert(nonNullableInts.schema.forall(!_.nullable)) + + val df1 = nonNullableInts.exceptAll(nullInts) + checkAnswer(df1, Row(11) :: Nil) + assert(df1.schema.forall(!_.nullable)) + + val df2 = nullInts.exceptAll(nonNullableInts) + checkAnswer(df2, Row(1) :: Row(2) :: Row(null) :: Nil) + assert(df2.schema.forall(_.nullable)) + + val df3 = nullInts.exceptAll(nullInts) + checkAnswer(df3, Nil) + assert(df3.schema.forall(_.nullable)) + + val df4 = nonNullableInts.exceptAll(nonNullableInts) + checkAnswer(df4, Nil) + assert(df4.schema.forall(!_.nullable)) + } + test("intersect") { checkAnswer( lowerCaseData.intersect(lowerCaseData), @@ -681,6 +751,60 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df4.schema.forall(!_.nullable)) } + test("intersectAll") { + checkAnswer( + lowerCaseDataWithDuplicates.intersectAll(lowerCaseDataWithDuplicates), + Row(1, "a") :: + Row(2, "b") :: + Row(2, "b") :: + Row(3, "c") :: + Row(3, "c") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.intersectAll(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.intersectAll(nullInts), + Row(1) :: + Row(2) :: + Row(3) :: + Row(null) :: Nil) + + // Duplicate nulls are preserved. 
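+ // (As in SQL's INTERSECT ALL, each distinct row is kept min(leftCount,
+ // rightCount) times, which is why four nulls on each side intersect to
+ // four nulls rather than one.)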
+ checkAnswer( + allNulls.intersectAll(allNulls), + Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil) + + val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id") + val df_right = Seq(1, 2, 2, 3).toDF("id") + + checkAnswer( + df_left.intersectAll(df_right), + Row(1) :: Row(2) :: Row(2) :: Row(3) :: Nil) + } + + test("intersectAll - nullability") { + val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF() + assert(nonNullableInts.schema.forall(!_.nullable)) + + val df1 = nonNullableInts.intersectAll(nullInts) + checkAnswer(df1, Row(1) :: Row(3) :: Nil) + assert(df1.schema.forall(!_.nullable)) + + val df2 = nullInts.intersectAll(nonNullableInts) + checkAnswer(df2, Row(1) :: Row(3) :: Nil) + assert(df2.schema.forall(!_.nullable)) + + val df3 = nullInts.intersectAll(nullInts) + checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) + assert(df3.schema.forall(_.nullable)) + + val df4 = nonNullableInts.intersectAll(nonNullableInts) + checkAnswer(df4, Row(1) :: Row(3) :: Nil) + assert(df4.schema.forall(!_.nullable)) + } + test("udf") { val foo = udf((a: Int, b: String) => a.toString + b) @@ -1606,10 +1730,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-9083: sort with non-deterministic expressions") { - import org.apache.spark.util.random.XORShiftRandom - val seed = 33 - val df = (1 to 100).map(Tuple1.apply).toDF("i") + val df = (1 to 100).map(Tuple1.apply).toDF("i").repartition(1) val random = new XORShiftRandom(seed) val expected = (1 to 100).map(_ -> random.nextDouble()).sortBy(_._2).map(_._1) val actual = df.sort(rand(seed)).collect().map(_.getInt(0)) @@ -2320,6 +2442,58 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) } + test("SPARK-24165: CaseWhen/If - nullability of nested types") { + val rows = new java.util.ArrayList[Row]() + rows.add(Row(true, ("x", 1), Seq("x", "y"), Map(0 -> "x"))) + rows.add(Row(false, (null, 2), Seq(null, "z"), Map(0 -> null))) + val schema = StructType(Seq( + StructField("cond", BooleanType, true), + StructField("s", StructType(Seq( + StructField("val1", StringType, true), + StructField("val2", IntegerType, false) + )), false), + StructField("a", ArrayType(StringType, true)), + StructField("m", MapType(IntegerType, StringType, true)) + )) + + val sourceDF = spark.createDataFrame(rows, schema) + + def structWhenDF: DataFrame = sourceDF + .select(when('cond, struct(lit("a").as("val1"), lit(10).as("val2"))).otherwise('s) as "res") + .select('res.getField("val1")) + def arrayWhenDF: DataFrame = sourceDF + .select(when('cond, array(lit("a"), lit("b"))).otherwise('a) as "res") + .select('res.getItem(0)) + def mapWhenDF: DataFrame = sourceDF + .select(when('cond, map(lit(0), lit("a"))).otherwise('m) as "res") + .select('res.getItem(0)) + + def structIfDF: DataFrame = sourceDF + .select(expr("if(cond, struct('a' as val1, 10 as val2), s)") as "res") + .select('res.getField("val1")) + def arrayIfDF: DataFrame = sourceDF + .select(expr("if(cond, array('a', 'b'), a)") as "res") + .select('res.getItem(0)) + def mapIfDF: DataFrame = sourceDF + .select(expr("if(cond, map(0, 'a'), m)") as "res") + .select('res.getItem(0)) + + def checkResult(): Unit = { + checkAnswer(structWhenDF, Seq(Row("a"), Row(null))) + checkAnswer(arrayWhenDF, Seq(Row("a"), Row(null))) + checkAnswer(mapWhenDF, Seq(Row("a"), Row(null))) + checkAnswer(structIfDF, Seq(Row("a"), Row(null))) + checkAnswer(arrayIfDF, Seq(Row("a"), Row(null))) + checkAnswer(mapIfDF, Seq(Row("a"), 
Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + checkResult() + // Test with cached relation, the Project will be evaluated with codegen + sourceDF.cache() + checkResult() + } + test("Uuid expressions should produce same results at retries in the same DataFrame") { val df = spark.range(1).select($"id", new Column(Uuid())) checkAnswer(df, df.collect()) @@ -2329,4 +2503,100 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val mapWithBinaryKey = map(lit(Array[Byte](1.toByte)), lit(1)) checkAnswer(spark.range(1).select(mapWithBinaryKey.getItem(Array[Byte](1.toByte))), Row(1)) } + + test("SPARK-24781: Using a reference from Dataset in Filter/Sort") { + val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id") + val filter1 = df.select(df("name")).filter(df("id") === 0) + val filter2 = df.select(col("name")).filter(col("id") === 0) + checkAnswer(filter1, filter2.collect()) + + val sort1 = df.select(df("name")).orderBy(df("id")) + val sort2 = df.select(col("name")).orderBy(col("id")) + checkAnswer(sort1, sort2.collect()) + } + + test("SPARK-24781: Using a reference not in aggregation in Filter/Sort") { + withSQLConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key -> "false") { + val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id") + + val aggPlusSort1 = df.groupBy(df("name")).agg(count(df("name"))).orderBy(df("name")) + val aggPlusSort2 = df.groupBy(col("name")).agg(count(col("name"))).orderBy(col("name")) + checkAnswer(aggPlusSort1, aggPlusSort2.collect()) + + val aggPlusFilter1 = df.groupBy(df("name")).agg(count(df("name"))).filter(df("name") === 0) + val aggPlusFilter2 = df.groupBy(col("name")).agg(count(col("name"))).filter(col("name") === 0) + checkAnswer(aggPlusFilter1, aggPlusFilter2.collect()) + } + } + + test("SPARK-25159: json schema inference should only trigger one job") { + withTempPath { path => + // This test is to prove that the `JsonInferSchema` does not use `RDD#toLocalIterator` which + // triggers one Spark job per RDD partition. + Seq(1 -> "a", 2 -> "b").toDF("i", "p") + // The data set has 2 partitions, so Spark will write at least 2 json files. + // Use a non-splittable compression (gzip), to make sure the json scan RDD has at least 2 + // partitions. 
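+ // (Gzip files cannot be split, so each written file maps to exactly one scan
+ // partition; an inference built on RDD#toLocalIterator would therefore launch
+ // one job per file instead of the single job asserted below.)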
+ .write.partitionBy("p").option("compression", "gzip").json(path.getCanonicalPath) + + var numJobs = 0 + sparkContext.addSparkListener(new SparkListener { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + numJobs += 1 + } + }) + + val df = spark.read.json(path.getCanonicalPath) + assert(df.columns === Array("i", "p")) + assert(numJobs == 1) + } + } + + test("SPARK-25352: Ordered global limit when more than topKSortFallbackThreshold ") { + withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") { + val baseDf = spark.range(1000).toDF.repartition(3).sort("id") + + withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "100") { + val expected = baseDf.limit(99) + val takeOrderedNode1 = expected.queryExecution.executedPlan + .find(_.isInstanceOf[TakeOrderedAndProjectExec]) + assert(takeOrderedNode1.isDefined) + + val result = baseDf.limit(100) + val takeOrderedNode2 = result.queryExecution.executedPlan + .find(_.isInstanceOf[TakeOrderedAndProjectExec]) + assert(takeOrderedNode2.isEmpty) + + checkAnswer(expected, result.collect().take(99)) + } + } + } + + test("SPARK-25368 Incorrect predicate pushdown returns wrong result") { + def check(newCol: Column, filter: Column, result: Seq[Row]): Unit = { + val df1 = spark.createDataFrame(Seq( + (1, 1) + )).toDF("a", "b").withColumn("c", newCol) + + val df2 = df1.union(df1).withColumn("d", spark_partition_id).filter(filter) + checkAnswer(df2, result) + } + + check(lit(null).cast("int"), $"c".isNull, Seq(Row(1, 1, null, 0), Row(1, 1, null, 1))) + check(lit(null).cast("int"), $"c".isNotNull, Seq()) + check(lit(2).cast("int"), $"c".isNull, Seq()) + check(lit(2).cast("int"), $"c".isNotNull, Seq(Row(1, 1, 2, 0), Row(1, 1, 2, 1))) + check(lit(2).cast("int"), $"c" === 2, Seq(Row(1, 1, 2, 0), Row(1, 1, 2, 1))) + check(lit(2).cast("int"), $"c" =!= 2, Seq()) + } + + test("SPARK-25402 Null handling in BooleanSimplification") { + val schema = StructType.fromDDL("a boolean, b int") + val rows = Seq(Row(null, 1)) + + val rdd = sparkContext.parallelize(rows) + val df = spark.createDataFrame(rdd, schema) + + checkAnswer(df.where("(NOT a) OR a"), Seq.empty) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala index 6fe356877c268..2953425b1db49 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -43,6 +43,22 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B ) } + test("SPARK-21590: tumbling window using negative start time") { + val df = Seq( + ("2016-03-27 19:39:30", 1, "a"), + ("2016-03-27 19:39:25", 2, "a")).toDF("time", "value", "id") + + checkAnswer( + df.groupBy(window($"time", "10 seconds", "10 seconds", "-5 seconds")) + .agg(count("*").as("counts")) + .orderBy($"window.start".asc) + .select($"window.start".cast("string"), $"window.end".cast("string"), $"counts"), + Seq( + Row("2016-03-27 19:39:25", "2016-03-27 19:39:35", 2) + ) + ) + } + test("tumbling window groupBy statement") { val df = Seq( ("2016-03-27 19:39:34", 1, "a"), @@ -72,6 +88,20 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B Seq(Row(1), Row(1), Row(1))) } + test("SPARK-21590: tumbling window groupBy statement with negative startTime") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, 
"b")).toDF("time", "value", "id") + + checkAnswer( + df.groupBy(window($"time", "10 seconds", "10 seconds", "-5 seconds"), $"id") + .agg(count("*").as("counts")) + .orderBy($"window.start".asc) + .select("counts"), + Seq(Row(1), Row(1), Row(1))) + } + test("tumbling window with multi-column projection") { val df = Seq( ("2016-03-27 19:39:34", 1, "a"), @@ -309,4 +339,19 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B ) } } + + test("SPARK-21590: time window in SQL with three expressions including negative start time") { + withTempTable { table => + checkAnswer( + spark.sql( + s"""select window(time, "10 seconds", 10000000, "-5 seconds"), value from $table""") + .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), + Seq( + Row("2016-03-27 19:39:25", "2016-03-27 19:39:35", 1), + Row("2016-03-27 19:39:25", "2016-03-27 19:39:35", 4), + Row("2016-03-27 19:39:55", "2016-03-27 19:40:05", 2) + ) + ) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 97a843978f0bd..78277d7dcf757 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.types._ * Window function testing for DataFrame API. */ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("reuse window partitionBy") { @@ -72,9 +73,9 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { cume_dist().over(Window.partitionBy("value").orderBy("key")), percent_rank().over(Window.partitionBy("value").orderBy("key"))), Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d, 0.0d) :: - Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d / 3.0d, 0.0d) :: - Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 1, 2, 2, 2, 1.0d, 0.5d) :: - Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 2, 3, 2, 2, 1.0d, 0.5d) :: Nil) + Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d / 3.0d, 0.0d) :: + Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 1, 2, 2, 2, 1.0d, 0.5d) :: + Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 2, 3, 2, 2, 1.0d, 0.5d) :: Nil) } test("window function should fail if order by clause is not specified") { @@ -162,12 +163,12 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { Seq( Row("a", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755), Row("b", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755), - Row("c", 0.0, 0.0, 0.0, 0.0, 0.0 ), - Row("d", 0.0, 0.0, 0.0, 0.0, 0.0 ), - Row("e", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), - Row("f", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), - Row("g", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), - Row("h", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), + Row("c", 0.0, 0.0, 0.0, 0.0, 0.0), + Row("d", 0.0, 0.0, 0.0, 0.0, 0.0), + Row("e", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544), + Row("f", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544), + Row("g", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544), + Row("h", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544), Row("i", Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN))) } @@ -326,7 +327,7 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { var_samp($"value").over(window), approx_count_distinct($"value").over(window)), 
Seq.fill(4)(Row("a", 1.0d / 4.0d, 1.0d / 3.0d, 2)) - ++ Seq.fill(3)(Row("b", 2.0d / 3.0d, 1.0d, 3))) + ++ Seq.fill(3)(Row("b", 2.0d / 3.0d, 1.0d, 3))) } test("window function with aggregates") { @@ -624,7 +625,7 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { test("SPARK-24575: Window functions inside WHERE and HAVING clauses") { def checkAnalysisError(df: => DataFrame): Unit = { - val thrownException = the [AnalysisException] thrownBy { + val thrownException = the[AnalysisException] thrownBy { df.queryExecution.analyzed } assert(thrownException.message.contains("window functions inside WHERE and HAVING clauses")) @@ -658,4 +659,26 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { |GROUP BY a |HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin)) } + + test("window functions in multiple selects") { + val df = Seq( + ("S1", "P1", 100), + ("S1", "P1", 700), + ("S2", "P1", 200), + ("S2", "P2", 300) + ).toDF("sno", "pno", "qty") + + val w1 = Window.partitionBy("sno") + val w2 = Window.partitionBy("sno", "pno") + + checkAnswer( + df.select($"sno", $"pno", $"qty", sum($"qty").over(w2).alias("sum_qty_2")) + .select($"sno", $"pno", $"qty", col("sum_qty_2"), sum("qty").over(w1).alias("sum_qty_1")), + Seq( + Row("S1", "P1", 100, 800, 800), + Row("S1", "P1", 700, 800, 800), + Row("S2", "P1", 200, 200, 500), + Row("S2", "P2", 300, 300, 500))) + + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 0e7eaa9e88d57..538ea3c66c40e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -148,6 +148,41 @@ object VeryComplexResultAgg extends Aggregator[Row, String, ComplexAggData] { } +case class OptionBooleanData(name: String, isGood: Option[Boolean]) + +case class OptionBooleanAggregator(colName: String) + extends Aggregator[Row, Option[Boolean], Option[Boolean]] { + + override def zero: Option[Boolean] = None + + override def reduce(buffer: Option[Boolean], row: Row): Option[Boolean] = { + val index = row.fieldIndex(colName) + val value = if (row.isNullAt(index)) { + Option.empty[Boolean] + } else { + Some(row.getBoolean(index)) + } + merge(buffer, value) + } + + override def merge(b1: Option[Boolean], b2: Option[Boolean]): Option[Boolean] = { + if ((b1.isDefined && b1.get) || (b2.isDefined && b2.get)) { + Some(true) + } else if (b1.isDefined) { + b1 + } else { + b2 + } + } + + override def finish(reduction: Option[Boolean]): Option[Boolean] = reduction + + override def bufferEncoder: Encoder[Option[Boolean]] = OptionalBoolEncoder + override def outputEncoder: Encoder[Option[Boolean]] = OptionalBoolEncoder + + def OptionalBoolEncoder: Encoder[Option[Boolean]] = ExpressionEncoder() +} + class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -333,4 +368,29 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { df.groupBy($"i").agg(VeryComplexResultAgg.toColumn), Row(1, Row(Row(1, "a"), Row(1, "a"))) :: Row(2, Row(Row(2, "bc"), Row(2, "bc"))) :: Nil) } + + test("SPARK-24569: Aggregator with output type Option[Boolean] creates column of type Row") { + val df = Seq( + OptionBooleanData("bob", Some(true)), + OptionBooleanData("bob", Some(false)), + OptionBooleanData("bob", None)).toDF() + val group = df + .groupBy("name") + 
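+      // Both encoders of OptionBooleanAggregator are ExpressionEncoder[Option[Boolean]], so the
+      // aggregated column should come back as a nullable boolean rather than a single-field
+      // struct, which is the regression this test guards against.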
.agg(OptionBooleanAggregator("isGood").toColumn.alias("isGood")) + assert(df.schema == group.schema) + checkAnswer(group, Row("bob", true) :: Nil) + checkDataset(group.as[OptionBooleanData], OptionBooleanData("bob", Some(true))) + } + + test("SPARK-24569: groupByKey with Aggregator of output type Option[Boolean]") { + val df = Seq( + OptionBooleanData("bob", Some(true)), + OptionBooleanData("bob", Some(false)), + OptionBooleanData("bob", None)).toDF() + val grouped = df.groupByKey((r: Row) => r.getString(0)) + .agg(OptionBooleanAggregator("isGood").toColumn).toDF("name", "isGood") + + assert(grouped.schema == df.schema) + checkDataset(grouped.as[OptionBooleanData], OptionBooleanData("bob", Some(true))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 2d20c50584c03..4e593ff046a53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -611,7 +611,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ).toDF("id", "stringData") val sampleDF = df.sample(false, 0.7, 50) // After sampling, sampleDF doesn't contain id=1. - assert(!sampleDF.select("id").collect.contains(1)) + assert(!sampleDF.select("id").as[Int].collect.contains(1)) // simpleUdf should not encounter id=1. checkAnswer(sampleDF.select(simpleUdf($"id")), List.fill(sampleDF.count.toInt)(Row(1))) } @@ -969,6 +969,55 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkShowString(ds, expected) } + test("SPARK-25108 Fix the show method to display the full width character alignment problem") { + // scalastyle:off nonascii + val df = Seq( + (0, null, 1), + (0, "", 1), + (0, "ab c", 1), + (0, "1098", 1), + (0, "mø", 1), + (0, "γύρ", 1), + (0, "pê", 1), + (0, "ー", 1), + (0, "测", 1), + (0, "か", 1), + (0, "걸", 1), + (0, "à", 1), + (0, "焼", 1), + (0, "羍む", 1), + (0, "뺭ᾘ", 1), + (0, "\u0967\u0968\u0969", 1) + ).toDF("b", "a", "c") + // scalastyle:on nonascii + val ds = df.as[ClassData] + val expected = + // scalastyle:off nonascii + """+---+----+---+ + || b| a| c| + |+---+----+---+ + || 0|null| 1| + || 0| | 1| + || 0|ab c| 1| + || 0|1098| 1| + || 0| mø| 1| + || 0| γύρ| 1| + || 0| pê| 1| + || 0| ー| 1| + || 0| 测| 1| + || 0| か| 1| + || 0| 걸| 1| + || 0| à| 1| + || 0| 焼| 1| + || 0|羍む| 1| + || 0| 뺭ᾘ| 1| + || 0| १२३| 1| + |+---+----+---+ + |""".stripMargin + // scalastyle:on nonascii + checkShowString(ds, expected) + } + test( "SPARK-15112: EmbedDeserializerInFilter should not optimize plan fragment that changes schema" ) { @@ -1296,7 +1345,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { new java.sql.Timestamp(100000)) } - test("SPARK-19896: cannot have circular references in in case class") { + test("SPARK-19896: cannot have circular references in case class") { val errMsg1 = intercept[UnsupportedOperationException] { Seq(CircularReferenceClassA(null)).toDS } @@ -1467,6 +1516,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext { intercept[NullPointerException](ds.as[(Int, Int)].collect()) } + test("SPARK-24569: Option of primitive types are mistakenly mapped to struct type") { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + val a = Seq(Some(1)).toDS + val b = Seq(Some(1.2)).toDS + val expected = Seq((Some(1), Some(1.2))).toDS + val joined = a.joinWith(b, lit(true)) + assert(joined.schema == expected.schema) + checkDataset(joined, expected.collect: _*) + } + } + test("SPARK-24548: 
Dataset with tuple encoders should have correct schema") { val encoder = Encoders.tuple(newStringEncoder, Encoders.tuple(newStringEncoder, newStringEncoder)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 237412aa692e5..3af80b36ec42c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -663,7 +663,7 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df.selectExpr("datediff(a, d)"), Seq(Row(1), Row(1))) } - test("from_utc_timestamp") { + test("from_utc_timestamp with literal zone") { val df = Seq( (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00"), (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00") @@ -680,7 +680,24 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { Row(Timestamp.valueOf("2015-07-24 17:00:00")))) } - test("to_utc_timestamp") { + test("from_utc_timestamp with column zone") { + val df = Seq( + (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00", "CET"), + (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00", "PST") + ).toDF("a", "b", "c") + checkAnswer( + df.select(from_utc_timestamp(col("a"), col("c"))), + Seq( + Row(Timestamp.valueOf("2015-07-24 02:00:00")), + Row(Timestamp.valueOf("2015-07-24 17:00:00")))) + checkAnswer( + df.select(from_utc_timestamp(col("b"), col("c"))), + Seq( + Row(Timestamp.valueOf("2015-07-24 02:00:00")), + Row(Timestamp.valueOf("2015-07-24 17:00:00")))) + } + + test("to_utc_timestamp with literal zone") { val df = Seq( (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00"), (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00") @@ -697,6 +714,23 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { Row(Timestamp.valueOf("2015-07-25 07:00:00")))) } + test("to_utc_timestamp with column zone") { + val df = Seq( + (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00", "PST"), + (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00", "CET") + ).toDF("a", "b", "c") + checkAnswer( + df.select(to_utc_timestamp(col("a"), col("c"))), + Seq( + Row(Timestamp.valueOf("2015-07-24 07:00:00")), + Row(Timestamp.valueOf("2015-07-24 22:00:00")))) + checkAnswer( + df.select(to_utc_timestamp(col("b"), col("c"))), + Seq( + Row(Timestamp.valueOf("2015-07-24 07:00:00")), + Row(Timestamp.valueOf("2015-07-24 22:00:00")))) + } + test("SPARK-23715: to/from_utc_timestamp can retain the previous behavior") { withSQLConf(SQLConf.REJECT_TIMEZONE_IN_STRING.key -> "false") { checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala new file mode 100644 index 0000000000000..56d300e30a58e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StructType + +class ExplainSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + /** + * Runs the plan and makes sure the plans contains all of the keywords. + */ + private def checkKeywordsExistsInExplain(df: DataFrame, keywords: String*): Unit = { + val output = new java.io.ByteArrayOutputStream() + Console.withOut(output) { + df.explain(extended = false) + } + for (key <- keywords) { + assert(output.toString.contains(key)) + } + } + + test("SPARK-23034 show rdd names in RDD scan nodes (Dataset)") { + val rddWithName = spark.sparkContext.parallelize(Row(1, "abc") :: Nil).setName("testRdd") + val df = spark.createDataFrame(rddWithName, StructType.fromDDL("c0 int, c1 string")) + checkKeywordsExistsInExplain(df, keywords = "Scan ExistingRDD testRdd") + } + + test("SPARK-23034 show rdd names in RDD scan nodes (DataFrame)") { + val rddWithName = spark.sparkContext.parallelize(ExplainSingleData(1) :: Nil).setName("testRdd") + val df = spark.createDataFrame(rddWithName) + checkKeywordsExistsInExplain(df, keywords = "Scan testRdd") + } + + test("SPARK-24850 InMemoryRelation string representation does not include cached plan") { + val df = Seq(1).toDF("a").cache() + checkKeywordsExistsInExplain(df, + keywords = "InMemoryRelation", "StorageLevel(disk, memory, deserialized, 1 replicas)") + } +} + +case class ExplainSingleData(id: Int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 86f9647b4ac4c..94f163708832c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -20,10 +20,13 @@ package org.apache.spark.sql import java.io.{File, FileNotFoundException} import java.util.Locale +import scala.collection.mutable + import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.sql.TestingUDT.{IntervalData, IntervalUDT, NullData, NullUDT} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -205,58 +208,116 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo } } + // Text file format only supports string type + test("SPARK-24691 error handling for unsupported types - text") { + withTempDir { dir => + // write path + val textDir = new File(dir, "text").getCanonicalPath + var msg = intercept[AnalysisException] { + Seq(1).toDF.write.text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support int data type")) + + msg = intercept[AnalysisException] { + Seq(1.2).toDF.write.text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support double data type")) + + msg = intercept[AnalysisException] { + Seq(true).toDF.write.text(textDir) + }.getMessage + 
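+      // The text source accepts exactly one string column; atomic types such as the boolean
+      // above are rejected during analysis, before any file is written.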
assert(msg.contains("Text data source does not support boolean data type")) + + msg = intercept[AnalysisException] { + Seq(1).toDF("a").selectExpr("struct(a)").write.text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support struct data type")) + + msg = intercept[AnalysisException] { + Seq((Map("Tesla" -> 3))).toDF("cars").write.mode("overwrite").text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support map data type")) + + msg = intercept[AnalysisException] { + Seq((Array("Tesla", "Chevy", "Ford"))).toDF("brands") + .write.mode("overwrite").text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support array data type")) + + // read path + Seq("aaa").toDF.write.mode("overwrite").text(textDir) + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", IntegerType, true) :: Nil) + spark.read.schema(schema).text(textDir).collect() + }.getMessage + assert(msg.contains("Text data source does not support int data type")) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", DoubleType, true) :: Nil) + spark.read.schema(schema).text(textDir).collect() + }.getMessage + assert(msg.contains("Text data source does not support double data type")) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", BooleanType, true) :: Nil) + spark.read.schema(schema).text(textDir).collect() + }.getMessage + assert(msg.contains("Text data source does not support boolean data type")) + } + } + // Unsupported data types of csv, json, orc, and parquet are as follows; - // csv -> R/W: Interval, Null, Array, Map, Struct - // json -> W: Interval - // orc -> W: Interval, Null + // csv -> R/W: Null, Array, Map, Struct + // json -> R/W: Interval + // orc -> R/W: Interval, W: Null // parquet -> R/W: Interval, Null test("SPARK-24204 error handling for unsupported Array/Map/Struct types - csv") { withTempDir { dir => val csvDir = new File(dir, "csv").getCanonicalPath - var msg = intercept[UnsupportedOperationException] { + var msg = intercept[AnalysisException] { Seq((1, "Tesla")).toDF("a", "b").selectExpr("struct(a, b)").write.csv(csvDir) }.getMessage assert(msg.contains("CSV data source does not support struct data type")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { val schema = StructType.fromDDL("a struct") spark.range(1).write.mode("overwrite").csv(csvDir) spark.read.schema(schema).csv(csvDir).collect() }.getMessage assert(msg.contains("CSV data source does not support struct data type")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { Seq((1, Map("Tesla" -> 3))).toDF("id", "cars").write.mode("overwrite").csv(csvDir) }.getMessage assert(msg.contains("CSV data source does not support map data type")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { val schema = StructType.fromDDL("a map") spark.range(1).write.mode("overwrite").csv(csvDir) spark.read.schema(schema).csv(csvDir).collect() }.getMessage assert(msg.contains("CSV data source does not support map data type")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { Seq((1, Array("Tesla", "Chevy", "Ford"))).toDF("id", "brands") .write.mode("overwrite").csv(csvDir) }.getMessage assert(msg.contains("CSV data source does not support array data type")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { val schema 
= StructType.fromDDL("a array") spark.range(1).write.mode("overwrite").csv(csvDir) spark.read.schema(schema).csv(csvDir).collect() }.getMessage assert(msg.contains("CSV data source does not support array data type")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { Seq((1, new UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))).toDF("id", "vectors") .write.mode("overwrite").csv(csvDir) }.getMessage assert(msg.contains("CSV data source does not support array data type")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { val schema = StructType(StructField("a", new UDT.MyDenseVectorUDT(), true) :: Nil) spark.range(1).write.mode("overwrite").csv(csvDir) spark.read.schema(schema).csv(csvDir).collect() @@ -276,7 +337,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo }.getMessage assert(msg.contains("Cannot save interval data type into external storage.")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { spark.udf.register("testType", () => new IntervalData()) sql("select testType()").write.format(format).mode("overwrite").save(tempDir) }.getMessage @@ -286,7 +347,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo // read path Seq("parquet", "csv").foreach { format => - var msg = intercept[UnsupportedOperationException] { + var msg = intercept[AnalysisException] { val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) spark.range(1).write.format(format).mode("overwrite").save(tempDir) spark.read.schema(schema).format(format).load(tempDir).collect() @@ -294,7 +355,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo assert(msg.toLowerCase(Locale.ROOT) .contains(s"$format data source does not support calendarinterval data type.")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) spark.range(1).write.format(format).mode("overwrite").save(tempDir) spark.read.schema(schema).format(format).load(tempDir).collect() @@ -302,19 +363,6 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo assert(msg.toLowerCase(Locale.ROOT) .contains(s"$format data source does not support calendarinterval data type.")) } - - // We expect the types below should be passed for backward-compatibility - Seq("orc", "json").foreach { format => - // Interval type - var schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) - spark.range(1).write.format(format).mode("overwrite").save(tempDir) - spark.read.schema(schema).format(format).load(tempDir).collect() - - // UDT having interval data - schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) - spark.range(1).write.format(format).mode("overwrite").save(tempDir) - spark.read.schema(schema).format(format).load(tempDir).collect() - } } } @@ -324,13 +372,13 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo Seq("orc").foreach { format => // write path - var msg = intercept[UnsupportedOperationException] { + var msg = intercept[AnalysisException] { sql("select null").write.format(format).mode("overwrite").save(tempDir) }.getMessage assert(msg.toLowerCase(Locale.ROOT) .contains(s"$format data source does not support null data type.")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { 
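+        // A UDT is validated through its underlying SQL type, so a UDT backed by the null type
+        // is expected to fail exactly like a plain null column.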
spark.udf.register("testType", () => new NullData()) sql("select testType()").write.format(format).mode("overwrite").save(tempDir) }.getMessage @@ -353,13 +401,13 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo Seq("parquet", "csv").foreach { format => // write path - var msg = intercept[UnsupportedOperationException] { + var msg = intercept[AnalysisException] { sql("select null").write.format(format).mode("overwrite").save(tempDir) }.getMessage assert(msg.toLowerCase(Locale.ROOT) .contains(s"$format data source does not support null data type.")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { spark.udf.register("testType", () => new NullData()) sql("select testType()").write.format(format).mode("overwrite").save(tempDir) }.getMessage @@ -367,7 +415,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo .contains(s"$format data source does not support null data type.")) // read path - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { val schema = StructType(StructField("a", NullType, true) :: Nil) spark.range(1).write.format(format).mode("overwrite").save(tempDir) spark.read.schema(schema).format(format).load(tempDir).collect() @@ -375,7 +423,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo assert(msg.toLowerCase(Locale.ROOT) .contains(s"$format data source does not support null data type.")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { val schema = StructType(StructField("a", new NullUDT(), true) :: Nil) spark.range(1).write.format(format).mode("overwrite").save(tempDir) spark.read.schema(schema).format(format).load(tempDir).collect() @@ -385,6 +433,71 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo } } } + + Seq("parquet", "orc").foreach { format => + test(s"Spark native readers should respect spark.sql.caseSensitive - ${format}") { + withTempDir { dir => + val tableName = s"spark_25132_${format}_native" + val tableDir = dir.getCanonicalPath + s"/$tableName" + withTable(tableName) { + val end = 5 + val data = spark.range(end).selectExpr("id as A", "id * 2 as b", "id * 3 as B") + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + data.write.format(format).mode("overwrite").save(tableDir) + } + sql(s"CREATE TABLE $tableName (a LONG, b LONG) USING $format LOCATION '$tableDir'") + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkAnswer(sql(s"select a from $tableName"), data.select("A")) + checkAnswer(sql(s"select A from $tableName"), data.select("A")) + + // RuntimeException is triggered at executor side, which is then wrapped as + // SparkException at driver side + val e1 = intercept[SparkException] { + sql(s"select b from $tableName").collect() + } + assert( + e1.getCause.isInstanceOf[RuntimeException] && + e1.getCause.getMessage.contains( + """Found duplicate field(s) "b": [b, B] in case-insensitive mode""")) + val e2 = intercept[SparkException] { + sql(s"select B from $tableName").collect() + } + assert( + e2.getCause.isInstanceOf[RuntimeException] && + e2.getCause.getMessage.contains( + """Found duplicate field(s) "b": [b, B] in case-insensitive mode""")) + } + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + checkAnswer(sql(s"select a from $tableName"), (0 until end).map(_ => Row(null))) + checkAnswer(sql(s"select b from $tableName"), data.select("b")) + } + } + } + } + } + + test("SPARK-25237 compute 
correct input metrics in FileScanRDD") { + withTempPath { p => + val path = p.getAbsolutePath + spark.range(1000).repartition(1).write.csv(path) + val bytesReads = new mutable.ArrayBuffer[Long]() + val bytesReadListener = new SparkListener() { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + bytesReads += taskEnd.taskMetrics.inputMetrics.bytesRead + } + } + sparkContext.addSparkListener(bytesReadListener) + try { + spark.read.csv(path).limit(1).collect() + sparkContext.listenerBus.waitUntilEmpty(1000L) + assert(bytesReads.sum === 7860) + } finally { + sparkContext.removeSparkListener(bytesReadListener) + } + } + } } object TestingUDT { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala deleted file mode 100644 index 147c0b61f5017..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql - -import org.apache.spark.api.python.PythonEvalType -import org.apache.spark.sql.catalyst.expressions.PythonUDF -import org.apache.spark.sql.catalyst.plans.logical.AnalysisBarrier -import org.apache.spark.sql.functions.udf -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{LongType, StructField, StructType} - -class GroupedDatasetSuite extends QueryTest with SharedSQLContext { - import testImplicits._ - - private val scalaUDF = udf((x: Long) => { x + 1 }) - private lazy val datasetWithUDF = spark.range(1).toDF("s").select($"s", scalaUDF($"s")) - - private def assertContainsAnalysisBarrier(ds: Dataset[_], atLevel: Int = 1): Unit = { - assert(atLevel >= 0) - var children = Seq(ds.queryExecution.logical) - (1 to atLevel).foreach { _ => - children = children.flatMap(_.children) - } - val barriers = children.collect { - case ab: AnalysisBarrier => ab - } - assert(barriers.nonEmpty, s"Plan does not contain AnalysisBarrier at level $atLevel:\n" + - ds.queryExecution.logical) - } - - test("SPARK-24373: avoid running Analyzer rules twice on RelationalGroupedDataset") { - val groupByDataset = datasetWithUDF.groupBy() - val rollupDataset = datasetWithUDF.rollup("s") - val cubeDataset = datasetWithUDF.cube("s") - val pivotDataset = datasetWithUDF.groupBy().pivot("s", Seq(1, 2)) - datasetWithUDF.cache() - Seq(groupByDataset, rollupDataset, cubeDataset, pivotDataset).foreach { rgDS => - val df = rgDS.count() - assertContainsAnalysisBarrier(df) - assertCached(df) - } - - val flatMapGroupsInRDF = datasetWithUDF.groupBy().flatMapGroupsInR( - Array.emptyByteArray, - Array.emptyByteArray, - Array.empty, - StructType(Seq(StructField("s", LongType)))) - val flatMapGroupsInPandasDF = datasetWithUDF.groupBy().flatMapGroupsInPandas(PythonUDF( - "pyUDF", - null, - StructType(Seq(StructField("s", LongType))), - Seq.empty, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - true)) - Seq(flatMapGroupsInRDF, flatMapGroupsInPandasDF).foreach { df => - assertContainsAnalysisBarrier(df, 2) - assertCached(df) - } - datasetWithUDF.unpersist(true) - } - - test("SPARK-24373: avoid running Analyzer rules twice on KeyValueGroupedDataset") { - val kvDasaset = datasetWithUDF.groupByKey(_.getLong(0)) - datasetWithUDF.cache() - val mapValuesKVDataset = kvDasaset.mapValues(_.getLong(0)).reduceGroups(_ + _) - val keysKVDataset = kvDasaset.keys - val flatMapGroupsKVDataset = kvDasaset.flatMapGroups((k, _) => Seq(k)) - val aggKVDataset = kvDasaset.count() - val otherKVDataset = spark.range(1).groupByKey(_ + 1) - val cogroupKVDataset = kvDasaset.cogroup(otherKVDataset)((k, _, _) => Seq(k)) - Seq((mapValuesKVDataset, 1), - (keysKVDataset, 2), - (flatMapGroupsKVDataset, 2), - (aggKVDataset, 1), - (cogroupKVDataset, 2)).foreach { case (df, analysisBarrierDepth) => - assertContainsAnalysisBarrier(df, analysisBarrierDepth) - assertCached(df) - } - datasetWithUDF.unpersist(true) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 7bf17cbcd9c97..fe4bf15fa3921 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.sql.functions.{from_json, lit, map, struct, to_json} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import 
org.apache.spark.sql.types._ @@ -133,15 +133,11 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row(null) :: Nil) } - test("from_json invalid schema") { + test("from_json - json doesn't conform to the array type") { val df = Seq("""{"a" 1}""").toDS() val schema = ArrayType(StringType) - val message = intercept[AnalysisException] { - df.select(from_json($"value", schema)) - }.getMessage - assert(message.contains( - "Input schema array must be a struct or an array of structs.")) + checkAnswer(df.select(from_json($"value", schema)), Seq(Row(null))) } test("from_json array support") { @@ -311,7 +307,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { val errMsg1 = intercept[AnalysisException] { df3.selectExpr("from_json(value, 1)") } - assert(errMsg1.getMessage.startsWith("Expected a string literal instead of")) + assert(errMsg1.getMessage.startsWith("Schema should be specified in DDL format as a string")) val errMsg2 = intercept[AnalysisException] { df3.selectExpr("""from_json(value, 'time InvalidType')""") } @@ -392,4 +388,134 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(Seq("""{"{"f": 1}": "a"}""").toDS().select(from_json($"value", schema)), Row(null)) } + + test("SPARK-24709: infers schemas of json strings and passes them to from_json") { + val in = Seq("""{"a": [1, 2, 3]}""").toDS() + val out = in.select(from_json('value, schema_of_json(lit("""{"a": [1]}"""))) as "parsed") + val expected = StructType(StructField( + "parsed", + StructType(StructField( + "a", + ArrayType(LongType, true), true) :: Nil), + true) :: Nil) + + assert(out.schema == expected) + } + + test("from_json - array of primitive types") { + val df = Seq("[1, 2, 3]").toDF("a") + val schema = new ArrayType(IntegerType, false) + + checkAnswer(df.select(from_json($"a", schema)), Seq(Row(Array(1, 2, 3)))) + } + + test("from_json - array of primitive types - malformed row") { + val df = Seq("[1, 2 3]").toDF("a") + val schema = new ArrayType(IntegerType, false) + + checkAnswer(df.select(from_json($"a", schema)), Seq(Row(null))) + } + + test("from_json - array of arrays") { + val jsonDF = Seq("[[1], [2, 3], [4, 5, 6]]").toDF("a") + val schema = new ArrayType(ArrayType(IntegerType, false), false) + jsonDF.select(from_json($"a", schema) as "json").createOrReplaceTempView("jsonTable") + + checkAnswer( + sql("select json[0][0], json[1][1], json[2][2] from jsonTable"), + Seq(Row(1, 3, 6))) + } + + test("from_json - array of arrays - malformed row") { + val jsonDF = Seq("[[1], [2, 3], 4, 5, 6]]").toDF("a") + val schema = new ArrayType(ArrayType(IntegerType, false), false) + jsonDF.select(from_json($"a", schema) as "json").createOrReplaceTempView("jsonTable") + + checkAnswer(sql("select json[0] from jsonTable"), Seq(Row(null))) + } + + test("from_json - array of structs") { + val jsonDF = Seq("""[{"a":1}, {"a":2}, {"a":3}]""").toDF("a") + val schema = new ArrayType(new StructType().add("a", IntegerType), false) + jsonDF.select(from_json($"a", schema) as "json").createOrReplaceTempView("jsonTable") + + checkAnswer( + sql("select json[0], json[1], json[2] from jsonTable"), + Seq(Row(Row(1), Row(2), Row(3)))) + } + + test("from_json - array of structs - malformed row") { + val jsonDF = Seq("""[{"a":1}, {"a:2}, {"a":3}]""").toDF("a") + val schema = new ArrayType(new StructType().add("a", IntegerType), false) + jsonDF.select(from_json($"a", schema) as "json").createOrReplaceTempView("jsonTable") + + checkAnswer(sql("select json[0], json[1] from jsonTable"), 
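+      // A single malformed element ({"a:2}) nullifies the whole parsed array in the default
+      // PERMISSIVE mode, so both element accesses below return null.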
Seq(Row(null, null))) + } + + test("from_json - array of maps") { + val jsonDF = Seq("""[{"a":1}, {"b":2}]""").toDF("a") + val schema = new ArrayType(MapType(StringType, IntegerType, false), false) + jsonDF.select(from_json($"a", schema) as "json").createOrReplaceTempView("jsonTable") + + checkAnswer( + sql("""select json[0], json[1] from jsonTable"""), + Seq(Row(Map("a" -> 1), Map("b" -> 2)))) + } + + test("from_json - array of maps - malformed row") { + val jsonDF = Seq("""[{"a":1} "b":2}]""").toDF("a") + val schema = new ArrayType(MapType(StringType, IntegerType, false), false) + jsonDF.select(from_json($"a", schema) as "json").createOrReplaceTempView("jsonTable") + + checkAnswer(sql("""select json[0] from jsonTable"""), Seq(Row(null))) + } + + test("to_json - array of primitive types") { + val df = Seq(Array(1, 2, 3)).toDF("a") + checkAnswer(df.select(to_json($"a")), Seq(Row("[1,2,3]"))) + } + + test("roundtrip to_json -> from_json - array of primitive types") { + val arr = Array(1, 2, 3) + val df = Seq(arr).toDF("a") + checkAnswer(df.select(from_json(to_json($"a"), ArrayType(IntegerType))), Row(arr)) + } + + test("roundtrip from_json -> to_json - array of primitive types") { + val json = "[1,2,3]" + val df = Seq(json).toDF("a") + val schema = new ArrayType(IntegerType, false) + + checkAnswer(df.select(to_json(from_json($"a", schema))), Seq(Row(json))) + } + + test("roundtrip from_json -> to_json - array of arrays") { + val json = "[[1],[2,3],[4,5,6]]" + val jsonDF = Seq(json).toDF("a") + val schema = new ArrayType(ArrayType(IntegerType, false), false) + + checkAnswer( + jsonDF.select(to_json(from_json($"a", schema))), + Seq(Row(json))) + } + + test("roundtrip from_json -> to_json - array of maps") { + val json = """[{"a":1},{"b":2}]""" + val jsonDF = Seq(json).toDF("a") + val schema = new ArrayType(MapType(StringType, IntegerType, false), false) + + checkAnswer( + jsonDF.select(to_json(from_json($"a", schema))), + Seq(Row(json))) + } + + test("roundtrip from_json -> to_json - array of structs") { + val json = """[{"a":1},{"a":2},{"a":3}]""" + val jsonDF = Seq(json).toDF("a") + val schema = new ArrayType(new StructType().add("a", IntegerType), false) + + checkAnswer( + jsonDF.select(to_json(from_json($"a", schema))), + Seq(Row(json))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala index cbef1c7828319..6b90f20a94fa4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala @@ -36,19 +36,14 @@ trait LocalSparkSession extends BeforeAndAfterEach with BeforeAndAfterAll { self override def afterEach() { try { - resetSparkContext() + LocalSparkSession.stop(spark) SparkSession.clearActiveSession() SparkSession.clearDefaultSession() + spark = null } finally { super.afterEach() } } - - def resetSparkContext(): Unit = { - LocalSparkSession.stop(spark) - spark = null - } - } object LocalSparkSession { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 9fb8be423614b..baca9c1cfb9a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -290,6 +290,16 @@ object QueryTest { Row.fromSeq(row.toSeq.map { case null => null case d: java.math.BigDecimal => BigDecimal(d) + // Equality of WrappedArray differs for AnyVal and 
AnyRef in Scala 2.12.2+ + case seq: Seq[_] => seq.map { + case b: java.lang.Byte => b.byteValue + case s: java.lang.Short => s.shortValue + case i: java.lang.Integer => i.intValue + case l: java.lang.Long => l.longValue + case f: java.lang.Float => f.floatValue + case d: java.lang.Double => d.doubleValue + case x => x + } // Convert array to Seq for easy equality check. case b: Array[_] => b.toSeq case r: Row => prepareRow(r) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala index cfe2e9f2dbc44..cdcea09ad9758 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala @@ -54,4 +54,18 @@ class RuntimeConfigSuite extends SparkFunSuite { conf.get("k1") } } + + test("SPARK-24761: is a config parameter modifiable") { + val conf = newConf() + + // SQL configs + assert(!conf.isModifiable("spark.sql.sources.schemaStringLengthThreshold")) + assert(conf.isModifiable("spark.sql.streaming.checkpointLocation")) + // Core configs + assert(!conf.isModifiable("spark.task.cpus")) + assert(!conf.isModifiable("spark.executor.cores")) + // Invalid config parameters + assert(!conf.isModifiable("")) + assert(!conf.isModifiable("invalid config parameter")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index dfb9c137b74f0..01dc28d70184e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} +import org.apache.spark.sql.execution.datasources.FilePartition import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -523,6 +524,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sortTest() } + test("limit for skewed dataframe") { + // Create a skewed dataframe. + val df = testData.repartition(100).union(testData).limit(50) + // Because calling `rdd` on a dataframe adds a `DeserializeToObject` on top of `GlobalLimit`, + // the `GlobalLimit` will not be replaced with `CollectLimit`. So we can test whether + // `GlobalLimit` works on skewed partitions.
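+    // The 50 surviving rows are spread unevenly across the union's partitions, so an
+    // implementation that took rows per partition without a global limit would miscount.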
+ assert(df.rdd.count() == 50L) + } + test("CTE feature") { checkAnswer( sql("with q1 as (select * from testData limit 10) select * from q1"), @@ -1689,22 +1699,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } assert(e.message.contains("Hive built-in ORC data source must be used with Hive support")) - e = intercept[AnalysisException] { - sql(s"select id from `com.databricks.spark.avro`.`file_path`") - } - assert(e.message.contains("Failed to find data source: com.databricks.spark.avro.")) - - // data source type is case insensitive - e = intercept[AnalysisException] { - sql(s"select id from Avro.`file_path`") - } - assert(e.message.contains("Failed to find data source: avro.")) - - e = intercept[AnalysisException] { - sql(s"select id from avro.`file_path`") - } - assert(e.message.contains("Failed to find data source: avro.")) - e = intercept[AnalysisException] { sql(s"select id from `org.apache.spark.sql.sources.HadoopFsRelationProvider`.`file_path`") } @@ -1950,7 +1944,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { // TODO: support subexpression elimination in whole stage codegen withSQLConf("spark.sql.codegen.wholeStage" -> "false") { // select from a table to prevent constant folding. - val df = sql("SELECT a, b from testData2 limit 1") + val df = sql("SELECT a, b from testData2 order by a, b limit 1") checkAnswer(df, Row(1, 1)) checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) @@ -2704,7 +2698,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { val m = intercept[AnalysisException] { sql("SELECT * FROM t, S WHERE c = C") }.message - assert(m.contains("cannot resolve '(t.`c` = S.`C`)' due to data type mismatch")) + assert( + m.contains("cannot resolve '(default.t.`c` = default.S.`C`)' due to data type mismatch")) } } } @@ -2813,4 +2808,56 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(df, Seq(Row(3, 99, 1))) } } + + + test("SPARK-24940: coalesce and repartition hint") { + withTempView("nums1") { + val numPartitionsSrc = 10 + spark.range(0, 100, 1, numPartitionsSrc).createOrReplaceTempView("nums1") + assert(spark.table("nums1").rdd.getNumPartitions == numPartitionsSrc) + + withTable("nums") { + sql("CREATE TABLE nums (id INT) USING parquet") + + Seq(5, 20, 2).foreach { numPartitions => + sql( + s""" + |INSERT OVERWRITE TABLE nums + |SELECT /*+ REPARTITION($numPartitions) */ * + |FROM nums1 + """.stripMargin) + assert(spark.table("nums").inputFiles.length == numPartitions) + + sql( + s""" + |INSERT OVERWRITE TABLE nums + |SELECT /*+ COALESCE($numPartitions) */ * + |FROM nums1 + """.stripMargin) + // Coalesce can not increase the number of partitions + assert(spark.table("nums").inputFiles.length == Seq(numPartitions, numPartitionsSrc).min) + } + } + } + } + + test("SPARK-25084: 'distribute by' on multiple columns may lead to codegen issue") { + withView("spark_25084") { + val count = 1000 + val df = spark.range(count) + val columns = (0 until 400).map{ i => s"id as id$i" } + val distributeExprs = (0 until 100).map(c => s"id$c").mkString(",") + df.selectExpr(columns : _*).createTempView("spark_25084") + assert( + spark.sql(s"select * from spark_25084 distribute by ($distributeExprs)").count === count) + } + } + + test("SPARK-25144 'distinct' causes memory leak") { + val ds = List(Foo(Some("bar"))).toDS + val result = ds.flatMap(_.bar).distinct + result.rdd.isEmpty + } } + +case class Foo(bar: Option[String]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index beac9699585d5..826408c7161e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -54,6 +54,7 @@ import org.apache.spark.sql.types.StructType * The format for input files is simple: * 1. A list of SQL queries separated by semicolon. * 2. Lines starting with -- are treated as comments and ignored. + * 3. Lines starting with --SET are used to run the file with the following set of configs. * * For example: * {{{ @@ -138,18 +139,58 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { private def runTest(testCase: TestCase): Unit = { val input = fileToString(new File(testCase.inputFile)) + val (comments, code) = input.split("\n").partition(_.startsWith("--")) + val configSets = { + val configLines = comments.filter(_.startsWith("--SET")).map(_.substring(5)) + val configs = configLines.map(_.split(",").map { confAndValue => + val (conf, value) = confAndValue.span(_ != '=') + conf.trim -> value.substring(1).trim + }) + // When we are regenerating the golden files we don't need to run all the configs as they + // all need to return the same result + if (regenerateGoldenFiles && configs.nonEmpty) { + configs.take(1) + } else { + configs + } + } // List of SQL queries to run - val queries: Seq[String] = { - val cleaned = input.split("\n").filterNot(_.startsWith("--")).mkString("\n") - // note: this is not a robust way to split queries using semicolon, but works for now. - cleaned.split("(?<=[^\\\\]);").map(_.trim).filter(_ != "").toSeq + // note: this is not a robust way to split queries using semicolon, but works for now. + val queries = code.mkString("\n").split("(?<=[^\\\\]);").map(_.trim).filter(_ != "").toSeq + + if (configSets.isEmpty) { + runQueries(queries, testCase.resultFile, None) + } else { + configSets.foreach { configSet => + try { + runQueries(queries, testCase.resultFile, Some(configSet)) + } catch { + case e: Throwable => + val configs = configSet.map { + case (k, v) => s"$k=$v" + } + logError(s"Error using configs: ${configs.mkString(",")}") + throw e + } + } } + } + private def runQueries( + queries: Seq[String], + resultFileName: String, + configSet: Option[Seq[(String, String)]]): Unit = { // Create a local SparkSession to have stronger isolation between different test cases. // This does not isolate catalog changes. val localSparkSession = spark.newSession() loadTestData(localSparkSession) + if (configSet.isDefined) { + // Execute the list of set operation in order to add the desired configs + val setOperations = configSet.get.map { case (key, value) => s"set $key=$value" } + logInfo(s"Setting configs: ${setOperations.mkString(", ")}") + setOperations.foreach(localSparkSession.sql) + } // Run the SQL queries preparing them for comparison. 
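+    // Every config set parsed from the --SET lines must reproduce the same golden file. A test
+    // file might, for instance, begin with a line such as (hypothetical):
+    //   --SET spark.sql.codegen.wholeStage=true
+    // and each key=value pair is applied above as a "set" statement before the queries run.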
val outputs: Seq[QueryOutput] = queries.map { sql => val (schema, output) = getNormalizedResult(localSparkSession, sql) @@ -167,7 +208,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { s"-- Number of queries: ${outputs.size}\n\n\n" + outputs.zipWithIndex.map{case (qr, i) => qr.toString(i)}.mkString("\n\n\n") + "\n" } - val resultFile = new File(testCase.resultFile) + val resultFile = new File(resultFileName) val parent = resultFile.getParentFile if (!parent.exists()) { assert(parent.mkdirs(), "Could not create directory: " + parent) @@ -177,7 +218,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { // Read back the golden file. val expectedOutputs: Seq[QueryOutput] = { - val goldenOutput = fileToString(new File(testCase.resultFile)) + val goldenOutput = fileToString(new File(resultFileName)) val segments = goldenOutput.split("-- !query.+\n") // each query has 3 segments, plus the header diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala index 7d1366092d1e6..e1b5eba53f06a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala @@ -41,13 +41,16 @@ class SessionStateSuite extends SparkFunSuite { } override def afterAll(): Unit = { - if (activeSession != null) { - activeSession.stop() - activeSession = null - SparkSession.clearActiveSession() - SparkSession.clearDefaultSession() + try { + if (activeSession != null) { + activeSession.stop() + activeSession = null + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } + } finally { + super.afterAll() } - super.afterAll() } test("fork new session and inherit RuntimeConfig options") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 60fa951e23178..cb562d65b6147 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -204,6 +204,24 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } } + test("SPARK-25028: column stats collection for null partitioning columns") { + val table = "analyze_partition_with_null" + withTempDir { dir => + withTable(table) { + sql(s""" + |CREATE TABLE $table (value string, name string) + |USING PARQUET + |PARTITIONED BY (name) + |LOCATION '${dir.toURI}'""".stripMargin) + val df = Seq(("a", null), ("b", null)).toDF("value", "name") + df.write.mode("overwrite").insertInto(table) + sql(s"ANALYZE TABLE $table PARTITION (name) COMPUTE STATISTICS") + val partitions = spark.sessionState.catalog.listPartitions(TableIdentifier(table)) + assert(partitions.head.stats.get.rowCount.get == 2) + } + } + } + test("number format in statistics") { val numbers = Seq( BigInt(0) -> (("0.0 B", "0")), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index acef62d81ee12..cbffed994bb4f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.plans.logical.Join +import scala.collection.mutable.ArrayBuffer + +import 
org.apache.spark.sql.catalyst.expressions.SubqueryExpression +import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Sort} import org.apache.spark.sql.test.SharedSQLContext class SubquerySuite extends QueryTest with SharedSQLContext { @@ -970,4 +973,299 @@ class SubquerySuite extends QueryTest with SharedSQLContext { Row("3", "b") :: Row("4", "b") :: Nil) } } + + private def getNumSortsInQuery(query: String): Int = { + val plan = sql(query).queryExecution.optimizedPlan + getNumSorts(plan) + getSubqueryExpressions(plan).map{s => getNumSorts(s.plan)}.sum + } + + private def getSubqueryExpressions(plan: LogicalPlan): Seq[SubqueryExpression] = { + val subqueryExpressions = ArrayBuffer.empty[SubqueryExpression] + plan transformAllExpressions { + case s: SubqueryExpression => + subqueryExpressions ++= (getSubqueryExpressions(s.plan) :+ s) + s + } + subqueryExpressions + } + + private def getNumSorts(plan: LogicalPlan): Int = { + plan.collect { case s: Sort => s }.size + } + + test("SPARK-23957 Remove redundant sort from subquery plan(in subquery)") { + withTempView("t1", "t2", "t3") { + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") + Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", "c3").createOrReplaceTempView("t3") + + // Simple order by + val query1 = + """ + |SELECT c1 FROM t1 + |WHERE + |c1 IN (SELECT c1 FROM t2 ORDER BY c1) + """.stripMargin + assert(getNumSortsInQuery(query1) == 0) + + // Nested order bys + val query2 = + """ + |SELECT c1 + |FROM t1 + |WHERE c1 IN (SELECT c1 + | FROM (SELECT * + | FROM t2 + | ORDER BY c2) + | ORDER BY c1) + """.stripMargin + assert(getNumSortsInQuery(query2) == 0) + + + // nested IN + val query3 = + """ + |SELECT c1 + |FROM t1 + |WHERE c1 IN (SELECT c1 + | FROM t2 + | WHERE c1 IN (SELECT c1 + | FROM t3 + | WHERE c1 = 1 + | ORDER BY c3) + | ORDER BY c2) + """.stripMargin + assert(getNumSortsInQuery(query3) == 0) + + // Complex subplan and multiple sorts + val query4 = + """ + |SELECT c1 + |FROM t1 + |WHERE c1 IN (SELECT c1 + | FROM (SELECT c1, c2, count(*) + | FROM t2 + | GROUP BY c1, c2 + | HAVING count(*) > 0 + | ORDER BY c2) + | ORDER BY c1) + """.stripMargin + assert(getNumSortsInQuery(query4) == 0) + + // Join in subplan + val query5 = + """ + |SELECT c1 FROM t1 + |WHERE + |c1 IN (SELECT t2.c1 FROM t2, t3 + | WHERE t2.c1 = t3.c1 + | ORDER BY t2.c1) + """.stripMargin + assert(getNumSortsInQuery(query5) == 0) + + val query6 = + """ + |SELECT c1 + |FROM t1 + |WHERE (c1, c2) IN (SELECT c1, max(c2) + | FROM (SELECT c1, c2, count(*) + | FROM t2 + | GROUP BY c1, c2 + | HAVING count(*) > 0 + | ORDER BY c2) + | GROUP BY c1 + | HAVING max(c2) > 0 + | ORDER BY c1) + """.stripMargin + // The rule to remove redundant sorts is not able to remove the inner sort under + // an Aggregate operator. We only remove the top level sort. 
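+      // Hence exactly one Sort is expected to remain in the optimized plan.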
+ assert(getNumSortsInQuery(query6) == 1) + + // Cases when sort is not removed from the plan + // Limit on top of sort + val query7 = + """ + |SELECT c1 FROM t1 + |WHERE + |c1 IN (SELECT c1 FROM t2 ORDER BY c1 limit 1) + """.stripMargin + assert(getNumSortsInQuery(query7) == 1) + + // Sort below a set operations (intersect, union) + val query8 = + """ + |SELECT c1 FROM t1 + |WHERE + |c1 IN (( + | SELECT c1 FROM t2 + | ORDER BY c1 + | ) + | UNION + | ( + | SELECT c1 FROM t2 + | ORDER BY c1 + | )) + """.stripMargin + assert(getNumSortsInQuery(query8) == 2) + } + } + + test("SPARK-23957 Remove redundant sort from subquery plan(exists subquery)") { + withTempView("t1", "t2", "t3") { + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") + Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", "c3").createOrReplaceTempView("t3") + + // Simple order by exists correlated + val query1 = + """ + |SELECT c1 FROM t1 + |WHERE + |EXISTS (SELECT t2.c1 FROM t2 WHERE t1.c1 = t2.c1 ORDER BY t2.c1) + """.stripMargin + assert(getNumSortsInQuery(query1) == 0) + + // Nested order by and correlated. + val query2 = + """ + |SELECT c1 + |FROM t1 + |WHERE EXISTS (SELECT c1 + | FROM (SELECT * + | FROM t2 + | WHERE t2.c1 = t1.c1 + | ORDER BY t2.c2) t2 + | ORDER BY t2.c1) + """.stripMargin + assert(getNumSortsInQuery(query2) == 0) + + // nested EXISTS + val query3 = + """ + |SELECT c1 + |FROM t1 + |WHERE EXISTS (SELECT c1 + | FROM t2 + | WHERE EXISTS (SELECT c1 + | FROM t3 + | WHERE t3.c1 = t2.c1 + | ORDER BY c3) + | AND t2.c1 = t1.c1 + | ORDER BY c2) + """.stripMargin + assert(getNumSortsInQuery(query3) == 0) + + // Cases when sort is not removed from the plan + // Limit on top of sort + val query4 = + """ + |SELECT c1 FROM t1 + |WHERE + |EXISTS (SELECT t2.c1 FROM t2 WHERE t2.c1 = 1 ORDER BY t2.c1 limit 1) + """.stripMargin + assert(getNumSortsInQuery(query4) == 1) + + // Sort below a set operations (intersect, union) + val query5 = + """ + |SELECT c1 FROM t1 + |WHERE + |EXISTS (( + | SELECT c1 FROM t2 + | WHERE t2.c1 = 1 + | ORDER BY t2.c1 + | ) + | UNION + | ( + | SELECT c1 FROM t2 + | WHERE t2.c1 = 2 + | ORDER BY t2.c1 + | )) + """.stripMargin + assert(getNumSortsInQuery(query5) == 2) + } + } + + test("SPARK-23957 Remove redundant sort from subquery plan(scalar subquery)") { + withTempView("t1", "t2", "t3") { + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") + Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", "c3").createOrReplaceTempView("t3") + + // Two scalar subqueries in OR + val query1 = + """ + |SELECT * FROM t1 + |WHERE c1 = (SELECT max(t2.c1) + | FROM t2 + | ORDER BY max(t2.c1)) + |OR c2 = (SELECT min(t3.c2) + | FROM t3 + | WHERE t3.c1 = 1 + | ORDER BY min(t3.c2)) + """.stripMargin + assert(getNumSortsInQuery(query1) == 0) + + // scalar subquery - groupby and having + val query2 = + """ + |SELECT * + |FROM t1 + |WHERE c1 = (SELECT max(t2.c1) + | FROM t2 + | GROUP BY t2.c1 + | HAVING count(*) >= 1 + | ORDER BY max(t2.c1)) + """.stripMargin + assert(getNumSortsInQuery(query2) == 0) + + // nested scalar subquery + val query3 = + """ + |SELECT * + |FROM t1 + |WHERE c1 = (SELECT max(t2.c1) + | FROM t2 + | WHERE c1 = (SELECT max(t3.c1) + | FROM t3 + | WHERE t3.c1 = 1 + | GROUP BY t3.c1 + | ORDER BY max(t3.c1) + | ) + | GROUP BY t2.c1 + | HAVING count(*) >= 1 + | ORDER BY max(t2.c1)) + """.stripMargin + assert(getNumSortsInQuery(query3) == 0) + + // Scalar 
subquery in projection + val query4 = + """ + |SELECT (SELECT min(c1) from t1 group by c1 order by c1) + |FROM t1 + |WHERE t1.c1 = 1 + """.stripMargin + assert(getNumSortsInQuery(query4) == 0) + + // Limit on top of sort prevents it from being pruned. + val query5 = + """ + |SELECT * + |FROM t1 + |WHERE c1 = (SELECT max(t2.c1) + | FROM t2 + | WHERE c1 = (SELECT max(t3.c1) + | FROM t3 + | WHERE t3.c1 = 1 + | GROUP BY t3.c1 + | ORDER BY max(t3.c1) + | ) + | GROUP BY t2.c1 + | HAVING count(*) >= 1 + | ORDER BY max(t2.c1) + | LIMIT 1) + """.stripMargin + assert(getNumSortsInQuery(query5) == 1) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala index bc95b4696190d..817224d1c28ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala @@ -147,7 +147,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { |`s_company_id` INT, `s_company_name` STRING, `s_street_number` STRING, |`s_street_name` STRING, `s_street_type` STRING, `s_suite_number` STRING, `s_city` STRING, |`s_county` STRING, `s_state` STRING, `s_zip` STRING, `s_country` STRING, - |`s_gmt_offset` DECIMAL(5,2), `s_tax_precentage` DECIMAL(5,2)) + |`s_gmt_offset` DECIMAL(5,2), `s_tax_percentage` DECIMAL(5,2)) |USING parquet """.stripMargin) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 21afdc7e2a33f..30dca9497ddde 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -19,11 +19,16 @@ package org.apache.spark.sql import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.plans.logical.Project -import org.apache.spark.sql.execution.command.ExplainCommand -import org.apache.spark.sql.functions.udf +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, ExplainCommand} +import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand +import org.apache.spark.sql.functions.{lit, udf} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types.{DataTypes, DoubleType} +import org.apache.spark.sql.util.QueryExecutionListener + private case class FunctionResult(f1: String, f2: String) @@ -324,4 +329,68 @@ class UDFSuite extends QueryTest with SharedSQLContext { assert(outputStream.toString.contains("UDF:f(a._1 AS `_1`)")) } } + + test("cached Data should be used in the write path") { + withTable("t") { + withTempPath { path => + var numTotalCachedHit = 0 + val listener = new QueryExecutionListener { + override def onFailure(f: String, qe: QueryExecution, e: Exception): Unit = {} + + override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + qe.withCachedData match { + case c: CreateDataSourceTableAsSelectCommand + if c.query.isInstanceOf[InMemoryRelation] => + numTotalCachedHit += 1 + case i: InsertIntoHadoopFsRelationCommand + if i.query.isInstanceOf[InMemoryRelation] => + numTotalCachedHit += 1 + case _ => + } + } + } + spark.listenerManager.register(listener) + + val udf1 = udf({ (x: Int, y: Int) => x + y }) + val df = spark.range(0, 3).toDF("a") + .withColumn("b", udf1($"a", lit(10))) 
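+ // df is cached below, so each subsequent write should plan over the cached + // InMemoryRelation, which the listener above observes via qe.withCachedData.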
+ df.cache() + df.write.saveAsTable("t") + assert(numTotalCachedHit == 1, "expected to be cached in saveAsTable") + df.write.insertInto("t") + assert(numTotalCachedHit == 2, "expected to be cached in insertInto") + df.write.save(path.getCanonicalPath) + assert(numTotalCachedHit == 3, "expected to be cached in save for native") + } + } + } + + test("SPARK-24891 Fix HandleNullInputsForUDF rule") { + val udf1 = udf({(x: Int, y: Int) => x + y}) + val df = spark.range(0, 3).toDF("a") + .withColumn("b", udf1($"a", udf1($"a", lit(10)))) + .withColumn("c", udf1($"a", lit(null))) + val plan = spark.sessionState.executePlan(df.logicalPlan).analyzed + + comparePlans(df.logicalPlan, plan) + checkAnswer( + df, + Seq( + Row(0, 10, null), + Row(1, 12, null), + Row(2, 14, null))) + } + + test("SPARK-24891 Fix HandleNullInputsForUDF rule - with table") { + withTable("x") { + Seq((1, "2"), (2, "4")).toDF("a", "b").write.format("json").saveAsTable("x") + sql("insert into table x values(3, null)") + sql("insert into table x values(null, '4')") + spark.udf.register("f", (a: Int, b: String) => a + b) + val df = spark.sql("SELECT f(a, b) FROM x") + val plan = spark.sessionState.executePlan(df.logicalPlan).analyzed + comparePlans(df.logicalPlan, plan) + checkAnswer(df, Seq(Row("12"), Row("24"), Row("3null"), Row(null))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index 737eeb0af586e..c627c51655c8d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -21,7 +21,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.{MapOutputStatistics, SparkConf, SparkFunSuite} import org.apache.spark.sql._ -import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ShuffleExchangeExec} +import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -31,6 +31,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { private var originalInstantiatedSparkSession: Option[SparkSession] = _ override protected def beforeAll(): Unit = { + super.beforeAll() originalActiveSparkSession = SparkSession.getActiveSession originalInstantiatedSparkSession = SparkSession.getDefaultSession @@ -39,9 +40,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } override protected def afterAll(): Unit = { - // Set these states back. - originalActiveSparkSession.foreach(ctx => SparkSession.setActiveSession(ctx)) - originalInstantiatedSparkSession.foreach(ctx => SparkSession.setDefaultSession(ctx)) + try { + // Set these states back. 
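+ // The try/finally guarantees super.afterAll() still runs if restoring the sessions fails.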
+ originalActiveSparkSession.foreach(ctx => SparkSession.setActiveSession(ctx)) + originalInstantiatedSparkSession.foreach(ctx => SparkSession.setDefaultSession(ctx)) + } finally { + super.afterAll() + } } private def checkEstimation( @@ -50,7 +55,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { expectedPartitionStartIndices: Array[Int]): Unit = { val mapOutputStatistics = bytesByPartitionIdArray.zipWithIndex.map { case (bytesByPartitionId, index) => - new MapOutputStatistics(index, bytesByPartitionId) + new MapOutputStatistics(index, bytesByPartitionId, Array[Long](1)) } val estimatedPartitionStartIndices = coordinator.estimatePartitionStartIndices(mapOutputStatistics) @@ -58,7 +63,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } test("test estimatePartitionStartIndices - 1 Exchange") { - val coordinator = new ExchangeCoordinator(1, 100L) + val coordinator = new ExchangeCoordinator(100L) { // All bytes per partition are 0. @@ -105,7 +110,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } test("test estimatePartitionStartIndices - 2 Exchanges") { - val coordinator = new ExchangeCoordinator(2, 100L) + val coordinator = new ExchangeCoordinator(100L) { // If there are multiple values of the number of pre-shuffle partitions, @@ -114,8 +119,8 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0, 0) val mapOutputStatistics = Array( - new MapOutputStatistics(0, bytesByPartitionId1), - new MapOutputStatistics(1, bytesByPartitionId2)) + new MapOutputStatistics(0, bytesByPartitionId1, Array[Long](0)), + new MapOutputStatistics(1, bytesByPartitionId2, Array[Long](0))) intercept[AssertionError](coordinator.estimatePartitionStartIndices(mapOutputStatistics)) } @@ -199,7 +204,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } test("test estimatePartitionStartIndices and enforce minimal number of reducers") { - val coordinator = new ExchangeCoordinator(2, 100L, Some(2)) + val coordinator = new ExchangeCoordinator(100L, Some(2)) { // The minimal number of post-shuffle partitions is not enforced because @@ -480,4 +485,17 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { withSparkSession(test, 6144, minNumPostShufflePartitions) } } + + test("SPARK-24705 adaptive query execution works correctly when exchange reuse enabled") { + val test = { spark: SparkSession => + spark.sql("SET spark.sql.exchange.reuse=true") + val df = spark.range(1).selectExpr("id AS key", "id AS value") + val resultDf = df.join(df, "key").join(df, "key") + val sparkPlan = resultDf.queryExecution.executedPlan + assert(sparkPlan.collect { case p: ReusedExchangeExec => p }.length == 1) + assert(sparkPlan.collect { case p @ ShuffleExchangeExec(_, _, Some(c)) => p }.length == 3) + checkAnswer(resultDf, Row(0, 0, 0, 0) :: Nil) + } + withSparkSession(test, 4, None) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala index ecc7264d79442..b29de9c4adbaa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala @@ -29,7 +29,11 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends 
SparkFunSuite with LocalSpar private val random = new java.util.Random() private var taskContext: TaskContext = _ - override def afterAll(): Unit = TaskContext.unset() + override def afterAll(): Unit = try { + TaskContext.unset() + } finally { + super.afterAll() + } private def withExternalArray(inMemoryThreshold: Int, spillThreshold: Int) (f: ExternalAppendOnlyUnsafeRowArray => Unit): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LimitSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LimitSuite.scala new file mode 100644 index 0000000000000..a7840a5fcfae0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/LimitSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import scala.util.Random + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext + + +class LimitSuite extends SparkPlanTest with SharedSQLContext { + + private var rand: Random = _ + private var seed: Long = 0 + + protected override def beforeAll(): Unit = { + super.beforeAll() + seed = System.currentTimeMillis() + rand = new Random(seed) + } + + test("Produce ordered global limit if more than topKSortFallbackThreshold") { + withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "100") { + val df = LimitTest.generateRandomInputData(spark, rand).sort("a") + + val globalLimit = df.limit(99).queryExecution.executedPlan.collect { + case g: GlobalLimitExec => g + } + assert(globalLimit.size == 0) + + val topKSort = df.limit(99).queryExecution.executedPlan.collect { + case t: TakeOrderedAndProjectExec => t + } + assert(topKSort.size == 1) + + val orderedGlobalLimit = df.limit(100).queryExecution.executedPlan.collect { + case g: GlobalLimitExec => g + } + assert(orderedGlobalLimit.size == 1 && orderedGlobalLimit(0).orderedLimit == true) + } + } + + test("Ordered global limit") { + val baseDf = LimitTest.generateRandomInputData(spark, rand) + .select("a").repartition(3).sort("a") + + withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") { + val orderedGlobalLimit = GlobalLimitExec(3, baseDf.queryExecution.sparkPlan, + orderedLimit = true) + val orderedGlobalLimitResult = SparkPlanTest.executePlan(orderedGlobalLimit, spark.sqlContext) + .map(_.getInt(0)) + + val globalLimit = GlobalLimitExec(3, baseDf.queryExecution.sparkPlan, orderedLimit = false) + val globalLimitResult = SparkPlanTest.executePlan(globalLimit, spark.sqlContext) + .map(_.getInt(0)) + + // Global limit without order takes values from each partition sequentially. + // After global sort, the values in the second partition must be larger than the values + // in the first partition.
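+ // For example (illustrative values): with sorted partitions [1,2,3], [4,5,6], [7,8,9] + // and limit 3, the ordered limit returns [1,2,3] while the flat limit takes one leading + // value per partition, e.g. [1,4,7], so the first values match and every later ordered + // value is strictly smaller, which is what the assertions below check.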
+ assert(orderedGlobalLimitResult(0) == globalLimitResult(0)) + assert(orderedGlobalLimitResult(1) < globalLimitResult(1)) + assert(orderedGlobalLimitResult(2) < globalLimitResult(2)) + } + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index d254345e8fa54..b10da6c70be16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -262,7 +262,7 @@ class PlannerSuite extends SharedSQLContext { ).queryExecution.executedPlan.collect { case exchange: ShuffleExchangeExec => exchange }.length - assert(numExchanges === 5) + assert(numExchanges === 3) } { @@ -277,7 +277,7 @@ class PlannerSuite extends SharedSQLContext { ).queryExecution.executedPlan.collect { case exchange: ShuffleExchangeExec => exchange }.length - assert(numExchanges === 5) + assert(numExchanges === 3) } } @@ -624,7 +624,7 @@ class PlannerSuite extends SharedSQLContext { dataType = LongType, nullable = false ) (exprId = exprId, - qualifier = Some("col1_qualifier") + qualifier = Seq("col1_qualifier") ) val attribute2 = @@ -704,6 +704,23 @@ class PlannerSuite extends SharedSQLContext { df.queryExecution.executedPlan.execute() } + test("SPARK-25278: physical nodes should be different instances for same logical nodes") { + val range = Range(1, 1, 1, 1) + val df = Union(range, range) + val ranges = df.queryExecution.optimizedPlan.collect { + case r: Range => r + } + assert(ranges.length == 2) + // Ensure the two Range instances are equal according to their equals method + assert(ranges.head == ranges.last) + val execRanges = df.queryExecution.sparkPlan.collect { + case r: RangeExec => r + } + assert(execRanges.length == 2) + // Ensure the two RangeExec instances are different instances + assert(!execRanges.head.eq(execRanges.last)) + } + test("SPARK-24556: always rewrite output partitioning in ReusedExchangeExec " + "and InMemoryTableScanExec") { def checkOutputPartitioningRewrite( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala index c2e62b987e0cc..08e40e28d3d57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala @@ -46,7 +46,7 @@ class SQLJsonProtocolSuite extends SparkFunSuite { """.stripMargin val reconstructedEvent = JsonProtocol.sparkEventFromJson(parse(SQLExecutionStartJsonString)) val expectedEvent = SparkListenerSQLExecutionStart(0, "test desc", "test detail", "test plan", - new SparkPlanInfo("TestNode", "test string", Nil, Nil), 0) + new SparkPlanInfo("TestNode", "test string", Nil, Map(), Nil), 0) assert(reconstructedEvent == expectedEvent) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala index aaf51b5b90111..d088e24e53bfe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala @@ -18,8 +18,11 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import
org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.IntegerType /** * Tests for the sameResult function for [[SparkPlan]]s. @@ -58,4 +61,16 @@ class SameResultSuite extends QueryTest with SharedSQLContext { val df4 = spark.range(10).agg(sumDistinct($"id")) assert(df3.queryExecution.executedPlan.sameResult(df4.queryExecution.executedPlan)) } + + test("Canonicalized result is case-insensitive") { + val a = AttributeReference("A", IntegerType)() + val b = AttributeReference("B", IntegerType)() + val planUppercase = Project(Seq(a), LocalRelation(a, b)) + + val c = AttributeReference("a", IntegerType)() + val d = AttributeReference("b", IntegerType)() + val planLowercase = Project(Seq(c), LocalRelation(c, d)) + + assert(planUppercase.sameResult(planLowercase)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SelectedFieldSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SelectedFieldSuite.scala new file mode 100644 index 0000000000000..05f7e3ce83880 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SelectedFieldSuite.scala @@ -0,0 +1,455 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.exceptions.TestFailedException + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.types._ + +class SelectedFieldSuite extends SparkFunSuite with BeforeAndAfterAll { + private val ignoredField = StructField("col1", StringType, nullable = false) + + // The test schema as a tree string, i.e. 
`schema.treeString` + // root + // |-- col1: string (nullable = false) + // |-- col2: struct (nullable = true) + // | |-- field1: integer (nullable = true) + // | |-- field6: struct (nullable = true) + // | | |-- subfield1: string (nullable = false) + // | | |-- subfield2: string (nullable = true) + // | |-- field7: struct (nullable = true) + // | | |-- subfield1: struct (nullable = true) + // | | | |-- subsubfield1: integer (nullable = true) + // | | | |-- subsubfield2: integer (nullable = true) + // | |-- field9: map (nullable = true) + // | | |-- key: string + // | | |-- value: integer (valueContainsNull = false) + private val nestedComplex = StructType(ignoredField :: + StructField("col2", StructType( + StructField("field1", IntegerType) :: + StructField("field6", StructType( + StructField("subfield1", StringType, nullable = false) :: + StructField("subfield2", StringType) :: Nil)) :: + StructField("field7", StructType( + StructField("subfield1", StructType( + StructField("subsubfield1", IntegerType) :: + StructField("subsubfield2", IntegerType) :: Nil)) :: Nil)) :: + StructField("field9", + MapType(StringType, IntegerType, valueContainsNull = false)) :: Nil)) :: Nil) + + test("SelectedField should not match an attribute reference") { + val testRelation = LocalRelation(nestedComplex.toAttributes) + assertResult(None)(unapplySelect("col1", testRelation)) + assertResult(None)(unapplySelect("col1 as foo", testRelation)) + assertResult(None)(unapplySelect("col2", testRelation)) + } + + // |-- col1: string (nullable = false) + // |-- col2: struct (nullable = true) + // | |-- field2: array (nullable = true) + // | | |-- element: integer (containsNull = false) + // | |-- field3: array (nullable = false) + // | | |-- element: struct (containsNull = true) + // | | | |-- subfield1: integer (nullable = true) + // | | | |-- subfield2: integer (nullable = true) + // | | | |-- subfield3: array (nullable = true) + // | | | | |-- element: integer (containsNull = true) + private val structOfArray = StructType(ignoredField :: + StructField("col2", StructType( + StructField("field2", ArrayType(IntegerType, containsNull = false)) :: + StructField("field3", ArrayType(StructType( + StructField("subfield1", IntegerType) :: + StructField("subfield2", IntegerType) :: + StructField("subfield3", ArrayType(IntegerType)) :: Nil)), nullable = false) + :: Nil)) + :: Nil) + + testSelect(structOfArray, "col2.field2", "col2.field2[0] as foo") { + StructField("col2", StructType( + StructField("field2", ArrayType(IntegerType, containsNull = false)) :: Nil)) + } + + testSelect(nestedComplex, "col2.field9", "col2.field9['foo'] as foo") { + StructField("col2", StructType( + StructField("field9", MapType(StringType, IntegerType, valueContainsNull = false)) :: Nil)) + } + + testSelect(structOfArray, "col2.field3.subfield3", "col2.field3[0].subfield3 as foo", + "col2.field3.subfield3[0] as foo", "col2.field3[0].subfield3[0] as foo") { + StructField("col2", StructType( + StructField("field3", ArrayType(StructType( + StructField("subfield3", ArrayType(IntegerType)) :: Nil)), nullable = false) :: Nil)) + } + + testSelect(structOfArray, "col2.field3.subfield1") { + StructField("col2", StructType( + StructField("field3", ArrayType(StructType( + StructField("subfield1", IntegerType) :: Nil)), nullable = false) :: Nil)) + } + + // |-- col1: string (nullable = false) + // |-- col2: struct (nullable = true) + // | |-- field4: map (nullable = true) + // | | |-- key: string + // | | |-- value: struct (valueContainsNull = false) + // | | 
| |-- subfield1: integer (nullable = true) + // | | | |-- subfield2: array (nullable = true) + // | | | | |-- element: integer (containsNull = false) + // | |-- field8: map (nullable = true) + // | | |-- key: string + // | | |-- value: array (valueContainsNull = false) + // | | | |-- element: struct (containsNull = true) + // | | | | |-- subfield1: integer (nullable = true) + // | | | | |-- subfield2: array (nullable = true) + // | | | | | |-- element: integer (containsNull = false) + private val structWithMap = StructType( + ignoredField :: + StructField("col2", StructType( + StructField("field4", MapType(StringType, StructType( + StructField("subfield1", IntegerType) :: + StructField("subfield2", ArrayType(IntegerType, containsNull = false)) :: Nil + ), valueContainsNull = false)) :: + StructField("field8", MapType(StringType, ArrayType(StructType( + StructField("subfield1", IntegerType) :: + StructField("subfield2", ArrayType(IntegerType, containsNull = false)) :: Nil) + ), valueContainsNull = false)) :: Nil + )) :: Nil + ) + + testSelect(structWithMap, "col2.field4['foo'].subfield1 as foo") { + StructField("col2", StructType( + StructField("field4", MapType(StringType, StructType( + StructField("subfield1", IntegerType) :: Nil), valueContainsNull = false)) :: Nil)) + } + + testSelect(structWithMap, + "col2.field4['foo'].subfield2 as foo", "col2.field4['foo'].subfield2[0] as foo") { + StructField("col2", StructType( + StructField("field4", MapType(StringType, StructType( + StructField("subfield2", ArrayType(IntegerType, containsNull = false)) + :: Nil), valueContainsNull = false)) :: Nil)) + } + + // |-- col1: string (nullable = false) + // |-- col2: struct (nullable = true) + // | |-- field5: array (nullable = false) + // | | |-- element: struct (containsNull = true) + // | | | |-- subfield1: struct (nullable = false) + // | | | | |-- subsubfield1: integer (nullable = true) + // | | | | |-- subsubfield2: integer (nullable = true) + // | | | |-- subfield2: struct (nullable = true) + // | | | | |-- subsubfield1: struct (nullable = true) + // | | | | | |-- subsubsubfield1: string (nullable = true) + // | | | | |-- subsubfield2: integer (nullable = true) + private val structWithArray = StructType( + ignoredField :: + StructField("col2", StructType( + StructField("field5", ArrayType(StructType( + StructField("subfield1", StructType( + StructField("subsubfield1", IntegerType) :: + StructField("subsubfield2", IntegerType) :: Nil), nullable = false) :: + StructField("subfield2", StructType( + StructField("subsubfield1", StructType( + StructField("subsubsubfield1", StringType) :: Nil)) :: + StructField("subsubfield2", IntegerType) :: Nil)) :: Nil)), nullable = false) :: Nil) + ) :: Nil + ) + + testSelect(structWithArray, "col2.field5.subfield1") { + StructField("col2", StructType( + StructField("field5", ArrayType(StructType( + StructField("subfield1", StructType( + StructField("subsubfield1", IntegerType) :: + StructField("subsubfield2", IntegerType) :: Nil), nullable = false) + :: Nil)), nullable = false) :: Nil)) + } + + testSelect(structWithArray, "col2.field5.subfield1.subsubfield1") { + StructField("col2", StructType( + StructField("field5", ArrayType(StructType( + StructField("subfield1", StructType( + StructField("subsubfield1", IntegerType) :: Nil), nullable = false) + :: Nil)), nullable = false) :: Nil)) + } + + testSelect(structWithArray, "col2.field5.subfield2.subsubfield1.subsubsubfield1") { + StructField("col2", StructType( + StructField("field5", ArrayType(StructType( + 
StructField("subfield2", StructType( + StructField("subsubfield1", StructType( + StructField("subsubsubfield1", StringType) :: Nil)) :: Nil)) + :: Nil)), nullable = false) :: Nil)) + } + + testSelect(structWithMap, "col2.field8['foo'][0].subfield1 as foo") { + StructField("col2", StructType( + StructField("field8", MapType(StringType, ArrayType(StructType( + StructField("subfield1", IntegerType) :: Nil)), valueContainsNull = false)) :: Nil)) + } + + testSelect(nestedComplex, "col2.field1") { + StructField("col2", StructType( + StructField("field1", IntegerType) :: Nil)) + } + + testSelect(nestedComplex, "col2.field6") { + StructField("col2", StructType( + StructField("field6", StructType( + StructField("subfield1", StringType, nullable = false) :: + StructField("subfield2", StringType) :: Nil)) :: Nil)) + } + + testSelect(nestedComplex, "col2.field6.subfield1") { + StructField("col2", StructType( + StructField("field6", StructType( + StructField("subfield1", StringType, nullable = false) :: Nil)) :: Nil)) + } + + testSelect(nestedComplex, "col2.field7.subfield1") { + StructField("col2", StructType( + StructField("field7", StructType( + StructField("subfield1", StructType( + StructField("subsubfield1", IntegerType) :: + StructField("subsubfield2", IntegerType) :: Nil)) :: Nil)) :: Nil)) + } + + // |-- col1: string (nullable = false) + // |-- col3: array (nullable = false) + // | |-- element: struct (containsNull = false) + // | | |-- field1: struct (nullable = true) + // | | | |-- subfield1: integer (nullable = false) + // | | | |-- subfield2: integer (nullable = true) + // | | |-- field2: map (nullable = true) + // | | | |-- key: string + // | | | |-- value: integer (valueContainsNull = false) + private val arrayWithStructAndMap = StructType(Array( + StructField("col3", ArrayType(StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType, nullable = false) :: + StructField("subfield2", IntegerType) :: Nil)) :: + StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false)) + :: Nil), containsNull = false), nullable = false) + )) + + testSelect(arrayWithStructAndMap, "col3.field1.subfield1") { + StructField("col3", ArrayType(StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType, nullable = false) :: Nil)) + :: Nil), containsNull = false), nullable = false) + } + + testSelect(arrayWithStructAndMap, "col3.field2['foo'] as foo") { + StructField("col3", ArrayType(StructType( + StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false)) + :: Nil), containsNull = false), nullable = false) + } + + // |-- col1: string (nullable = false) + // |-- col4: map (nullable = false) + // | |-- key: string + // | |-- value: struct (valueContainsNull = false) + // | | |-- field1: struct (nullable = true) + // | | | |-- subfield1: integer (nullable = false) + // | | | |-- subfield2: integer (nullable = true) + // | | |-- field2: map (nullable = true) + // | | | |-- key: string + // | | | |-- value: integer (valueContainsNull = false) + private val col4 = StructType(Array(ignoredField, + StructField("col4", MapType(StringType, StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType, nullable = false) :: + StructField("subfield2", IntegerType) :: Nil)) :: + StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false)) + :: Nil), valueContainsNull = false), nullable = false) + )) + + testSelect(col4, "col4['foo'].field1.subfield1 as foo") { + 
StructField("col4", MapType(StringType, StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType, nullable = false) :: Nil)) + :: Nil), valueContainsNull = false), nullable = false) + } + + testSelect(col4, "col4['foo'].field2['bar'] as foo") { + StructField("col4", MapType(StringType, StructType( + StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false)) + :: Nil), valueContainsNull = false), nullable = false) + } + + // |-- col1: string (nullable = false) + // |-- col5: array (nullable = true) + // | |-- element: map (containsNull = true) + // | | |-- key: string + // | | |-- value: struct (valueContainsNull = false) + // | | | |-- field1: struct (nullable = true) + // | | | | |-- subfield1: integer (nullable = true) + // | | | | |-- subfield2: integer (nullable = true) + private val arrayOfStruct = StructType(Array(ignoredField, + StructField("col5", ArrayType(MapType(StringType, StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType) :: + StructField("subfield2", IntegerType) :: Nil)) :: Nil), valueContainsNull = false))) + )) + + testSelect(arrayOfStruct, "col5[0]['foo'].field1.subfield1 as foo") { + StructField("col5", ArrayType(MapType(StringType, StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType) :: Nil)) :: Nil), valueContainsNull = false))) + } + + // |-- col1: string (nullable = false) + // |-- col6: map (nullable = true) + // | |-- key: string + // | |-- value: array (valueContainsNull = true) + // | | |-- element: struct (containsNull = false) + // | | | |-- field1: struct (nullable = true) + // | | | | |-- subfield1: integer (nullable = true) + // | | | | |-- subfield2: integer (nullable = true) + private val mapOfArray = StructType(Array(ignoredField, + StructField("col6", MapType(StringType, ArrayType(StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType) :: + StructField("subfield2", IntegerType) :: Nil)) :: Nil), containsNull = false))))) + + testSelect(mapOfArray, "col6['foo'][0].field1.subfield1 as foo") { + StructField("col6", MapType(StringType, ArrayType(StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType) :: Nil)) :: Nil), containsNull = false))) + } + + // An array with a struct with a different fields + // |-- col1: string (nullable = false) + // |-- col7: array (nullable = true) + // | |-- element: struct (containsNull = true) + // | | |-- field1: integer (nullable = false) + // | | |-- field2: struct (nullable = true) + // | | | |-- subfield1: integer (nullable = false) + // | | |-- field3: array (nullable = true) + // | | | |-- element: struct (containsNull = true) + // | | | | |-- subfield1: integer (nullable = false) + private val arrayWithMultipleFields = StructType(Array(ignoredField, + StructField("col7", ArrayType(StructType( + StructField("field1", IntegerType, nullable = false) :: + StructField("field2", StructType( + StructField("subfield1", IntegerType, nullable = false) :: Nil)) :: + StructField("field3", ArrayType(StructType( + StructField("subfield1", IntegerType, nullable = false) :: Nil))) :: Nil))))) + + testSelect(arrayWithMultipleFields, + "col7.field1", "col7[0].field1 as foo", "col7.field1[0] as foo") { + StructField("col7", ArrayType(StructType( + StructField("field1", IntegerType, nullable = false) :: Nil))) + } + + testSelect(arrayWithMultipleFields, "col7.field2.subfield1") { + StructField("col7", ArrayType(StructType( + StructField("field2", 
StructType( + StructField("subfield1", IntegerType, nullable = false) :: Nil)) :: Nil))) + } + + testSelect(arrayWithMultipleFields, "col7.field3.subfield1") { + StructField("col7", ArrayType(StructType( + StructField("field3", ArrayType(StructType( + StructField("subfield1", IntegerType, nullable = false) :: Nil))) :: Nil))) + } + + // Array with a nested int array + // |-- col1: string (nullable = false) + // |-- col8: array (nullable = true) + // | |-- element: struct (containsNull = true) + // | | |-- field1: array (nullable = false) + // | | | |-- element: integer (containsNull = false) + private val arrayOfArray = StructType(Array(ignoredField, + StructField("col8", + ArrayType(StructType(Array(StructField("field1", + ArrayType(IntegerType, containsNull = false), nullable = false)))) + ))) + + testSelect(arrayOfArray, "col8.field1", + "col8[0].field1 as foo", + "col8.field1[0] as foo", + "col8[0].field1[0] as foo") { + StructField("col8", ArrayType(StructType( + StructField("field1", ArrayType(IntegerType, containsNull = false), nullable = false) + :: Nil))) + } + + def assertResult(expected: StructField)(actual: StructField)(selectExpr: String): Unit = { + try { + super.assertResult(expected)(actual) + } catch { + case ex: TestFailedException => + // Print some helpful diagnostics in the case of failure + alert("Expected SELECT \"" + selectExpr + "\" to select the schema\n" + + indent(StructType(expected :: Nil).treeString) + + indent("but it actually selected\n") + + indent(StructType(actual :: Nil).treeString) + + indent("Note that expected.dataType.sameType(actual.dataType) = " + + expected.dataType.sameType(actual.dataType))) + throw ex + } + } + + // Test that the given SELECT expressions prune the test schema to the single-column schema + // defined by the given field + private def testSelect(inputSchema: StructType, selectExprs: String*) + (expected: StructField) { + test(s"SELECT ${selectExprs.map(s => s""""$s"""").mkString(", ")} should select the schema\n" + + indent(StructType(expected :: Nil).treeString)) { + for (selectExpr <- selectExprs) { + assertSelect(selectExpr, expected, inputSchema) + } + } + } + + private def assertSelect(expr: String, expected: StructField, inputSchema: StructType): Unit = { + val relation = LocalRelation(inputSchema.toAttributes) + unapplySelect(expr, relation) match { + case Some(field) => + assertResult(expected)(field)(expr) + case None => + val failureMessage = + "Failed to select a field from " + expr + ". 
" + + "Expected:\n" + + StructType(expected :: Nil).treeString + fail(failureMessage) + } + } + + private def unapplySelect(expr: String, relation: LocalRelation) = { + val parsedExpr = parseAsCatalystExpression(Seq(expr)).head + val select = relation.select(parsedExpr) + val analyzed = select.analyze + SelectedField.unapply(analyzed.expressions.head) + } + + private def parseAsCatalystExpression(exprs: Seq[String]) = { + exprs.map(CatalystSqlParser.parseExpression(_) match { + case namedExpr: NamedExpression => namedExpr + }) + } + + // Indent every line in `string` by four spaces + private def indent(string: String) = string.replaceAll("(?m)^", " ") +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala index 750d9e4adf8b4..47ff372992b91 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.SparkEnv import org.apache.spark.sql.QueryTest import org.apache.spark.sql.test.SharedSQLContext @@ -33,4 +34,28 @@ class SparkPlanSuite extends QueryTest with SharedSQLContext { intercept[IllegalStateException] { plan.executeTake(1) } } + test("SPARK-23731 plans should be canonicalizable after being (de)serialized") { + withTempPath { path => + spark.range(1).write.parquet(path.getAbsolutePath) + val df = spark.read.parquet(path.getAbsolutePath) + val fileSourceScanExec = + df.queryExecution.sparkPlan.collectFirst { case p: FileSourceScanExec => p }.get + val serializer = SparkEnv.get.serializer.newInstance() + val readback = + serializer.deserialize[FileSourceScanExec](serializer.serialize(fileSourceScanExec)) + try { + readback.canonicalized + } catch { + case e: Throwable => fail("FileSourceScanExec was not canonicalizable", e) + } + } + } + + test("SPARK-25357 SparkPlanInfo of FileScan contains nonEmpty metadata") { + withTempPath { path => + spark.range(5).write.parquet(path.getAbsolutePath) + val f = spark.read.parquet(path.getAbsolutePath) + assert(SparkPlanInfo.fromSparkPlan(f.queryExecution.sparkPlan).metadata.nonEmpty) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 107a2f7109793..28a060aff47b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -366,4 +366,15 @@ class SparkSqlParserSuite extends AnalysisTest { "SELECT a || b || c FROM t", Project(UnresolvedAlias(concat) :: Nil, UnresolvedRelation(TableIdentifier("t")))) } + + test("SPARK-25046 Fix Alter View ... As Insert Into Table") { + // Single insert query + intercept("ALTER VIEW testView AS INSERT INTO jt VALUES(1, 1)", + "Operation not allowed: ALTER VIEW ... AS INSERT INTO") + + // Multi insert query + intercept("ALTER VIEW testView AS FROM jt INSERT INTO tbl1 SELECT * WHERE jt.id < 5 " + + "INSERT INTO tbl2 SELECT * WHERE jt.id > 4", + "Operation not allowed: ALTER VIEW ... AS FROM ... 
[INSERT INTO ...]+") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala index 7e317a4d80265..9322204063af3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql.execution import scala.util.Random -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -37,14 +38,6 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext { rand = new Random(seed) } - private def generateRandomInputData(): DataFrame = { - val schema = new StructType() - .add("a", IntegerType, nullable = false) - .add("b", IntegerType, nullable = false) - val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt())) - spark.createDataFrame(sparkContext.parallelize(Random.shuffle(inputData), 10), schema) - } - /** * Adds a no-op filter to the child plan in order to prevent executeCollect() from being * called directly on the child plan. @@ -55,32 +48,62 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext { val sortOrder = 'a.desc :: 'b.desc :: Nil test("TakeOrderedAndProject.doExecute without project") { - withClue(s"seed = $seed") { - checkThatPlansAgree( - generateRandomInputData(), - input => - noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)), - input => - GlobalLimitExec(limit, - LocalLimitExec(limit, - SortExec(sortOrder, true, input))), - sortAnswers = false) + withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "false") { + withClue(s"seed = $seed") { + checkThatPlansAgree( + LimitTest.generateRandomInputData(spark, rand), + input => + noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)), + input => + GlobalLimitExec(limit, + LocalLimitExec(limit, + SortExec(sortOrder, true, input))), + sortAnswers = false) + } } } test("TakeOrderedAndProject.doExecute with project") { - withClue(s"seed = $seed") { - checkThatPlansAgree( - generateRandomInputData(), - input => - noOpFilter( - TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)), - input => - GlobalLimitExec(limit, - LocalLimitExec(limit, - ProjectExec(Seq(input.output.last), - SortExec(sortOrder, true, input)))), - sortAnswers = false) + withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "false") { + withClue(s"seed = $seed") { + checkThatPlansAgree( + LimitTest.generateRandomInputData(spark, rand), + input => + noOpFilter( + TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)), + input => + GlobalLimitExec(limit, + LocalLimitExec(limit, + ProjectExec(Seq(input.output.last), + SortExec(sortOrder, true, input)))), + sortAnswers = false) + } + } + } + + test("TakeOrderedAndProject.doExecute equals to ordered global limit") { + withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") { + withClue(s"seed = $seed") { + checkThatPlansAgree( + LimitTest.generateRandomInputData(spark, rand), + input => + noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)), + input => + 
GlobalLimitExec(limit, + LocalLimitExec(limit, + SortExec(sortOrder, true, input)), orderedLimit = true), + sortAnswers = false) + } } } } + +object LimitTest { + def generateRandomInputData(spark: SparkSession, rand: Random): DataFrame = { + val schema = new StructType() + .add("a", IntegerType, nullable = false) + .add("b", IntegerType, nullable = false) + val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt())) + spark.createDataFrame(spark.sparkContext.parallelize(Random.shuffle(inputData), 10), schema) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 3e31d22e15c0e..5c15ecd42fa0c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -23,6 +23,7 @@ import scala.collection.mutable import scala.util.{Random, Try} import scala.util.control.NonFatal +import org.mockito.Mockito._ import org.scalatest.Matchers import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl} @@ -54,6 +55,8 @@ class UnsafeFixedWidthAggregationMapSuite private var memoryManager: TestMemoryManager = null private var taskMemoryManager: TaskMemoryManager = null + private var taskContext: TaskContext = null + def testWithMemoryLeakDetection(name: String)(f: => Unit) { def cleanup(): Unit = { if (taskMemoryManager != null) { @@ -67,6 +70,8 @@ class UnsafeFixedWidthAggregationMapSuite val conf = new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false") memoryManager = new TestMemoryManager(conf) taskMemoryManager = new TaskMemoryManager(memoryManager, 0) + taskContext = mock(classOf[TaskContext]) + when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, @@ -111,7 +116,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 1024, // initial capacity, PAGE_SIZE_BYTES ) @@ -124,7 +129,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 1024, // initial capacity PAGE_SIZE_BYTES ) @@ -151,7 +156,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -176,7 +181,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -223,7 +228,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -263,7 +268,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, StructType(Nil), StructType(Nil), - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -307,7 +312,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity pageSize ) @@ -344,7 +349,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity 
pageSize ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala index 3fad7dfddadcc..dc67446460877 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala @@ -39,7 +39,11 @@ class SortBasedAggregationStoreSuite extends SparkFunSuite with LocalSparkConte new TaskContextImpl(0, 0, 0, 0, 0, taskManager, new Properties, null)) } - override def afterAll(): Unit = TaskContext.unset() + override def afterAll(): Unit = try { + TaskContext.unset() + } finally { + super.afterAll() + } private val rand = new java.util.Random() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 261df06100aef..c36872a6a5289 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.execution.arrow -import java.io.File +import java.io.{ByteArrayOutputStream, DataOutputStream, File} import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat @@ -26,7 +26,7 @@ import com.google.common.io.Files import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} import org.apache.arrow.vector.ipc.JsonFileReader -import org.apache.arrow.vector.util.Validator +import org.apache.arrow.vector.util.{ByteArrayReadableSeekableByteChannel, Validator} import org.scalatest.BeforeAndAfterAll import org.apache.spark.{SparkException, TaskContext} @@ -51,11 +51,11 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { test("collect to arrow record batch") { val indexData = (1 to 6).toDF("i") - val arrowPayloads = indexData.toArrowPayload.collect() - assert(arrowPayloads.nonEmpty) - assert(arrowPayloads.length == indexData.rdd.getNumPartitions) + val arrowBatches = indexData.toArrowBatchRdd.collect() + assert(arrowBatches.nonEmpty) + assert(arrowBatches.length == indexData.rdd.getNumPartitions) val allocator = new RootAllocator(Long.MaxValue) - val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + val arrowRecordBatches = arrowBatches.map(ArrowConverters.loadBatch(_, allocator)) val rowCount = arrowRecordBatches.map(_.getLength).sum assert(rowCount === indexData.count()) arrowRecordBatches.foreach(batch => assert(batch.getNodes.size() > 0)) @@ -1153,9 +1153,9 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { |} """.stripMargin - val arrowPayloads = testData2.toArrowPayload.collect() - // NOTE: testData2 should have 2 partitions -> 2 arrow batches in payload - assert(arrowPayloads.length === 2) + val arrowBatches = testData2.toArrowBatchRdd.collect() + // NOTE: testData2 should have 2 partitions -> 2 arrow batches + assert(arrowBatches.length === 2) val schema = testData2.schema val tempFile1 = new File(tempDataPath, "testData2-ints-part1.json") @@ -1163,25 +1163,25 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { Files.write(json1, tempFile1, StandardCharsets.UTF_8) Files.write(json2, tempFile2, 
StandardCharsets.UTF_8) - validateConversion(schema, arrowPayloads(0), tempFile1) - validateConversion(schema, arrowPayloads(1), tempFile2) + validateConversion(schema, arrowBatches(0), tempFile1) + validateConversion(schema, arrowBatches(1), tempFile2) } test("empty frame collect") { - val arrowPayload = spark.emptyDataFrame.toArrowPayload.collect() - assert(arrowPayload.isEmpty) + val arrowBatches = spark.emptyDataFrame.toArrowBatchRdd.collect() + assert(arrowBatches.isEmpty) val filteredDF = List[Int](1, 2, 3, 4, 5, 6).toDF("i") - val filteredArrowPayload = filteredDF.filter("i < 0").toArrowPayload.collect() - assert(filteredArrowPayload.isEmpty) + val filteredArrowBatches = filteredDF.filter("i < 0").toArrowBatchRdd.collect() + assert(filteredArrowBatches.isEmpty) } test("empty partition collect") { val emptyPart = spark.sparkContext.parallelize(Seq(1), 2).toDF("i") - val arrowPayloads = emptyPart.toArrowPayload.collect() - assert(arrowPayloads.length === 1) + val arrowBatches = emptyPart.toArrowBatchRdd.collect() + assert(arrowBatches.length === 1) val allocator = new RootAllocator(Long.MaxValue) - val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + val arrowRecordBatches = arrowBatches.map(ArrowConverters.loadBatch(_, allocator)) assert(arrowRecordBatches.head.getLength == 1) arrowRecordBatches.foreach(_.close()) allocator.close() @@ -1192,10 +1192,10 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val maxRecordsPerBatch = 3 spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", maxRecordsPerBatch) val df = spark.sparkContext.parallelize(1 to totalRecords, 2).toDF("i") - val arrowPayloads = df.toArrowPayload.collect() - assert(arrowPayloads.length >= 4) + val arrowBatches = df.toArrowBatchRdd.collect() + assert(arrowBatches.length >= 4) val allocator = new RootAllocator(Long.MaxValue) - val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + val arrowRecordBatches = arrowBatches.map(ArrowConverters.loadBatch(_, allocator)) var recordCount = 0 arrowRecordBatches.foreach { batch => assert(batch.getLength > 0) @@ -1217,8 +1217,8 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { assert(msg.getCause.getClass === classOf[UnsupportedOperationException]) } - runUnsupported { mapData.toDF().toArrowPayload.collect() } - runUnsupported { complexData.toArrowPayload.collect() } + runUnsupported { mapData.toDF().toArrowBatchRdd.collect() } + runUnsupported { complexData.toArrowBatchRdd.collect() } } test("test Arrow Validator") { @@ -1318,7 +1318,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } } - test("roundtrip payloads") { + test("roundtrip arrow batches") { val inputRows = (0 until 9).map { i => InternalRow(i) } :+ InternalRow(null) @@ -1326,10 +1326,41 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) val ctx = TaskContext.empty() - val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, null, ctx) - val outputRowIter = ArrowConverters.fromPayloadIterator(payloadIter, ctx) + val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx) + val outputRowIter = ArrowConverters.fromBatchIterator(batchIter, schema, null, ctx) - assert(schema == outputRowIter.schema) + var count = 0 + outputRowIter.zipWithIndex.foreach { case (row, i) => + if (i != 9) { + assert(row.getInt(0) == i) + } else { + 
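// the last of the ten input rows is InternalRow(null), so index 9 must round-trip as null +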
assert(row.isNullAt(0)) + } + count += 1 + } + + assert(count == inputRows.length) + } + + test("ArrowBatchStreamWriter roundtrip") { + val inputRows = (0 until 9).map(InternalRow(_)) :+ InternalRow(null) + + val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) + val ctx = TaskContext.empty() + val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx) + + // Write batches to Arrow stream format as a byte array + val out = new ByteArrayOutputStream() + Utils.tryWithResource(new DataOutputStream(out)) { dataOut => + val writer = new ArrowBatchStreamWriter(schema, dataOut, null) + writer.writeBatches(batchIter) + writer.end() + } + + // Read Arrow stream into batches, then convert back to rows + val in = new ByteArrayReadableSeekableByteChannel(out.toByteArray) + val readBatches = ArrowConverters.getBatchesFromStream(in) + val outputRowIter = ArrowConverters.fromBatchIterator(readBatches, schema, null, ctx) var count = 0 outputRowIter.zipWithIndex.foreach { case (row, i) => @@ -1348,15 +1379,15 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { private def collectAndValidate( df: DataFrame, json: String, file: String, timeZoneId: String = null): Unit = { // NOTE: coalesce to single partition because can only load 1 batch in validator - val arrowPayload = df.coalesce(1).toArrowPayload.collect().head + val batchBytes = df.coalesce(1).toArrowBatchRdd.collect().head val tempFile = new File(tempDataPath, file) Files.write(json, tempFile, StandardCharsets.UTF_8) - validateConversion(df.schema, arrowPayload, tempFile, timeZoneId) + validateConversion(df.schema, batchBytes, tempFile, timeZoneId) } private def validateConversion( sparkSchema: StructType, - arrowPayload: ArrowPayload, + batchBytes: Array[Byte], jsonFile: File, timeZoneId: String = null): Unit = { val allocator = new RootAllocator(Long.MaxValue) @@ -1368,7 +1399,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val arrowRoot = VectorSchemaRoot.create(arrowSchema, allocator) val vectorLoader = new VectorLoader(arrowRoot) - val arrowRecordBatch = arrowPayload.loadBatch(allocator) + val arrowRecordBatch = ArrowConverters.loadBatch(batchBytes, allocator) vectorLoader.load(arrowRecordBatch) val jsonRoot = jsonReader.read() Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala index 6d7c7de9a856e..8596abd1b4ff2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala @@ -17,25 +17,31 @@ package org.apache.spark.sql.execution.benchmark -import java.io.File +import java.io.{File, FileOutputStream, OutputStream} import scala.util.{Random, Try} +import org.scalatest.{BeforeAndAfterEachTestData, Suite, TestData} + import org.apache.spark.SparkConf +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.functions.monotonically_increasing_id import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType +import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType, TimestampType} import org.apache.spark.util.{Benchmark, Utils} - /** * Benchmark to measure 
read performance with Filter pushdown. * To run this: - * spark-submit --class + * build/sbt "sql/test-only *FilterPushdownBenchmark" + * + * Results will be written to "benchmarks/FilterPushdownBenchmark-results.txt". */ -object FilterPushdownBenchmark { - val conf = new SparkConf() - .setAppName("FilterPushdownBenchmark") +class FilterPushdownBenchmark extends SparkFunSuite with BenchmarkBeforeAndAfterEachTest { + private val conf = new SparkConf() + .setAppName(this.getClass.getSimpleName) // Since `spark.master` always exists, overrides this value .set("spark.master", "local[1]") .setIfMissing("spark.driver.memory", "3g") @@ -44,8 +50,40 @@ object FilterPushdownBenchmark { .setIfMissing("orc.compression", "snappy") .setIfMissing("spark.sql.parquet.compression.codec", "snappy") + private val numRows = 1024 * 1024 * 15 + private val width = 5 + private val mid = numRows / 2 + private val blockSize = 1048576 + private val spark = SparkSession.builder().config(conf).getOrCreate() + private var out: OutputStream = _ + + override def beforeAll() { + super.beforeAll() + out = new FileOutputStream(new File("benchmarks/FilterPushdownBenchmark-results.txt")) + } + + override def beforeEach(td: TestData) { + super.beforeEach(td) + val separator = "=" * 96 + val testHeader = (separator + '\n' + td.name + '\n' + separator + '\n' + '\n').getBytes + out.write(testHeader) + } + + override def afterEach(td: TestData) { + out.write('\n') + super.afterEach(td) + } + + override def afterAll() { + try { + out.close() + } finally { + super.afterAll() + } + } + def withTempPath(f: File => Unit): Unit = { val path = Utils.createTempDir() path.delete() @@ -81,8 +119,7 @@ object FilterPushdownBenchmark { .withColumn("value", valueCol) .sort("value") - saveAsOrcTable(df, dir.getCanonicalPath + "/orc") - saveAsParquetTable(df, dir.getCanonicalPath + "/parquet") + saveAsTable(df, dir) } private def prepareStringDictTable( @@ -93,19 +130,22 @@ object FilterPushdownBenchmark { } val df = spark.range(numRows).selectExpr(selectExpr: _*).sort("value") - saveAsOrcTable(df, dir.getCanonicalPath + "/orc") - saveAsParquetTable(df, dir.getCanonicalPath + "/parquet") + saveAsTable(df, dir) } - private def saveAsOrcTable(df: DataFrame, dir: String): Unit = { - // To always turn on dictionary encoding, we set 1.0 at the threshold (the default is 0.8) - df.write.mode("overwrite").option("orc.dictionary.key.threshold", 1.0).orc(dir) - spark.read.orc(dir).createOrReplaceTempView("orcTable") - } + private def saveAsTable(df: DataFrame, dir: File): Unit = { + val orcPath = dir.getCanonicalPath + "/orc" + val parquetPath = dir.getCanonicalPath + "/parquet" - private def saveAsParquetTable(df: DataFrame, dir: String): Unit = { - df.write.mode("overwrite").parquet(dir) - spark.read.parquet(dir).createOrReplaceTempView("parquetTable") + // To always turn on dictionary encoding, we set 1.0 at the threshold (the default is 0.8) + df.write.mode("overwrite") + .option("orc.dictionary.key.threshold", 1.0) + .option("orc.stripe.size", blockSize).orc(orcPath) + spark.read.orc(orcPath).createOrReplaceTempView("orcTable") + + df.write.mode("overwrite") + .option("parquet.block.size", blockSize).parquet(parquetPath) + spark.read.parquet(parquetPath).createOrReplaceTempView("parquetTable") } def filterPushDownBenchmark( @@ -113,7 +153,7 @@ object FilterPushdownBenchmark { title: String, whereExpr: String, selectExpr: String = "*"): Unit = { - val benchmark = new Benchmark(title, values, minNumIters = 5) + val benchmark = new Benchmark(title, 
values, minNumIters = 5, output = Some(out)) Seq(false, true).foreach { pushDownEnabled => val name = s"Parquet Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" @@ -133,214 +173,6 @@ object FilterPushdownBenchmark { } } - /* - OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 - Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz - Select 0 string row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 9201 / 9300 1.7 585.0 1.0X - Parquet Vectorized (Pushdown) 89 / 105 176.3 5.7 103.1X - Native ORC Vectorized 8886 / 8898 1.8 564.9 1.0X - Native ORC Vectorized (Pushdown) 110 / 128 143.4 7.0 83.9X - - - Select 0 string row - ('7864320' < value < '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 9336 / 9357 1.7 593.6 1.0X - Parquet Vectorized (Pushdown) 927 / 937 17.0 58.9 10.1X - Native ORC Vectorized 9026 / 9041 1.7 573.9 1.0X - Native ORC Vectorized (Pushdown) 257 / 272 61.1 16.4 36.3X - - - Select 1 string row (value = '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 9209 / 9223 1.7 585.5 1.0X - Parquet Vectorized (Pushdown) 908 / 925 17.3 57.7 10.1X - Native ORC Vectorized 8878 / 8904 1.8 564.4 1.0X - Native ORC Vectorized (Pushdown) 248 / 261 63.4 15.8 37.1X - - - Select 1 string row - (value <=> '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 9194 / 9216 1.7 584.5 1.0X - Parquet Vectorized (Pushdown) 899 / 908 17.5 57.2 10.2X - Native ORC Vectorized 8934 / 8962 1.8 568.0 1.0X - Native ORC Vectorized (Pushdown) 249 / 254 63.3 15.8 37.0X - - - Select 1 string row - ('7864320' <= value <= '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 9332 / 9351 1.7 593.3 1.0X - Parquet Vectorized (Pushdown) 915 / 934 17.2 58.2 10.2X - Native ORC Vectorized 9049 / 9057 1.7 575.3 1.0X - Native ORC Vectorized (Pushdown) 248 / 258 63.5 15.8 37.7X - - - Select all string rows - (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 20478 / 20497 0.8 1301.9 1.0X - Parquet Vectorized (Pushdown) 20461 / 20550 0.8 1300.9 1.0X - Native ORC Vectorized 27464 / 27482 0.6 1746.1 0.7X - Native ORC Vectorized (Pushdown) 27454 / 27488 0.6 1745.5 0.7X - - - Select 0 int row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8489 / 8519 1.9 539.7 1.0X - Parquet Vectorized (Pushdown) 64 / 69 246.1 4.1 132.8X - Native ORC Vectorized 8064 / 8099 2.0 512.7 1.1X - Native ORC Vectorized (Pushdown) 88 / 94 178.6 5.6 96.4X - - - Select 0 int row - (7864320 < value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8494 / 8514 1.9 540.0 1.0X - Parquet Vectorized 
(Pushdown) 835 / 840 18.8 53.1 10.2X - Native ORC Vectorized 8090 / 8106 1.9 514.4 1.0X - Native ORC Vectorized (Pushdown) 249 / 257 63.2 15.8 34.1X - - - Select 1 int row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8552 / 8560 1.8 543.7 1.0X - Parquet Vectorized (Pushdown) 837 / 841 18.8 53.2 10.2X - Native ORC Vectorized 8178 / 8188 1.9 519.9 1.0X - Native ORC Vectorized (Pushdown) 249 / 258 63.2 15.8 34.4X - - - Select 1 int row (value <=> 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8562 / 8580 1.8 544.3 1.0X - Parquet Vectorized (Pushdown) 833 / 836 18.9 53.0 10.3X - Native ORC Vectorized 8164 / 8185 1.9 519.0 1.0X - Native ORC Vectorized (Pushdown) 245 / 254 64.3 15.6 35.0X - - - Select 1 int row - (7864320 <= value <= 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8540 / 8555 1.8 542.9 1.0X - Parquet Vectorized (Pushdown) 837 / 839 18.8 53.2 10.2X - Native ORC Vectorized 8182 / 8231 1.9 520.2 1.0X - Native ORC Vectorized (Pushdown) 250 / 259 62.9 15.9 34.1X - - - Select 1 int row - (7864319 < value < 7864321): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8535 / 8555 1.8 542.6 1.0X - Parquet Vectorized (Pushdown) 835 / 841 18.8 53.1 10.2X - Native ORC Vectorized 8159 / 8179 1.9 518.8 1.0X - Native ORC Vectorized (Pushdown) 244 / 250 64.5 15.5 35.0X - - - Select 10% int rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 9609 / 9634 1.6 610.9 1.0X - Parquet Vectorized (Pushdown) 2663 / 2672 5.9 169.3 3.6X - Native ORC Vectorized 9824 / 9850 1.6 624.6 1.0X - Native ORC Vectorized (Pushdown) 2717 / 2722 5.8 172.7 3.5X - - - Select 50% int rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 13592 / 13613 1.2 864.2 1.0X - Parquet Vectorized (Pushdown) 9720 / 9738 1.6 618.0 1.4X - Native ORC Vectorized 16366 / 16397 1.0 1040.5 0.8X - Native ORC Vectorized (Pushdown) 12437 / 12459 1.3 790.7 1.1X - - - Select 90% int rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 17580 / 17617 0.9 1117.7 1.0X - Parquet Vectorized (Pushdown) 16803 / 16827 0.9 1068.3 1.0X - Native ORC Vectorized 24169 / 24187 0.7 1536.6 0.7X - Native ORC Vectorized (Pushdown) 22147 / 22341 0.7 1408.1 0.8X - - - Select all int rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 18461 / 18491 0.9 1173.7 1.0X - Parquet Vectorized (Pushdown) 18466 / 18530 0.9 1174.1 1.0X - Native ORC Vectorized 24231 / 24270 0.6 1540.6 0.8X - Native ORC Vectorized (Pushdown) 24207 / 24304 0.6 1539.0 0.8X - - - Select all int rows (value > -1): Best/Avg Time(ms) 
Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 18414 / 18453 0.9 1170.7 1.0X - Parquet Vectorized (Pushdown) 18435 / 18464 0.9 1172.1 1.0X - Native ORC Vectorized 24430 / 24454 0.6 1553.2 0.8X - Native ORC Vectorized (Pushdown) 24410 / 24465 0.6 1552.0 0.8X - - - Select all int rows (value != -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 18446 / 18457 0.9 1172.8 1.0X - Parquet Vectorized (Pushdown) 18428 / 18440 0.9 1171.6 1.0X - Native ORC Vectorized 24414 / 24450 0.6 1552.2 0.8X - Native ORC Vectorized (Pushdown) 24385 / 24472 0.6 1550.4 0.8X - - - Select 0 distinct string row - (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8322 / 8352 1.9 529.1 1.0X - Parquet Vectorized (Pushdown) 53 / 57 296.3 3.4 156.7X - Native ORC Vectorized 7903 / 7953 2.0 502.4 1.1X - Native ORC Vectorized (Pushdown) 80 / 82 197.2 5.1 104.3X - - - Select 0 distinct string row - ('100' < value < '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8712 / 8743 1.8 553.9 1.0X - Parquet Vectorized (Pushdown) 995 / 1030 15.8 63.3 8.8X - Native ORC Vectorized 8345 / 8362 1.9 530.6 1.0X - Native ORC Vectorized (Pushdown) 84 / 87 187.6 5.3 103.9X - - - Select 1 distinct string row - (value = '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8574 / 8610 1.8 545.1 1.0X - Parquet Vectorized (Pushdown) 1127 / 1135 14.0 71.6 7.6X - Native ORC Vectorized 8163 / 8181 1.9 519.0 1.1X - Native ORC Vectorized (Pushdown) 426 / 433 36.9 27.1 20.1X - - - Select 1 distinct string row - (value <=> '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8549 / 8568 1.8 543.5 1.0X - Parquet Vectorized (Pushdown) 1124 / 1131 14.0 71.4 7.6X - Native ORC Vectorized 8163 / 8210 1.9 519.0 1.0X - Native ORC Vectorized (Pushdown) 426 / 436 36.9 27.1 20.1X - - - Select 1 distinct string row - ('100' <= value <= '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8889 / 8896 1.8 565.2 1.0X - Parquet Vectorized (Pushdown) 1161 / 1168 13.6 73.8 7.7X - Native ORC Vectorized 8519 / 8554 1.8 541.6 1.0X - Native ORC Vectorized (Pushdown) 430 / 437 36.6 27.3 20.7X - - - Select all distinct string rows - (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 20433 / 20533 0.8 1299.1 1.0X - Parquet Vectorized (Pushdown) 20433 / 20456 0.8 1299.1 1.0X - Native ORC Vectorized 25435 / 25513 0.6 1617.1 0.8X - Native ORC Vectorized (Pushdown) 25435 / 25507 0.6 1617.1 0.8X - */ - benchmark.run() } @@ -408,15 +240,9 @@ object FilterPushdownBenchmark { } } - def main(args: Array[String]): Unit = { - val numRows = 1024 * 1024 * 15 - val width = 5 - - // Pushdown for many distinct 
value case + ignore("Pushdown for many distinct value case") { withTempPath { dir => - val mid = numRows / 2 - - withTempTable("orcTable", "patquetTable") { + withTempTable("orcTable", "parquetTable") { Seq(true, false).foreach { useStringForValue => prepareTable(dir, numRows, width, useStringForValue) if (useStringForValue) { @@ -427,16 +253,178 @@ object FilterPushdownBenchmark { } } } + } - // Pushdown for few distinct value case (use dictionary encoding) + ignore("Pushdown for few distinct value case (use dictionary encoding)") { withTempPath { dir => val numDistinctValues = 200 - val mid = numDistinctValues / 2 - withTempTable("orcTable", "patquetTable") { + withTempTable("orcTable", "parquetTable") { prepareStringDictTable(dir, numRows, numDistinctValues, width) - runStringBenchmark(numRows, width, mid, "distinct string") + runStringBenchmark(numRows, width, numDistinctValues / 2, "distinct string") } } } + + ignore("Pushdown benchmark for StringStartsWith") { + withTempPath { dir => + withTempTable("orcTable", "parquetTable") { + prepareTable(dir, numRows, width, true) + Seq( + "value like '10%'", + "value like '1000%'", + s"value like '${mid.toString.substring(0, mid.toString.length - 1)}%'" + ).foreach { whereExpr => + val title = s"StringStartsWith filter: ($whereExpr)" + filterPushDownBenchmark(numRows, title, whereExpr) + } + } + } + } + + ignore(s"Pushdown benchmark for ${DecimalType.simpleString}") { + withTempPath { dir => + Seq( + s"decimal(${Decimal.MAX_INT_DIGITS}, 2)", + s"decimal(${Decimal.MAX_LONG_DIGITS}, 2)", + s"decimal(${DecimalType.MAX_PRECISION}, 2)" + ).foreach { dt => + val columns = (1 to width).map(i => s"CAST(id AS string) c$i") + val valueCol = if (dt.equalsIgnoreCase(s"decimal(${Decimal.MAX_INT_DIGITS}, 2)")) { + monotonically_increasing_id() % 9999999 + } else { + monotonically_increasing_id() + } + val df = spark.range(numRows).selectExpr(columns: _*).withColumn("value", valueCol.cast(dt)) + withTempTable("orcTable", "parquetTable") { + saveAsTable(df, dir) + + Seq(s"value = $mid").foreach { whereExpr => + val title = s"Select 1 $dt row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% $dt rows (value < ${numRows * percent / 100})", + s"value < ${numRows * percent / 100}", + selectExpr + ) + } + } + } + } + } + + ignore("Pushdown benchmark for InSet -> InFilters") { + withTempPath { dir => + withTempTable("orcTable", "parquetTable") { + prepareTable(dir, numRows, width, false) + Seq(5, 10, 50, 100).foreach { count => + Seq(10, 50, 90).foreach { distribution => + val filter = + Range(0, count).map(r => scala.util.Random.nextInt(numRows * distribution / 100)) + val whereExpr = s"value in(${filter.mkString(",")})" + val title = s"InSet -> InFilters (values count: $count, distribution: $distribution)" + filterPushDownBenchmark(numRows, title, whereExpr) + } + } + } + } + } + + ignore(s"Pushdown benchmark for ${ByteType.simpleString}") { + withTempPath { dir => + val columns = (1 to width).map(i => s"CAST(id AS string) c$i") + val df = spark.range(numRows).selectExpr(columns: _*) + .withColumn("value", (monotonically_increasing_id() % Byte.MaxValue).cast(ByteType)) + .orderBy("value") + withTempTable("orcTable", "parquetTable") { + saveAsTable(df, dir) + + Seq(s"value = CAST(${Byte.MaxValue / 2} AS 
${ByteType.simpleString})") + .foreach { whereExpr => + val title = s"Select 1 ${ByteType.simpleString} row ($whereExpr)" + .replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% ${ByteType.simpleString} rows " + + s"(value < CAST(${Byte.MaxValue * percent / 100} AS ${ByteType.simpleString}))", + s"value < CAST(${Byte.MaxValue * percent / 100} AS ${ByteType.simpleString})", + selectExpr + ) + } + } + } + } + + ignore(s"Pushdown benchmark for Timestamp") { + withTempPath { dir => + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED.key -> true.toString) { + ParquetOutputTimestampType.values.toSeq.map(_.toString).foreach { fileType => + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> fileType) { + val columns = (1 to width).map(i => s"CAST(id AS string) c$i") + val df = spark.range(numRows).selectExpr(columns: _*) + .withColumn("value", monotonically_increasing_id().cast(TimestampType)) + withTempTable("orcTable", "parquetTable") { + saveAsTable(df, dir) + + Seq(s"value = CAST($mid AS timestamp)").foreach { whereExpr => + val title = s"Select 1 timestamp stored as $fileType row ($whereExpr)" + .replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% timestamp stored as $fileType rows " + + s"(value < CAST(${numRows * percent / 100} AS timestamp))", + s"value < CAST(${numRows * percent / 100} as timestamp)", + selectExpr + ) + } + } + } + } + } + } + } + + ignore(s"Pushdown benchmark with many filters") { + val numRows = 1 + val width = 500 + + withTempPath { dir => + val columns = (1 to width).map(i => s"id c$i") + val df = spark.range(1).selectExpr(columns: _*) + withTempTable("orcTable", "parquetTable") { + saveAsTable(df, dir) + Seq(1, 250, 500).foreach { numFilter => + val whereExpr = (1 to numFilter).map(i => s"c$i = 0").mkString(" and ") + // Note: InferFiltersFromConstraints will add more filters to this given filters + filterPushDownBenchmark(numRows, s"Select 1 row with $numFilter filters", whereExpr) + } + } + } + } +} + +trait BenchmarkBeforeAndAfterEachTest extends BeforeAndAfterEachTestData { this: Suite => + + override def beforeEach(td: TestData) { + super.beforeEach(td) + } + + override def afterEach(td: TestData) { + super.afterEach(td) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala index 470b93efd1974..50ae26a3ff9d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.benchmark import java.util.{Arrays, Comparator} import org.apache.spark.unsafe.array.LongArray -import org.apache.spark.unsafe.memory.OnHeapMemoryBlock +import org.apache.spark.unsafe.memory.MemoryBlock import org.apache.spark.util.Benchmark import org.apache.spark.util.collection.Sorter import org.apache.spark.util.collection.unsafe.sort._ @@ -36,7 +36,7 @@ import org.apache.spark.util.random.XORShiftRandom class 
SortBenchmark extends BenchmarkBase { private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) { - val sortBuffer = new LongArray(new OnHeapMemoryBlock(buf.size() * 8L)) + val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt))) new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort( buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] { override def compare( @@ -50,8 +50,8 @@ class SortBenchmark extends BenchmarkBase { private def generateKeyPrefixTestData(size: Int, rand: => Long): (LongArray, LongArray) = { val ref = Array.tabulate[Long](size * 2) { i => rand } val extended = ref ++ Array.fill[Long](size * 2)(0) - (new LongArray(OnHeapMemoryBlock.fromArray(ref)), - new LongArray(OnHeapMemoryBlock.fromArray(extended))) + (new LongArray(MemoryBlock.fromLongArray(ref)), + new LongArray(MemoryBlock.fromLongArray(extended))) } ignore("sort") { @@ -60,7 +60,7 @@ class SortBenchmark extends BenchmarkBase { val benchmark = new Benchmark("radix sort " + size, size) benchmark.addTimerCase("reference TimSort key prefix array") { timer => val array = Array.tabulate[Long](size * 2) { i => rand.nextLong } - val buf = new LongArray(OnHeapMemoryBlock.fromArray(array)) + val buf = new LongArray(MemoryBlock.fromLongArray(array)) timer.startTiming() referenceKeyPrefixSort(buf, 0, size, PrefixComparators.BINARY) timer.stopTiming() @@ -78,7 +78,7 @@ class SortBenchmark extends BenchmarkBase { array(i) = rand.nextLong & 0xff i += 1 } - val buf = new LongArray(OnHeapMemoryBlock.fromArray(array)) + val buf = new LongArray(MemoryBlock.fromLongArray(array)) timer.startTiming() RadixSort.sort(buf, size, 0, 7, false, false) timer.stopTiming() @@ -90,7 +90,7 @@ class SortBenchmark extends BenchmarkBase { array(i) = rand.nextLong & 0xffff i += 1 } - val buf = new LongArray(OnHeapMemoryBlock.fromArray(array)) + val buf = new LongArray(MemoryBlock.fromLongArray(array)) timer.startTiming() RadixSort.sort(buf, size, 0, 7, false, false) timer.stopTiming() @@ -102,7 +102,7 @@ class SortBenchmark extends BenchmarkBase { array(i) = rand.nextLong i += 1 } - val buf = new LongArray(OnHeapMemoryBlock.fromArray(array)) + val buf = new LongArray(MemoryBlock.fromLongArray(array)) timer.startTiming() RadixSort.sort(buf, size, 0, 7, false, false) timer.stopTiming() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala index abe61a2c2b9c4..fccee97820e75 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala @@ -72,7 +72,7 @@ object TPCDSQueryBenchmark extends Logging { val queryRelations = scala.collection.mutable.HashSet[String]() spark.sql(queryString).queryExecution.analyzed.foreach { case SubqueryAlias(alias, _: LogicalRelation) => - queryRelations.add(alias) + queryRelations.add(alias.identifier) case LogicalRelation(_, _, Some(catalogTable), _) => queryRelations.add(catalogTable.identifier.table) case HiveTableRelation(tableMeta, _, _) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala index a42891e55a18a..c368f17a84364 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala 
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala @@ -54,8 +54,11 @@ class WideSchemaBenchmark extends SparkFunSuite with BeforeAndAfterEach { } override def afterAll() { - super.afterAll() - out.close() + try { + out.close() + } finally { + super.afterAll() + } } override def afterEach() { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala index 9d862cfdecb21..af493e93b5192 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.columnar import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ @@ -35,6 +36,12 @@ class PartitionBatchPruningSuite private lazy val originalColumnBatchSize = spark.conf.get(SQLConf.COLUMN_BATCH_SIZE) private lazy val originalInMemoryPartitionPruning = spark.conf.get(SQLConf.IN_MEMORY_PARTITION_PRUNING) + private val testArrayData = (1 to 100).map { key => + Tuple1(Array.fill(key)(key)) + } + private val testBinaryData = (1 to 100).map { key => + Tuple1(Array.fill(key)(key.toByte)) + } override protected def beforeAll(): Unit = { super.beforeAll() @@ -71,12 +78,22 @@ class PartitionBatchPruningSuite }, 5).toDF() pruningStringData.createOrReplaceTempView("pruningStringData") spark.catalog.cacheTable("pruningStringData") + + val pruningArrayData = sparkContext.makeRDD(testArrayData, 5).toDF() + pruningArrayData.createOrReplaceTempView("pruningArrayData") + spark.catalog.cacheTable("pruningArrayData") + + val pruningBinaryData = sparkContext.makeRDD(testBinaryData, 5).toDF() + pruningBinaryData.createOrReplaceTempView("pruningBinaryData") + spark.catalog.cacheTable("pruningBinaryData") } override protected def afterEach(): Unit = { try { spark.catalog.uncacheTable("pruningData") spark.catalog.uncacheTable("pruningStringData") + spark.catalog.uncacheTable("pruningArrayData") + spark.catalog.uncacheTable("pruningBinaryData") } finally { super.afterEach() } @@ -95,6 +112,14 @@ class PartitionBatchPruningSuite checkBatchPruning("SELECT key FROM pruningData WHERE 11 >= key", 1, 2)(1 to 11) checkBatchPruning("SELECT key FROM pruningData WHERE 88 < key", 1, 2)(89 to 100) checkBatchPruning("SELECT key FROM pruningData WHERE 89 <= key", 1, 2)(89 to 100) + // Do not filter on array type + checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 = array(1)", 5, 10)(Seq(Array(1))) + checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 <= array(1)", 5, 10)(Seq(Array(1))) + checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 >= array(1)", 5, 10)( + testArrayData.map(_._1)) + // Do not filter on binary type + checkBatchPruning( + "SELECT _1 FROM pruningBinaryData WHERE _1 == binary(chr(1))", 5, 10)(Seq(Array(1.toByte))) // IS NULL checkBatchPruning("SELECT key FROM pruningData WHERE value IS NULL", 5, 5) { @@ -131,6 +156,9 @@ class PartitionBatchPruningSuite checkBatchPruning( "SELECT CAST(s AS INT) FROM pruningStringData WHERE s IN ('99', '150', '201')", 1, 1)( Seq(150)) + // Do not filter on array type + checkBatchPruning("SELECT _1 FROM 
pruningArrayData WHERE _1 IN (array(1), array(2, 2))", 5, 10)( + Seq(Array(1), Array(2, 2))) // With unsupported `InSet` predicate { @@ -161,7 +189,7 @@ class PartitionBatchPruningSuite query: String, expectedReadPartitions: Int, expectedReadBatches: Int)( - expectedQueryResult: => Seq[Int]): Unit = { + expectedQueryResult: => Seq[Any]): Unit = { test(query) { val df = sql(query) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 3998ceca38b30..f8d98dead2d42 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -52,23 +52,24 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo protected override def generateTable( catalog: SessionCatalog, name: TableIdentifier, - isDataSource: Boolean = true): CatalogTable = { + isDataSource: Boolean = true, + partitionCols: Seq[String] = Seq("a", "b")): CatalogTable = { val storage = CatalogStorageFormat.empty.copy(locationUri = Some(catalog.defaultTablePath(name))) val metadata = new MetadataBuilder() .putString("key", "value") .build() + val schema = new StructType() + .add("col1", "int", nullable = true, metadata = metadata) + .add("col2", "string") CatalogTable( identifier = name, tableType = CatalogTableType.EXTERNAL, storage = storage, - schema = new StructType() - .add("col1", "int", nullable = true, metadata = metadata) - .add("col2", "string") - .add("a", "int") - .add("b", "int"), + schema = schema.copy( + fields = schema.fields ++ partitionCols.map(StructField(_, IntegerType))), provider = Some("parquet"), - partitionColumnNames = Seq("a", "b"), + partitionColumnNames = partitionCols, createTime = 0L, createVersion = org.apache.spark.SPARK_VERSION, tracksPartitionsInCatalog = true) @@ -176,7 +177,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { protected def generateTable( catalog: SessionCatalog, name: TableIdentifier, - isDataSource: Boolean = true): CatalogTable + isDataSource: Boolean = true, + partitionCols: Seq[String] = Seq("a", "b")): CatalogTable private val escapedIdentifier = "`(.+)`".r @@ -228,8 +230,10 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { private def createTable( catalog: SessionCatalog, name: TableIdentifier, - isDataSource: Boolean = true): Unit = { - catalog.createTable(generateTable(catalog, name, isDataSource), ignoreIfExists = false) + isDataSource: Boolean = true, + partitionCols: Seq[String] = Seq("a", "b")): Unit = { + catalog.createTable( + generateTable(catalog, name, isDataSource, partitionCols), ignoreIfExists = false) } private def createTablePartition( @@ -441,6 +445,24 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("rename a managed table with existing empty directory") { + val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier("tab2"))) + try { + withTable("tab1") { + sql(s"CREATE TABLE tab1 USING $dataSource AS SELECT 1, 'a'") + tableLoc.mkdir() + val ex = intercept[AnalysisException] { + sql("ALTER TABLE tab1 RENAME TO tab2") + }.getMessage + val expectedMsg = "Can not rename the managed table('`tab1`'). 
The associated location" + assert(ex.contains(expectedMsg)) + } + } finally { + waitForTasksToFinish() + Utils.deleteRecursively(tableLoc) + } + } + private def checkSchemaInCreatedDataSourceTable( path: File, userSpecifiedSchema: Option[String], @@ -1113,7 +1135,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } test("alter table: recover partition (parallel)") { - withSQLConf("spark.rdd.parallelListingThreshold" -> "1") { + withSQLConf("spark.rdd.parallelListingThreshold" -> "0") { testRecoverPartitions() } } @@ -1126,23 +1148,32 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } val tableIdent = TableIdentifier("tab1") - createTable(catalog, tableIdent) - val part1 = Map("a" -> "1", "b" -> "5") + createTable(catalog, tableIdent, partitionCols = Seq("a", "b", "c")) + val part1 = Map("a" -> "1", "b" -> "5", "c" -> "19") createTablePartition(catalog, part1, tableIdent) assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) - val part2 = Map("a" -> "2", "b" -> "6") + val part2 = Map("a" -> "2", "b" -> "6", "c" -> "31") val root = new Path(catalog.getTableMetadata(tableIdent).location) val fs = root.getFileSystem(spark.sessionState.newHadoopConf()) // valid - fs.mkdirs(new Path(new Path(root, "a=1"), "b=5")) - fs.createNewFile(new Path(new Path(root, "a=1/b=5"), "a.csv")) // file - fs.createNewFile(new Path(new Path(root, "a=1/b=5"), "_SUCCESS")) // file - fs.mkdirs(new Path(new Path(root, "A=2"), "B=6")) - fs.createNewFile(new Path(new Path(root, "A=2/B=6"), "b.csv")) // file - fs.createNewFile(new Path(new Path(root, "A=2/B=6"), "c.csv")) // file - fs.createNewFile(new Path(new Path(root, "A=2/B=6"), ".hiddenFile")) // file - fs.mkdirs(new Path(new Path(root, "A=2/B=6"), "_temporary")) + fs.mkdirs(new Path(new Path(new Path(root, "a=1"), "b=5"), "c=19")) + fs.createNewFile(new Path(new Path(root, "a=1/b=5/c=19"), "a.csv")) // file + fs.createNewFile(new Path(new Path(root, "a=1/b=5/c=19"), "_SUCCESS")) // file + + fs.mkdirs(new Path(new Path(new Path(root, "A=2"), "B=6"), "C=31")) + fs.createNewFile(new Path(new Path(root, "A=2/B=6/C=31"), "b.csv")) // file + fs.createNewFile(new Path(new Path(root, "A=2/B=6/C=31"), "c.csv")) // file + fs.createNewFile(new Path(new Path(root, "A=2/B=6/C=31"), ".hiddenFile")) // file + fs.mkdirs(new Path(new Path(root, "A=2/B=6/C=31"), "_temporary")) + + val parts = (10 to 100).map { a => + val part = Map("a" -> a.toString, "b" -> "5", "c" -> "42") + fs.mkdirs(new Path(new Path(new Path(root, s"a=$a"), "b=5"), "c=42")) + fs.createNewFile(new Path(new Path(root, s"a=$a/b=5/c=42"), "a.csv")) // file + createTablePartition(catalog, part, tableIdent) + part + } // invalid fs.mkdirs(new Path(new Path(root, "a"), "b")) // bad name @@ -1156,7 +1187,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { try { sql("ALTER TABLE tab1 RECOVER PARTITIONS") assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == - Set(part1, part2)) + Set(part1, part2) ++ parts) if (!isUsingHiveMetastore) { assert(catalog.getPartition(tableIdent, part1).parameters("numFiles") == "1") assert(catalog.getPartition(tableIdent, part2).parameters("numFiles") == "2") @@ -2231,6 +2262,68 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("Partition table should load empty static partitions") { + // All static partitions + withTable("t", "t1", "t2") { + withTempPath { dir => + spark.sql("CREATE TABLE t(a int) USING parquet") + spark.sql("CREATE TABLE t1(a int, c string, b string) " + + s"USING 
parquet PARTITIONED BY(c, b) LOCATION '${dir.toURI}'") + + // datasource table + validateStaticPartitionTable("t1") + + // hive table + if (isUsingHiveMetastore) { + spark.sql("CREATE TABLE t2(a int) " + + s"PARTITIONED BY(c string, b string) LOCATION '${dir.toURI}'") + validateStaticPartitionTable("t2") + } + + def validateStaticPartitionTable(tableName: String): Unit = { + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + assert(spark.sql(s"SHOW PARTITIONS $tableName").count() == 0) + spark.sql( + s"INSERT INTO TABLE $tableName PARTITION(b='b', c='c') SELECT * FROM t WHERE 1 = 0") + assert(spark.sql(s"SHOW PARTITIONS $tableName").count() == 1) + assert(new File(dir, "c=c/b=b").exists()) + checkAnswer(spark.table(tableName), Nil) + } + } + } + + // Partial dynamic partitions + withTable("t", "t1", "t2") { + withTempPath { dir => + spark.sql("CREATE TABLE t(a int) USING parquet") + spark.sql("CREATE TABLE t1(a int, b string, c string) " + + s"USING parquet PARTITIONED BY(c, b) LOCATION '${dir.toURI}'") + + // datasource table + validatePartialStaticPartitionTable("t1") + + // hive table + if (isUsingHiveMetastore) { + spark.sql("CREATE TABLE t2(a int) " + + s"PARTITIONED BY(c string, b string) LOCATION '${dir.toURI}'") + validatePartialStaticPartitionTable("t2") + } + + def validatePartialStaticPartitionTable(tableName: String): Unit = { + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + assert(spark.sql(s"SHOW PARTITIONS $tableName").count() == 0) + spark.sql( + s"INSERT INTO TABLE $tableName PARTITION(c='c', b) SELECT *, 'b' FROM t WHERE 1 = 0") + assert(spark.sql(s"SHOW PARTITIONS $tableName").count() == 0) + assert(!new File(dir, "c=c/b=b").exists()) + checkAnswer(spark.table(tableName), Nil) + } + } + } + } + Seq(true, false).foreach { shouldDelete => val tcName = if (shouldDelete) "non-existing" else "existed" test(s"CTAS for external data source table with a $tcName location") { @@ -2495,7 +2588,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { test("alter datasource table add columns - text format not supported") { withTable("t1") { - sql("CREATE TABLE t1 (c1 int) USING text") + sql("CREATE TABLE t1 (c1 string) USING text") val e = intercept[AnalysisException] { sql("ALTER TABLE t1 ADD COLUMNS (c2 int)") }.getMessage diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteTaskStatsTrackerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteTaskStatsTrackerSuite.scala index bf3c8ede9a980..32941d8d2cd11 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteTaskStatsTrackerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteTaskStatsTrackerSuite.scala @@ -49,7 +49,11 @@ class BasicWriteTaskStatsTrackerSuite extends SparkFunSuite { * In teardown delete the temp dir. 
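+ * Wrap the delete in try/finally so that super.afterAll() always runs.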
*/ protected override def afterAll(): Unit = { - Utils.deleteRecursively(tempDir) + try { + Utils.deleteRecursively(tempDir) + } finally { + super.afterAll() + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 8764f0c42cf9f..bceaf1a9ec061 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path, RawLocalFileSystem} import org.apache.hadoop.mapreduce.Job -import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.BucketSpec diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReaderSuite.scala index a39a25be262a6..508614a7e476c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReaderSuite.scala @@ -38,7 +38,7 @@ class HadoopFileLinesReaderSuite extends SharedSQLContext { val lines = ranges.map { case (start, length) => val file = PartitionedFile(InternalRow.empty, path.getCanonicalPath, start, length) - val hadoopConf = conf.getOrElse(spark.sparkContext.hadoopConfiguration) + val hadoopConf = conf.getOrElse(spark.sessionState.newHadoopConf()) val reader = new HadoopFileLinesReader(file, delimOpt, hadoopConf) reader.map(_.toString) @@ -111,20 +111,20 @@ class HadoopFileLinesReaderSuite extends SharedSQLContext { } test("io.file.buffer.size is less than line length") { - val conf = spark.sparkContext.hadoopConfiguration - conf.set("io.file.buffer.size", "2") - withTempPath { path => - val lines = getLines(path, text = "abcdef\n123456", ranges = Seq((4, 4), (8, 5))) - assert(lines == Seq("123456")) + withSQLConf("io.file.buffer.size" -> "2") { + withTempPath { path => + val lines = getLines(path, text = "abcdef\n123456", ranges = Seq((4, 4), (8, 5))) + assert(lines == Seq("123456")) + } } } test("line cannot be longer than line.maxlength") { - val conf = spark.sparkContext.hadoopConfiguration - conf.set("mapreduce.input.linerecordreader.line.maxlength", "5") - withTempPath { path => - val lines = getLines(path, text = "abcdef\n1234", ranges = Seq((0, 15))) - assert(lines == Seq("1234")) + withSQLConf("mapreduce.input.linerecordreader.line.maxlength" -> "5") { + withTempPath { path => + val lines = getLines(path, text = "abcdef\n1234", ranges = Seq((0, 15))) + assert(lines == Seq("1234")) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaSuite.scala new file mode 100644 index 0000000000000..23c58e175fe5e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaSuite.scala @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.internal.SQLConf + +/** + * Read schema suites have the following hierarchy and aim to guarantee users + * backward-compatible read-schema change coverage on file-based data sources, and + * to prevent future regressions. + * + * ReadSchemaSuite + * -> CSVReadSchemaSuite + * -> HeaderCSVReadSchemaSuite + * + * -> JsonReadSchemaSuite + * + * -> OrcReadSchemaSuite + * -> VectorizedOrcReadSchemaSuite + * + * -> ParquetReadSchemaSuite + * -> VectorizedParquetReadSchemaSuite + * -> MergedParquetReadSchemaSuite + */ + +/** + * All file-based data sources support column addition and removal at the end. + */ +abstract class ReadSchemaSuite + extends AddColumnTest + with HideColumnAtTheEndTest { + + var originalConf: Boolean = _ +} + +class CSVReadSchemaSuite + extends ReadSchemaSuite + with IntegralTypeTest + with ToDoubleTypeTest + with ToDecimalTypeTest + with ToStringTypeTest { + + override val format: String = "csv" +} + +class HeaderCSVReadSchemaSuite + extends ReadSchemaSuite + with IntegralTypeTest + with ToDoubleTypeTest + with ToDecimalTypeTest + with ToStringTypeTest { + + override val format: String = "csv" + + override val options = Map("header" -> "true") +} + +class JsonReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest + with IntegralTypeTest + with ToDoubleTypeTest + with ToDecimalTypeTest + with ToStringTypeTest { + + override val format: String = "json" +} + +class OrcReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest { + + override val format: String = "orc" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.ORC_VECTORIZED_READER_ENABLED) + spark.conf.set(SQLConf.ORC_VECTORIZED_READER_ENABLED.key, "false") + } + + override def afterAll() { + spark.conf.set(SQLConf.ORC_VECTORIZED_READER_ENABLED.key, originalConf) + super.afterAll() + } +} + +class VectorizedOrcReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest + with BooleanTypeTest + with IntegralTypeTest + with ToDoubleTypeTest { + + override val format: String = "orc" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.ORC_VECTORIZED_READER_ENABLED) + spark.conf.set(SQLConf.ORC_VECTORIZED_READER_ENABLED.key, "true") + } + + override def afterAll() { + spark.conf.set(SQLConf.ORC_VECTORIZED_READER_ENABLED.key, originalConf) + super.afterAll() + } +} + +class ParquetReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest { + + override val format: String = "parquet" + + override def beforeAll() { + super.beforeAll() + originalConf =
spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED) + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "false") + } + + override def afterAll() { + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, originalConf) + super.afterAll() + } +} + +class VectorizedParquetReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest { + + override val format: String = "parquet" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED) + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") + } + + override def afterAll() { + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, originalConf) + super.afterAll() + } +} + +class MergedParquetReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest { + + override val format: String = "parquet" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED) + spark.conf.set(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key, "true") + } + + override def afterAll() { + spark.conf.set(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key, originalConf) + super.afterAll() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaTest.scala new file mode 100644 index 0000000000000..2a5457e00b4ef --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaTest.scala @@ -0,0 +1,493 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.File + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} + +/** + * The reader schema is said to be evolved (or projected) when it changes after the data has been + * written. The following changes are supported in file-based data sources. + * Note that partition columns are not maintained in files. Here, `column` means a non-partition + * column. + * + * 1. Add a column + * 2. Hide a column + * 3. Change a column position + * 4. Change a column type (Upcast) + * + * Here, we consider safe changes without data loss. For example, data type changes should be + * from small types to larger types like `int`-to-`long`, not vice versa. + * + * So far, file-based data sources have the following coverage. + * + * | File Format | Coverage | Note | + * | ------------ | ------------ | ------------------------------------------------------ | + * | TEXT | N/A | Schema consists of a single string column.
| + * | CSV | 1, 2, 4 | | + * | JSON | 1, 2, 3, 4 | | + * | ORC | 1, 2, 3, 4 | Native vectorized ORC reader has the widest coverage. | + * | PARQUET | 1, 2, 3 | | + * + * This aims to provide explicit test coverage for reader schema changes on file-based data + * sources. Since a file format has its own coverage, we need a test suite for each file-based + * data source with corresponding supported test case traits. + * + * The following is a hierarchy of test traits. + * + * ReadSchemaTest + * -> AddColumnTest + * -> HideColumnAtTheEndTest + * -> HideColumnInTheMiddleTest + * -> ChangePositionTest + * -> BooleanTypeTest + * -> IntegralTypeTest + * -> ToDoubleTypeTest + * -> ToDecimalTypeTest + * -> ToStringTypeTest + */ + +trait ReadSchemaTest extends QueryTest with SQLTestUtils with SharedSQLContext { + val format: String + val options: Map[String, String] = Map.empty[String, String] +} + +/** + * Add column (Case 1). + * This test suite assumes that the missing column is read back as `null`. + */ +trait AddColumnTest extends ReadSchemaTest { + import testImplicits._ + + test("append column at the end") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = Seq("a", "b").toDF("col1") + val df2 = df1.withColumn("col2", lit("x")) + val df3 = df2.withColumn("col3", lit("y")) + + val dir1 = s"$path${File.separator}part=one" + val dir2 = s"$path${File.separator}part=two" + val dir3 = s"$path${File.separator}part=three" + + df1.write.format(format).options(options).save(dir1) + df2.write.format(format).options(options).save(dir2) + df3.write.format(format).options(options).save(dir3) + + val df = spark.read + .schema(df3.schema) + .format(format) + .options(options) + .load(path) + + checkAnswer(df, Seq( + Row("a", null, null, "one"), + Row("b", null, null, "one"), + Row("a", "x", null, "two"), + Row("b", "x", null, "two"), + Row("a", "x", "y", "three"), + Row("b", "x", "y", "three"))) + } + } +} + +/** + * Hide column (Case 2-1). + */ +trait HideColumnAtTheEndTest extends ReadSchemaTest { + import testImplicits._ + + test("hide column at the end") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = Seq(("1", "a"), ("2", "b")).toDF("col1", "col2") + val df2 = df1.withColumn("col3", lit("y")) + + val dir1 = s"$path${File.separator}part=two" + val dir2 = s"$path${File.separator}part=three" + + df1.write.format(format).options(options).save(dir1) + df2.write.format(format).options(options).save(dir2) + + val df = spark.read + .schema(df1.schema) + .format(format) + .options(options) + .load(path) + + checkAnswer(df, Seq( + Row("1", "a", "two"), + Row("2", "b", "two"), + Row("1", "a", "three"), + Row("2", "b", "three"))) + + val df3 = spark.read + .schema("col1 string") + .format(format) + .options(options) + .load(path) + + checkAnswer(df3, Seq( + Row("1", "two"), + Row("2", "two"), + Row("1", "three"), + Row("2", "three"))) + } + } +} + +/** + * Hide column in the middle (Case 2-2).
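+ * Dropping a middle column means the remaining columns must be resolved by name rather than by position.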
+ */ +trait HideColumnInTheMiddleTest extends ReadSchemaTest { + import testImplicits._ + + test("hide column in the middle") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = Seq(("1", "a"), ("2", "b")).toDF("col1", "col2") + val df2 = df1.withColumn("col3", lit("y")) + + val dir1 = s"$path${File.separator}part=two" + val dir2 = s"$path${File.separator}part=three" + + df1.write.format(format).options(options).save(dir1) + df2.write.format(format).options(options).save(dir2) + + val df = spark.read + .schema("col2 string") + .format(format) + .options(options) + .load(path) + + checkAnswer(df, Seq( + Row("a", "two"), + Row("b", "two"), + Row("a", "three"), + Row("b", "three"))) + } + } +} + +/** + * Change column positions (Case 3). + * This suite assumes that all data sets have the same number of columns. + */ +trait ChangePositionTest extends ReadSchemaTest { + import testImplicits._ + + test("change column position") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = Seq(("1", "a"), ("2", "b"), ("3", "c")).toDF("col1", "col2") + val df2 = Seq(("d", "4"), ("e", "5"), ("f", "6")).toDF("col2", "col1") + val unionDF = df1.unionByName(df2) + + val dir1 = s"$path${File.separator}part=one" + val dir2 = s"$path${File.separator}part=two" + + df1.write.format(format).options(options).save(dir1) + df2.write.format(format).options(options).save(dir2) + + val df = spark.read + .schema(unionDF.schema) + .format(format) + .options(options) + .load(path) + .select("col1", "col2") + + checkAnswer(df, unionDF) + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. + */ +trait BooleanTypeTest extends ReadSchemaTest { + import testImplicits._ + + test("change column type from boolean to byte/short/int/long") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val values = (1 to 10).map(_ % 2) + val booleanDF = (1 to 10).map(_ % 2 == 1).toDF("col1") + val byteDF = values.map(_.toByte).toDF("col1") + val shortDF = values.map(_.toShort).toDF("col1") + val intDF = values.toDF("col1") + val longDF = values.map(_.toLong).toDF("col1") + + booleanDF.write.mode("overwrite").format(format).options(options).save(path) + + Seq( + ("col1 byte", byteDF), + ("col1 short", shortDF), + ("col1 int", intDF), + ("col1 long", longDF)).foreach { case (schema, answerDF) => + checkAnswer(spark.read.schema(schema).format(format).options(options).load(path), answerDF) + } + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally.
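+ * Here, integral values written as byte/short/int/long are expected to read back as strings.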
+ */ +trait ToStringTypeTest extends ReadSchemaTest { + import testImplicits._ + + test("read as string") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val byteDF = (Byte.MaxValue - 2 to Byte.MaxValue).map(_.toByte).toDF("col1") + val shortDF = (Short.MaxValue - 2 to Short.MaxValue).map(_.toShort).toDF("col1") + val intDF = (Int.MaxValue - 2 to Int.MaxValue).toDF("col1") + val longDF = (Long.MaxValue - 2 to Long.MaxValue).toDF("col1") + val unionDF = byteDF.union(shortDF).union(intDF).union(longDF) + .selectExpr("cast(col1 AS STRING) col1") + + val byteDir = s"$path${File.separator}part=byte" + val shortDir = s"$path${File.separator}part=short" + val intDir = s"$path${File.separator}part=int" + val longDir = s"$path${File.separator}part=long" + + byteDF.write.format(format).options(options).save(byteDir) + shortDF.write.format(format).options(options).save(shortDir) + intDF.write.format(format).options(options).save(intDir) + longDF.write.format(format).options(options).save(longDir) + + val df = spark.read + .schema("col1 string") + .format(format) + .options(options) + .load(path) + .select("col1") + + checkAnswer(df, unionDF) + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. + */ +trait IntegralTypeTest extends ReadSchemaTest { + + import testImplicits._ + + private lazy val values = 1 to 10 + private lazy val byteDF = values.map(_.toByte).toDF("col1") + private lazy val shortDF = values.map(_.toShort).toDF("col1") + private lazy val intDF = values.toDF("col1") + private lazy val longDF = values.map(_.toLong).toDF("col1") + + test("change column type from byte to short/int/long") { + withTempPath { dir => + val path = dir.getCanonicalPath + + byteDF.write.format(format).options(options).save(path) + + Seq( + ("col1 short", shortDF), + ("col1 int", intDF), + ("col1 long", longDF)).foreach { case (schema, answerDF) => + checkAnswer(spark.read.schema(schema).format(format).options(options).load(path), answerDF) + } + } + } + + test("change column type from short to int/long") { + withTempPath { dir => + val path = dir.getCanonicalPath + + shortDF.write.format(format).options(options).save(path) + + Seq(("col1 int", intDF), ("col1 long", longDF)).foreach { case (schema, answerDF) => + checkAnswer(spark.read.schema(schema).format(format).options(options).load(path), answerDF) + } + } + } + + test("change column type from int to long") { + withTempPath { dir => + val path = dir.getCanonicalPath + + intDF.write.format(format).options(options).save(path) + + Seq(("col1 long", longDF)).foreach { case (schema, answerDF) => + checkAnswer(spark.read.schema(schema).format(format).options(options).load(path), answerDF) + } + } + } + + test("read byte, int, short, long together") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val byteDF = (Byte.MaxValue - 2 to Byte.MaxValue).map(_.toByte).toDF("col1") + val shortDF = (Short.MaxValue - 2 to Short.MaxValue).map(_.toShort).toDF("col1") + val intDF = (Int.MaxValue - 2 to Int.MaxValue).toDF("col1") + val longDF = (Long.MaxValue - 2 to Long.MaxValue).toDF("col1") + val unionDF = byteDF.union(shortDF).union(intDF).union(longDF) + + val byteDir = s"$path${File.separator}part=byte" + val shortDir = s"$path${File.separator}part=short" + val intDir = s"$path${File.separator}part=int" + val longDir = s"$path${File.separator}part=long" + + byteDF.write.format(format).options(options).save(byteDir) + 
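// each integral type goes to its own partition directory so all four can be read together +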
shortDF.write.format(format).options(options).save(shortDir) + intDF.write.format(format).options(options).save(intDir) + longDF.write.format(format).options(options).save(longDir) + + val df = spark.read + .schema(unionDF.schema) + .format(format) + .options(options) + .load(path) + .select("col1") + + checkAnswer(df, unionDF) + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. + */ +trait ToDoubleTypeTest extends ReadSchemaTest { + import testImplicits._ + + private lazy val values = 1 to 10 + private lazy val floatDF = values.map(_.toFloat).toDF("col1") + private lazy val doubleDF = values.map(_.toDouble).toDF("col1") + private lazy val unionDF = floatDF.union(doubleDF) + + test("change column type from float to double") { + withTempPath { dir => + val path = dir.getCanonicalPath + + floatDF.write.format(format).options(options).save(path) + + val df = spark.read.schema("col1 double").format(format).options(options).load(path) + + checkAnswer(df, doubleDF) + } + } + + test("read float and double together") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val floatDir = s"$path${File.separator}part=float" + val doubleDir = s"$path${File.separator}part=double" + + floatDF.write.format(format).options(options).save(floatDir) + doubleDF.write.format(format).options(options).save(doubleDir) + + val df = spark.read + .schema(unionDF.schema) + .format(format) + .options(options) + .load(path) + .select("col1") + + checkAnswer(df, unionDF) + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. + */ +trait ToDecimalTypeTest extends ReadSchemaTest { + import testImplicits._ + + private lazy val values = 1 to 10 + private lazy val floatDF = values.map(_.toFloat).toDF("col1") + private lazy val doubleDF = values.map(_.toDouble).toDF("col1") + private lazy val decimalDF = values.map(BigDecimal(_)).toDF("col1") + private lazy val unionDF = floatDF.union(doubleDF).union(decimalDF) + + test("change column type from float to decimal") { + withTempPath { dir => + val path = dir.getCanonicalPath + + floatDF.write.format(format).options(options).save(path) + + val df = spark.read + .schema("col1 decimal(38,18)") + .format(format) + .options(options) + .load(path) + + checkAnswer(df, decimalDF) + } + } + + test("change column type from double to decimal") { + withTempPath { dir => + val path = dir.getCanonicalPath + + doubleDF.write.format(format).options(options).save(path) + + val df = spark.read + .schema("col1 decimal(38,18)") + .format(format) + .options(options) + .load(path) + + checkAnswer(df, decimalDF) + } + } + + test("read float, double, decimal together") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val floatDir = s"$path${File.separator}part=float" + val doubleDir = s"$path${File.separator}part=double" + val decimalDir = s"$path${File.separator}part=decimal" + + floatDF.write.format(format).options(options).save(floatDir) + doubleDF.write.format(format).options(options).save(doubleDir) + decimalDF.write.format(format).options(options).save(decimalDir) + + val df = spark.read + .schema(unionDF.schema) + .format(format) + .options(options) + .load(path) + .select("col1") + + checkAnswer(df, unionDF) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala index 
1a3dacb8398e6..24f5f55d55485 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala @@ -119,8 +119,47 @@ object CSVBenchmarks { } } + def countBenchmark(rowsNum: Int): Unit = { + val colsNum = 10 + val benchmark = new Benchmark(s"Count a dataset with $colsNum columns", rowsNum) + + withTempPath { path => + val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType)) + val schema = StructType(fields) + + spark.range(rowsNum) + .select(Seq.tabulate(colsNum)(i => lit(i).as(s"col$i")): _*) + .write + .csv(path.getAbsolutePath) + + val ds = spark.read.schema(schema).csv(path.getAbsolutePath) + + benchmark.addCase(s"Select $colsNum columns + count()", 3) { _ => + ds.select("*").filter((_: Row) => true).count() + } + benchmark.addCase(s"Select 1 column + count()", 3) { _ => + ds.select($"col1").filter((_: Row) => true).count() + } + benchmark.addCase(s"count()", 3) { _ => + ds.count() + } + + /* + Intel(R) Core(TM) i7-7700HQ CPU @ 2.80GHz + + Count a dataset with 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + --------------------------------------------------------------------------------------------- + Select 10 columns + count() 12598 / 12740 0.8 1259.8 1.0X + Select 1 column + count() 7960 / 8175 1.3 796.0 1.6X + count() 2332 / 2386 4.3 233.2 5.4X + */ + benchmark.run() + } + } + def main(args: Array[String]): Unit = { quotedValuesBenchmark(rowsNum = 50 * 1000, numIters = 3) multiColumnsBenchmark(rowsNum = 1000 * 1000) + countBenchmark(10 * 1000 * 1000) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala index 842251be92c18..57e36e082653c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -132,7 +132,7 @@ class CSVInferSchemaSuite extends SparkFunSuite { == StringType) } - test("DoubleType should be infered when user defined nan/inf are provided") { + test("DoubleType should be inferred when user defined nan/inf are provided") { val options = new CSVOptions(Map("nanValue" -> "nan", "negativeInf" -> "-inf", "positiveInf" -> "inf"), false, "GMT") assert(CSVInferSchema.inferField(NullType, "nan", options) == DoubleType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 84b91f6309fe8..f70df0bcecde7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -18,12 +18,14 @@ package org.apache.spark.sql.execution.datasources.csv import java.io.File -import java.nio.charset.UnsupportedCharsetException +import java.nio.charset.{Charset, UnsupportedCharsetException} +import java.nio.file.Files import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat import java.util.Locale import scala.collection.JavaConverters._ +import scala.util.Properties import org.apache.commons.lang3.time.FastDateFormat import org.apache.hadoop.io.SequenceFile.CompressionType @@ -32,7 +34,7 @@ import org.apache.log4j.{AppenderSkeleton, LogManager} 
import org.apache.log4j.spi.LoggingEvent import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT} +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} @@ -48,6 +50,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te private val carsAltFile = "test-data/cars-alternative.csv" private val carsUnbalancedQuotesFile = "test-data/cars-unbalanced-quotes.csv" private val carsNullFile = "test-data/cars-null.csv" + private val carsEmptyValueFile = "test-data/cars-empty-value.csv" private val carsBlankColName = "test-data/cars-blank-column-name.csv" private val emptyFile = "test-data/empty.csv" private val commentsFile = "test-data/comments.csv" @@ -60,10 +63,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te private val unescapedQuotesFile = "test-data/unescaped-quotes.csv" private val valueMalformedFile = "test-data/value-malformed.csv" - private def testFile(fileName: String): String = { - Thread.currentThread().getContextClassLoader.getResource(fileName).toString - } - /** Verifies data and schema. */ private def verifyCars( df: DataFrame, @@ -518,6 +517,41 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te } } + test("SPARK-19018: Save csv with custom charset") { + + // scalastyle:off nonascii + val content = "µß áâä ÁÂÄ" + // scalastyle:on nonascii + + Seq("iso-8859-1", "utf-8", "utf-16", "utf-32", "windows-1250").foreach { encoding => + withTempPath { path => + val csvDir = new File(path, "csv") + Seq(content).toDF().write + .option("encoding", encoding) + .csv(csvDir.getCanonicalPath) + + csvDir.listFiles().filter(_.getName.endsWith("csv")).foreach({ csvFile => + val readback = Files.readAllBytes(csvFile.toPath) + val expected = (content + Properties.lineSeparator).getBytes(Charset.forName(encoding)) + assert(readback === expected) + }) + } + } + } + + test("SPARK-19018: error handling for unsupported charsets") { + val exception = intercept[SparkException] { + withTempPath { path => + val csvDir = new File(path, "csv").getCanonicalPath + Seq("a,A,c,A,b,B").toDF().write + .option("encoding", "1-9588-osi") + .csv(csvDir) + } + } + + assert(exception.getCause.getMessage.contains("1-9588-osi")) + } + test("commented lines in CSV data") { Seq("false", "true").foreach { multiLine => @@ -635,6 +669,70 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te assert(results(2).toSeq === Array(null, "Chevy", "Volt", null, null)) } + test("empty fields with user defined empty values") { + + // year,make,model,comment,blank + val dataSchema = StructType(List( + StructField("year", IntegerType, nullable = true), + StructField("make", StringType, nullable = false), + StructField("model", StringType, nullable = false), + StructField("comment", StringType, nullable = true), + StructField("blank", StringType, nullable = true))) + val cars = spark.read + .format("csv") + .schema(dataSchema) + .option("header", "true") + .option("emptyValue", "empty") + .load(testFile(carsEmptyValueFile)) + + verifyCars(cars, withHeader = true, checkValues = false) + val results = cars.collect() + assert(results(0).toSeq === Array(2012, "Tesla", "S", "empty", "empty")) + assert(results(1).toSeq === + Array(1997, "Ford", "E350", "Go get one now they are 
going fast", null)) + assert(results(2).toSeq === Array(2015, "Chevy", "Volt", null, "empty")) + } + + test("save csv with empty fields with user defined empty values") { + withTempDir { dir => + val csvDir = new File(dir, "csv").getCanonicalPath + + // year,make,model,comment,blank + val dataSchema = StructType(List( + StructField("year", IntegerType, nullable = true), + StructField("make", StringType, nullable = false), + StructField("model", StringType, nullable = false), + StructField("comment", StringType, nullable = true), + StructField("blank", StringType, nullable = true))) + val cars = spark.read + .format("csv") + .schema(dataSchema) + .option("header", "true") + .option("nullValue", "NULL") + .load(testFile(carsEmptyValueFile)) + + cars.coalesce(1).write + .format("csv") + .option("header", "true") + .option("emptyValue", "empty") + .option("nullValue", null) + .save(csvDir) + + val carsCopy = spark.read + .format("csv") + .schema(dataSchema) + .option("header", "true") + .load(csvDir) + + verifyCars(carsCopy, withHeader = true, checkValues = false) + val results = carsCopy.collect() + assert(results(0).toSeq === Array(2012, "Tesla", "S", "empty", "empty")) + assert(results(1).toSeq === + Array(1997, "Ford", "E350", "Go get one now they are going fast", null)) + assert(results(2).toSeq === Array(2015, "Chevy", "Volt", null, "empty")) + } + } + test("save csv with compression codec option") { withTempDir { dir => val csvDir = new File(dir, "csv").getCanonicalPath @@ -1342,6 +1440,52 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te } } + test("SPARK-25241: An empty string should not be coerced to null when emptyValue is passed.") { + val litNull: String = null + val df = Seq( + (1, "John Doe"), + (2, ""), + (3, "-"), + (4, litNull) + ).toDF("id", "name") + + // Checks for new behavior where a null is not coerced to an empty string when `emptyValue` is + // set to anything but an empty string literal. + withTempPath { path => + df.write + .option("emptyValue", "-") + .csv(path.getAbsolutePath) + val computed = spark.read + .option("emptyValue", "-") + .schema(df.schema) + .csv(path.getAbsolutePath) + val expected = Seq( + (1, "John Doe"), + (2, "-"), + (3, "-"), + (4, "-") + ).toDF("id", "name") + + checkAnswer(computed, expected) + } + // Keeps the old behavior where empty string us coerced to emptyValue is not passed. 
+ withTempPath { path => + df.write + .csv(path.getAbsolutePath) + val computed = spark.read + .schema(df.schema) + .csv(path.getAbsolutePath) + val expected = Seq( + (1, "John Doe"), + (2, litNull), + (3, "-"), + (4, litNull) + ).toDF("id", "name") + + checkAnswer(computed, expected) + } + } + test("SPARK-24329: skip lines with comments, and one or multiple whitespaces") { val schema = new StructType().add("colA", StringType) val ds = spark @@ -1570,6 +1714,39 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te .exists(msg => msg.getRenderedMessage.contains("CSV header does not conform to the schema"))) } + test("SPARK-25134: check header on parsing of dataset with projection and column pruning") { + withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "true") { + Seq(false, true).foreach { multiLine => + withTempPath { path => + val dir = path.getAbsolutePath + Seq(("a", "b")).toDF("columnA", "columnB").write + .format("csv") + .option("header", true) + .save(dir) + + // schema with one column + checkAnswer(spark.read + .format("csv") + .option("header", true) + .option("enforceSchema", false) + .option("multiLine", multiLine) + .load(dir) + .select("columnA"), + Row("a")) + + // empty schema + assert(spark.read + .format("csv") + .option("header", true) + .option("enforceSchema", false) + .option("multiLine", multiLine) + .load(dir) + .count() === 1L) + } + } + } + } + test("SPARK-24645 skip parsing when columnPruning enabled and partitions scanned only") { withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "true") { withTempPath { path => @@ -1579,4 +1756,68 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te } } } + + test("SPARK-24676 project required data from parsed data when columnPruning disabled") { + withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "false") { + withTempPath { path => + val dir = path.getAbsolutePath + spark.range(10).selectExpr("id % 2 AS p", "id AS c0", "id AS c1").write.partitionBy("p") + .option("header", "true").csv(dir) + val df1 = spark.read.option("header", true).csv(dir).selectExpr("sum(p)", "count(c0)") + checkAnswer(df1, Row(5, 10)) + + // empty required column case + val df2 = spark.read.option("header", true).csv(dir).selectExpr("sum(p)") + checkAnswer(df2, Row(5)) + } + + // the case where tokens length != parsedSchema length + withTempPath { path => + val dir = path.getAbsolutePath + Seq("1,2").toDF().write.text(dir) + // more tokens + val df1 = spark.read.schema("c0 int").format("csv").option("mode", "permissive").load(dir) + checkAnswer(df1, Row(1)) + // less tokens + val df2 = spark.read.schema("c0 int, c1 int, c2 int").format("csv") + .option("mode", "permissive").load(dir) + checkAnswer(df2, Row(1, 2, null)) + } + } + } + + test("count() for malformed input") { + def countForMalformedCSV(expected: Long, input: Seq[String]): Unit = { + val schema = new StructType().add("a", IntegerType) + val strings = spark.createDataset(input) + val df = spark.read.schema(schema).option("header", false).csv(strings) + + assert(df.count() == expected) + } + def checkCount(expected: Long): Unit = { + val validRec = "1" + val inputs = Seq( + Seq("{-}", validRec), + Seq(validRec, "?"), + Seq("0xAC", validRec), + Seq(validRec, "0.314"), + Seq("\\\\\\", validRec) + ) + inputs.foreach { input => + countForMalformedCSV(expected, input) + } + } + + checkCount(2) + countForMalformedCSV(0, Seq("")) + } + + test("SPARK-25387: bad input should not cause NPE") { + val schema = 
StructType(StructField("a", IntegerType) :: Nil) + val input = spark.createDataset(Seq("\u0000\u0000\u0001234")) + + checkAnswer(spark.read.schema(schema).csv(input), Row(null)) + checkAnswer(spark.read.option("multiLine", true).schema(schema).csv(input), Row(null)) + assert(spark.read.csv(input).collect().toSet == Set(Row())) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala index 85cf054e51f6b..a2b747eaab411 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.execution.datasources.json import java.io.File import org.apache.spark.SparkConf -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.types.{LongType, StringType, StructType} +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.types._ import org.apache.spark.util.{Benchmark, Utils} /** @@ -171,9 +172,49 @@ object JSONBenchmarks { } } + def countBenchmark(rowsNum: Int): Unit = { + val colsNum = 10 + val benchmark = new Benchmark(s"Count a dataset with $colsNum columns", rowsNum) + + withTempPath { path => + val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType)) + val schema = StructType(fields) + val columnNames = schema.fieldNames + + spark.range(rowsNum) + .select(Seq.tabulate(colsNum)(i => lit(i).as(s"col$i")): _*) + .write + .json(path.getAbsolutePath) + + val ds = spark.read.schema(schema).json(path.getAbsolutePath) + + benchmark.addCase(s"Select $colsNum columns + count()", 3) { _ => + ds.select("*").filter((_: Row) => true).count() + } + benchmark.addCase(s"Select 1 column + count()", 3) { _ => + ds.select($"col1").filter((_: Row) => true).count() + } + benchmark.addCase(s"count()", 3) { _ => + ds.count() + } + + /* + Intel(R) Core(TM) i7-7700HQ CPU @ 2.80GHz + + Count a dataset with 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + --------------------------------------------------------------------------------------------- + Select 10 columns + count() 9961 / 10006 1.0 996.1 1.0X + Select 1 column + count() 8355 / 8470 1.2 835.5 1.2X + count() 2104 / 2156 4.8 210.4 4.7X + */ + benchmark.run() + } + } + def main(args: Array[String]): Unit = { schemaInferring(100 * 1000 * 1000) perlineParsing(100 * 1000 * 1000) perlineParsingOfWideColumn(10 * 1000 * 1000) + countBenchmark(10 * 1000 * 1000) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 897424daca0cb..3e4cc8f166279 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -31,11 +31,11 @@ import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.{SparkException, TestUtils} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{functions => F, _} -import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JsonInferSchema, JSONOptions} +import 
org.apache.spark.sql.catalyst.json.JsonInferSchema.compatibleType import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.ExternalRDD import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.datasources.json.JsonInferSchema.compatibleType import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -48,10 +48,6 @@ class TestFileFilter extends PathFilter { class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { import testImplicits._ - def testFile(fileName: String): String = { - Thread.currentThread().getContextClassLoader.getResource(fileName).toString - } - test("Type promotion") { def checkTypePromotion(expected: Any, actual: Any) { assert(expected.getClass == actual.getClass, @@ -2227,7 +2223,6 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { checkAnswer(jsonDF, Seq(Row("Chris", "Baird"))) } - test("SPARK-23723: specified encoding is not matched to actual encoding") { val fileName = "test-data/utf16LE.json" val schema = new StructType().add("firstName", StringType).add("lastName", StringType) @@ -2494,4 +2489,30 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(exception.getMessage.contains("encoding must not be included in the blacklist")) } } + + test("count() for malformed input") { + def countForMalformedJSON(expected: Long, input: Seq[String]): Unit = { + val schema = new StructType().add("a", StringType) + val strings = spark.createDataset(input) + val df = spark.read.schema(schema).json(strings) + + assert(df.count() == expected) + } + def checkCount(expected: Long): Unit = { + val validRec = """{"a":"b"}""" + val inputs = Seq( + Seq("{-}", validRec), + Seq(validRec, "?"), + Seq("}", validRec), + Seq(validRec, """{"a": [1, 2, 3]}"""), + Seq("""{"a": {"a": "b"}}""", validRec) + ) + inputs.foreach { input => + countForMalformedJSON(expected, input) + } + } + + checkCount(2) + countForMalformedJSON(0, Seq("")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala index f58c331f33ca8..e9dccbf2e261c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala @@ -562,20 +562,57 @@ abstract class OrcQueryTest extends OrcTest { } } + def testAllCorruptFiles(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(1).toDF("a").write.json(new Path(basePath, "first").toString) + spark.range(1, 2).toDF("a").write.json(new Path(basePath, "second").toString) + val df = spark.read.orc( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString) + assert(df.count() == 0) + } + } + + def testAllCorruptFilesWithoutSchemaInfer(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(1).toDF("a").write.json(new Path(basePath, "first").toString) + spark.range(1, 2).toDF("a").write.json(new Path(basePath, "second").toString) + val df = spark.read.schema("a long").orc( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString) + assert(df.count() == 0) + } + } + withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "true") { testIgnoreCorruptFiles() 
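+ // Even with corrupt files ignored, reading a directory in which every file is corrupt
+ // (JSON data read as ORC here) fails at analysis time when the schema must be inferred,
+ // as checked below; an explicit schema instead lets the read succeed with zero rows.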
testIgnoreCorruptFilesWithoutSchemaInfer() + val m1 = intercept[AnalysisException] { + testAllCorruptFiles() + }.getMessage + assert(m1.contains("Unable to infer schema for ORC")) + testAllCorruptFilesWithoutSchemaInfer() } withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "false") { val m1 = intercept[SparkException] { testIgnoreCorruptFiles() }.getMessage - assert(m1.contains("Could not read footer for file")) + assert(m1.contains("Malformed ORC file")) val m2 = intercept[SparkException] { testIgnoreCorruptFilesWithoutSchemaInfer() }.getMessage assert(m2.contains("Malformed ORC file")) + val m3 = intercept[SparkException] { + testAllCorruptFiles() + }.getMessage + assert(m3.contains("Could not read footer for file")) + val m4 = intercept[SparkException] { + testAllCorruptFilesWithoutSchemaInfer() + }.getMessage + assert(m4.contains("Malformed ORC file")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala index ed8fd2b453456..09de715e87a11 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.test.SharedSQLContext class ParquetCompressionCodecPrecedenceSuite extends ParquetTest with SharedSQLContext { test("Test `spark.sql.parquet.compression.codec` config") { - Seq("NONE", "UNCOMPRESSED", "SNAPPY", "GZIP", "LZO").foreach { c => + Seq("NONE", "UNCOMPRESSED", "SNAPPY", "GZIP", "LZO", "LZ4", "BROTLI", "ZSTD").foreach { c => withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> c) { val expected = if (c == "NONE") "UNCOMPRESSED" else c val option = new ParquetOptions(Map.empty[String, String], spark.sessionState.conf) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala index 3a0867fd2b78b..94abf115cef35 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala @@ -51,9 +51,9 @@ class ParquetFileFormatSuite extends QueryTest with ParquetTest with SharedSQLCo } testReadFooters(true) - val exception = intercept[java.io.IOException] { + val exception = intercept[SparkException] { testReadFooters(false) - } + }.getCause assert(exception.getMessage().contains("Could not read footer for file")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index d9ae5858e5ed0..7ebb75009555a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.execution.datasources.parquet +import java.math.{BigDecimal => JBigDecimal} import java.nio.charset.StandardCharsets -import java.sql.Date +import java.sql.{Date, Timestamp} -import 
org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} +import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Operators} import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.filter2.predicate.Operators.{Column => _, _} +import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ @@ -31,6 +33,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} @@ -56,7 +59,9 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext { private lazy val parquetFilters = - new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownStringStartWith) + new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownTimestamp, + conf.parquetFilterPushDownDecimal, conf.parquetFilterPushDownStringStartWith, + conf.parquetFilterPushDownInFilterThreshold, conf.caseSensitiveAnalysis) override def beforeEach(): Unit = { super.beforeEach() @@ -83,6 +88,8 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex withSQLConf( SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true", SQLConf.PARQUET_FILTER_PUSHDOWN_DATE_ENABLED.key -> "true", + SQLConf.PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED.key -> "true", + SQLConf.PARQUET_FILTER_PUSHDOWN_DECIMAL_ENABLED.key -> "true", SQLConf.PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED.key -> "true", SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { val query = df @@ -103,7 +110,8 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex assert(selectedFilters.nonEmpty, "No filter is pushed down") selectedFilters.foreach { pred => - val maybeFilter = parquetFilters.createFilter(df.schema, pred) + val maybeFilter = parquetFilters.createFilter( + new SparkToParquetSchemaConverter(conf).convert(df.schema), pred) assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") // Doesn't bother checking type parameters here (e.g. 
`Eq[Integer]`) maybeFilter.exists(_.getClass === filterClass) @@ -142,6 +150,46 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df) } + private def testTimestampPushdown(data: Seq[Timestamp]): Unit = { + assert(data.size === 4) + val ts1 = data.head + val ts2 = data(1) + val ts3 = data(2) + val ts4 = data(3) + + withParquetDataFrame(data.map(i => Tuple1(i))) { implicit df => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], data.map(i => Row.apply(i))) + + checkFilterPredicate('_1 === ts1, classOf[Eq[_]], ts1) + checkFilterPredicate('_1 <=> ts1, classOf[Eq[_]], ts1) + checkFilterPredicate('_1 =!= ts1, classOf[NotEq[_]], + Seq(ts2, ts3, ts4).map(i => Row.apply(i))) + + checkFilterPredicate('_1 < ts2, classOf[Lt[_]], ts1) + checkFilterPredicate('_1 > ts1, classOf[Gt[_]], Seq(ts2, ts3, ts4).map(i => Row.apply(i))) + checkFilterPredicate('_1 <= ts1, classOf[LtEq[_]], ts1) + checkFilterPredicate('_1 >= ts4, classOf[GtEq[_]], ts4) + + checkFilterPredicate(Literal(ts1) === '_1, classOf[Eq[_]], ts1) + checkFilterPredicate(Literal(ts1) <=> '_1, classOf[Eq[_]], ts1) + checkFilterPredicate(Literal(ts2) > '_1, classOf[Lt[_]], ts1) + checkFilterPredicate(Literal(ts3) < '_1, classOf[Gt[_]], ts4) + checkFilterPredicate(Literal(ts1) >= '_1, classOf[LtEq[_]], ts1) + checkFilterPredicate(Literal(ts4) <= '_1, classOf[GtEq[_]], ts4) + + checkFilterPredicate(!('_1 < ts4), classOf[GtEq[_]], ts4) + checkFilterPredicate('_1 < ts2 || '_1 > ts3, classOf[Operators.Or], Seq(Row(ts1), Row(ts4))) + } + } + + private def testDecimalPushDown(data: DataFrame)(f: DataFrame => Unit): Unit = { + withTempPath { file => + data.write.parquet(file.getCanonicalPath) + readParquetFile(file.toString)(f) + } + } + + // This function tests filters that go exactly through the `canDrop` and `inverseCanDrop` code paths.
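+ // In Parquet's UserDefinedPredicate API, `canDrop` decides from a row group's min/max
+ // statistics whether the whole group can be skipped, and `inverseCanDrop` does the same
+ // for the negated predicate.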
private def testStringStartsWith(dataFrame: DataFrame, filter: String): Unit = { withTempPath { dir => @@ -178,6 +226,62 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } + test("filter pushdown - tinyint") { + withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toByte)))) { implicit df => + assert(df.schema.head.dataType === ByteType) + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 === 1.toByte, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1.toByte, classOf[Eq[_]], 1) + checkFilterPredicate('_1 =!= 1.toByte, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 < 2.toByte, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3.toByte, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1.toByte, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4.toByte, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1.toByte) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1.toByte) <=> '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2.toByte) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3.toByte) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1.toByte) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4.toByte) <= '_1, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('_1 < 4.toByte), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2.toByte || '_1 > 3.toByte, + classOf[Operators.Or], Seq(Row(1), Row(4))) + } + } + + test("filter pushdown - smallint") { + withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit df => + assert(df.schema.head.dataType === ShortType) + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 === 1.toShort, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1.toShort, classOf[Eq[_]], 1) + checkFilterPredicate('_1 =!= 1.toShort, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 < 2.toShort, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3.toShort, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1.toShort, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4.toShort, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1.toShort) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1.toShort) <=> '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2.toShort) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3.toShort) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1.toShort) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4.toShort) <= '_1, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('_1 < 4.toShort), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2.toShort || '_1 > 3.toShort, + classOf[Operators.Or], Seq(Row(1), Row(4))) + } + } + test("filter pushdown - integer") { withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) @@ -386,6 +490,117 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } + test("filter pushdown - timestamp") { + // spark.sql.parquet.outputTimestampType = TIMESTAMP_MILLIS + val millisData = Seq(Timestamp.valueOf("2018-06-14 08:28:53.123"), + Timestamp.valueOf("2018-06-15 08:28:53.123"), + Timestamp.valueOf("2018-06-16 08:28:53.123"), + Timestamp.valueOf("2018-06-17 08:28:53.123")) 
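+ // TIMESTAMP_MILLIS and TIMESTAMP_MICROS are stored as Parquet INT64, so min/max
+ // statistics are available and the predicates can be pushed down; INT96 offers no such
+ // support, which the last block of this test verifies.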
+ withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> + ParquetOutputTimestampType.TIMESTAMP_MILLIS.toString) { + testTimestampPushdown(millisData) + } + + // spark.sql.parquet.outputTimestampType = TIMESTAMP_MICROS + val microsData = Seq(Timestamp.valueOf("2018-06-14 08:28:53.123456"), + Timestamp.valueOf("2018-06-15 08:28:53.123456"), + Timestamp.valueOf("2018-06-16 08:28:53.123456"), + Timestamp.valueOf("2018-06-17 08:28:53.123456")) + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> + ParquetOutputTimestampType.TIMESTAMP_MICROS.toString) { + testTimestampPushdown(microsData) + } + + // spark.sql.parquet.outputTimestampType = INT96 doesn't support pushdown + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> + ParquetOutputTimestampType.INT96.toString) { + withParquetDataFrame(millisData.map(i => Tuple1(i))) { implicit df => + assertResult(None) { + parquetFilters.createFilter( + new SparkToParquetSchemaConverter(conf).convert(df.schema), sources.IsNull("_1")) + } + } + } + } + + test("filter pushdown - decimal") { + Seq(true, false).foreach { legacyFormat => + withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> legacyFormat.toString) { + Seq( + s"a decimal(${Decimal.MAX_INT_DIGITS}, 2)", // 32BitDecimalType + s"a decimal(${Decimal.MAX_LONG_DIGITS}, 2)", // 64BitDecimalType + "a decimal(38, 18)" // ByteArrayDecimalType + ).foreach { schemaDDL => + val schema = StructType.fromDDL(schemaDDL) + val rdd = + spark.sparkContext.parallelize((1 to 4).map(i => Row(new java.math.BigDecimal(i)))) + val dataFrame = spark.createDataFrame(rdd, schema) + testDecimalPushDown(dataFrame) { implicit df => + assert(df.schema === schema) + checkFilterPredicate('a.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('a.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('a === 1, classOf[Eq[_]], 1) + checkFilterPredicate('a <=> 1, classOf[Eq[_]], 1) + checkFilterPredicate('a =!= 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('a < 2, classOf[Lt[_]], 1) + checkFilterPredicate('a > 3, classOf[Gt[_]], 4) + checkFilterPredicate('a <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('a >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === 'a, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1) <=> 'a, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > 'a, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < 'a, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= 'a, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= 'a, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('a < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('a < 2 || 'a > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + } + } + } + } + } + + test("Ensure that filter value matches the parquet file schema") { + val scale = 2 + val schema = StructType(Seq( + StructField("cint", IntegerType), + StructField("cdecimal1", DecimalType(Decimal.MAX_INT_DIGITS, scale)), + StructField("cdecimal2", DecimalType(Decimal.MAX_LONG_DIGITS, scale)), + StructField("cdecimal3", DecimalType(DecimalType.MAX_PRECISION, scale)) + )) + + val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema) + + val decimal = new JBigDecimal(10).setScale(scale) + val decimal1 = new JBigDecimal(10).setScale(scale + 1) + assert(decimal.scale() === scale) + assert(decimal1.scale() === scale + 1) + + assertResult(Some(lt(intColumn("cdecimal1"), 1000: Integer))) { + parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal1", decimal)) + } +
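+ // A 9-digit decimal is stored as INT32 holding the unscaled value, so 10.00 at scale 2
+ // pushes down as the integer literal 1000 above. A literal whose scale differs from the
+ // column's (decimal1 below) must not be pushed down: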
assertResult(None) { + parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal1", decimal1)) + } + + assertResult(Some(lt(longColumn("cdecimal2"), 1000L: java.lang.Long))) { + parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal2", decimal)) + } + assertResult(None) { + parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal2", decimal1)) + } + + assert(parquetFilters.createFilter( + parquetSchema, sources.LessThan("cdecimal3", decimal)).isDefined) + assertResult(None) { + parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal3", decimal1)) + } + } + test("SPARK-6554: don't push down predicates which reference partition columns") { import testImplicits._ @@ -542,12 +757,14 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex StructField("c", DoubleType, nullable = true) )) + val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema) + assertResult(Some(and( lt(intColumn("a"), 10: Integer), gt(doubleColumn("c"), 1.5: java.lang.Double))) ) { parquetFilters.createFilter( - schema, + parquetSchema, sources.And( sources.LessThan("a", 10), sources.GreaterThan("c", 1.5D))) @@ -555,7 +772,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex assertResult(None) { parquetFilters.createFilter( - schema, + parquetSchema, sources.And( sources.LessThan("a", 10), sources.StringContains("b", "prefix"))) @@ -563,7 +780,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex assertResult(None) { parquetFilters.createFilter( - schema, + parquetSchema, sources.Not( sources.And( sources.GreaterThan("a", 1), @@ -615,21 +832,25 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } test("SPARK-17213: Broken Parquet filter push-down for string columns") { - withTempPath { dir => - import testImplicits._ + Seq(true, false).foreach { vectorizedEnabled => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorizedEnabled.toString) { + withTempPath { dir => + import testImplicits._ - val path = dir.getCanonicalPath - // scalastyle:off nonascii - Seq("a", "é").toDF("name").write.parquet(path) - // scalastyle:on nonascii + val path = dir.getCanonicalPath + // scalastyle:off nonascii + Seq("a", "é").toDF("name").write.parquet(path) + // scalastyle:on nonascii - assert(spark.read.parquet(path).where("name > 'a'").count() == 1) - assert(spark.read.parquet(path).where("name >= 'a'").count() == 2) + assert(spark.read.parquet(path).where("name > 'a'").count() == 1) + assert(spark.read.parquet(path).where("name >= 'a'").count() == 2) - // scalastyle:off nonascii - assert(spark.read.parquet(path).where("name < 'é'").count() == 1) - assert(spark.read.parquet(path).where("name <= 'é'").count() == 2) - // scalastyle:on nonascii + // scalastyle:off nonascii + assert(spark.read.parquet(path).where("name < 'é'").count() == 1) + assert(spark.read.parquet(path).where("name <= 'é'").count() == 2) + // scalastyle:on nonascii + } + } } } @@ -729,7 +950,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex assertResult(None) { parquetFilters.createFilter( - df.schema, + new SparkToParquetSchemaConverter(conf).convert(df.schema), sources.StringStartsWith("_1", null)) } } @@ -740,6 +961,179 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // Test inverseCanDrop() has taken effect testStringStartsWith(spark.range(1024).map(c => "100").toDF(), "value not like 
'10%'") } + + test("SPARK-17091: Convert IN predicate to Parquet filter push-down") { + val schema = StructType(Seq( + StructField("a", IntegerType, nullable = false) + )) + + val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema) + + assertResult(Some(FilterApi.eq(intColumn("a"), null: Integer))) { + parquetFilters.createFilter(parquetSchema, sources.In("a", Array(null))) + } + + assertResult(Some(FilterApi.eq(intColumn("a"), 10: Integer))) { + parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10))) + } + + // Remove duplicates + assertResult(Some(FilterApi.eq(intColumn("a"), 10: Integer))) { + parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10, 10))) + } + + assertResult(Some(or(or( + FilterApi.eq(intColumn("a"), 10: Integer), + FilterApi.eq(intColumn("a"), 20: Integer)), + FilterApi.eq(intColumn("a"), 30: Integer))) + ) { + parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10, 20, 30))) + } + + assert(parquetFilters.createFilter(parquetSchema, sources.In("a", + Range(0, conf.parquetFilterPushDownInFilterThreshold).toArray)).isDefined) + assert(parquetFilters.createFilter(parquetSchema, sources.In("a", + Range(0, conf.parquetFilterPushDownInFilterThreshold + 1).toArray)).isEmpty) + + import testImplicits._ + withTempPath { path => + val data = 0 to 1024 + data.toDF("a").selectExpr("if (a = 1024, null, a) AS a") // convert 1024 to null + .coalesce(1).write.option("parquet.block.size", 512) + .parquet(path.getAbsolutePath) + val df = spark.read.parquet(path.getAbsolutePath) + Seq(true, false).foreach { pushEnabled => + withSQLConf( + SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> pushEnabled.toString) { + Seq(1, 5, 10, 11).foreach { count => + val filter = s"a in(${Range(0, count).mkString(",")})" + assert(df.where(filter).count() === count) + val actual = stripSparkFilter(df.where(filter)).collect().length + if (pushEnabled && count <= conf.parquetFilterPushDownInFilterThreshold) { + assert(actual > 1 && actual < data.length) + } else { + assert(actual === data.length) + } + } + assert(df.where("a in(null)").count() === 0) + assert(df.where("a = null").count() === 0) + assert(df.where("a is null").count() === 1) + } + } + } + } + + test("SPARK-25207: Case-insensitive field resolution for pushdown when reading parquet") { + def createParquetFilter(caseSensitive: Boolean): ParquetFilters = { + new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownTimestamp, + conf.parquetFilterPushDownDecimal, conf.parquetFilterPushDownStringStartWith, + conf.parquetFilterPushDownInFilterThreshold, caseSensitive) + } + val caseSensitiveParquetFilters = createParquetFilter(caseSensitive = true) + val caseInsensitiveParquetFilters = createParquetFilter(caseSensitive = false) + + def testCaseInsensitiveResolution( + schema: StructType, + expected: FilterPredicate, + filter: sources.Filter): Unit = { + val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema) + + assertResult(Some(expected)) { + caseInsensitiveParquetFilters.createFilter(parquetSchema, filter) + } + assertResult(None) { + caseSensitiveParquetFilters.createFilter(parquetSchema, filter) + } + } + + val schema = StructType(Seq(StructField("cint", IntegerType))) + + testCaseInsensitiveResolution( + schema, FilterApi.eq(intColumn("cint"), null.asInstanceOf[Integer]), sources.IsNull("CINT")) + + testCaseInsensitiveResolution( + schema, + FilterApi.notEq(intColumn("cint"), null.asInstanceOf[Integer]), + sources.IsNotNull("CINT")) + + 
testCaseInsensitiveResolution( + schema, FilterApi.eq(intColumn("cint"), 1000: Integer), sources.EqualTo("CINT", 1000)) + + testCaseInsensitiveResolution( + schema, + FilterApi.notEq(intColumn("cint"), 1000: Integer), + sources.Not(sources.EqualTo("CINT", 1000))) + + testCaseInsensitiveResolution( + schema, FilterApi.eq(intColumn("cint"), 1000: Integer), sources.EqualNullSafe("CINT", 1000)) + + testCaseInsensitiveResolution( + schema, + FilterApi.notEq(intColumn("cint"), 1000: Integer), + sources.Not(sources.EqualNullSafe("CINT", 1000))) + + testCaseInsensitiveResolution( + schema, + FilterApi.lt(intColumn("cint"), 1000: Integer), sources.LessThan("CINT", 1000)) + + testCaseInsensitiveResolution( + schema, + FilterApi.ltEq(intColumn("cint"), 1000: Integer), + sources.LessThanOrEqual("CINT", 1000)) + + testCaseInsensitiveResolution( + schema, FilterApi.gt(intColumn("cint"), 1000: Integer), sources.GreaterThan("CINT", 1000)) + + testCaseInsensitiveResolution( + schema, + FilterApi.gtEq(intColumn("cint"), 1000: Integer), + sources.GreaterThanOrEqual("CINT", 1000)) + + testCaseInsensitiveResolution( + schema, + FilterApi.or( + FilterApi.eq(intColumn("cint"), 10: Integer), + FilterApi.eq(intColumn("cint"), 20: Integer)), + sources.In("CINT", Array(10, 20))) + + val dupFieldSchema = StructType( + Seq(StructField("cint", IntegerType), StructField("cINT", IntegerType))) + val dupParquetSchema = new SparkToParquetSchemaConverter(conf).convert(dupFieldSchema) + assertResult(None) { + caseInsensitiveParquetFilters.createFilter( + dupParquetSchema, sources.EqualTo("CINT", 1000)) + } + } + + test("SPARK-25207: exception when duplicate fields in case-insensitive mode") { + withTempPath { dir => + val count = 10 + val tableName = "spark_25207" + val tableDir = dir.getAbsoluteFile + "/table" + withTable(tableName) { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + spark.range(count).selectExpr("id as A", "id as B", "id as b") + .write.mode("overwrite").parquet(tableDir) + } + sql( + s""" + |CREATE TABLE $tableName (A LONG, B LONG) USING PARQUET LOCATION '$tableDir' + """.stripMargin) + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + val e = intercept[SparkException] { + sql(s"select a from $tableName where b > 0").collect() + } + assert(e.getCause.isInstanceOf[RuntimeException] && e.getCause.getMessage.contains( + """Found duplicate field(s) "B": [B, b] in case-insensitive mode""")) + } + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + checkAnswer(sql(s"select A from $tableName where B > 0"), (1 until count).map(Row(_))) + } + } + } + } } class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala index 9c75965639d8a..f06e1867151e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.File +import scala.language.existentials + import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} import org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER @@ -175,8 +177,9 @@ class ParquetInteroperabilitySuite extends 
ParquetCompatibilityTest with SharedS val oneFooter = ParquetFileReader.readFooter(hadoopConf, part.getPath, NO_FILTER) assert(oneFooter.getFileMetaData.getSchema.getColumns.size === 1) - assert(oneFooter.getFileMetaData.getSchema.getColumns.get(0).getType() === - PrimitiveTypeName.INT96) + val typeName = oneFooter + .getFileMetaData.getSchema.getColumns.get(0).getPrimitiveType.getPrimitiveTypeName + assert(typeName === PrimitiveTypeName.INT96) val oneBlockMeta = oneFooter.getBlocks().get(0) val oneBlockColumnMeta = oneBlockMeta.getColumns().get(0) val columnStats = oneBlockColumnMeta.getStatistics diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index dbf637783e6d2..54c77dddc3525 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -108,7 +108,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext val queryOutput = selfJoin.queryExecution.analyzed.output assertResult(4, "Field count mismatches")(queryOutput.size) - assertResult(2, "Duplicated expression ID in query plan:\n $selfJoin") { + assertResult(2, s"Duplicated expression ID in query plan:\n $selfJoin") { queryOutput.filter(_.name == "_1").map(_.exprId).size } @@ -117,7 +117,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } test("nested data - struct with array field") { - val data = (1 to 10).map(i => Tuple1((i, Seq("val_$i")))) + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) withParquetTable(data, "t") { checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map { case Tuple1((_, Seq(string))) => Row(string) @@ -126,7 +126,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } test("nested data - array of struct") { - val data = (1 to 10).map(i => Tuple1(Seq(i -> "val_$i"))) + val data = (1 to 10).map(i => Tuple1(Seq(i -> s"val_$i"))) withParquetTable(data, "t") { checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map { case Tuple1(Seq((_, string))) => Row(string) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala new file mode 100644 index 0000000000000..434c4414edeba --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala @@ -0,0 +1,383 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File + +import org.scalactic.Equality + +import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.catalyst.SchemaPruningTest +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.execution.FileSourceScanExec +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StructType + +class ParquetSchemaPruningSuite + extends QueryTest + with ParquetTest + with SchemaPruningTest + with SharedSQLContext { + case class FullName(first: String, middle: String, last: String) + case class Company(name: String, address: String) + case class Employer(id: Int, company: Company) + case class Contact( + id: Int, + name: FullName, + address: String, + pets: Int, + friends: Array[FullName] = Array.empty, + relatives: Map[String, FullName] = Map.empty, + employer: Employer = null) + + val janeDoe = FullName("Jane", "X.", "Doe") + val johnDoe = FullName("John", "Y.", "Doe") + val susanSmith = FullName("Susan", "Z.", "Smith") + + val employer = Employer(0, Company("abc", "123 Business Street")) + val employerWithNullCompany = Employer(1, null) + + private val contacts = + Contact(0, janeDoe, "123 Main Street", 1, friends = Array(susanSmith), + relatives = Map("brother" -> johnDoe), employer = employer) :: + Contact(1, johnDoe, "321 Wall Street", 3, relatives = Map("sister" -> janeDoe), + employer = employerWithNullCompany) :: Nil + + case class Name(first: String, last: String) + case class BriefContact(id: Int, name: Name, address: String) + + private val briefContacts = + BriefContact(2, Name("Janet", "Jones"), "567 Maple Drive") :: + BriefContact(3, Name("Jim", "Jones"), "6242 Ash Street") :: Nil + + case class ContactWithDataPartitionColumn( + id: Int, + name: FullName, + address: String, + pets: Int, + friends: Array[FullName] = Array(), + relatives: Map[String, FullName] = Map(), + employer: Employer = null, + p: Int) + + case class BriefContactWithDataPartitionColumn(id: Int, name: Name, address: String, p: Int) + + private val contactsWithDataPartitionColumn = + contacts.map { case Contact(id, name, address, pets, friends, relatives, employer) => + ContactWithDataPartitionColumn(id, name, address, pets, friends, relatives, employer, 1) } + private val briefContactsWithDataPartitionColumn = + briefContacts.map { case BriefContact(id, name, address) => + BriefContactWithDataPartitionColumn(id, name, address, 2) } + + testSchemaPruning("select a single complex field") { + val query = sql("select name.middle from contacts") + checkScan(query, "struct<name:struct<middle:string>>") + checkAnswer(query.orderBy("id"), Row("X.") :: Row("Y.") :: Row(null) :: Row(null) :: Nil) + } + + testSchemaPruning("select a single complex field and its parent struct") { + val query = sql("select name.middle, name from contacts") + checkScan(query, "struct<name:struct<first:string,middle:string,last:string>>") + checkAnswer(query.orderBy("id"), + Row("X.", Row("Jane", "X.", "Doe")) :: + Row("Y.", Row("John", "Y.", "Doe")) :: + Row(null, Row("Janet", null, "Jones")) :: + Row(null, Row("Jim", null, "Jones")) :: + Nil) + } + + testSchemaPruning("select a single complex field array and its parent struct array") { + val query = sql("select friends.middle, friends from contacts where p=1") + checkScan(query, + "struct<friends:array<struct<first:string,middle:string,last:string>>>") + checkAnswer(query.orderBy("id"), + Row(Array("Z."), Array(Row("Susan", "Z.", "Smith"))) :: + Row(Array.empty[String], Array.empty[Row]) :: + Nil) + } + +
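+ // Map values are pruned like array elements: selecting a single value field keeps only
+ // that field (plus the map's key type) in the scanned schema, as the next test shows.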
testSchemaPruning("select a single complex field from a map entry and its parent map entry") { + val query = + sql("select relatives[\"brother\"].middle, relatives[\"brother\"] from contacts where p=1") + checkScan(query, + "struct>>") + checkAnswer(query.orderBy("id"), + Row("Y.", Row("John", "Y.", "Doe")) :: + Row(null, null) :: + Nil) + } + + testSchemaPruning("select a single complex field and the partition column") { + val query = sql("select name.middle, p from contacts") + checkScan(query, "struct>") + checkAnswer(query.orderBy("id"), + Row("X.", 1) :: Row("Y.", 1) :: Row(null, 2) :: Row(null, 2) :: Nil) + } + + ignore("partial schema intersection - select missing subfield") { + val query = sql("select name.middle, address from contacts where p=2") + checkScan(query, "struct,address:string>") + checkAnswer(query.orderBy("id"), + Row(null, "567 Maple Drive") :: + Row(null, "6242 Ash Street") :: Nil) + } + + testSchemaPruning("no unnecessary schema pruning") { + val query = + sql("select id, name.last, name.middle, name.first, relatives[''].last, " + + "relatives[''].middle, relatives[''].first, friends[0].last, friends[0].middle, " + + "friends[0].first, pets, address from contacts where p=2") + // We've selected every field in the schema. Therefore, no schema pruning should be performed. + // We check this by asserting that the scanned schema of the query is identical to the schema + // of the contacts relation, even though the fields are selected in different orders. + checkScan(query, + "struct,address:string,pets:int," + + "friends:array>," + + "relatives:map>>") + checkAnswer(query.orderBy("id"), + Row(2, "Jones", null, "Janet", null, null, null, null, null, null, null, "567 Maple Drive") :: + Row(3, "Jones", null, "Jim", null, null, null, null, null, null, null, "6242 Ash Street") :: + Nil) + } + + testSchemaPruning("empty schema intersection") { + val query = sql("select name.middle from contacts where p=2") + checkScan(query, "struct>") + checkAnswer(query.orderBy("id"), + Row(null) :: Row(null) :: Nil) + } + + testSchemaPruning("select a single complex field and in where clause") { + val query1 = sql("select name.first from contacts where name.first = 'Jane'") + checkScan(query1, "struct>") + checkAnswer(query1, Row("Jane") :: Nil) + + val query2 = sql("select name.first, name.last from contacts where name.first = 'Jane'") + checkScan(query2, "struct>") + checkAnswer(query2, Row("Jane", "Doe") :: Nil) + + val query3 = sql("select name.first from contacts " + + "where employer.company.name = 'abc' and p = 1") + checkScan(query3, "struct," + + "employer:struct>>") + checkAnswer(query3, Row("Jane") :: Nil) + + val query4 = sql("select name.first, employer.company.name from contacts " + + "where employer.company is not null and p = 1") + checkScan(query4, "struct," + + "employer:struct>>") + checkAnswer(query4, Row("Jane", "abc") :: Nil) + } + + testSchemaPruning("select nullable complex field and having is not null predicate") { + val query = sql("select employer.company from contacts " + + "where employer is not null and p = 1") + checkScan(query, "struct>>") + checkAnswer(query, Row(Row("abc", "123 Business Street")) :: Row(null) :: Nil) + } + + testSchemaPruning("select a single complex field and is null expression in project") { + val query = sql("select name.first, address is not null from contacts") + checkScan(query, "struct,address:string>") + checkAnswer(query.orderBy("id"), + Row("Jane", true) :: Row("John", true) :: Row("Janet", true) :: Row("Jim", true) :: Nil) + } 
+
+  testSchemaPruning("select a single complex field array and in clause") {
+    val query = sql("select friends.middle from contacts where friends.first[0] = 'Susan'")
+    checkScan(query,
+      "struct<friends:array<struct<first:string,middle:string>>>")
+    checkAnswer(query.orderBy("id"),
+      Row(Array("Z.")) :: Nil)
+  }
+
+  testSchemaPruning("select a single complex field from a map entry and in clause") {
+    val query =
+      sql("select relatives[\"brother\"].middle from contacts " +
+        "where relatives[\"brother\"].first = 'John'")
+    checkScan(query,
+      "struct<relatives:map<string,struct<first:string,middle:string>>>")
+    checkAnswer(query.orderBy("id"),
+      Row("Y.") :: Nil)
+  }
+
+  private def testSchemaPruning(testName: String)(testThunk: => Unit) {
+    test(s"Spark vectorized reader - without partition data column - $testName") {
+      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") {
+        withContacts(testThunk)
+      }
+    }
+    test(s"Spark vectorized reader - with partition data column - $testName") {
+      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") {
+        withContactsWithDataPartitionColumn(testThunk)
+      }
+    }
+
+    test(s"Parquet-mr reader - without partition data column - $testName") {
+      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+        withContacts(testThunk)
+      }
+    }
+    test(s"Parquet-mr reader - with partition data column - $testName") {
+      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+        withContactsWithDataPartitionColumn(testThunk)
+      }
+    }
+  }
+
+  private def withContacts(testThunk: => Unit) {
+    withTempPath { dir =>
+      val path = dir.getCanonicalPath
+
+      makeParquetFile(contacts, new File(path + "/contacts/p=1"))
+      makeParquetFile(briefContacts, new File(path + "/contacts/p=2"))
+
+      spark.read.parquet(path + "/contacts").createOrReplaceTempView("contacts")
+
+      testThunk
+    }
+  }
+
+  private def withContactsWithDataPartitionColumn(testThunk: => Unit) {
+    withTempPath { dir =>
+      val path = dir.getCanonicalPath
+
+      makeParquetFile(contactsWithDataPartitionColumn, new File(path + "/contacts/p=1"))
+      makeParquetFile(briefContactsWithDataPartitionColumn, new File(path + "/contacts/p=2"))
+
+      spark.read.parquet(path + "/contacts").createOrReplaceTempView("contacts")
+
+      testThunk
+    }
+  }
+
+  case class MixedCaseColumn(a: String, B: Int)
+  case class MixedCase(id: Int, CoL1: String, coL2: MixedCaseColumn)
+
+  private val mixedCaseData =
+    MixedCase(0, "r0c1", MixedCaseColumn("abc", 1)) ::
+    MixedCase(1, "r1c1", MixedCaseColumn("123", 2)) ::
+    Nil
+
+  testExactCaseQueryPruning("select with exact column names") {
+    val query = sql("select CoL1, coL2.B from mixedcase")
+    checkScan(query, "struct<CoL1:string,coL2:struct<B:int>>")
+    checkAnswer(query.orderBy("id"),
+      Row("r0c1", 1) ::
+      Row("r1c1", 2) ::
+      Nil)
+  }
+
+  testMixedCaseQueryPruning("select with lowercase column names") {
+    val query = sql("select col1, col2.b from mixedcase")
+    checkScan(query, "struct<CoL1:string,coL2:struct<B:int>>")
+    checkAnswer(query.orderBy("id"),
+      Row("r0c1", 1) ::
+      Row("r1c1", 2) ::
+      Nil)
+  }
+
+  testMixedCaseQueryPruning("select with different-case column names") {
+    val query = sql("select cOL1, cOl2.b from mixedcase")
+    checkScan(query, "struct<CoL1:string,coL2:struct<B:int>>")
+    checkAnswer(query.orderBy("id"),
+      Row("r0c1", 1) ::
+      Row("r1c1", 2) ::
+      Nil)
+  }
+
+  testMixedCaseQueryPruning("filter with different-case column names") {
+    val query = sql("select id from mixedcase where Col2.b = 2")
+    checkScan(query, "struct<id:int,coL2:struct<B:int>>")
+    checkAnswer(query.orderBy("id"), Row(1) :: Nil)
+  }
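The `testSchemaPruning` helper above fans a single test body out across both Parquet read paths and both fixture layouts. A stripped-down sketch of the same matrix idiom (hypothetical suite and helper names, not from the patch; it assumes Spark's internal test traits such as `QueryTest`, `SharedSQLContext`, and `SQLTestUtils` are on the classpath):

```scala
package org.apache.spark.sql.execution.datasources.parquet

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}

// Hypothetical suite, for illustration only.
class ParquetReaderMatrixSketchSuite extends QueryTest with SharedSQLContext with SQLTestUtils {

  // Register one test per reader codepath so a single body covers both implementations.
  private def testWithAllParquetReaders(name: String)(body: => Unit): Unit = {
    Seq("true", "false").foreach { vectorized =>
      test(s"$name - vectorized reader $vectorized") {
        withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) {
          body
        }
      }
    }
  }

  testWithAllParquetReaders("round trip a small parquet table") {
    // Any assertion placed here runs under both reader configurations.
    import testImplicits._
    withTempPath { dir =>
      Seq((1, "a"), (2, "b")).toDF("id", "s").write.parquet(dir.getCanonicalPath)
      checkAnswer(spark.read.parquet(dir.getCanonicalPath).select("id"), Seq(Row(1), Row(2)))
    }
  }
}
```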
+
+  // Tests schema pruning for a query whose column and field names are exactly the same as the table
+  // schema's column and field names. N.B. this implies that `testThunk` should pass using either a
+  // case-sensitive or case-insensitive query parser
+  private def testExactCaseQueryPruning(testName: String)(testThunk: => Unit) {
+    test(s"Spark vectorized reader - case-sensitive parser - mixed-case schema - $testName") {
+      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true",
+          SQLConf.CASE_SENSITIVE.key -> "true") {
+        withMixedCaseData(testThunk)
+      }
+    }
+    test(s"Parquet-mr reader - case-sensitive parser - mixed-case schema - $testName") {
+      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false",
+          SQLConf.CASE_SENSITIVE.key -> "true") {
+        withMixedCaseData(testThunk)
+      }
+    }
+    testMixedCaseQueryPruning(testName)(testThunk)
+  }
+
+  // Tests schema pruning for a query whose column and field names may differ in case from the table
+  // schema's column and field names
+  private def testMixedCaseQueryPruning(testName: String)(testThunk: => Unit) {
+    test(s"Spark vectorized reader - case-insensitive parser - mixed-case schema - $testName") {
+      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true",
+          SQLConf.CASE_SENSITIVE.key -> "false") {
+        withMixedCaseData(testThunk)
+      }
+    }
+    test(s"Parquet-mr reader - case-insensitive parser - mixed-case schema - $testName") {
+      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false",
+          SQLConf.CASE_SENSITIVE.key -> "false") {
+        withMixedCaseData(testThunk)
+      }
+    }
+  }
+
+  private def withMixedCaseData(testThunk: => Unit) {
+    withParquetTable(mixedCaseData, "mixedcase") {
+      testThunk
+    }
+  }
+
+  private val schemaEquality = new Equality[StructType] {
+    override def areEqual(a: StructType, b: Any): Boolean =
+      b match {
+        case otherType: StructType => a.sameType(otherType)
+        case _ => false
+      }
+  }
+
+  protected def checkScan(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = {
+    checkScanSchemata(df, expectedSchemaCatalogStrings: _*)
+    // We check here that we can execute the query without throwing an exception. The results
+    // themselves are irrelevant, and should be checked elsewhere as needed
+    df.collect()
+  }
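The `schemaEquality` value above works because scalatest's `===` resolves an implicit `org.scalactic.Equality` before falling back to plain equality. A self-contained sketch of that mechanism (hypothetical object name; kept in an `org.apache.spark.sql` package only because `sameType` is `private[spark]`):

```scala
package org.apache.spark.sql

import org.scalactic.Equality
import org.scalatest.Assertions

import org.apache.spark.sql.types.{StringType, StructType}

// Hypothetical illustration object, not part of the patch.
object SchemaEqualitySketch extends Assertions {
  def main(args: Array[String]): Unit = {
    // An Equality that delegates to Catalyst's sameType, mirroring schemaEquality above.
    implicit val schemaEquality: Equality[StructType] = new Equality[StructType] {
      override def areEqual(a: StructType, b: Any): Boolean = b match {
        case other: StructType => a.sameType(other)
        case _ => false
      }
    }

    val scanned = new StructType().add("name", StringType, nullable = true)
    val expected = new StructType().add("name", StringType, nullable = false)

    // Strict equality sees the nullability difference ...
    assert(scanned != expected)
    // ... but === consults the implicit Equality, so the comparison uses sameType.
    assert(scanned === expected)
  }
}
```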
+
+  private def checkScanSchemata(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = {
+    val fileSourceScanSchemata =
+      df.queryExecution.executedPlan.collect {
+        case scan: FileSourceScanExec => scan.requiredSchema
+      }
+    assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size,
+      s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " +
+        s"but expected $expectedSchemaCatalogStrings")
+    fileSourceScanSchemata.zip(expectedSchemaCatalogStrings).foreach {
+      case (scanSchema, expectedScanSchemaCatalogString) =>
+        val expectedScanSchema = CatalystSqlParser.parseDataType(expectedScanSchemaCatalogString)
+        implicit val equality = schemaEquality
+        assert(scanSchema === expectedScanSchema)
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
index 9d3dfae348beb..528a4d0ca8004 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
@@ -427,12 +427,12 @@ class ParquetSchemaSuite extends ParquetSchemaTest {
       assert(errMsg.startsWith("Parquet column cannot be converted in file"))
       val file = errMsg.substring("Parquet column cannot be converted in file ".length,
         errMsg.indexOf(". "))
-      val col = spark.read.parquet(file).schema.fields.filter(_.name.equals("a"))
+      val col = spark.read.parquet(file).schema.fields.filter(_.name == "a")
       assert(col.length == 1)
       if (col(0).dataType == StringType) {
-        assert(errMsg.contains("Column: [a], Expected: IntegerType, Found: BINARY"))
+        assert(errMsg.contains("Column: [a], Expected: int, Found: BINARY"))
       } else {
-        assert(errMsg.endsWith("Column: [a], Expected: StringType, Found: INT32"))
+        assert(errMsg.endsWith("Column: [a], Expected: string, Found: INT32"))
       }
     }
   }
@@ -1014,19 +1014,21 @@ class ParquetSchemaSuite extends ParquetSchemaTest {
       testName: String,
       parquetSchema: String,
       catalystSchema: StructType,
-      expectedSchema: String): Unit = {
+      expectedSchema: String,
+      caseSensitive: Boolean = true): Unit = {
     testSchemaClipping(testName, parquetSchema, catalystSchema,
-      MessageTypeParser.parseMessageType(expectedSchema))
+      MessageTypeParser.parseMessageType(expectedSchema), caseSensitive)
   }
 
   private def testSchemaClipping(
       testName: String,
       parquetSchema: String,
       catalystSchema: StructType,
-      expectedSchema: MessageType): Unit = {
+      expectedSchema: MessageType,
+      caseSensitive: Boolean): Unit = {
     test(s"Clipping - $testName") {
       val actual = ParquetReadSupport.clipParquetSchema(
-        MessageTypeParser.parseMessageType(parquetSchema), catalystSchema)
+        MessageTypeParser.parseMessageType(parquetSchema), catalystSchema, caseSensitive)
 
       try {
         expectedSchema.checkContains(actual)
@@ -1387,7 +1389,8 @@ class ParquetSchemaSuite extends ParquetSchemaTest {
 
     catalystSchema = new StructType(),
 
-    expectedSchema = ParquetSchemaConverter.EMPTY_MESSAGE)
+    expectedSchema = ParquetSchemaConverter.EMPTY_MESSAGE,
+    caseSensitive = true)
 
   testSchemaClipping(
     "disjoint field sets",
@@ -1544,4 +1547,52 @@ class ParquetSchemaSuite extends ParquetSchemaTest {
       |  }
       |}
     """.stripMargin)
+
+  testSchemaClipping(
+    "case-insensitive resolution: no ambiguity",
+    parquetSchema =
+      """message root {
+        |  required group A {
+        |    optional
int32 B; + | } + | optional int32 c; + |} + """.stripMargin, + catalystSchema = { + val nestedType = new StructType().add("b", IntegerType, nullable = true) + new StructType() + .add("a", nestedType, nullable = true) + .add("c", IntegerType, nullable = true) + }, + expectedSchema = + """message root { + | required group A { + | optional int32 B; + | } + | optional int32 c; + |} + """.stripMargin, + caseSensitive = false) + + test("Clipping - case-insensitive resolution: more than one field is matched") { + val parquetSchema = + """message root { + | required group A { + | optional int32 B; + | } + | optional int32 c; + | optional int32 a; + |} + """.stripMargin + val catalystSchema = { + val nestedType = new StructType().add("b", IntegerType, nullable = true) + new StructType() + .add("a", nestedType, nullable = true) + .add("c", IntegerType, nullable = true) + } + assertThrows[RuntimeException] { + ParquetReadSupport.clipParquetSchema( + MessageTypeParser.parseMessageType(parquetSchema), catalystSchema, caseSensitive = false) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala index fff0f82f9bc2b..a302d67b5cbf7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala @@ -21,10 +21,10 @@ import java.io.File import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types.{StringType, StructType} -class WholeTextFileSuite extends QueryTest with SharedSQLContext { +class WholeTextFileSuite extends QueryTest with SharedSQLContext with SQLTestUtils { // Hadoop's FileSystem caching does not use the Configuration as part of its cache key, which // can cause Filesystem.get(Configuration) to return a cached instance created with a different @@ -35,13 +35,10 @@ class WholeTextFileSuite extends QueryTest with SharedSQLContext { protected override def sparkConf = super.sparkConf.set("spark.hadoop.fs.file.impl.disable.cache", "true") - private def testFile: String = { - Thread.currentThread().getContextClassLoader.getResource("test-data/text-suite.txt").toString - } - test("reading text file with option wholetext=true") { val df = spark.read.option("wholetext", "true") - .format("text").load(testFile) + .format("text") + .load(testFile("test-data/text-suite.txt")) // schema assert(df.schema == new StructType().add("value", StringType)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index bcdee792f4c70..b4ad1db20a9ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -54,8 +54,12 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } override def afterAll(): Unit = { - spark.stop() - spark = null + try { + spark.stop() + spark = null + } finally { + super.afterAll() + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 037cc2e3ccad7..d9b34dcd16476 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -278,6 +278,35 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { map.free() } + test("SPARK-24809: Serializing LongToUnsafeRowMap in executor may result in data error") { + val unsafeProj = UnsafeProjection.create(Array[DataType](LongType)) + val originalMap = new LongToUnsafeRowMap(mm, 1) + + val key1 = 1L + val value1 = 4852306286022334418L + + val key2 = 2L + val value2 = 8813607448788216010L + + originalMap.append(key1, unsafeProj(InternalRow(value1))) + originalMap.append(key2, unsafeProj(InternalRow(value2))) + originalMap.optimize() + + val ser = sparkContext.env.serializer.newInstance() + // Simulate serialize/deserialize twice on driver and executor + val firstTimeSerialized = ser.deserialize[LongToUnsafeRowMap](ser.serialize(originalMap)) + val secondTimeSerialized = + ser.deserialize[LongToUnsafeRowMap](ser.serialize(firstTimeSerialized)) + + val resultRow = new UnsafeRow(1) + assert(secondTimeSerialized.getValue(key1, resultRow).getLong(0) === value1) + assert(secondTimeSerialized.getValue(key2, resultRow).getLong(0) === value2) + + originalMap.free() + firstTimeSerialized.free() + secondTimeSerialized.free() + } + test("Spark-14521") { val ser = new KryoSerializer( (new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 8263c9c81c49e..d45eb0c27a6b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.execution.metric import java.io.File -import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future} import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.execution.ui.SQLAppStatusStore import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -497,6 +497,19 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared } } + test("SPARK-25278: output metrics are wrong for plans repeated in the query") { + val name = "demo_view" + withView(name) { + sql(s"CREATE OR REPLACE VIEW $name AS VALUES 1,2") + val view = spark.table(name) + val union = view.union(view) + testSparkPlanMetrics(union, 1, Map( + 0L -> ("Union" -> Map()), + 1L -> ("LocalTableScan" -> Map("number of output rows" -> 2L)), + 2L -> ("LocalTableScan" -> Map("number of output rows" -> 2L)))) + } + } + test("writing data out metrics: parquet") { testMetricsNonDynamicPartition("parquet", "t1") } @@ -504,38 +517,4 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared test("writing data out metrics with dynamic partition: parquet") { testMetricsDynamicPartition("parquet", "parquet", "t1") } - - test("writing metrics from single thread") { - val nAdds = 10 - val acc = new SQLMetric("test", -10) - assert(acc.isZero()) - 
acc.set(0) - for (i <- 1 to nAdds) acc.add(1) - assert(!acc.isZero()) - assert(nAdds === acc.value) - acc.reset() - assert(acc.isZero()) - } - - test("writing metrics from multiple threads") { - implicit val ec: ExecutionContextExecutor = ExecutionContext.global - val nFutures = 1000 - val nAdds = 100 - val acc = new SQLMetric("test", -10) - assert(acc.isZero() === true) - acc.set(0) - val l = for ( i <- 1 to nFutures ) yield { - Future { - for (j <- 1 to nAdds) acc.add(1) - i - } - } - for (futures <- Future.sequence(l)) { - assert(nFutures === futures.length) - assert(!acc.isZero()) - assert(nFutures * nAdds === acc.value) - acc.reset() - assert(acc.isZero()) - } - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index d456c931f5275..289cc667a1c66 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -37,8 +37,11 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext { } override def afterAll(): Unit = { - spark.sessionState.functionRegistry.dropFunction(FunctionIdentifier("dummyPythonUDF")) - super.afterAll() + try { + spark.sessionState.functionRegistry.dropFunction(FunctionIdentifier("dummyPythonUDF")) + } finally { + super.afterAll() + } } test("Python UDF: push down deterministic FilterExec predicates") { @@ -115,3 +118,10 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction( dataType = BooleanType, pythonEvalType = PythonEvalType.SQL_BATCHED_UDF, udfDeterministic = true) + +class MyDummyScalarPandasUDF extends UserDefinedPythonFunction( + name = "dummyScalarPandasUDF", + func = new DummyUDF, + dataType = BooleanType, + pythonEvalType = PythonEvalType.SQL_SCALAR_PANDAS_UDF, + udfDeterministic = true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala new file mode 100644 index 0000000000000..76b609d111acd --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.test.SharedSQLContext
+
+class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSQLContext {
+  import testImplicits.newProductEncoder
+  import testImplicits.localSeqToDatasetHolder
+
+  val batchedPythonUDF = new MyDummyPythonUDF
+  val scalarPandasUDF = new MyDummyScalarPandasUDF
+
+  private def collectBatchExec(plan: SparkPlan): Seq[BatchEvalPythonExec] = plan.collect {
+    case b: BatchEvalPythonExec => b
+  }
+
+  private def collectArrowExec(plan: SparkPlan): Seq[ArrowEvalPythonExec] = plan.collect {
+    case b: ArrowEvalPythonExec => b
+  }
+
+  test("Chained Batched Python UDFs should be combined to a single physical node") {
+    val df = Seq(("Hello", 4)).toDF("a", "b")
+    val df2 = df.withColumn("c", batchedPythonUDF(col("a")))
+      .withColumn("d", batchedPythonUDF(col("c")))
+    val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan)
+    assert(pythonEvalNodes.size == 1)
+  }
+
+  test("Chained Scalar Pandas UDFs should be combined to a single physical node") {
+    val df = Seq(("Hello", 4)).toDF("a", "b")
+    val df2 = df.withColumn("c", scalarPandasUDF(col("a")))
+      .withColumn("d", scalarPandasUDF(col("c")))
+    val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan)
+    assert(arrowEvalNodes.size == 1)
+  }
+
+  test("Mixed Batched Python UDFs and Pandas UDF should be separate physical node") {
+    val df = Seq(("Hello", 4)).toDF("a", "b")
+    val df2 = df.withColumn("c", batchedPythonUDF(col("a")))
+      .withColumn("d", scalarPandasUDF(col("b")))
+
+    val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan)
+    val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan)
+    assert(pythonEvalNodes.size == 1)
+    assert(arrowEvalNodes.size == 1)
+  }
+
+  test("Independent Batched Python UDFs and Scalar Pandas UDFs should be combined separately") {
+    val df = Seq(("Hello", 4)).toDF("a", "b")
+    val df2 = df.withColumn("c1", batchedPythonUDF(col("a")))
+      .withColumn("c2", batchedPythonUDF(col("c1")))
+      .withColumn("d1", scalarPandasUDF(col("a")))
+      .withColumn("d2", scalarPandasUDF(col("d1")))
+
+    val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan)
+    val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan)
+    assert(pythonEvalNodes.size == 1)
+    assert(arrowEvalNodes.size == 1)
+  }
+
+  test("Dependent Batched Python UDFs and Scalar Pandas UDFs should not be combined") {
+    val df = Seq(("Hello", 4)).toDF("a", "b")
+    val df2 = df.withColumn("c1", batchedPythonUDF(col("a")))
+      .withColumn("d1", scalarPandasUDF(col("c1")))
+      .withColumn("c2", batchedPythonUDF(col("d1")))
+      .withColumn("d2", scalarPandasUDF(col("c2")))
+
+    val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan)
+    val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan)
+    assert(pythonEvalNodes.size == 2)
+    assert(arrowEvalNodes.size == 2)
+  }
+}
+
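The new suite above exercises the UDF-extraction planner rule purely by counting physical nodes. A hedged standalone sketch of that inspection technique (hypothetical object and method names, not part of the patch; `collect` is inherited from Catalyst's `TreeNode`, so the same pattern works for any `SparkPlan` node type):

```scala
package org.apache.spark.sql.execution.python

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.SparkPlan

// Hypothetical helper object, for illustration only.
object PlanInspectionSketch {
  // TreeNode.collect walks the entire physical plan and keeps every node the partial
  // function matches, which is how the suite above counts evaluation nodes.
  def countNodes(df: DataFrame)(pf: PartialFunction[SparkPlan, SparkPlan]): Int =
    df.queryExecution.executedPlan.collect(pf).size
}

// Example usage (inside a suite with the dummy UDFs above in scope):
//   val fused = PlanInspectionSketch.countNodes(df2) { case b: BatchEvalPythonExec => b }
//   assert(fused == 1)  // chained batched UDFs collapse into one BatchEvalPythonExec
```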
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
index 25ee95daa034c..ffda33cf906c5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
@@ -22,13 +22,13 @@ import java.io.File
 
 import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.memory.{MemoryManager, TaskMemoryManager, TestMemoryManager}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.unsafe.memory.OnHeapMemoryBlock
+import org.apache.spark.unsafe.memory.MemoryBlock
 import org.apache.spark.util.Utils
 
 class RowQueueSuite extends SparkFunSuite {
 
   test("in-memory queue") {
-    val page = new OnHeapMemoryBlock((1<<10) * 8L)
+    val page = MemoryBlock.fromLongArray(new Array[Long](1<<10))
     val queue = new InMemoryRowQueue(page, 1) {
       override def close() {}
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
index b2fd6ba27ebb8..3bc36ce55d902 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
@@ -17,13 +17,11 @@
 
 package org.apache.spark.sql.execution.streaming
 
-import scala.collection.JavaConverters._
 import scala.language.implicitConversions
 
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql._
-import org.apache.spark.sql.sources.v2.DataSourceOptions
 import org.apache.spark.sql.streaming.{OutputMode, StreamTest}
 import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
 import org.apache.spark.util.Utils
@@ -38,7 +36,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
 
   test("directly add data in Append output mode") {
     implicit val schema = new StructType().add(new StructField("value", IntegerType))
-    val sink = new MemorySink(schema, OutputMode.Append, DataSourceOptions.empty())
+    val sink = new MemorySink(schema, OutputMode.Append)
 
     // Before adding data, check output
     assert(sink.latestBatchId === None)
@@ -70,35 +68,9 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
     checkAnswer(sink.allData, 1 to 9)
   }
 
-  test("directly add data in Append output mode with row limit") {
-    implicit val schema = new StructType().add(new StructField("value", IntegerType))
-
-    var optionsMap = new scala.collection.mutable.HashMap[String, String]
-    optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString())
-    var options = new DataSourceOptions(optionsMap.toMap.asJava)
-    val sink = new MemorySink(schema, OutputMode.Append, options)
-
-    // Before adding data, check output
-    assert(sink.latestBatchId === None)
-    checkAnswer(sink.latestBatchData, Seq.empty)
-    checkAnswer(sink.allData, Seq.empty)
-
-    // Add batch 0 and check outputs
-    sink.addBatch(0, 1 to 3)
-    assert(sink.latestBatchId === Some(0))
-    checkAnswer(sink.latestBatchData, 1 to 3)
-    checkAnswer(sink.allData, 1 to 3)
-
-    // Add batch 1 and check outputs
-    sink.addBatch(1, 4 to 6)
-    assert(sink.latestBatchId === Some(1))
-    checkAnswer(sink.latestBatchData, 4 to 5)
-    checkAnswer(sink.allData, 1 to 5) // new data should not go over the limit
-  }
-
   test("directly add data in Update output mode") {
     implicit val schema = new StructType().add(new StructField("value", IntegerType))
-    val sink = new MemorySink(schema, OutputMode.Update, DataSourceOptions.empty())
+    val sink = new MemorySink(schema, OutputMode.Update)
 
     // Before adding data, check output
     assert(sink.latestBatchId === None)
@@ -132,7 +104,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
 
   test("directly add data in Complete output mode") {
     implicit val schema = new StructType().add(new StructField("value", IntegerType))
-    val sink = new MemorySink(schema, OutputMode.Complete,
DataSourceOptions.empty()) + val sink = new MemorySink(schema, OutputMode.Complete) // Before adding data, check output assert(sink.latestBatchId === None) @@ -164,32 +136,6 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { checkAnswer(sink.allData, 7 to 9) } - test("directly add data in Complete output mode with row limit") { - implicit val schema = new StructType().add(new StructField("value", IntegerType)) - - var optionsMap = new scala.collection.mutable.HashMap[String, String] - optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString()) - var options = new DataSourceOptions(optionsMap.toMap.asJava) - val sink = new MemorySink(schema, OutputMode.Complete, options) - - // Before adding data, check output - assert(sink.latestBatchId === None) - checkAnswer(sink.latestBatchData, Seq.empty) - checkAnswer(sink.allData, Seq.empty) - - // Add batch 0 and check outputs - sink.addBatch(0, 1 to 3) - assert(sink.latestBatchId === Some(0)) - checkAnswer(sink.latestBatchData, 1 to 3) - checkAnswer(sink.allData, 1 to 3) - - // Add batch 1 and check outputs - sink.addBatch(1, 4 to 10) - assert(sink.latestBatchId === Some(1)) - checkAnswer(sink.latestBatchData, 4 to 8) - checkAnswer(sink.allData, 4 to 8) // new data should replace old data - } - test("registering as a table in Append output mode") { val input = MemoryStream[Int] @@ -265,7 +211,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("MemoryPlan statistics") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Append, DataSourceOptions.empty()) + val sink = new MemorySink(schema, OutputMode.Append) val plan = new MemoryPlan(sink) // Before adding data, check output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index e539510e15755..61857365ac989 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -17,24 +17,22 @@ package org.apache.spark.sql.execution.streaming -import scala.collection.JavaConverters._ - import org.scalatest.BeforeAndAfter import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.streaming.sources._ -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.{OutputMode, StreamTest} -import org.apache.spark.sql.types.IntegerType import org.apache.spark.sql.types.StructType class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("data writer") { val partition = 1234 - val writer = new MemoryDataWriter(partition, OutputMode.Append()) - writer.write(Row(1)) - writer.write(Row(2)) - writer.write(Row(44)) + val writer = new MemoryDataWriter( + partition, OutputMode.Append(), new StructType().add("i", "int")) + writer.write(InternalRow(1)) + writer.write(InternalRow(2)) + writer.write(InternalRow(44)) val msg = writer.commit() assert(msg.data.map(_.getInt(0)) == Seq(1, 2, 44)) assert(msg.partition == partition) @@ -43,31 +41,11 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { assert(writer.commit().data.isEmpty) } - test("continuous writer") { - val sink = new MemorySinkV2 - val writer = new MemoryStreamWriter(sink, OutputMode.Append(), DataSourceOptions.empty()) - writer.commit(0, - Array( - 
MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), - MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), - MemoryWriterCommitMessage(2, Seq(Row(6), Row(7))) - )) - assert(sink.latestBatchId.contains(0)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - writer.commit(19, - Array( - MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), - MemoryWriterCommitMessage(0, Seq(Row(33))) - )) - assert(sink.latestBatchId.contains(19)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33)) - - assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33)) - } - - test("microbatch writer") { + test("streaming writer") { val sink = new MemorySinkV2 - new MemoryWriter(sink, 0, OutputMode.Append(), DataSourceOptions.empty()).commit( + val writeSupport = new MemoryStreamingWriteSupport( + sink, OutputMode.Append(), new StructType().add("i", "int")) + writeSupport.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), @@ -75,7 +53,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { )) assert(sink.latestBatchId.contains(0)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - new MemoryWriter(sink, 19, OutputMode.Append(), DataSourceOptions.empty()).commit( + writeSupport.commit(19, Array( MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), MemoryWriterCommitMessage(0, Seq(Row(33))) @@ -85,73 +63,4 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33)) } - - test("continuous writer with row limit") { - val sink = new MemorySinkV2 - val optionsMap = new scala.collection.mutable.HashMap[String, String] - optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 7.toString()) - val options = new DataSourceOptions(optionsMap.toMap.asJava) - val appendWriter = new MemoryStreamWriter(sink, OutputMode.Append(), options) - appendWriter.commit(0, Array( - MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), - MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), - MemoryWriterCommitMessage(2, Seq(Row(6), Row(7))))) - assert(sink.latestBatchId.contains(0)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - appendWriter.commit(19, Array( - MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), - MemoryWriterCommitMessage(0, Seq(Row(33))))) - assert(sink.latestBatchId.contains(19)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11)) - - assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11)) - - val completeWriter = new MemoryStreamWriter(sink, OutputMode.Complete(), options) - completeWriter.commit(20, Array( - MemoryWriterCommitMessage(4, Seq(Row(11), Row(22))), - MemoryWriterCommitMessage(5, Seq(Row(33))))) - assert(sink.latestBatchId.contains(20)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33)) - completeWriter.commit(21, Array( - MemoryWriterCommitMessage(0, Seq(Row(1), Row(2), Row(3))), - MemoryWriterCommitMessage(1, Seq(Row(4), Row(5), Row(6))), - MemoryWriterCommitMessage(2, Seq(Row(7), Row(8), Row(9))))) - assert(sink.latestBatchId.contains(21)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5, 6, 7)) - - assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5, 6, 7)) - } - - test("microbatch writer with row limit") { - val sink = new MemorySinkV2 - val optionsMap = new scala.collection.mutable.HashMap[String, String] - 
optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString()) - val options = new DataSourceOptions(optionsMap.toMap.asJava) - - new MemoryWriter(sink, 25, OutputMode.Append(), options).commit(Array( - MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), - MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))))) - assert(sink.latestBatchId.contains(25)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4)) - assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4)) - new MemoryWriter(sink, 26, OutputMode.Append(), options).commit(Array( - MemoryWriterCommitMessage(2, Seq(Row(5), Row(6))), - MemoryWriterCommitMessage(3, Seq(Row(7), Row(8))))) - assert(sink.latestBatchId.contains(26)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(5)) - assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5)) - - new MemoryWriter(sink, 27, OutputMode.Complete(), options).commit(Array( - MemoryWriterCommitMessage(4, Seq(Row(9), Row(10))), - MemoryWriterCommitMessage(5, Seq(Row(11), Row(12))))) - assert(sink.latestBatchId.contains(27)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(9, 10, 11, 12)) - assert(sink.allData.map(_.getInt(0)).sorted == Seq(9, 10, 11, 12)) - new MemoryWriter(sink, 28, OutputMode.Complete(), options).commit(Array( - MemoryWriterCommitMessage(4, Seq(Row(13), Row(14), Row(15))), - MemoryWriterCommitMessage(5, Seq(Row(16), Row(17), Row(18))))) - assert(sink.latestBatchId.contains(28)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(13, 14, 15, 16, 17)) - assert(sink.allData.map(_.getInt(0)).sorted == Seq(13, 14, 15, 16, 17)) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala index 55acf2ba28d2f..5884380271f0e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala @@ -19,12 +19,10 @@ package org.apache.spark.sql.execution.streaming.sources import java.io.ByteArrayOutputStream -import org.scalatest.time.SpanSugar._ - import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.streaming.{StreamTest, Trigger} -class ConsoleWriterSuite extends StreamTest { +class ConsoleWriteSupportSuite extends StreamTest { import testImplicits._ test("microbatch - default") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala index a4233e15e4ffd..71dff443e8836 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.streaming.sources import scala.collection.mutable +import scala.language.implicitConversions import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming.MemoryStream diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index bf72e5c99689f..dd74af873c2e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -17,20 +17,18 @@ package org.apache.spark.sql.execution.streaming.sources -import java.nio.file.Files -import java.util.Optional import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -import org.apache.spark.sql.{AnalysisException, Row, SparkSession} -import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, MicroBatchReadSupportProvider} import org.apache.spark.sql.sources.v2.reader.streaming.Offset import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.util.ManualClock @@ -43,7 +41,7 @@ class RateSourceSuite extends StreamTest { override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { assert(query.nonEmpty) val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source + case StreamingExecutionRelation(source: RateStreamMicroBatchReadSupport, _) => source }.head rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) @@ -54,19 +52,22 @@ class RateSourceSuite extends StreamTest { } test("microbatch in registry") { - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => - val reader = ds.createMicroBatchReader(Optional.empty(), "dummy", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamMicroBatchReader]) - case _ => - throw new IllegalStateException("Could not find read support for rate") + withTempDir { temp => + DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupportProvider => + val readSupport = ds.createMicroBatchReadSupport( + temp.getCanonicalPath, DataSourceOptions.empty()) + assert(readSupport.isInstanceOf[RateStreamMicroBatchReadSupport]) + case _ => + throw new IllegalStateException("Could not find read support for rate") + } } } test("compatible with old path in registry") { DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => + case ds: MicroBatchReadSupportProvider => assert(ds.isInstanceOf[RateStreamProvider]) case _ => throw new IllegalStateException("Could not find read support for rate") @@ -81,12 +82,43 @@ class RateSourceSuite extends StreamTest { .load() testStream(input)( AdvanceRateManualClock(seconds = 1), - CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), + CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v 
* 100L) -> v): _*) + ) + } + + test("microbatch - restart") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .load() + .select('value) + + var streamDuration = 0 + + // Microbatch rate stream offsets contain the number of seconds since the beginning of + // the stream. + def updateStreamDurationFromOffset(s: StreamExecution, expectedMin: Int): Unit = { + streamDuration = s.lastProgress.sources(0).endOffset.toInt + assert(streamDuration >= expectedMin) + } + + // We have to use the lambda version of CheckAnswer because we don't know the right range + // until we see the last offset. + def expectedResultsFromDuration(rows: Seq[Row]): Unit = { + assert(rows.map(_.getLong(0)).sorted == (0 until (streamDuration * 10))) + } + + testStream(input)( + StartStream(), + Execute(_.awaitOffset(0, LongOffset(2), streamingTimeout.toMillis)), StopStream, + Execute(updateStreamDurationFromOffset(_, 2)), + CheckAnswer(expectedResultsFromDuration _), StartStream(), - // Advance 2 seconds because creating a new RateSource will also create a new ManualClock - AdvanceRateManualClock(seconds = 2), - CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) + Execute(_.awaitOffset(0, LongOffset(4), streamingTimeout.toMillis)), + StopStream, + Execute(updateStreamDurationFromOffset(_, 4)), + CheckAnswer(expectedResultsFromDuration _) ) } @@ -107,70 +139,67 @@ class RateSourceSuite extends StreamTest { ) } - test("microbatch - set offset") { - val temp = Files.createTempDirectory("dummy").toString - val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty(), temp) - val startOffset = LongOffset(0L) - val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - assert(reader.getStartOffset() == startOffset) - assert(reader.getEndOffset() == endOffset) - } - test("microbatch - infer offsets") { - val tempFolder = Files.createTempDirectory("dummy").toString - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions( - Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava), - tempFolder) - reader.clock.asInstanceOf[ManualClock].advance(100000) - reader.setOffsetRange(Optional.empty(), Optional.empty()) - reader.getStartOffset() match { - case r: LongOffset => assert(r.offset === 0L) - case _ => throw new IllegalStateException("unexpected offset type") - } - reader.getEndOffset() match { - case r: LongOffset => assert(r.offset >= 100) - case _ => throw new IllegalStateException("unexpected offset type") + withTempDir { temp => + val readSupport = new RateStreamMicroBatchReadSupport( + new DataSourceOptions( + Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava), + temp.getCanonicalPath) + readSupport.clock.asInstanceOf[ManualClock].advance(100000) + val startOffset = readSupport.initialOffset() + startOffset match { + case r: LongOffset => assert(r.offset === 0L) + case _ => throw new IllegalStateException("unexpected offset type") + } + readSupport.latestOffset() match { + case r: LongOffset => assert(r.offset >= 100) + case _ => throw new IllegalStateException("unexpected offset type") + } } } test("microbatch - predetermined batch size") { - val temp = Files.createTempDirectory("dummy").toString - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava), temp) - val startOffset = LongOffset(0L) - val endOffset = LongOffset(1L) - 
reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.planInputPartitions() - assert(tasks.size == 1) - val dataReader = tasks.get(0).createPartitionReader() - val data = ArrayBuffer[Row]() - while (dataReader.next()) { - data.append(dataReader.get()) + withTempDir { temp => + val readSupport = new RateStreamMicroBatchReadSupport( + new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava), + temp.getCanonicalPath) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + val config = readSupport.newScanConfigBuilder(startOffset, endOffset).build() + val tasks = readSupport.planInputPartitions(config) + val readerFactory = readSupport.createReaderFactory(config) + assert(tasks.size == 1) + val dataReader = readerFactory.createReader(tasks(0)) + val data = ArrayBuffer[InternalRow]() + while (dataReader.next()) { + data.append(dataReader.get()) + } + assert(data.size === 20) } - assert(data.size === 20) } test("microbatch - data read") { - val temp = Files.createTempDirectory("dummy").toString - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava), temp) - val startOffset = LongOffset(0L) - val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.planInputPartitions() - assert(tasks.size == 11) - - val readData = tasks.asScala - .map(_.createPartitionReader()) - .flatMap { reader => - val buf = scala.collection.mutable.ListBuffer[Row]() - while (reader.next()) buf.append(reader.get()) - buf - } + withTempDir { temp => + val readSupport = new RateStreamMicroBatchReadSupport( + new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava), + temp.getCanonicalPath) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + val config = readSupport.newScanConfigBuilder(startOffset, endOffset).build() + val tasks = readSupport.planInputPartitions(config) + val readerFactory = readSupport.createReaderFactory(config) + assert(tasks.size == 11) + + val readData = tasks + .map(readerFactory.createReader) + .flatMap { reader => + val buf = scala.collection.mutable.ListBuffer[InternalRow]() + while (reader.next()) buf.append(reader.get()) + buf + } - assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) + assert(readData.map(_.getLong(1)).sorted === 0.until(33).toArray) + } } test("valueAtSecond") { @@ -280,41 +309,44 @@ class RateSourceSuite extends StreamTest { } test("user-specified schema given") { - val exception = intercept[AnalysisException] { + val exception = intercept[UnsupportedOperationException] { spark.readStream .format("rate") .schema(spark.range(1).schema) .load() } assert(exception.getMessage.contains( - "rate source does not support a user-specified schema")) + "rate source does not support user-specified schema")) } test("continuous in registry") { DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: ContinuousReadSupport => - val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamContinuousReader]) + case ds: ContinuousReadSupportProvider => + val readSupport = ds.createContinuousReadSupport( + "", DataSourceOptions.empty()) + assert(readSupport.isInstanceOf[RateStreamContinuousReadSupport]) case _ => throw new IllegalStateException("Could not find read support for continuous rate") } } test("continuous data") { - val 
reader = new RateStreamContinuousReader( + val readSupport = new RateStreamContinuousReadSupport( new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) - reader.setStartOffset(Optional.empty()) - val tasks = reader.planInputPartitions() + val config = readSupport.newScanConfigBuilder(readSupport.initialOffset).build() + val tasks = readSupport.planInputPartitions(config) + val readerFactory = readSupport.createContinuousReaderFactory(config) assert(tasks.size == 2) - val data = scala.collection.mutable.ListBuffer[Row]() - tasks.asScala.foreach { + val data = scala.collection.mutable.ListBuffer[InternalRow]() + tasks.foreach { case t: RateStreamContinuousInputPartition => - val startTimeMs = reader.getStartOffset() + val startTimeMs = readSupport.initialOffset() .asInstanceOf[RateStreamOffset] .partitionToValueAndRunTimeMs(t.partitionIndex) .runTimeMs - val r = t.createPartitionReader().asInstanceOf[RateStreamContinuousInputPartitionReader] + val r = readerFactory.createReader(t) + .asInstanceOf[RateStreamContinuousPartitionReader] for (rowIndex <- 0 to 9) { r.next() data.append(r.get()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index 52e8386f6b1fa..409156e5ebc70 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -21,7 +21,6 @@ import java.net.{InetSocketAddress, SocketException} import java.nio.ByteBuffer import java.nio.channels.ServerSocketChannel import java.sql.Timestamp -import java.util.Optional import java.util.concurrent.LinkedBlockingQueue import scala.collection.JavaConverters._ @@ -32,12 +31,13 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupportProvider} +import org.apache.spark.sql.sources.v2.reader.streaming.Offset import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} +import org.apache.spark.sql.types._ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with BeforeAndAfterEach { @@ -48,14 +48,9 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread.join() serverThread = null } - if (batchReader != null) { - batchReader.stop() - batchReader = null - } } private var serverThread: ServerThread = null - private var batchReader: MicroBatchReader = null case class AddSocketData(data: String*) extends AddData { override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { @@ -64,7 +59,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before "Cannot add data when there is no query for finding the active socket source") val sources = 
query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: TextSocketMicroBatchReader, _) => source + case StreamingExecutionRelation(source: TextSocketMicroBatchReadSupport, _) => source } if (sources.isEmpty) { throw new Exception( @@ -90,7 +85,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before test("backward compatibility with old path") { DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.TextSocketSourceProvider", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => + case ds: MicroBatchReadSupportProvider => assert(ds.isInstanceOf[TextSocketSourceProvider]) case _ => throw new IllegalStateException("Could not find socket source") @@ -180,16 +175,16 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before test("params not given") { val provider = new TextSocketSourceProvider intercept[AnalysisException] { - provider.createMicroBatchReader(Optional.empty(), "", - new DataSourceOptions(Map.empty[String, String].asJava)) + provider.createMicroBatchReadSupport( + "", new DataSourceOptions(Map.empty[String, String].asJava)) } intercept[AnalysisException] { - provider.createMicroBatchReader(Optional.empty(), "", - new DataSourceOptions(Map("host" -> "localhost").asJava)) + provider.createMicroBatchReadSupport( + "", new DataSourceOptions(Map("host" -> "localhost").asJava)) } intercept[AnalysisException] { - provider.createMicroBatchReader(Optional.empty(), "", - new DataSourceOptions(Map("port" -> "1234").asJava)) + provider.createMicroBatchReadSupport( + "", new DataSourceOptions(Map("port" -> "1234").asJava)) } } @@ -198,7 +193,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before val params = Map("host" -> "localhost", "port" -> "1234", "includeTimestamp" -> "fasle") intercept[AnalysisException] { val a = new DataSourceOptions(params.asJava) - provider.createMicroBatchReader(Optional.empty(), "", a) + provider.createMicroBatchReadSupport("", a) } } @@ -208,12 +203,12 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before StructField("name", StringType) :: StructField("area", StringType) :: Nil) val params = Map("host" -> "localhost", "port" -> "1234") - val exception = intercept[AnalysisException] { - provider.createMicroBatchReader( - Optional.of(userSpecifiedSchema), "", new DataSourceOptions(params.asJava)) + val exception = intercept[UnsupportedOperationException] { + provider.createMicroBatchReadSupport( + userSpecifiedSchema, "", new DataSourceOptions(params.asJava)) } assert(exception.getMessage.contains( - "socket source does not support a user-specified schema")) + "socket source does not support user-specified schema")) } test("input row metrics") { @@ -300,6 +295,102 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before } } + test("continuous data") { + serverThread = new ServerThread() + serverThread.start() + + val readSupport = new TextSocketContinuousReadSupport( + new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost", + "port" -> serverThread.port.toString).asJava)) + + val scanConfig = readSupport.newScanConfigBuilder(readSupport.initialOffset()).build() + val tasks = readSupport.planInputPartitions(scanConfig) + assert(tasks.size == 2) + + val numRecords = 10 + val data = scala.collection.mutable.ListBuffer[Int]() + val offsets = scala.collection.mutable.ListBuffer[Int]() + val readerFactory = 
readSupport.createContinuousReaderFactory(scanConfig) + import org.scalatest.time.SpanSugar._ + failAfter(5 seconds) { + // inject rows, read and check the data and offsets + for (i <- 0 until numRecords) { + serverThread.enqueue(i.toString) + } + tasks.foreach { + case t: TextSocketContinuousInputPartition => + val r = readerFactory.createReader(t).asInstanceOf[TextSocketContinuousPartitionReader] + for (i <- 0 until numRecords / 2) { + r.next() + offsets.append(r.getOffset().asInstanceOf[ContinuousRecordPartitionOffset].offset) + data.append(r.get().get(0, DataTypes.StringType).asInstanceOf[String].toInt) + // commit the offsets in the middle and validate if processing continues + if (i == 2) { + commitOffset(t.partitionId, i + 1) + } + } + assert(offsets.toSeq == Range.inclusive(1, 5)) + assert(data.toSeq == Range(t.partitionId, 10, 2)) + offsets.clear() + data.clear() + case _ => throw new IllegalStateException("Unexpected task type") + } + assert(readSupport.startOffset.offsets == List(3, 3)) + readSupport.commit(TextSocketOffset(List(5, 5))) + assert(readSupport.startOffset.offsets == List(5, 5)) + } + + def commitOffset(partition: Int, offset: Int): Unit = { + val offsetsToCommit = readSupport.startOffset.offsets.updated(partition, offset) + readSupport.commit(TextSocketOffset(offsetsToCommit)) + assert(readSupport.startOffset.offsets == offsetsToCommit) + } + } + + test("continuous data - invalid commit") { + serverThread = new ServerThread() + serverThread.start() + + val readSupport = new TextSocketContinuousReadSupport( + new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost", + "port" -> serverThread.port.toString).asJava)) + + readSupport.startOffset = TextSocketOffset(List(5, 5)) + assertThrows[IllegalStateException] { + readSupport.commit(TextSocketOffset(List(6, 6))) + } + } + + test("continuous data with timestamp") { + serverThread = new ServerThread() + serverThread.start() + + val readSupport = new TextSocketContinuousReadSupport( + new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost", + "includeTimestamp" -> "true", + "port" -> serverThread.port.toString).asJava)) + val scanConfig = readSupport.newScanConfigBuilder(readSupport.initialOffset()).build() + val tasks = readSupport.planInputPartitions(scanConfig) + assert(tasks.size == 2) + + val numRecords = 4 + // inject rows, read and check the data and offsets + for (i <- 0 until numRecords) { + serverThread.enqueue(i.toString) + } + val readerFactory = readSupport.createContinuousReaderFactory(scanConfig) + tasks.foreach { + case t: TextSocketContinuousInputPartition => + val r = readerFactory.createReader(t).asInstanceOf[TextSocketContinuousPartitionReader] + for (i <- 0 until numRecords / 2) { + r.next() + assert(r.get().get(0, TextSocketReader.SCHEMA_TIMESTAMP) + .isInstanceOf[(String, Timestamp)]) + } + case _ => throw new IllegalStateException("Unexpected task type") + } + } + /** * This class tries to mimic the behavior of netcat, so that we can ensure * TextSocketStream supports netcat, which only accepts the first connection diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala new file mode 100644 index 0000000000000..dec30fd01f7e2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala @@ -0,0 +1,218 @@ +/* + 
* Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.execution.streaming.GroupStateImpl._ +import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite._ +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types._ + + +class FlatMapGroupsWithStateExecHelperSuite extends StreamTest { + + import testImplicits._ + import FlatMapGroupsWithStateExecHelper._ + + // ============================ StateManagerImplV1 ============================ + + test(s"StateManager v1 - primitive type - without timestamp") { + val schema = new StructType().add("value", IntegerType, nullable = false) + testStateManagerWithoutTimestamp[Int](version = 1, schema, Seq(0, 10)) + } + + test(s"StateManager v1 - primitive type - with timestamp") { + val schema = new StructType() + .add("value", IntegerType, nullable = false) + .add("timeoutTimestamp", IntegerType, nullable = false) + testStateManagerWithTimestamp[Int](version = 1, schema, Seq(0, 10)) + } + + test(s"StateManager v1 - nested type - without timestamp") { + val schema = StructType(Seq( + StructField("i", IntegerType, nullable = false), + StructField("nested", StructType(Seq( + StructField("d", DoubleType, nullable = false), + StructField("str", StringType)) + )) + )) + + val testValues = Seq( + NestedStruct(1, Struct(1.0, "someString")), + NestedStruct(0, Struct(0.0, "")), + NestedStruct(0, null)) + + testStateManagerWithoutTimestamp[NestedStruct](version = 1, schema, testValues) + + // Verify the limitation of v1 with null state + intercept[Exception] { + testStateManagerWithoutTimestamp[NestedStruct](version = 1, schema, testValues = Seq(null)) + } + } + + test(s"StateManager v1 - nested type - with timestamp") { + val schema = StructType(Seq( + StructField("i", IntegerType, nullable = false), + StructField("nested", StructType(Seq( + StructField("d", DoubleType, nullable = false), + StructField("str", StringType)) + )), + StructField("timeoutTimestamp", IntegerType, nullable = false) + )) + + val testValues = Seq( + NestedStruct(1, Struct(1.0, "someString")), + NestedStruct(0, Struct(0.0, "")), + NestedStruct(0, null)) + + testStateManagerWithTimestamp[NestedStruct](version = 1, schema, testValues) + + // Verify the limitation of v1 with null state + intercept[Exception] { + testStateManagerWithTimestamp[NestedStruct](version = 1, schema, testValues = Seq(null)) + } + } + + // ============================ StateManagerImplV2 
============================ + + test(s"StateManager v2 - primitive type - without timestamp") { + val schema = new StructType() + .add("groupState", new StructType().add("value", IntegerType, nullable = false)) + testStateManagerWithoutTimestamp[Int](version = 2, schema, Seq(0, 10)) + } + + test(s"StateManager v2 - primitive type - with timestamp") { + val schema = new StructType() + .add("groupState", new StructType().add("value", IntegerType, nullable = false)) + .add("timeoutTimestamp", LongType, nullable = false) + testStateManagerWithTimestamp[Int](version = 2, schema, Seq(0, 10)) + } + + test(s"StateManager v2 - nested type - without timestamp") { + val schema = StructType(Seq( + StructField("groupState", StructType(Seq( + StructField("i", IntegerType, nullable = false), + StructField("nested", StructType(Seq( + StructField("d", DoubleType, nullable = false), + StructField("str", StringType) + ))) + ))) + )) + + val testValues = Seq( + NestedStruct(1, Struct(1.0, "someString")), + NestedStruct(0, Struct(0.0, "")), + NestedStruct(0, null), + null) + + testStateManagerWithoutTimestamp[NestedStruct](version = 2, schema, testValues) + } + + test(s"StateManager v2 - nested type - with timestamp") { + val schema = StructType(Seq( + StructField("groupState", StructType(Seq( + StructField("i", IntegerType, nullable = false), + StructField("nested", StructType(Seq( + StructField("d", DoubleType, nullable = false), + StructField("str", StringType) + ))) + ))), + StructField("timeoutTimestamp", LongType, nullable = false) + )) + + val testValues = Seq( + NestedStruct(1, Struct(1.0, "someString")), + NestedStruct(0, Struct(0.0, "")), + NestedStruct(0, null), + null) + + testStateManagerWithTimestamp[NestedStruct](version = 2, schema, testValues) + } + + + def testStateManagerWithoutTimestamp[T: Encoder]( + version: Int, + expectedStateSchema: StructType, + testValues: Seq[T]): Unit = { + val stateManager = newStateManager[T](version, withTimestamp = false) + assert(stateManager.stateSchema === expectedStateSchema) + testStateManager(stateManager, testValues, NO_TIMESTAMP) + } + + def testStateManagerWithTimestamp[T: Encoder]( + version: Int, + expectedStateSchema: StructType, + testValues: Seq[T]): Unit = { + val stateManager = newStateManager[T](version, withTimestamp = true) + assert(stateManager.stateSchema === expectedStateSchema) + for (timestamp <- Seq(NO_TIMESTAMP, 1000)) { + testStateManager(stateManager, testValues, timestamp) + } + } + + private def testStateManager[T: Encoder]( + stateManager: StateManager, + values: Seq[T], + timestamp: Long): Unit = { + val keys = (1 to values.size).map(_ => newKey()) + val store = new MemoryStateStore() + + // Test stateManager.getState(), putState(), removeState() + keys.zip(values).foreach { case (key, value) => + try { + stateManager.putState(store, key, value, timestamp) + val data = stateManager.getState(store, key) + assert(data.stateObj == value) + assert(data.timeoutTimestamp === timestamp) + stateManager.removeState(store, key) + assert(stateManager.getState(store, key).stateObj == null) + } catch { + case e: Throwable => + fail(s"put/get/remove test with '$value' failed", e) + } + } + + // Test stateManager.getAllState() + for (i <- keys.indices) { + stateManager.putState(store, keys(i), values(i), timestamp) + } + val allData = stateManager.getAllState(store).map(_.copy()).toArray + assert(allData.map(_.timeoutTimestamp).toSet == Set(timestamp)) + assert(allData.map(_.stateObj).toSet == values.toSet) + } + + private def 
newStateManager[T: Encoder](version: Int, withTimestamp: Boolean): StateManager = { + FlatMapGroupsWithStateExecHelper.createStateManager( + implicitly[Encoder[T]].asInstanceOf[ExpressionEncoder[Any]], + withTimestamp, + version) + } + + private val proj = UnsafeProjection.create(Array[DataType](IntegerType)) + private val keyCounter = new AtomicInteger(0) + private def newKey(): UnsafeRow = { + proj.apply(new GenericInternalRow(Array[Any](keyCounter.getAndDecrement()))).copy() + } +} + +case class Struct(d: Double, str: String) +case class NestedStruct(i: Int, nested: Struct) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala new file mode 100644 index 0000000000000..98586d6492c9e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.state + +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +class MemoryStateStore extends StateStore() { + import scala.collection.JavaConverters._ + private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow] + + override def iterator(): Iterator[UnsafeRowPair] = { + map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, e.getValue) } + } + + override def get(key: UnsafeRow): UnsafeRow = map.get(key) + + override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = map.put(key.copy(), newValue.copy()) + + override def remove(key: UnsafeRow): Unit = map.remove(key) + + override def commit(): Long = version + 1 + + override def abort(): Unit = {} + + override def id: StateStoreId = null + + override def version: Long = 0 + + override def metrics: StateStoreMetrics = new StateStoreMetrics(map.size, 0, Map.empty) + + override def hasCommitted: Boolean = true +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 579a364ebc3e5..015415a534ff5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -49,8 +49,11 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } override def afterAll(): Unit = { - super.afterAll() - Utils.deleteRecursively(new File(tempDir)) + try { + super.afterAll() + } finally { + Utils.deleteRecursively(new File(tempDir)) + } } test("versioning and immutability") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 73f8705060402..5e973145b0a37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming.state import java.io.{File, IOException} import java.net.URI +import java.util import java.util.UUID import scala.collection.JavaConverters._ @@ -47,6 +48,7 @@ import org.apache.spark.util.Utils class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] with BeforeAndAfter with PrivateMethodTester { type MapType = mutable.HashMap[UnsafeRow, UnsafeRow] + type ProviderMapType = java.util.concurrent.ConcurrentHashMap[UnsafeRow, UnsafeRow] import StateStoreCoordinatorSuite._ import StateStoreTestsHelper._ @@ -64,21 +66,143 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] require(!StateStore.isMaintenanceRunning) } + def updateVersionTo( + provider: StateStoreProvider, + currentVersion: Int, + targetVersion: Int): Int = { + var newCurrentVersion = currentVersion + for (i <- newCurrentVersion until targetVersion) { + newCurrentVersion = incrementVersion(provider, i) + } + require(newCurrentVersion === targetVersion) + newCurrentVersion + } + + def incrementVersion(provider: StateStoreProvider, currentVersion: Int): Int = { + val store = provider.getStore(currentVersion) + put(store, "a", currentVersion + 1) + store.commit() + currentVersion + 1 + } + + def checkLoadedVersions( + 
loadedMaps: util.SortedMap[Long, ProviderMapType],
+      count: Int,
+      earliestKey: Long,
+      latestKey: Long): Unit = {
+    assert(loadedMaps.size() === count)
+    assert(loadedMaps.firstKey() === earliestKey)
+    assert(loadedMaps.lastKey() === latestKey)
+  }
+
+  def checkVersion(
+      loadedMaps: util.SortedMap[Long, ProviderMapType],
+      version: Long,
+      expectedData: Map[String, Int]): Unit = {
+
+    val originValueMap = loadedMaps.get(version).asScala.map { entry =>
+      rowToString(entry._1) -> rowToInt(entry._2)
+    }.toMap
+
+    assert(originValueMap === expectedData)
+  }
+
+  test("retaining only two latest versions when MAX_BATCHES_TO_RETAIN_IN_MEMORY set to 2") {
+    val provider = newStoreProvider(opId = Random.nextInt, partition = 0,
+      numOfVersToRetainInMemory = 2)
+
+    var currentVersion = 0
+
+    // commit version 1: the cache will hold one entry
+    currentVersion = incrementVersion(provider, currentVersion)
+    assert(getData(provider) === Set("a" -> 1))
+    var loadedMaps = provider.getLoadedMaps()
+    checkLoadedVersions(loadedMaps, count = 1, earliestKey = 1, latestKey = 1)
+    checkVersion(loadedMaps, 1, Map("a" -> 1))
+
+    // commit version 2: the cache will hold two entries
+    currentVersion = incrementVersion(provider, currentVersion)
+    assert(getData(provider) === Set("a" -> 2))
+    loadedMaps = provider.getLoadedMaps()
+    checkLoadedVersions(loadedMaps, count = 2, earliestKey = 2, latestKey = 1)
+    checkVersion(loadedMaps, 2, Map("a" -> 2))
+    checkVersion(loadedMaps, 1, Map("a" -> 1))
+
+    // commit version 3: the cache already holds two entries, so adding version 3 exceeds the
+    // limit; version 3 is added and version 1 is evicted
+    currentVersion = incrementVersion(provider, currentVersion)
+    assert(getData(provider) === Set("a" -> 3))
+    loadedMaps = provider.getLoadedMaps()
+    checkLoadedVersions(loadedMaps, count = 2, earliestKey = 3, latestKey = 2)
+    checkVersion(loadedMaps, 3, Map("a" -> 3))
+    checkVersion(loadedMaps, 2, Map("a" -> 2))
+  }
+
+  test("failure after committing with MAX_BATCHES_TO_RETAIN_IN_MEMORY set to 1") {
+    val provider = newStoreProvider(opId = Random.nextInt, partition = 0,
+      numOfVersToRetainInMemory = 1)
+
+    var currentVersion = 0
+
+    // commit version 1: the cache will hold one entry
+    currentVersion = incrementVersion(provider, currentVersion)
+    assert(getData(provider) === Set("a" -> 1))
+    var loadedMaps = provider.getLoadedMaps()
+    checkLoadedVersions(loadedMaps, count = 1, earliestKey = 1, latestKey = 1)
+    checkVersion(loadedMaps, 1, Map("a" -> 1))
+
+    // commit version 2: the cache already holds one entry, so adding version 2 exceeds the
+    // limit; version 2 is added and version 1 is evicted.
+    // This guarantees a cache miss when this partition has committed successfully
+    // but a later failure forces the previous batch to be reprocessed.
+    currentVersion = incrementVersion(provider, currentVersion)
+    assert(getData(provider) === Set("a" -> 2))
+    loadedMaps = provider.getLoadedMaps()
+    checkLoadedVersions(loadedMaps, count = 1, earliestKey = 2, latestKey = 2)
+    checkVersion(loadedMaps, 2, Map("a" -> 2))
+
+    // suppose a failure occurred after committing, and the previous batch must be reprocessed
+    currentVersion = 1
+
+    // commit to an existing version that was committed partially but abandoned globally
+    val store = provider.getStore(currentVersion)
+    // negative value to represent reprocessing
+    put(store, "a", -2)
+    store.commit()
+    currentVersion += 1
+
+    // make sure the newly committed version is reflected in the cache (overwritten)
+    assert(getData(provider)
=== Set("a" -> -2))
+    loadedMaps = provider.getLoadedMaps()
+    checkLoadedVersions(loadedMaps, count = 1, earliestKey = 2, latestKey = 2)
+    checkVersion(loadedMaps, 2, Map("a" -> -2))
+  }
+
+  test("no cache data with MAX_BATCHES_TO_RETAIN_IN_MEMORY set to 0") {
+    val provider = newStoreProvider(opId = Random.nextInt, partition = 0,
+      numOfVersToRetainInMemory = 0)
+
+    var currentVersion = 0
+
+    // commit version 1: never cached
+    currentVersion = incrementVersion(provider, currentVersion)
+    assert(getData(provider) === Set("a" -> 1))
+    var loadedMaps = provider.getLoadedMaps()
+    assert(loadedMaps.size() === 0)
+
+    // commit version 2: never cached
+    currentVersion = incrementVersion(provider, currentVersion)
+    assert(getData(provider) === Set("a" -> 2))
+    loadedMaps = provider.getLoadedMaps()
+    assert(loadedMaps.size() === 0)
+  }
+
   test("snapshotting") {
     val provider = newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5)

     var currentVersion = 0

-    def updateVersionTo(targetVersion: Int): Unit = {
-      for (i <- currentVersion + 1 to targetVersion) {
-        val store = provider.getStore(currentVersion)
-        put(store, "a", i)
-        store.commit()
-        currentVersion += 1
-      }
-      require(currentVersion === targetVersion)
-    }
-    updateVersionTo(2)
+    currentVersion = updateVersionTo(provider, currentVersion, 2)
     require(getData(provider) === Set("a" -> 2))
     provider.doMaintenance() // should not generate snapshot files
     assert(getData(provider) === Set("a" -> 2))
@@ -89,7 +213,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     }

     // After version 6, snapshotting should generate one snapshot file
-    updateVersionTo(6)
+    currentVersion = updateVersionTo(provider, currentVersion, 6)
     require(getData(provider) === Set("a" -> 6), "store not updated correctly")
     provider.doMaintenance() // should generate snapshot files
@@ -104,7 +228,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
       "snapshotting messed up the data of the final version")

     // After version 20, snapshotting should generate newer snapshot files
-    updateVersionTo(20)
+    currentVersion = updateVersionTo(provider, currentVersion, 20)
     require(getData(provider) === Set("a" -> 20), "store not updated correctly")
     provider.doMaintenance() // do snapshot
@@ -193,6 +317,22 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     assert(store.metrics.memoryUsedBytes > noDataMemoryUsed)
   }

+  test("reports memory usage on current version") {
+    def getSizeOfStateForCurrentVersion(metrics: StateStoreMetrics): Long = {
+      val metricPair = metrics.customMetrics.find(_._1.name == "stateOnCurrentVersionSizeBytes")
+      assert(metricPair.isDefined)
+      metricPair.get._2
+    }
+
+    val provider = newStoreProvider()
+    val store = provider.getStore(0)
+    val noDataMemoryUsed = getSizeOfStateForCurrentVersion(store.metrics)
+
+    put(store, "a", 1)
+    store.commit()
+    assert(getSizeOfStateForCurrentVersion(store.metrics) > noDataMemoryUsed)
+  }
+
   test("StateStore.get") {
     quietly {
       val dir = newDir()
@@ -507,6 +647,90 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     assert(CreateAtomicTestManager.cancelCalledInCreateAtomic)
   }

+  test("expose metrics with custom metrics to StateStoreMetrics") {
+    def getCustomMetric(metrics: StateStoreMetrics, name: String): Long = {
+      val metricPair = metrics.customMetrics.find(_._1.name == name)
+      assert(metricPair.isDefined)
+      metricPair.get._2
+    }
+
+    def getLoadedMapSizeMetric(metrics: StateStoreMetrics): Long =
{ + metrics.memoryUsedBytes + } + + def assertCacheHitAndMiss( + metrics: StateStoreMetrics, + expectedCacheHitCount: Long, + expectedCacheMissCount: Long): Unit = { + val cacheHitCount = getCustomMetric(metrics, "loadedMapCacheHitCount") + val cacheMissCount = getCustomMetric(metrics, "loadedMapCacheMissCount") + assert(cacheHitCount === expectedCacheHitCount) + assert(cacheMissCount === expectedCacheMissCount) + } + + val provider = newStoreProvider() + + // Verify state before starting a new set of updates + assert(getLatestData(provider).isEmpty) + + val store = provider.getStore(0) + assert(!store.hasCommitted) + + assert(store.metrics.numKeys === 0) + + val initialLoadedMapSize = getLoadedMapSizeMetric(store.metrics) + assert(initialLoadedMapSize >= 0) + assertCacheHitAndMiss(store.metrics, expectedCacheHitCount = 0, expectedCacheMissCount = 0) + + put(store, "a", 1) + assert(store.metrics.numKeys === 1) + + put(store, "b", 2) + put(store, "aa", 3) + assert(store.metrics.numKeys === 3) + remove(store, _.startsWith("a")) + assert(store.metrics.numKeys === 1) + assert(store.commit() === 1) + + assert(store.hasCommitted) + + val loadedMapSizeForVersion1 = getLoadedMapSizeMetric(store.metrics) + assert(loadedMapSizeForVersion1 > initialLoadedMapSize) + assertCacheHitAndMiss(store.metrics, expectedCacheHitCount = 0, expectedCacheMissCount = 0) + + val storeV2 = provider.getStore(1) + assert(!storeV2.hasCommitted) + assert(storeV2.metrics.numKeys === 1) + + put(storeV2, "cc", 4) + assert(storeV2.metrics.numKeys === 2) + assert(storeV2.commit() === 2) + + assert(storeV2.hasCommitted) + + val loadedMapSizeForVersion1And2 = getLoadedMapSizeMetric(storeV2.metrics) + assert(loadedMapSizeForVersion1And2 > loadedMapSizeForVersion1) + assertCacheHitAndMiss(storeV2.metrics, expectedCacheHitCount = 1, expectedCacheMissCount = 0) + + val reloadedProvider = newStoreProvider(store.id) + // intended to load version 2 instead of 1 + // version 2 will not be loaded to the cache in provider + val reloadedStore = reloadedProvider.getStore(1) + assert(reloadedStore.metrics.numKeys === 1) + + assert(getLoadedMapSizeMetric(reloadedStore.metrics) === loadedMapSizeForVersion1) + assertCacheHitAndMiss(reloadedStore.metrics, expectedCacheHitCount = 0, + expectedCacheMissCount = 1) + + // now we are loading version 2 + val reloadedStoreV2 = reloadedProvider.getStore(2) + assert(reloadedStoreV2.metrics.numKeys === 2) + + assert(getLoadedMapSizeMetric(reloadedStoreV2.metrics) > loadedMapSizeForVersion1) + assertCacheHitAndMiss(reloadedStoreV2.metrics, expectedCacheHitCount = 0, + expectedCacheMissCount = 2) + } + override def newStoreProvider(): HDFSBackedStateStoreProvider = { newStoreProvider(opId = Random.nextInt(), partition = 0) } @@ -535,9 +759,11 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] partition: Int, dir: String = newDir(), minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get, + numOfVersToRetainInMemory: Int = SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get, hadoopConf: Configuration = new Configuration): HDFSBackedStateStoreProvider = { val sqlConf = new SQLConf() sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot) + sqlConf.setConf(SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY, numOfVersToRetainInMemory) sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2) val provider = new HDFSBackedStateStoreProvider() provider.init( diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala new file mode 100644 index 0000000000000..daacdfd58c7b9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.sql.catalyst.expressions.{Attribute, SpecificInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + +class StreamingAggregationStateManagerSuite extends StreamTest { + // ============================ fields and method for test data ============================ + + val testKeys: Seq[String] = Seq("key1", "key2") + val testValues: Seq[String] = Seq("sum(key1)", "sum(key2)") + + val testOutputSchema: StructType = StructType( + testKeys.map(createIntegerField) ++ testValues.map(createIntegerField)) + + val testOutputAttributes: Seq[Attribute] = testOutputSchema.toAttributes + val testKeyAttributes: Seq[Attribute] = testOutputAttributes.filter { p => + testKeys.contains(p.name) + } + val testValuesAttributes: Seq[Attribute] = testOutputAttributes.filter { p => + testValues.contains(p.name) + } + val expectedTestValuesSchema: StructType = testValuesAttributes.toStructType + + val testRow: UnsafeRow = { + val unsafeRowProjection = UnsafeProjection.create(testOutputSchema) + val row = unsafeRowProjection(new SpecificInternalRow(testOutputSchema)) + (testKeys ++ testValues).zipWithIndex.foreach { case (_, index) => row.setInt(index, index) } + row + } + + val expectedTestKeyRow: UnsafeRow = { + val keyProjector = GenerateUnsafeProjection.generate(testKeyAttributes, testOutputAttributes) + keyProjector(testRow) + } + + val expectedTestValueRowForV2: UnsafeRow = { + val valueProjector = GenerateUnsafeProjection.generate(testValuesAttributes, + testOutputAttributes) + valueProjector(testRow) + } + + private def createIntegerField(name: String): StructField = { + StructField(name, IntegerType, nullable = false) + } + + // ============================ StateManagerImplV1 ============================ + + test("StateManager v1 - get, put, iter") { + val stateManager = StreamingAggregationStateManager.createStateManager(testKeyAttributes, + testOutputAttributes, 1) + + // in V1, input row is stored as value + testGetPutIterOnStateManager(stateManager, testOutputSchema, testRow, + expectedTestKeyRow, 
expectedStateValue = testRow)
+  }
+
+  // ============================ StateManagerImplV2 ============================
+  test("StateManager v2 - get, put, iter") {
+    val stateManager = StreamingAggregationStateManager.createStateManager(testKeyAttributes,
+      testOutputAttributes, 2)
+
+    // In V2, only the value part of the row (the input row minus its key columns) is stored,
+    // so the stored value has no key part; state manager V2 still provides the same output
+    // as V1 when fetching the row for a key.
+    testGetPutIterOnStateManager(stateManager, expectedTestValuesSchema, testRow,
+      expectedTestKeyRow, expectedTestValueRowForV2)
+  }
+
+  private def testGetPutIterOnStateManager(
+      stateManager: StreamingAggregationStateManager,
+      expectedValueSchema: StructType,
+      inputRow: UnsafeRow,
+      expectedStateKey: UnsafeRow,
+      expectedStateValue: UnsafeRow): Unit = {
+
+    assert(stateManager.getStateValueSchema === expectedValueSchema)
+
+    val memoryStateStore = new MemoryStateStore()
+    stateManager.put(memoryStateStore, inputRow)
+
+    assert(memoryStateStore.iterator().size === 1)
+    assert(stateManager.iterator(memoryStateStore).size === memoryStateStore.iterator().size)
+
+    val keyRow = stateManager.getKey(inputRow)
+    assert(keyRow === expectedStateKey)
+
+    // iterate the state store directly and verify the key and value are stored in the expected format
+    val pair = memoryStateStore.iterator().next()
+    assert(pair.key === keyRow)
+    assert(pair.value === expectedStateValue)
+
+    // iterate via the state manager and verify that the original rows are returned as values
+    val pairFromStateManager = stateManager.iterator(memoryStateStore).next()
+    assert(pairFromStateManager.key === keyRow)
+    assert(pairFromStateManager.value === inputRow)
+
+    // the keys and values iterators should expose the same key and the original row
+    assert(stateManager.keys(memoryStateStore).next() === keyRow)
+    assert(stateManager.values(memoryStateStore).next() === inputRow)
+
+    // verify the stored value once again via get
+    assert(memoryStateStore.get(keyRow) === expectedStateValue)
+
+    // the state manager should return the same row as the input row regardless of format version
+    assert(inputRow === stateManager.get(memoryStateStore, keyRow))
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala
index b55489cb2678a..4592a1663faed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala
@@ -336,7 +336,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite {
     val allocator = ArrowUtils.rootAllocator.newChildAllocator("struct", 0, Long.MaxValue)
     val schema = new StructType().add("int", IntegerType).add("long", LongType)
     val vector = ArrowUtils.toArrowField("struct", schema, nullable = false, null)
-      .createVector(allocator).asInstanceOf[NullableMapVector]
+
.createVector(allocator).asInstanceOf[StructVector] vector.allocateNew() val intVector = vector.getChildByOrdinal(0).asInstanceOf[IntVector] val longVector = vector.getChildByOrdinal(1).asInstanceOf[BigIntVector] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala index 3dd0712e02448..d885348f3774a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.internal import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.execution.debug.codegenStringSeq +import org.apache.spark.sql.functions.col import org.apache.spark.sql.test.SQLTestUtils class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { @@ -36,20 +38,32 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { } override def afterAll(): Unit = { - spark.stop() - spark = null + try { + spark.stop() + spark = null + } finally { + super.afterAll() + } + } + + override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + pairs.foreach { case (k, v) => + SQLConf.get.setConfString(k, v) + } + try f finally { + pairs.foreach { case (k, _) => + SQLConf.get.unsetConf(k) + } + } } test("ReadOnlySQLConf is correctly created at the executor side") { - SQLConf.get.setConfString("spark.sql.x", "a") - try { - val checks = spark.range(10).mapPartitions { it => + withSQLConf("spark.sql.x" -> "a") { + val checks = spark.range(10).mapPartitions { _ => val conf = SQLConf.get Iterator(conf.isInstanceOf[ReadOnlySQLConf] && conf.getConfString("spark.sql.x") == "a") }.collect() assert(checks.forall(_ == true)) - } finally { - SQLConf.get.unsetConf("spark.sql.x") } } @@ -63,4 +77,29 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { } } } + + test("SPARK-24727 CODEGEN_CACHE_MAX_ENTRIES is correctly referenced at the executor side") { + withSQLConf(StaticSQLConf.CODEGEN_CACHE_MAX_ENTRIES.key -> "300") { + val checks = spark.range(10).mapPartitions { _ => + val conf = SQLConf.get + Iterator(conf.isInstanceOf[ReadOnlySQLConf] && + conf.getConfString(StaticSQLConf.CODEGEN_CACHE_MAX_ENTRIES.key) == "300") + }.collect() + assert(checks.forall(_ == true)) + } + } + + test("SPARK-22219: refactor to control to generate comment") { + Seq(true, false).foreach { flag => + withSQLConf(StaticSQLConf.CODEGEN_COMMENTS.key -> flag.toString) { + val res = codegenStringSeq(spark.range(10).groupBy(col("id") * 2).count() + .queryExecution.executedPlan) + assert(res.length == 2) + assert(res.forall { case (_, code) => + (code.contains("* Codegend pipeline") == flag) && + (code.contains("// input[") == flag) + }) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfGetterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfGetterSuite.scala new file mode 100644 index 0000000000000..bb79d3a84e5a3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfGetterSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{LocalSparkSession, SparkSession} + +class SQLConfGetterSuite extends SparkFunSuite with LocalSparkSession { + + test("SPARK-25076: SQLConf should not be retrieved from a stopped SparkSession") { + spark = SparkSession.builder().master("local").getOrCreate() + assert(SQLConf.get eq spark.sessionState.conf, + "SQLConf.get should get the conf from the active spark session.") + spark.stop() + assert(SQLConf.get eq SQLConf.getFallbackConf, + "SQLConf.get should not get conf from a stopped spark session.") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 0389273d6cdfa..7fa0e7fc162ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -24,7 +24,7 @@ import java.util.{Calendar, GregorianCalendar, Properties} import org.h2.jdbc.JdbcSQLException import org.scalatest.{BeforeAndAfter, PrivateMethodTester} -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap @@ -244,6 +244,17 @@ class JDBCSuite extends QueryTest .executeUpdate() conn.commit() + conn.prepareStatement("CREATE TABLE test.datetime (d DATE, t TIMESTAMP)").executeUpdate() + conn.prepareStatement( + "INSERT INTO test.datetime VALUES ('2018-07-06', '2018-07-06 05:50:00.0')").executeUpdate() + conn.prepareStatement( + "INSERT INTO test.datetime VALUES ('2018-07-06', '2018-07-06 08:10:08.0')").executeUpdate() + conn.prepareStatement( + "INSERT INTO test.datetime VALUES ('2018-07-08', '2018-07-08 13:32:01.0')").executeUpdate() + conn.prepareStatement( + "INSERT INTO test.datetime VALUES ('2018-07-12', '2018-07-12 09:51:15.0')").executeUpdate() + conn.commit() + // Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types. } @@ -261,21 +272,32 @@ class JDBCSuite extends QueryTest s"Expecting a JDBCRelation with $expectedNumPartitions partitions, but got:`$jdbcRelations`") } + private def checkPushdown(df: DataFrame): DataFrame = { + val parentPlan = df.queryExecution.executedPlan + // Check if SparkPlan Filter is removed in a physical plan and + // the plan only has PhysicalRDD to scan JDBCRelation. 
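+    // For example, sql("SELECT * FROM foobar WHERE THEID = 1") (used below) is expected to
+    // compile to a single WholeStageCodegen node whose child scans JDBCRelation directly,
+    // with the predicate pushed into the generated JDBC query rather than kept as a
+    // FilterExec in the Spark plan. (Illustrative sketch, not an additional assertion.)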
+ assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegenExec]) + val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegenExec] + assert(node.child.isInstanceOf[org.apache.spark.sql.execution.DataSourceScanExec]) + assert(node.child.asInstanceOf[DataSourceScanExec].nodeName.contains("JDBCRelation")) + df + } + + private def checkNotPushdown(df: DataFrame): DataFrame = { + val parentPlan = df.queryExecution.executedPlan + // Check if SparkPlan Filter is not removed in a physical plan because JDBCRDD + // cannot compile given predicates. + assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegenExec]) + val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegenExec] + assert(node.child.isInstanceOf[org.apache.spark.sql.execution.FilterExec]) + df + } + test("SELECT *") { assert(sql("SELECT * FROM foobar").collect().size === 3) } test("SELECT * WHERE (simple predicates)") { - def checkPushdown(df: DataFrame): DataFrame = { - val parentPlan = df.queryExecution.executedPlan - // Check if SparkPlan Filter is removed in a physical plan and - // the plan only has PhysicalRDD to scan JDBCRelation. - assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegenExec]) - val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegenExec] - assert(node.child.isInstanceOf[org.apache.spark.sql.execution.DataSourceScanExec]) - assert(node.child.asInstanceOf[DataSourceScanExec].nodeName.contains("JDBCRelation")) - df - } assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size == 0) assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID != 2")).collect().size == 2) assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID = 1")).collect().size == 1) @@ -308,15 +330,6 @@ class JDBCSuite extends QueryTest "WHERE (THEID > 0 AND TRIM(NAME) = 'mary') OR (NAME = 'fred')") assert(df2.collect.toSet === Set(Row("fred", 1), Row("mary", 2))) - def checkNotPushdown(df: DataFrame): DataFrame = { - val parentPlan = df.queryExecution.executedPlan - // Check if SparkPlan Filter is not removed in a physical plan because JDBCRDD - // cannot compile given predicates. 
-      assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegenExec])
-      val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegenExec]
-      assert(node.child.isInstanceOf[org.apache.spark.sql.execution.FilterExec])
-      df
-    }
     assert(checkNotPushdown(sql("SELECT * FROM foobar WHERE (THEID + 1) < 2")).collect().size == 0)
     assert(checkNotPushdown(sql("SELECT * FROM foobar WHERE (THEID + 2) != 4")).collect().size == 2)
   }
@@ -861,19 +874,51 @@ class JDBCSuite extends QueryTest
   }

   test("truncate table query by jdbc dialect") {
-    val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db")
-    val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db")
+    val mysql = JdbcDialects.get("jdbc:mysql://127.0.0.1/db")
+    val postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db")
     val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db")
     val h2 = JdbcDialects.get(url)
     val derby = JdbcDialects.get("jdbc:derby:db")
+    val oracle = JdbcDialects.get("jdbc:oracle://127.0.0.1/db")
+    val teradata = JdbcDialects.get("jdbc:teradata://127.0.0.1/db")
+
     val table = "weblogs"
     val defaultQuery = s"TRUNCATE TABLE $table"
     val postgresQuery = s"TRUNCATE TABLE ONLY $table"
-    assert(MySQL.getTruncateQuery(table) == defaultQuery)
-    assert(Postgres.getTruncateQuery(table) == postgresQuery)
-    assert(db2.getTruncateQuery(table) == defaultQuery)
-    assert(h2.getTruncateQuery(table) == defaultQuery)
-    assert(derby.getTruncateQuery(table) == defaultQuery)
+    val teradataQuery = s"DELETE FROM $table ALL"
+
+    Seq(mysql, db2, h2, derby).foreach { dialect =>
+      assert(dialect.getTruncateQuery(table) == defaultQuery)
+    }
+
+    assert(postgres.getTruncateQuery(table) == postgresQuery)
+    assert(oracle.getTruncateQuery(table) == defaultQuery)
+    assert(teradata.getTruncateQuery(table) == teradataQuery)
+  }
+
+  test("SPARK-22880: Truncate table with CASCADE by jdbc dialect") {
+    // CASCADE in a TRUNCATE should only be applied for databases that support it,
+    // even if the parameter is passed.
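+    // Illustrative expectations under this behavior (a sketch matching the asserts below,
+    // not additional checks):
+    //   postgres.getTruncateQuery("weblogs", Some(true)) => "TRUNCATE TABLE ONLY weblogs CASCADE"
+    //   mysql.getTruncateQuery("weblogs", Some(true))    => "TRUNCATE TABLE weblogs" (cascade ignored)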
+    val mysql = JdbcDialects.get("jdbc:mysql://127.0.0.1/db")
+    val postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db")
+    val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db")
+    val h2 = JdbcDialects.get(url)
+    val derby = JdbcDialects.get("jdbc:derby:db")
+    val oracle = JdbcDialects.get("jdbc:oracle://127.0.0.1/db")
+    val teradata = JdbcDialects.get("jdbc:teradata://127.0.0.1/db")
+
+    val table = "weblogs"
+    val defaultQuery = s"TRUNCATE TABLE $table"
+    val postgresQuery = s"TRUNCATE TABLE ONLY $table CASCADE"
+    val oracleQuery = s"TRUNCATE TABLE $table CASCADE"
+    val teradataQuery = s"DELETE FROM $table ALL"
+
+    Seq(mysql, db2, h2, derby).foreach { dialect =>
+      assert(dialect.getTruncateQuery(table, Some(true)) == defaultQuery)
+    }
+    assert(postgres.getTruncateQuery(table, Some(true)) == postgresQuery)
+    assert(oracle.getTruncateQuery(table, Some(true)) == oracleQuery)
+    assert(teradata.getTruncateQuery(table, Some(true)) == teradataQuery)
   }

   test("Test DataFrame.where for Date and Timestamp") {
@@ -1341,6 +1386,96 @@ class JDBCSuite extends QueryTest
     checkAnswer(
       sql("select name, theid from queryOption"),
       Row("fred", 1) :: Nil)
+  }
+
+  test("SPARK-22814 support date/timestamp types in partitionColumn") {
+    val expectedResult = Seq(
+      ("2018-07-06", "2018-07-06 05:50:00.0"),
+      ("2018-07-06", "2018-07-06 08:10:08.0"),
+      ("2018-07-08", "2018-07-08 13:32:01.0"),
+      ("2018-07-12", "2018-07-12 09:51:15.0")
+    ).map { case (date, timestamp) =>
+      Row(Date.valueOf(date), Timestamp.valueOf(timestamp))
+    }
+
+    // DateType partition column
+    val df1 = spark.read.format("jdbc")
+      .option("url", urlWithUserAndPass)
+      .option("dbtable", "TEST.DATETIME")
+      .option("partitionColumn", "d")
+      .option("lowerBound", "2018-07-06")
+      .option("upperBound", "2018-07-20")
+      .option("numPartitions", 3)
+      .load()
+
+    df1.logicalPlan match {
+      case LogicalRelation(JDBCRelation(_, parts, _), _, _, _) =>
+        val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet
+        assert(whereClauses === Set(
+          """"D" < '2018-07-10' or "D" is null""",
+          """"D" >= '2018-07-10' AND "D" < '2018-07-14'""",
+          """"D" >= '2018-07-14'"""))
+    }
+    checkAnswer(df1, expectedResult)
+
+    // TimestampType partition column
+    val df2 = spark.read.format("jdbc")
+      .option("url", urlWithUserAndPass)
+      .option("dbtable", "TEST.DATETIME")
+      .option("partitionColumn", "t")
+      .option("lowerBound", "2018-07-04 03:30:00.0")
+      .option("upperBound", "2018-07-27 14:11:05.0")
+      .option("numPartitions", 2)
+      .load()
+
+    df2.logicalPlan match {
+      case LogicalRelation(JDBCRelation(_, parts, _), _, _, _) =>
+        val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet
+        assert(whereClauses === Set(
+          """"T" < '2018-07-15 20:50:32.5' or "T" is null""",
+          """"T" >= '2018-07-15 20:50:32.5'"""))
+    }
+    checkAnswer(df2, expectedResult)
+  }
+
+  test("throws an exception for unsupported partition column types") {
+    val errMsg = intercept[AnalysisException] {
+      spark.read.format("jdbc")
+        .option("url", urlWithUserAndPass)
+        .option("dbtable", "TEST.PEOPLE")
+        .option("partitionColumn", "name")
+        .option("lowerBound", "aaa")
+        .option("upperBound", "zzz")
+        .option("numPartitions", 2)
+        .load()
+    }.getMessage
+    assert(errMsg.contains(
+      "Partition column type should be numeric, date, or timestamp, but string found."))
+  }
+
+  test("SPARK-24288: Enable preventing predicate pushdown") {
+    val table = "test.people"
+
+    val df = spark.read.format("jdbc")
+      .option("Url", urlWithUserAndPass)
+      .option("dbTable", table)
+
.option("pushDownPredicate", false) + .load() + .filter("theid = 1") + .select("name", "theid") + checkAnswer( + checkNotPushdown(df), + Row("fred", 1) :: Nil) + // pushDownPredicate option in the create table path. + sql( + s""" + |CREATE OR REPLACE TEMPORARY VIEW predicateOption + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$urlWithUserAndPass', dbTable '$table', pushDownPredicate 'false') + """.stripMargin.replaceAll("\n", " ")) + checkAnswer( + checkNotPushdown(sql("SELECT name, theid FROM predicateOption WHERE theid = 1")), + Row("fred", 1) :: Nil) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index 5ff1ea84d9a7b..fc61050dc7458 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.sources import java.io.File -import java.net.URI import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.datasources.BucketingUtils @@ -48,16 +49,40 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { intercept[AnalysisException](df.write.bucketBy(2, "k").saveAsTable("tt")) } - test("numBuckets be greater than 0 but less than 100000") { + test("numBuckets be greater than 0 but less/eq than default bucketing.maxBuckets (100000)") { val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") - Seq(-1, 0, 100000).foreach(numBuckets => { + Seq(-1, 0, 100001).foreach(numBuckets => { val e = intercept[AnalysisException](df.write.bucketBy(numBuckets, "i").saveAsTable("tt")) assert( - e.getMessage.contains("Number of buckets should be greater than 0 but less than 100000")) + e.getMessage.contains("Number of buckets should be greater than 0 but less than")) }) } + test("numBuckets be greater than 0 but less/eq than overridden bucketing.maxBuckets (200000)") { + val maxNrBuckets: Int = 200000 + val catalog = spark.sessionState.catalog + + withSQLConf("spark.sql.sources.bucketing.maxBuckets" -> maxNrBuckets.toString) { + // within the new limit + Seq(100001, maxNrBuckets).foreach(numBuckets => { + withTable("t") { + df.write.bucketBy(numBuckets, "i").saveAsTable("t") + val table = catalog.getTableMetadata(TableIdentifier("t")) + assert(table.bucketSpec == Option(BucketSpec(numBuckets, Seq("i"), Seq()))) + } + }) + + // over the new limit + withTable("t") { + val e = intercept[AnalysisException]( + df.write.bucketBy(maxNrBuckets + 1, "i").saveAsTable("t")) + assert( + e.getMessage.contains("Number of buckets should be greater than 0 but less than")) + } + } + } + test("specify sorting columns without bucketing columns") { val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") val e = intercept[AnalysisException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 916a01ee0ca8e..d46029e84433c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -225,7 +225,7 @@ class 
CreateTableAsSelectSuite

   test("create table using as select - with invalid number of buckets") {
     withTable("t") {
-      Seq(0, 100000).foreach(numBuckets => {
+      Seq(0, 100001).foreach(numBuckets => {
         val e = intercept[AnalysisException] {
           sql(
             s"""
@@ -236,11 +236,42 @@ class CreateTableAsSelectSuite
             """.stripMargin
           )
         }.getMessage
-        assert(e.contains("Number of buckets should be greater than 0 but less than 100000"))
+        assert(e.contains("Number of buckets should be greater than 0 but less than"))
       })
     }
   }

+  test("create table using as select - with overridden max number of buckets") {
+    def createTableSql(numBuckets: Int): String =
+      s"""
+         |CREATE TABLE t USING PARQUET
+         |OPTIONS (PATH '${path.toURI}')
+         |CLUSTERED BY (a) SORTED BY (b) INTO $numBuckets BUCKETS
+         |AS SELECT 1 AS a, 2 AS b
+       """.stripMargin
+
+    val maxNrBuckets: Int = 200000
+    val catalog = spark.sessionState.catalog
+    withSQLConf("spark.sql.sources.bucketing.maxBuckets" -> maxNrBuckets.toString) {
+
+      // Within the new limit
+      Seq(100001, maxNrBuckets).foreach(numBuckets => {
+        withTable("t") {
+          sql(createTableSql(numBuckets))
+          val table = catalog.getTableMetadata(TableIdentifier("t"))
+          assert(table.bucketSpec == Option(BucketSpec(numBuckets, Seq("a"), Seq("b"))))
+        }
+      })
+
+      // Over the new limit
+      withTable("t") {
+        val e = intercept[AnalysisException](sql(createTableSql(maxNrBuckets + 1)))
+        assert(
+          e.getMessage.contains("Number of buckets should be greater than 0 but less than "))
+      }
+    }
+  }
+
   test("SPARK-17409: CTAS of decimal calculation") {
     withTable("tab2") {
       withTempView("tab1") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
index 438d5d8176b8b..0b6d93975daef 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
@@ -545,6 +545,26 @@ class InsertSuite extends DataSourceTest with SharedSQLContext {
     }
   }

+  test("SPARK-24860: dynamic partition overwrite specified per source without catalog table") {
+    withTempPath { path =>
+      Seq((1, 1), (2, 2)).toDF("i", "part")
+        .write.partitionBy("part")
+        .parquet(path.getAbsolutePath)
+      checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(1, 1) :: Row(2, 2) :: Nil)
+
+      Seq((1, 2), (1, 3)).toDF("i", "part")
+        .write.partitionBy("part").mode("overwrite")
+        .option("partitionOverwriteMode", "dynamic").parquet(path.getAbsolutePath)
+      checkAnswer(spark.read.parquet(path.getAbsolutePath),
+        Row(1, 1) :: Row(1, 2) :: Row(1, 3) :: Nil)
+
+      Seq((1, 2), (1, 3)).toDF("i", "part")
+        .write.partitionBy("part").mode("overwrite")
+        .option("partitionOverwriteMode", "static").parquet(path.getAbsolutePath)
+      checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(1, 2) :: Row(1, 3) :: Nil)
+    }
+  }
+
   test("SPARK-24583 Wrong schema type in InsertIntoDataSourceCommand") {
     withTable("test_table") {
       val schema = new StructType()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
index 4adbff5c663bc..0aa67bf1b0d48 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
@@ -76,20 +76,28 @@ class ResolvedDataSourceSuite extends SparkFunSuite with SharedSQLContext {
classOf[org.apache.spark.sql.execution.datasources.csv.CSVFileFormat]) } - test("error message for unknown data sources") { - val error1 = intercept[AnalysisException] { - getProvidingClass("avro") + test("avro: show deploy guide for loading the external avro module") { + Seq("avro", "org.apache.spark.sql.avro").foreach { provider => + val message = intercept[AnalysisException] { + getProvidingClass(provider) + }.getMessage + assert(message.contains(s"Failed to find data source: $provider")) + assert(message.contains("Please deploy the application as per the deployment section of")) } - assert(error1.getMessage.contains("Failed to find data source: avro.")) + } - val error2 = intercept[AnalysisException] { - getProvidingClass("com.databricks.spark.avro") - } - assert(error2.getMessage.contains("Failed to find data source: com.databricks.spark.avro.")) + test("kafka: show deploy guide for loading the external kafka module") { + val message = intercept[AnalysisException] { + getProvidingClass("kafka") + }.getMessage + assert(message.contains("Failed to find data source: kafka")) + assert(message.contains("Please deploy the application as per the deployment section of")) + } - val error3 = intercept[ClassNotFoundException] { + test("error message for unknown data sources") { + val error = intercept[ClassNotFoundException] { getProvidingClass("asfdwefasdfasdf") } - assert(error3.getMessage.contains("Failed to find data source: asfdwefasdfasdf.")) + assert(error.getMessage.contains("Failed to find data source: asfdwefasdfasdf.")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 17690e3df9155..13a126ff963d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -28,6 +28,8 @@ import org.apache.spark.sql.types._ class DefaultSource extends SimpleScanSource +// This class is used by pyspark tests. If this class is modified/moved, make sure pyspark +// tests still pass. 
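+// (A note on the coupling: the pyspark suites load this source by its fully qualified class
+// name, so renaming or repackaging it would break them without any failure appearing here.)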
class SimpleScanSource extends RelationProvider { override def createRelation( sqlContext: SQLContext, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index e96cd4500458d..f6c3e0ce82e3f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -17,13 +17,11 @@ package org.apache.spark.sql.sources.v2 -import java.util.{ArrayList, List => JList} - import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} -import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanExec} import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector @@ -38,6 +36,21 @@ import org.apache.spark.sql.vectorized.ColumnarBatch class DataSourceV2Suite extends QueryTest with SharedSQLContext { import testImplicits._ + private def getScanConfig(query: DataFrame): AdvancedScanConfigBuilder = { + query.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => + d.scanConfig.asInstanceOf[AdvancedScanConfigBuilder] + }.head + } + + private def getJavaScanConfig( + query: DataFrame): JavaAdvancedDataSourceV2.AdvancedScanConfigBuilder = { + query.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => + d.scanConfig.asInstanceOf[JavaAdvancedDataSourceV2.AdvancedScanConfigBuilder] + }.head + } + test("simplest implementation") { Seq(classOf[SimpleDataSourceV2], classOf[JavaSimpleDataSourceV2]).foreach { cls => withClue(cls.getName) { @@ -50,18 +63,6 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } test("advanced implementation") { - def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = { - query.queryExecution.executedPlan.collect { - case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader] - }.head - } - - def getJavaReader(query: DataFrame): JavaAdvancedDataSourceV2#Reader = { - query.queryExecution.executedPlan.collect { - case d: DataSourceV2ScanExec => d.reader.asInstanceOf[JavaAdvancedDataSourceV2#Reader] - }.head - } - Seq(classOf[AdvancedDataSourceV2], classOf[JavaAdvancedDataSourceV2]).foreach { cls => withClue(cls.getName) { val df = spark.read.format(cls.getName).load() @@ -70,69 +71,58 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val q1 = df.select('j) checkAnswer(q1, (0 until 10).map(i => Row(-i))) if (cls == classOf[AdvancedDataSourceV2]) { - val reader = getReader(q1) - assert(reader.filters.isEmpty) - assert(reader.requiredSchema.fieldNames === Seq("j")) + val config = getScanConfig(q1) + assert(config.filters.isEmpty) + assert(config.requiredSchema.fieldNames === Seq("j")) } else { - val reader = getJavaReader(q1) - assert(reader.filters.isEmpty) - assert(reader.requiredSchema.fieldNames === Seq("j")) + val config = getJavaScanConfig(q1) + assert(config.filters.isEmpty) + assert(config.requiredSchema.fieldNames === Seq("j")) } val q2 = df.filter('i > 3) checkAnswer(q2, (4 until 10).map(i => Row(i, -i))) if (cls == classOf[AdvancedDataSourceV2]) { - val reader = getReader(q2) - 
assert(reader.filters.flatMap(_.references).toSet == Set("i")) - assert(reader.requiredSchema.fieldNames === Seq("i", "j")) + val config = getScanConfig(q2) + assert(config.filters.flatMap(_.references).toSet == Set("i")) + assert(config.requiredSchema.fieldNames === Seq("i", "j")) } else { - val reader = getJavaReader(q2) - assert(reader.filters.flatMap(_.references).toSet == Set("i")) - assert(reader.requiredSchema.fieldNames === Seq("i", "j")) + val config = getJavaScanConfig(q2) + assert(config.filters.flatMap(_.references).toSet == Set("i")) + assert(config.requiredSchema.fieldNames === Seq("i", "j")) } val q3 = df.select('i).filter('i > 6) checkAnswer(q3, (7 until 10).map(i => Row(i))) if (cls == classOf[AdvancedDataSourceV2]) { - val reader = getReader(q3) - assert(reader.filters.flatMap(_.references).toSet == Set("i")) - assert(reader.requiredSchema.fieldNames === Seq("i")) + val config = getScanConfig(q3) + assert(config.filters.flatMap(_.references).toSet == Set("i")) + assert(config.requiredSchema.fieldNames === Seq("i")) } else { - val reader = getJavaReader(q3) - assert(reader.filters.flatMap(_.references).toSet == Set("i")) - assert(reader.requiredSchema.fieldNames === Seq("i")) + val config = getJavaScanConfig(q3) + assert(config.filters.flatMap(_.references).toSet == Set("i")) + assert(config.requiredSchema.fieldNames === Seq("i")) } val q4 = df.select('j).filter('j < -10) checkAnswer(q4, Nil) if (cls == classOf[AdvancedDataSourceV2]) { - val reader = getReader(q4) + val config = getScanConfig(q4) // 'j < 10 is not supported by the testing data source. - assert(reader.filters.isEmpty) - assert(reader.requiredSchema.fieldNames === Seq("j")) + assert(config.filters.isEmpty) + assert(config.requiredSchema.fieldNames === Seq("j")) } else { - val reader = getJavaReader(q4) + val config = getJavaScanConfig(q4) // 'j < 10 is not supported by the testing data source. 
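/*
 Note: the filter assertions in this test exercise the SupportsPushDownFilters contract:
 pushFilters() receives the candidate predicates and returns the ones the source cannot
 handle, which Spark then re-evaluates on top of the scan. A minimal sketch of the
 convention the testing sources in this patch follow (only GreaterThan on column "i" is
 accepted, so 'j < -10 above comes back unsupported and config.filters stays empty):

   var pushed = Array.empty[Filter]

   override def pushFilters(filters: Array[Filter]): Array[Filter] = {
     val (supported, unsupported) = filters.partition {
       case GreaterThan("i", _: Int) => true
       case _ => false
     }
     pushed = supported
     unsupported // whatever is returned here is re-applied by Spark after the scan
   }
*/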
- assert(reader.filters.isEmpty) - assert(reader.requiredSchema.fieldNames === Seq("j")) + assert(config.filters.isEmpty) + assert(config.requiredSchema.fieldNames === Seq("j")) } } } } - test("unsafe row scan implementation") { - Seq(classOf[UnsafeRowDataSourceV2], classOf[JavaUnsafeRowDataSourceV2]).foreach { cls => - withClue(cls.getName) { - val df = spark.read.format(cls.getName).load() - checkAnswer(df, (0 until 10).map(i => Row(i, -i))) - checkAnswer(df.select('j), (0 until 10).map(i => Row(-i))) - checkAnswer(df.filter('i > 5), (6 until 10).map(i => Row(i, -i))) - } - } - } - test("columnar batch scan implementation") { - Seq(classOf[BatchDataSourceV2], classOf[JavaBatchDataSourceV2]).foreach { cls => + Seq(classOf[ColumnarDataSourceV2], classOf[JavaColumnarDataSourceV2]).foreach { cls => withClue(cls.getName) { val df = spark.read.format(cls.getName).load() checkAnswer(df, (0 until 90).map(i => Row(i, -i))) @@ -145,8 +135,8 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { test("schema required data source") { Seq(classOf[SchemaRequiredDataSource], classOf[JavaSchemaRequiredDataSource]).foreach { cls => withClue(cls.getName) { - val e = intercept[AnalysisException](spark.read.format(cls.getName).load()) - assert(e.message.contains("requires a user-supplied schema")) + val e = intercept[IllegalArgumentException](spark.read.format(cls.getName).load()) + assert(e.getMessage.contains("requires a user-supplied schema")) val schema = new StructType().add("i", "int").add("s", "string") val df = spark.read.format(cls.getName).schema(schema).load() @@ -164,25 +154,25 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val df = spark.read.format(cls.getName).load() checkAnswer(df, Seq(Row(1, 4), Row(1, 4), Row(3, 6), Row(2, 6), Row(4, 2), Row(4, 2))) - val groupByColA = df.groupBy('a).agg(sum('b)) + val groupByColA = df.groupBy('i).agg(sum('j)) checkAnswer(groupByColA, Seq(Row(1, 8), Row(2, 6), Row(3, 6), Row(4, 4))) assert(groupByColA.queryExecution.executedPlan.collectFirst { case e: ShuffleExchangeExec => e }.isEmpty) - val groupByColAB = df.groupBy('a, 'b).agg(count("*")) + val groupByColAB = df.groupBy('i, 'j).agg(count("*")) checkAnswer(groupByColAB, Seq(Row(1, 4, 2), Row(2, 6, 1), Row(3, 6, 1), Row(4, 2, 2))) assert(groupByColAB.queryExecution.executedPlan.collectFirst { case e: ShuffleExchangeExec => e }.isEmpty) - val groupByColB = df.groupBy('b).agg(sum('a)) + val groupByColB = df.groupBy('j).agg(sum('i)) checkAnswer(groupByColB, Seq(Row(2, 8), Row(4, 2), Row(6, 5))) assert(groupByColB.queryExecution.executedPlan.collectFirst { case e: ShuffleExchangeExec => e }.isDefined) - val groupByAPlusB = df.groupBy('a + 'b).agg(count("*")) + val groupByAPlusB = df.groupBy('i + 'j).agg(count("*")) checkAnswer(groupByAPlusB, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1))) assert(groupByAPlusB.queryExecution.executedPlan.collectFirst { case e: ShuffleExchangeExec => e @@ -203,33 +193,33 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val path = file.getCanonicalPath assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) - spark.range(10).select('id, -'id).write.format(cls.getName) + spark.range(10).select('id as 'i, -'id as 'j).write.format(cls.getName) .option("path", path).save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), spark.range(10).select('id, -'id)) // test with different save modes - spark.range(10).select('id, -'id).write.format(cls.getName) + 
spark.range(10).select('id as 'i, -'id as 'j).write.format(cls.getName) .option("path", path).mode("append").save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), spark.range(10).union(spark.range(10)).select('id, -'id)) - spark.range(5).select('id, -'id).write.format(cls.getName) + spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) .option("path", path).mode("overwrite").save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), spark.range(5).select('id, -'id)) - spark.range(5).select('id, -'id).write.format(cls.getName) + spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) .option("path", path).mode("ignore").save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), spark.range(5).select('id, -'id)) val e = intercept[Exception] { - spark.range(5).select('id, -'id).write.format(cls.getName) + spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) .option("path", path).mode("error").save() } assert(e.getMessage.contains("data already exists")) @@ -246,20 +236,13 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } // this input data will fail to read middle way. - val input = spark.range(10).select(failingUdf('id).as('i)).select('i, -'i) + val input = spark.range(10).select(failingUdf('id).as('i)).select('i, -'i as 'j) val e2 = intercept[SparkException] { input.write.format(cls.getName).option("path", path).mode("overwrite").save() } assert(e2.getMessage.contains("Writing job aborted")) // make sure we don't have partial data. assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) - - // test internal row writer - spark.range(5).select('id, -'id).write.format(cls.getName) - .option("path", path).option("internal", "true").mode("overwrite").save() - checkAnswer( - spark.read.format(cls.getName).option("path", path).load(), - spark.range(5).select('id, -'id)) } } } @@ -271,7 +254,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) val numPartition = 6 - spark.range(0, 10, 1, numPartition).select('id, -'id).write.format(cls.getName) + spark.range(0, 10, 1, numPartition).select('id as 'i, -'id as 'j).write.format(cls.getName) .option("path", path).save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), @@ -290,36 +273,30 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } test("SPARK-23301: column pruning with arbitrary expressions") { - def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = { - query.queryExecution.executedPlan.collect { - case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader] - }.head - } - val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() val q1 = df.select('i + 1) checkAnswer(q1, (1 until 11).map(i => Row(i))) - val reader1 = getReader(q1) - assert(reader1.requiredSchema.fieldNames === Seq("i")) + val config1 = getScanConfig(q1) + assert(config1.requiredSchema.fieldNames === Seq("i")) val q2 = df.select(lit(1)) checkAnswer(q2, (0 until 10).map(i => Row(1))) - val reader2 = getReader(q2) - assert(reader2.requiredSchema.isEmpty) + val config2 = getScanConfig(q2) + assert(config2.requiredSchema.isEmpty) // 'j === 1 can't be pushed down, but we should still be able do column pruning val q3 = df.filter('j === -1).select('j * 2) checkAnswer(q3, Row(-2)) - val reader3 = getReader(q3) - 
assert(reader3.filters.isEmpty) - assert(reader3.requiredSchema.fieldNames === Seq("j")) + val config3 = getScanConfig(q3) + assert(config3.filters.isEmpty) + assert(config3.requiredSchema.fieldNames === Seq("j")) // column pruning should work with other operators. val q4 = df.sort('i).limit(1).select('i + 1) checkAnswer(q4, Row(1)) - val reader4 = getReader(q4) - assert(reader4.requiredSchema.fieldNames === Seq("i")) + val config4 = getScanConfig(q4) + assert(config4.requiredSchema.fieldNames === Seq("i")) } test("SPARK-23315: get output from canonicalized data source v2 related plans") { @@ -342,272 +319,291 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } -class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceReader { - override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") +case class RangeInputPartition(start: Int, end: Int) extends InputPartition - override def planInputPartitions(): JList[InputPartition[Row]] = { - java.util.Arrays.asList(new SimpleInputPartition(0, 5)) - } - } - - override def createReader(options: DataSourceOptions): DataSourceReader = new Reader +case class NoopScanConfigBuilder(readSchema: StructType) extends ScanConfigBuilder with ScanConfig { + override def build(): ScanConfig = this } -class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { +object SimpleReaderFactory extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val RangeInputPartition(start, end) = partition + new PartitionReader[InternalRow] { + private var current = start - 1 - class Reader extends DataSourceReader { - override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + override def next(): Boolean = { + current += 1 + current < end + } - override def planInputPartitions(): JList[InputPartition[Row]] = { - java.util.Arrays.asList(new SimpleInputPartition(0, 5), new SimpleInputPartition(5, 10)) + override def get(): InternalRow = InternalRow(current, -current) + + override def close(): Unit = {} } } - - override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class SimpleInputPartition(start: Int, end: Int) - extends InputPartition[Row] - with InputPartitionReader[Row] { - private var current = start - 1 +abstract class SimpleReadSupport extends BatchReadSupport { + override def fullSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createPartitionReader(): InputPartitionReader[Row] = - new SimpleInputPartition(start, end) - - override def next(): Boolean = { - current += 1 - current < end + override def newScanConfigBuilder(): ScanConfigBuilder = { + NoopScanConfigBuilder(fullSchema()) } - override def get(): Row = Row(current, -current) - - override def close(): Unit = {} + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + SimpleReaderFactory + } } +class SimpleSinglePartitionSource extends DataSourceV2 with BatchReadSupportProvider { -class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { + class ReadSupport extends SimpleReadSupport { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + Array(RangeInputPartition(0, 5)) + } + } - class Reader extends DataSourceReader - with SupportsPushDownRequiredColumns with SupportsPushDownFilters { + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + new 
ReadSupport + } +} - var requiredSchema = new StructType().add("i", "int").add("j", "int") - var filters = Array.empty[Filter] +// This class is used by pyspark tests. If this class is modified/moved, make sure pyspark +// tests still pass. +class SimpleDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { - override def pruneColumns(requiredSchema: StructType): Unit = { - this.requiredSchema = requiredSchema + class ReadSupport extends SimpleReadSupport { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + Array(RangeInputPartition(0, 5), RangeInputPartition(5, 10)) } + } - override def pushFilters(filters: Array[Filter]): Array[Filter] = { - val (supported, unsupported) = filters.partition { - case GreaterThan("i", _: Int) => true - case _ => false - } - this.filters = supported - unsupported - } + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + new ReadSupport + } +} - override def pushedFilters(): Array[Filter] = filters - override def readSchema(): StructType = { - requiredSchema - } +class AdvancedDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { + + class ReadSupport extends SimpleReadSupport { + override def newScanConfigBuilder(): ScanConfigBuilder = new AdvancedScanConfigBuilder() - override def planInputPartitions(): JList[InputPartition[Row]] = { - val lowerBound = filters.collect { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val filters = config.asInstanceOf[AdvancedScanConfigBuilder].filters + + val lowerBound = filters.collectFirst { case GreaterThan("i", v: Int) => v - }.headOption + } - val res = new ArrayList[InputPartition[Row]] + val res = scala.collection.mutable.ArrayBuffer.empty[InputPartition] if (lowerBound.isEmpty) { - res.add(new AdvancedInputPartition(0, 5, requiredSchema)) - res.add(new AdvancedInputPartition(5, 10, requiredSchema)) + res.append(RangeInputPartition(0, 5)) + res.append(RangeInputPartition(5, 10)) } else if (lowerBound.get < 4) { - res.add(new AdvancedInputPartition(lowerBound.get + 1, 5, requiredSchema)) - res.add(new AdvancedInputPartition(5, 10, requiredSchema)) + res.append(RangeInputPartition(lowerBound.get + 1, 5)) + res.append(RangeInputPartition(5, 10)) } else if (lowerBound.get < 9) { - res.add(new AdvancedInputPartition(lowerBound.get + 1, 10, requiredSchema)) + res.append(RangeInputPartition(lowerBound.get + 1, 10)) } - res + res.toArray + } + + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + val requiredSchema = config.asInstanceOf[AdvancedScanConfigBuilder].requiredSchema + new AdvancedReaderFactory(requiredSchema) } } - override def createReader(options: DataSourceOptions): DataSourceReader = new Reader + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + new ReadSupport + } } -class AdvancedInputPartition(start: Int, end: Int, requiredSchema: StructType) - extends InputPartition[Row] with InputPartitionReader[Row] { +class AdvancedScanConfigBuilder extends ScanConfigBuilder with ScanConfig + with SupportsPushDownRequiredColumns with SupportsPushDownFilters { - private var current = start - 1 + var requiredSchema = new StructType().add("i", "int").add("j", "int") + var filters = Array.empty[Filter] - override def createPartitionReader(): InputPartitionReader[Row] = { - new AdvancedInputPartition(start, end, requiredSchema) + override def pruneColumns(requiredSchema: StructType): Unit = { + this.requiredSchema = 
requiredSchema } - override def close(): Unit = {} + override def readSchema(): StructType = requiredSchema - override def next(): Boolean = { - current += 1 - current < end - } - - override def get(): Row = { - val values = requiredSchema.map(_.name).map { - case "i" => current - case "j" => -current + override def pushFilters(filters: Array[Filter]): Array[Filter] = { + val (supported, unsupported) = filters.partition { + case GreaterThan("i", _: Int) => true + case _ => false } - Row.fromSeq(values) + this.filters = supported + unsupported } + + override def pushedFilters(): Array[Filter] = filters + + override def build(): ScanConfig = this } +class AdvancedReaderFactory(requiredSchema: StructType) extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val RangeInputPartition(start, end) = partition + new PartitionReader[InternalRow] { + private var current = start - 1 -class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport { + override def next(): Boolean = { + current += 1 + current < end + } - class Reader extends DataSourceReader with SupportsScanUnsafeRow { - override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + override def get(): InternalRow = { + val values = requiredSchema.map(_.name).map { + case "i" => current + case "j" => -current + } + InternalRow.fromSeq(values) + } - override def planUnsafeInputPartitions(): JList[InputPartition[UnsafeRow]] = { - java.util.Arrays.asList(new UnsafeRowInputPartitionReader(0, 5), - new UnsafeRowInputPartitionReader(5, 10)) + override def close(): Unit = {} } } - - override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class UnsafeRowInputPartitionReader(start: Int, end: Int) - extends InputPartition[UnsafeRow] with InputPartitionReader[UnsafeRow] { - private val row = new UnsafeRow(2) - row.pointTo(new Array[Byte](8 * 3), 8 * 3) +class SchemaRequiredDataSource extends DataSourceV2 with BatchReadSupportProvider { - private var current = start - 1 + class ReadSupport(val schema: StructType) extends SimpleReadSupport { + override def fullSchema(): StructType = schema - override def createPartitionReader(): InputPartitionReader[UnsafeRow] = this - - override def next(): Boolean = { - current += 1 - current < end - } - override def get(): UnsafeRow = { - row.setInt(0, current) - row.setInt(1, -current) - row + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = + Array.empty } - override def close(): Unit = {} -} - -class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema { - - class Reader(val readSchema: StructType) extends DataSourceReader { - override def planInputPartitions(): JList[InputPartition[Row]] = - java.util.Collections.emptyList() + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + throw new IllegalArgumentException("requires a user-supplied schema") } - override def createReader(schema: StructType, options: DataSourceOptions): DataSourceReader = - new Reader(schema) + override def createBatchReadSupport( + schema: StructType, options: DataSourceOptions): BatchReadSupport = { + new ReadSupport(schema) + } } -class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { +class ColumnarDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { - class Reader extends DataSourceReader with SupportsScanColumnarBatch { - override def readSchema(): StructType = new StructType().add("i", 
"int").add("j", "int") + class ReadSupport extends SimpleReadSupport { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + Array(RangeInputPartition(0, 50), RangeInputPartition(50, 90)) + } - override def planBatchInputPartitions(): JList[InputPartition[ColumnarBatch]] = { - java.util.Arrays.asList( - new BatchInputPartitionReader(0, 50), new BatchInputPartitionReader(50, 90)) + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + ColumnarReaderFactory } } - override def createReader(options: DataSourceOptions): DataSourceReader = new Reader + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + new ReadSupport + } } -class BatchInputPartitionReader(start: Int, end: Int) - extends InputPartition[ColumnarBatch] with InputPartitionReader[ColumnarBatch] { - +object ColumnarReaderFactory extends PartitionReaderFactory { private final val BATCH_SIZE = 20 - private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) - private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType) - private lazy val batch = new ColumnarBatch(Array(i, j)) - private var current = start + override def supportColumnarReads(partition: InputPartition): Boolean = true - override def createPartitionReader(): InputPartitionReader[ColumnarBatch] = this + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + throw new UnsupportedOperationException + } - override def next(): Boolean = { - i.reset() - j.reset() + override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = { + val RangeInputPartition(start, end) = partition + new PartitionReader[ColumnarBatch] { + private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) + private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType) + private lazy val batch = new ColumnarBatch(Array(i, j)) + + private var current = start + + override def next(): Boolean = { + i.reset() + j.reset() + + var count = 0 + while (current < end && count < BATCH_SIZE) { + i.putInt(count, current) + j.putInt(count, -current) + current += 1 + count += 1 + } - var count = 0 - while (current < end && count < BATCH_SIZE) { - i.putInt(count, current) - j.putInt(count, -current) - current += 1 - count += 1 - } + if (count == 0) { + false + } else { + batch.setNumRows(count) + true + } + } - if (count == 0) { - false - } else { - batch.setNumRows(count) - true - } - } + override def get(): ColumnarBatch = batch - override def get(): ColumnarBatch = { - batch + override def close(): Unit = batch.close() + } } - - override def close(): Unit = batch.close() } -class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceReader with SupportsReportPartitioning { - override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int") +class PartitionAwareDataSource extends DataSourceV2 with BatchReadSupportProvider { - override def planInputPartitions(): JList[InputPartition[Row]] = { + class ReadSupport extends SimpleReadSupport with SupportsReportPartitioning { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { // Note that we don't have same value of column `a` across partitions. 
- java.util.Arrays.asList( - new SpecificInputPartitionReader(Array(1, 1, 3), Array(4, 4, 6)), - new SpecificInputPartitionReader(Array(2, 4, 4), Array(6, 2, 2))) + Array( + SpecificInputPartition(Array(1, 1, 3), Array(4, 4, 6)), + SpecificInputPartition(Array(2, 4, 4), Array(6, 2, 2))) } - override def outputPartitioning(): Partitioning = new MyPartitioning + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + SpecificReaderFactory + } + + override def outputPartitioning(config: ScanConfig): Partitioning = new MyPartitioning } class MyPartitioning extends Partitioning { override def numPartitions(): Int = 2 override def satisfy(distribution: Distribution): Boolean = distribution match { - case c: ClusteredDistribution => c.clusteredColumns.contains("a") + case c: ClusteredDistribution => c.clusteredColumns.contains("i") case _ => false } } - override def createReader(options: DataSourceOptions): DataSourceReader = new Reader + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + new ReadSupport + } } -class SpecificInputPartitionReader(i: Array[Int], j: Array[Int]) - extends InputPartition[Row] - with InputPartitionReader[Row] { - assert(i.length == j.length) +case class SpecificInputPartition(i: Array[Int], j: Array[Int]) extends InputPartition - private var current = -1 +object SpecificReaderFactory extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val p = partition.asInstanceOf[SpecificInputPartition] + new PartitionReader[InternalRow] { + private var current = -1 - override def createPartitionReader(): InputPartitionReader[Row] = this - - override def next(): Boolean = { - current += 1 - current < i.length - } + override def next(): Boolean = { + current += 1 + current < p.i.length + } - override def get(): Row = Row(i(current), j(current)) + override def get(): InternalRow = InternalRow(p.i(current), p.j(current)) - override def close(): Unit = {} + override def close(): Unit = {} + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index 1334cf71ae988..952241b0b6be5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -18,34 +18,36 @@ package org.apache.spark.sql.sources.v2 import java.io.{BufferedReader, InputStreamReader, IOException} -import java.util.{Collections, List => JList, Optional} +import java.util.Optional import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.SparkContext -import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader} +import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.SerializableConfiguration /** * A HDFS based transactional writable data source. - * Each task writes data to `target/_temporary/jobId/$jobId-$partitionId-$attemptNumber`. 
- * Each job moves files from `target/_temporary/jobId/` to `target`. + * Each task writes data to `target/_temporary/queryId/$jobId-$partitionId-$attemptNumber`. + * Each job moves files from `target/_temporary/queryId/` to `target`. */ -class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteSupport { +class SimpleWritableDataSource extends DataSourceV2 + with BatchReadSupportProvider with BatchWriteSupportProvider { private val schema = new StructType().add("i", "long").add("j", "long") - class Reader(path: String, conf: Configuration) extends DataSourceReader { - override def readSchema(): StructType = schema + class ReadSupport(path: String, conf: Configuration) extends SimpleReadSupport { - override def planInputPartitions(): JList[InputPartition[Row]] = { + override def fullSchema(): StructType = schema + + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { val dataPath = new Path(path) val fs = dataPath.getFileSystem(conf) if (fs.exists(dataPath)) { @@ -53,21 +55,23 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS val name = status.getPath.getName name.startsWith("_") || name.startsWith(".") }.map { f => - val serializableConf = new SerializableConfiguration(conf) - new SimpleCSVInputPartitionReader( - f.getPath.toUri.toString, - serializableConf): InputPartition[Row] - }.toList.asJava + CSVInputPartitionReader(f.getPath.toUri.toString) + }.toArray } else { - Collections.emptyList() + Array.empty } } + + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + val serializableConf = new SerializableConfiguration(conf) + new CSVReaderFactory(serializableConf) + } } - class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceWriter { - override def createWriterFactory(): DataWriterFactory[Row] = { + class WriteSupport(queryId: String, path: String, conf: Configuration) extends BatchWriteSupport { + override def createBatchWriterFactory(): DataWriterFactory = { SimpleCounter.resetCounter - new SimpleCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) + new CSVDataWriterFactory(path, queryId, new SerializableConfiguration(conf)) } override def onDataWriterCommit(message: WriterCommitMessage): Unit = { @@ -76,7 +80,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS override def commit(messages: Array[WriterCommitMessage]): Unit = { val finalPath = new Path(path) - val jobPath = new Path(new Path(finalPath, "_temporary"), jobId) + val jobPath = new Path(new Path(finalPath, "_temporary"), queryId) val fs = jobPath.getFileSystem(conf) try { for (file <- fs.listStatus(jobPath).map(_.getPath)) { @@ -91,40 +95,27 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } override def abort(messages: Array[WriterCommitMessage]): Unit = { - val jobPath = new Path(new Path(path, "_temporary"), jobId) + val jobPath = new Path(new Path(path, "_temporary"), queryId) val fs = jobPath.getFileSystem(conf) fs.delete(jobPath, true) } } - class InternalRowWriter(jobId: String, path: String, conf: Configuration) - extends Writer(jobId, path, conf) with SupportsWriteInternalRow { - - override def createWriterFactory(): DataWriterFactory[Row] = { - throw new IllegalArgumentException("not expected!") - } - - override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = { - new InternalRowCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) - } - } - - override
def createReader(options: DataSourceOptions): DataSourceReader = { + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { val path = new Path(options.get("path").get()) val conf = SparkContext.getActive.get.hadoopConfiguration - new Reader(path.toUri.toString, conf) + new ReadSupport(path.toUri.toString, conf) } - override def createWriter( - jobId: String, + override def createBatchWriteSupport( + queryId: String, schema: StructType, mode: SaveMode, - options: DataSourceOptions): Optional[DataSourceWriter] = { + options: DataSourceOptions): Optional[BatchWriteSupport] = { assert(DataType.equalsStructurally(schema.asNullable, this.schema.asNullable)) assert(!SparkContext.getActive.get.conf.getBoolean("spark.speculation", false)) val path = new Path(options.get("path").get()) - val internal = options.get("internal").isPresent val conf = SparkContext.getActive.get.hadoopConfiguration val fs = path.getFileSystem(conf) @@ -142,49 +133,43 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS fs.delete(path, true) } - Optional.of(createWriter(jobId, path, conf, internal)) - } - - private def createWriter( - jobId: String, path: Path, conf: Configuration, internal: Boolean): DataSourceWriter = { val pathStr = path.toUri.toString - if (internal) { - new InternalRowWriter(jobId, pathStr, conf) - } else { - new Writer(jobId, pathStr, conf) - } + Optional.of(new WriteSupport(queryId, pathStr, conf)) } } -class SimpleCSVInputPartitionReader(path: String, conf: SerializableConfiguration) - extends InputPartition[Row] with InputPartitionReader[Row] { +case class CSVInputPartitionReader(path: String) extends InputPartition - @transient private var lines: Iterator[String] = _ - @transient private var currentLine: String = _ - @transient private var inputStream: FSDataInputStream = _ +class CSVReaderFactory(conf: SerializableConfiguration) + extends PartitionReaderFactory { - override def createPartitionReader(): InputPartitionReader[Row] = { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val path = partition.asInstanceOf[CSVInputPartitionReader].path val filePath = new Path(path) val fs = filePath.getFileSystem(conf.value) - inputStream = fs.open(filePath) - lines = new BufferedReader(new InputStreamReader(inputStream)) - .lines().iterator().asScala - this - } - override def next(): Boolean = { - if (lines.hasNext) { - currentLine = lines.next() - true - } else { - false - } - } + new PartitionReader[InternalRow] { + private val inputStream = fs.open(filePath) + private val lines = new BufferedReader(new InputStreamReader(inputStream)) + .lines().iterator().asScala + + private var currentLine: String = _ + + override def next(): Boolean = { + if (lines.hasNext) { + currentLine = lines.next() + true + } else { + false + } + } - override def get(): Row = Row(currentLine.split(",").map(_.trim.toLong): _*) + override def get(): InternalRow = InternalRow(currentLine.split(",").map(_.trim.toLong): _*) - override def close(): Unit = { - inputStream.close() + override def close(): Unit = { + inputStream.close() + } + } } } @@ -204,57 +189,20 @@ private[v2] object SimpleCounter { } } -class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) - extends DataWriterFactory[Row] { +class CSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) + extends DataWriterFactory { - override def createDataWriter( - partitionId: Int, - taskId: Long, - epochId:
Long): DataWriter[Row] = { - val jobPath = new Path(new Path(path, "_temporary"), jobId) - val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId") - val fs = filePath.getFileSystem(conf.value) - new SimpleCSVDataWriter(fs, filePath) - } -} - -class SimpleCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[Row] { - - private val out = fs.create(file) - - override def write(record: Row): Unit = { - out.writeBytes(s"${record.getLong(0)},${record.getLong(1)}\n") - } - - override def commit(): WriterCommitMessage = { - out.close() - null - } - - override def abort(): Unit = { - try { - out.close() - } finally { - fs.delete(file, false) - } - } -} - -class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) - extends DataWriterFactory[InternalRow] { - - override def createDataWriter( + override def createWriter( partitionId: Int, - taskId: Long, - epochId: Long): DataWriter[InternalRow] = { + taskId: Long): DataWriter[InternalRow] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId") val fs = filePath.getFileSystem(conf.value) - new InternalRowCSVDataWriter(fs, filePath) + new CSVDataWriter(fs, filePath) } } -class InternalRowCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[InternalRow] { +class CSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[InternalRow] { private val out = fs.create(file) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index 7e8fde1ff8e56..026af17c7b23f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -18,18 +18,22 @@ package org.apache.spark.sql.streaming import java.{util => ju} +import java.io.File import java.text.SimpleDateFormat import java.util.{Calendar, Date} +import org.apache.commons.io.FileUtils import org.scalatest.{BeforeAndAfter, Matchers} import org.apache.spark.internal.Logging -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.{AnalysisException, Dataset} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode._ +import org.apache.spark.util.Utils class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matchers with Logging { @@ -123,31 +127,133 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(aggWithWatermark)( AddData(inputData2, 15), CheckAnswer(), - assertEventStats { e => - assert(e.get("max") === formatTimestamp(15)) - assert(e.get("min") === formatTimestamp(15)) - assert(e.get("avg") === formatTimestamp(15)) - assert(e.get("watermark") === formatTimestamp(0)) - }, + assertEventStats(min = 15, max = 15, avg = 15, wtrmark = 0), AddData(inputData2, 10, 12, 14), CheckAnswer(), - assertEventStats { e => - assert(e.get("max") === formatTimestamp(14)) - assert(e.get("min") === formatTimestamp(10)) - assert(e.get("avg") === formatTimestamp(12)) - assert(e.get("watermark") === formatTimestamp(5)) - }, + assertEventStats(min = 10, max = 14, avg = 12, wtrmark = 5), 
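/*
 Note on the expected stats: min/max/avg describe the event times seen in the last batch
 with data, while `watermark` is the value in force *during* that batch, i.e. the one
 derived from earlier batches as (max event time so far) - (the 10 second delay). The
 first batch saw max = 15, so the batch reading {10, 12, 14} runs with wtrmark = 15 - 10
 = 5; its own max does not raise the watermark until the following batch.
*/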
AddData(inputData2, 25), CheckAnswer((10, 3)), - assertEventStats { e => - assert(e.get("max") === formatTimestamp(25)) - assert(e.get("min") === formatTimestamp(25)) - assert(e.get("avg") === formatTimestamp(25)) - assert(e.get("watermark") === formatTimestamp(5)) - } + assertEventStats(min = 25, max = 25, avg = 25, wtrmark = 5) ) } + test("event time and watermark metrics with Trigger.Once (SPARK-24699)") { + // All event time metrics where watermarking is set + val inputData = MemoryStream[Int] + val aggWithWatermark = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + // Unlike the ProcessingTime trigger, Trigger.Once only runs one trigger every time + // the query is started and it does not run no-data batches. Hence the answer generated + // by the updated watermark is only generated the next time the query is started. + // Also, the data to process in the next trigger is added *before* starting the stream in + // Trigger.Once to ensure that first and only trigger picks up the new data. + + testStream(aggWithWatermark)( + StartStream(Trigger.Once), // to make sure the query is not running when adding data 1st time + awaitTermination(), + + AddData(inputData, 15), + StartStream(Trigger.Once), + awaitTermination(), + CheckNewAnswer(), + assertEventStats(min = 15, max = 15, avg = 15, wtrmark = 0), + // watermark should be updated to 15 - 10 = 5 + + AddData(inputData, 10, 12, 14), + StartStream(Trigger.Once), + awaitTermination(), + CheckNewAnswer(), + assertEventStats(min = 10, max = 14, avg = 12, wtrmark = 5), + // watermark should stay at 5 + + AddData(inputData, 25), + StartStream(Trigger.Once), + awaitTermination(), + CheckNewAnswer(), + assertEventStats(min = 25, max = 25, avg = 25, wtrmark = 5), + // watermark should be updated to 25 - 10 = 15 + + AddData(inputData, 50), + StartStream(Trigger.Once), + awaitTermination(), + CheckNewAnswer((10, 3)), // watermark = 15 is used to generate this + assertEventStats(min = 50, max = 50, avg = 50, wtrmark = 15), + // watermark should be updated to 50 - 10 = 40 + + AddData(inputData, 50), + StartStream(Trigger.Once), + awaitTermination(), + CheckNewAnswer((15, 1), (25, 1)), // watermark = 40 is used to generate this + assertEventStats(min = 50, max = 50, avg = 50, wtrmark = 40)) + } + + test("recovery from Spark ver 2.3.1 commit log without commit metadata (SPARK-24699)") { + // All event time metrics where watermarking is set + val inputData = MemoryStream[Int] + val aggWithWatermark = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/").toURI + + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. 
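/*
 Note: recovery tests against a checked-in checkpoint follow a fixed recipe: copy the
 resource directory to a scratch location (so commits written by this run cannot pollute
 the versioned files), then re-add to the MemoryStream the same batches that produced
 the checkpoint, so the recovered offset log can resolve its batch ids against in-memory
 data before any new data is appended.
*/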
+ FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + + inputData.addData(15) + inputData.addData(10, 12, 14) + + testStream(aggWithWatermark)( + /* + + Note: The checkpoint was generated using the following input in Spark version 2.3.1 + + StartStream(checkpointLocation = "./sql/core/src/test/resources/structured-streaming/" + + "checkpoint-version-2.3.1-without-commit-log-metadata/")), + AddData(inputData, 15), // watermark should be updated to 15 - 10 = 5 + CheckAnswer(), + AddData(inputData, 10, 12, 14), // watermark should stay at 5 + CheckAnswer(), + StopStream, + + // Offset log should have watermark recorded as 5. + */ + + StartStream(Trigger.Once), + awaitTermination(), + + AddData(inputData, 25), + StartStream(Trigger.Once, checkpointLocation = checkpointDir.getAbsolutePath), + awaitTermination(), + CheckNewAnswer(), + assertEventStats(min = 25, max = 25, avg = 25, wtrmark = 5), + // watermark should be updated to 25 - 10 = 15 + + AddData(inputData, 50), + StartStream(Trigger.Once, checkpointLocation = checkpointDir.getAbsolutePath), + awaitTermination(), + CheckNewAnswer((10, 3)), // watermark = 15 is used to generate this + assertEventStats(min = 50, max = 50, avg = 50, wtrmark = 15), + // watermark should be updated to 50 - 10 = 40 + + AddData(inputData, 50), + StartStream(Trigger.Once, checkpointLocation = checkpointDir.getAbsolutePath), + awaitTermination(), + CheckNewAnswer((15, 1), (25, 1)), // watermark = 40 is used to generate this + assertEventStats(min = 50, max = 50, avg = 50, wtrmark = 40)) + } + test("append mode") { val inputData = MemoryStream[Int] @@ -484,6 +590,136 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testWithFlag(false) } + test("MultipleWatermarkPolicy: max") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "max") { + testStream(dfWithMultipleWatermarks(input1, input2))( + MultiAddData(input1, 20)(input2, 30), + CheckLastBatch(20, 30), + checkWatermark(input1, 15), // max(20 - 10, 30 - 15) = 15 + StopStream, + StartStream(), + checkWatermark(input1, 15), // watermark recovered correctly + MultiAddData(input1, 120)(input2, 130), + CheckLastBatch(120, 130), + checkWatermark(input1, 115), // max(120 - 10, 130 - 15) = 115, policy recovered correctly + AddData(input1, 150), + CheckLastBatch(150), + checkWatermark(input1, 140) // should advance even if one of the input has data + ) + } + } + + test("MultipleWatermarkPolicy: min") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "min") { + testStream(dfWithMultipleWatermarks(input1, input2))( + MultiAddData(input1, 20)(input2, 30), + CheckLastBatch(20, 30), + checkWatermark(input1, 10), // min(20 - 10, 30 - 15) = 10 + StopStream, + StartStream(), + checkWatermark(input1, 10), // watermark recovered correctly + MultiAddData(input1, 120)(input2, 130), + CheckLastBatch(120, 130), + checkWatermark(input2, 110), // min(120 - 10, 130 - 15) = 110, policy recovered correctly + AddData(input2, 150), + CheckLastBatch(150), + checkWatermark(input2, 110) // does not advance when only one of the input has data + ) + } + } + + test("MultipleWatermarkPolicy: recovery from checkpoints ignores session conf") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val checkpointDir = Utils.createTempDir().getCanonicalFile + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "max") 
{ + testStream(dfWithMultipleWatermarks(input1, input2))( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + MultiAddData(input1, 20)(input2, 30), + CheckLastBatch(20, 30), + checkWatermark(input1, 15) // max(20 - 10, 30 - 15) = 15 + ) + } + + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "min") { + testStream(dfWithMultipleWatermarks(input1, input2))( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + checkWatermark(input1, 15), // watermark recovered correctly + MultiAddData(input1, 120)(input2, 130), + CheckLastBatch(120, 130), + checkWatermark(input1, 115), // max(120 - 10, 130 - 15) = 115, policy recovered correctly + AddData(input1, 150), + CheckLastBatch(150), + checkWatermark(input1, 140) // should advance even if one of the input has data + ) + } + } + + test("MultipleWatermarkPolicy: recovery from Spark ver 2.3.1 checkpoints ensures min policy") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/").toURI + + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. + FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + + input1.addData(20) + input2.addData(30) + input1.addData(10) + + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "max") { + testStream(dfWithMultipleWatermarks(input1, input2))( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + Execute { _.processAllAvailable() }, + MultiAddData(input1, 120)(input2, 130), + CheckLastBatch(120, 130), + checkWatermark(input2, 110), // should calculate 'min' even if session conf has 'max' policy + AddData(input2, 150), + CheckLastBatch(150), + checkWatermark(input2, 110) + ) + } + } + + test("MultipleWatermarkPolicy: fail on incorrect conf values") { + val invalidValues = Seq("", "random") + invalidValues.foreach { value => + val e = intercept[IllegalArgumentException] { + spark.conf.set(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key, value) + } + assert(e.getMessage.toLowerCase.contains("valid values are 'min' and 'max'")) + } + } + + private def dfWithMultipleWatermarks( + input1: MemoryStream[Int], + input2: MemoryStream[Int]): Dataset[_] = { + val df1 = input1.toDF + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + val df2 = input2.toDF + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "15 seconds") + df1.union(df2).select($"eventTime".cast("int")) + } + + private def checkWatermark(input: MemoryStream[Int], watermark: Long) = Execute { q => + input.addData(1) + q.processAllAvailable() + assert(q.lastProgress.eventTime.get("watermark") == formatTimestamp(watermark)) + } + private def assertNumStateRows(numTotalRows: Long): AssertOnQuery = AssertOnQuery { q => q.processAllAvailable() val progressWithData = q.recentProgress.lastOption.get @@ -491,10 +727,20 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche true } + /** Assert event stats generated on that last batch with data in it */ private def assertEventStats(body: ju.Map[String, String] => Unit): AssertOnQuery = { - AssertOnQuery { q => + Execute("AssertEventStats") { q => body(q.recentProgress.filter(_.numInputRows > 0).lastOption.get.eventTime) - true + 
} + } + + /** Assert event stats generated on that last batch with data in it */ + private def assertEventStats(min: Long, max: Long, avg: Double, wtrmark: Long): AssertOnQuery = { + assertEventStats { e => + assert(e.get("min") === formatTimestamp(min), s"min value mismatch") + assert(e.get("max") === formatTimestamp(max), s"max value mismatch") + assert(e.get("avg") === formatTimestamp(avg.toLong), s"avg value mismatch") + assert(e.get("watermark") === formatTimestamp(wtrmark), s"watermark value mismatch") } } @@ -504,4 +750,8 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche private def formatTimestamp(sec: Long): String = { timestampFormat.format(new ju.Date(sec * 1000)) } + + private def awaitTermination(): AssertOnQuery = Execute("AwaitTermination") { q => + q.awaitTermination() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 988c8e6753e25..43463a84093ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -17,43 +17,41 @@ package org.apache.spark.sql.streaming +import java.io.File import java.sql.Date -import java.util.concurrent.ConcurrentHashMap +import org.apache.commons.io.FileUtils import org.scalatest.BeforeAndAfterAll import org.scalatest.exceptions.TestFailedException import org.apache.spark.SparkException import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.RDDScanExec -import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream} -import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, MemoryStateStore, StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.util.Utils /** Class to check custom state types */ case class RunningCount(count: Long) case class Result(key: Long, count: Int) -class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest - with BeforeAndAfterAll { +class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { import testImplicits._ import GroupStateImpl._ import GroupStateTimeout._ import FlatMapGroupsWithStateSuite._ - override def afterAll(): Unit = { - super.afterAll() - StateStore.stop() - } - test("GroupState - get, exists, update, remove") { var state: GroupStateImpl[String] = null @@ -359,13 +357,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest } } - // Values used for testing StateStoreUpdater + // Values used for testing InputProcessor val currentBatchTimestamp = 1000 val currentBatchWatermark = 1000 
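/*
 Note: the thresholds declared next are interpreted relative to the 1000 above: a prior
 timeout timestamp of beforeTimeoutThreshold (999) is already expired when the batch
 runs, while afterTimeoutThreshold (1001) is not yet due.
*/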
val beforeTimeoutThreshold = 999 val afterTimeoutThreshold = 1001 - // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout = NoTimeout + // Tests for InputProcessor.processNewData() when timeout = NoTimeout for (priorState <- Seq(None, Some(0))) { val priorStateStr = if (priorState.nonEmpty) "prior state set" else "no prior state" val testName = s"NoTimeout - $priorStateStr - " @@ -396,7 +394,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest expectedState = None) // should be removed } - // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout != NoTimeout + // Tests for InputProcessor.processTimedOutState() when timeout != NoTimeout for (priorState <- Seq(None, Some(0))) { for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { var testName = "" @@ -443,6 +441,18 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest expectedState = None) // state should be removed } + // Tests with ProcessingTimeTimeout + if (priorState == None) { + testStateUpdateWithData( + s"ProcessingTimeTimeout - $testName - timeout updated without initializing state", + stateUpdates = state => { state.setTimeoutDuration(5000) }, + timeoutConf = ProcessingTimeTimeout, + priorState = None, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = currentBatchTimestamp + 5000) + } + testStateUpdateWithData( s"ProcessingTimeTimeout - $testName - state and timeout duration updated", stateUpdates = @@ -453,6 +463,30 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest expectedState = Some(5), // state should change expectedTimeoutTimestamp = currentBatchTimestamp + 5000) // timestamp should change + testStateUpdateWithData( + s"ProcessingTimeTimeout - $testName - timeout updated after state removed", + stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) }, + timeoutConf = ProcessingTimeTimeout, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = currentBatchTimestamp + 5000) + + // Tests with EventTimeTimeout + + if (priorState == None) { + testStateUpdateWithData( + s"EventTimeTimeout - $testName - setting timeout without init state not allowed", + stateUpdates = state => { + state.setTimeoutTimestamp(10000) + }, + timeoutConf = EventTimeTimeout, + priorState = None, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = 10000) + } + testStateUpdateWithData( s"EventTimeTimeout - $testName - state and timeout timestamp updated", stateUpdates = @@ -477,48 +511,21 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest priorTimeoutTimestamp = priorTimeoutTimestamp, expectedState = Some(5), // state should change expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should not update - } - } - - // Currently disallowed cases for StateStoreUpdater.updateStateForKeysWithData(), - // Try to remove these cases in the future - for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { - val testName = - if (priorTimeoutTimestamp != NO_TIMESTAMP) "prior timeout set" else "no prior timeout" - testStateUpdateWithData( - s"ProcessingTimeTimeout - $testName - setting timeout without init state not allowed", - stateUpdates = state => { state.setTimeoutDuration(5000) }, - timeoutConf = ProcessingTimeTimeout, - priorState = None, - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) - testStateUpdateWithData( - 
s"ProcessingTimeTimeout - $testName - setting timeout with state removal not allowed", - stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) }, - timeoutConf = ProcessingTimeTimeout, - priorState = Some(5), - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) - - testStateUpdateWithData( - s"EventTimeTimeout - $testName - setting timeout without init state not allowed", - stateUpdates = state => { state.setTimeoutTimestamp(10000) }, - timeoutConf = EventTimeTimeout, - priorState = None, - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) - - testStateUpdateWithData( - s"EventTimeTimeout - $testName - setting timeout with state removal not allowed", - stateUpdates = state => { state.remove(); state.setTimeoutTimestamp(10000) }, - timeoutConf = EventTimeTimeout, - priorState = Some(5), - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) + testStateUpdateWithData( + s"EventTimeTimeout - $testName - setting timeout with state removal not allowed", + stateUpdates = state => { + state.remove(); state.setTimeoutTimestamp(10000) + }, + timeoutConf = EventTimeTimeout, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = 10000) + } } - // Tests for StateStoreUpdater.updateStateForTimedOutKeys() + // Tests for InputProcessor.processTimedOutState() val preTimeoutState = Some(5) for (timeoutConf <- Seq(ProcessingTimeTimeout, EventTimeTimeout)) { testStateUpdateWithTimeout( @@ -590,7 +597,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest expectedState = Some(5), // state should change expectedTimeoutTimestamp = 5000) // timestamp should change - test("flatMapGroupsWithState - streaming") { + testWithAllStateVersions("flatMapGroupsWithState - streaming") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count if state is defined, otherwise does not return anything val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { @@ -669,7 +676,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest ) } - test("flatMapGroupsWithState - streaming + aggregation") { + testWithAllStateVersions("flatMapGroupsWithState - streaming + aggregation") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { @@ -728,7 +735,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest checkAnswer(df, Seq(("a", 2), ("b", 1)).toDF) } - test("flatMapGroupsWithState - streaming with processing time timeout") { + testWithAllStateVersions("flatMapGroupsWithState - streaming with processing time timeout") { // Function to maintain the count as state and set the proc. time timeout delay of 10 seconds. // It returns the count if changed, or -1 if the state was removed by timeout. 
val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { @@ -792,7 +799,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest ) } - test("flatMapGroupsWithState - streaming with event time timeout + watermark") { + testWithAllStateVersions("flatMapGroupsWithState - streaming w/ event time timeout + watermark") { // Function to maintain the max event time as state and set the timeout timestamp based on the // current max event time seen. It returns the max event time in the state, or -1 if the state // was removed by timeout. @@ -843,6 +850,105 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest ) } + test("flatMapGroupsWithState - uses state format version 2 by default") { + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + val count = state.getOption.map(_.count).getOrElse(0L) + values.size + state.update(RunningCount(count)) + Iterator((key, count.toString)) + } + + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc) + + testStream(result, Update)( + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + Execute { query => + // Verify state format = 2 + val f = query.lastExecution.executedPlan.collect { case f: FlatMapGroupsWithStateExec => f } + assert(f.size == 1) + assert(f.head.stateFormatVersion == 2) + } + ) + } + + test("flatMapGroupsWithState - recovery from checkpoint uses state format version 1") { + // Function to maintain the max event time as state and set the timeout timestamp based on the + // current max event time seen. It returns the max event time in the state, or -1 if the state + // was removed by timeout. + val stateFunc = (key: String, values: Iterator[(String, Long)], state: GroupState[Long]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } + assertCanGetWatermark { state.getCurrentWatermarkMs() >= -1 } + + val timeoutDelaySec = 5 + if (state.hasTimedOut) { + state.remove() + Iterator((key, -1)) + } else { + val valuesSeq = values.toSeq + val maxEventTimeSec = math.max(valuesSeq.map(_._2).max, state.getOption.getOrElse(0L)) + val timeoutTimestampSec = maxEventTimeSec + timeoutDelaySec + state.update(maxEventTimeSec) + state.setTimeoutTimestamp(timeoutTimestampSec * 1000) + Iterator((key, maxEventTimeSec.toInt)) + } + } + val inputData = MemoryStream[(String, Int)] + val result = + inputData.toDS + .select($"_1".as("key"), $"_2".cast("timestamp").as("eventTime")) + .withWatermark("eventTime", "10 seconds") + .as[(String, Long)] + .groupByKey(_._1) + .flatMapGroupsWithState(Update, EventTimeTimeout)(stateFunc) + + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/").toURI + + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. 
+ FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + + inputData.addData(("a", 11), ("a", 13), ("a", 15)) + inputData.addData(("a", 4)) + + testStream(result, Update)( + StartStream( + checkpointLocation = checkpointDir.getAbsolutePath, + additionalConfs = Map(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> "2")), + /* + Note: The checkpoint was generated using the following input in Spark version 2.3.1 + + AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), + // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. Watermark = 15 - 10 = 5. + CheckNewAnswer(("a", 15)), // Output = max event time of a + + AddData(inputData, ("a", 4)), // Add data older than watermark for "a" + CheckNewAnswer(), // No output as data should get filtered by watermark + */ + + AddData(inputData, ("a", 10)), // Add data newer than watermark for "a" + CheckNewAnswer(("a", 15)), // Max event time is still the same + // Timeout timestamp for "a" is still 20 as max event time for "a" is still 15. + // Watermark is still 5 as max event time for all data is still 15. + + Execute { query => + // Verify state format = 1 + val f = query.lastExecution.executedPlan.collect { case f: FlatMapGroupsWithStateExec => f } + assert(f.size == 1) + assert(f.head.stateFormatVersion == 1) + }, + + AddData(inputData, ("b", 31)), // Add data newer than watermark for "b", not "a" + // Watermark = 31 - 10 = 21, so "a" should be timed out as timeout timestamp for "a" is 20. + CheckNewAnswer(("a", -1), ("b", 31)) // State for "a" should timeout and emit -1 + ) + } + + test("mapGroupsWithState - streaming") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) @@ -1032,7 +1138,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest if (priorState.isEmpty && priorTimeoutTimestamp != NO_TIMESTAMP) { return // there can be no prior timestamp, when there is no prior state } - test(s"StateStoreUpdater - updates with data - $testName") { + test(s"InputProcessor - process new data - $testName") { val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { assert(state.hasTimedOut === false, "hasTimedOut not false") assert(values.nonEmpty, "Some value is expected") @@ -1054,7 +1160,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest expectedState: Option[Int], expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = { - test(s"StateStoreUpdater - updates for timeout - $testName") { + test(s"InputProcessor - process timed out state - $testName") { val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { assert(state.hasTimedOut === true, "hasTimedOut not true") assert(values.isEmpty, "values not empty") @@ -1081,21 +1187,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest val store = newStateStore() val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec( mapGroupsFunc, timeoutConf, currentBatchTimestamp) - val updater = new mapGroupsSparkPlan.StateStoreUpdater(store) + val inputProcessor = new mapGroupsSparkPlan.InputProcessor(store) + val stateManager = mapGroupsSparkPlan.stateManager val key = intToRow(0) // Prepare store with prior state configs - if (priorState.nonEmpty) { - val row = updater.getStateRow(priorState.get) - updater.setTimeoutTimestamp(row, priorTimeoutTimestamp) - store.put(key.copy(), row.copy()) + if (priorState.nonEmpty || priorTimeoutTimestamp != NO_TIMESTAMP) { + 
stateManager.putState(store, key, priorState.orNull, priorTimeoutTimestamp) } // Call updating function to update state store def callFunction() = { val returnedIter = if (testTimeoutUpdates) { - updater.updateStateForTimedOutKeys() + inputProcessor.processTimedOutState() } else { - updater.updateStateForKeysWithData(Iterator(key)) + inputProcessor.processNewData(Iterator(key)) } returnedIter.size // consume the iterator to force state updates } @@ -1106,15 +1211,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest } else { // Call function to update and verify updated state in store callFunction() - val updatedStateRow = store.get(key) - assert( - Option(updater.getStateObj(updatedStateRow)).map(_.toString.toInt) === expectedState, + val updatedState = stateManager.getState(store, key) + assert(Option(updatedState.stateObj).map(_.toString.toInt) === expectedState, "final state not as expected") - if (updatedStateRow != null) { - assert( - updater.getTimeoutTimestamp(updatedStateRow) === expectedTimeoutTimestamp, - "final timeout timestamp not as expected") - } + assert(updatedState.timeoutTimestamp === expectedTimeoutTimestamp, + "final timeout timestamp not as expected") } } @@ -1122,6 +1223,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest func: (Int, Iterator[Int], GroupState[Int]) => Iterator[Int], timeoutType: GroupStateTimeout = GroupStateTimeout.NoTimeout, batchTimestampMs: Long = NO_TIMESTAMP): FlatMapGroupsWithStateExec = { + val stateFormatVersion = spark.conf.get(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) + val emptyRdd = spark.sparkContext.emptyRDD[InternalRow] MemoryStream[Int] .toDS .groupByKey(x => x) @@ -1129,8 +1232,9 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest .logicalPlan.collectFirst { case FlatMapGroupsWithState(f, k, v, g, d, o, s, m, _, t, _) => FlatMapGroupsWithStateExec( - f, k, v, g, d, o, None, s, m, t, - Some(currentBatchTimestamp), Some(currentBatchWatermark), RDDScanExec(g, null, "rdd")) + f, k, v, g, d, o, None, s, stateFormatVersion, m, t, + Some(currentBatchTimestamp), Some(currentBatchWatermark), + RDDScanExec(g, emptyRdd, "rdd")) }.get } @@ -1162,33 +1266,22 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest } def rowToInt(row: UnsafeRow): Int = row.getInt(0) + + def testWithAllStateVersions(name: String)(func: => Unit): Unit = { + for (version <- FlatMapGroupsWithStateExecHelper.supportedVersions) { + test(s"$name - state format version $version") { + withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> version.toString) { + func + } + } + } + } } object FlatMapGroupsWithStateSuite { var failInTask = true - class MemoryStateStore extends StateStore() { - import scala.collection.JavaConverters._ - private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow] - - override def iterator(): Iterator[UnsafeRowPair] = { - map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, e.getValue) } - } - - override def get(key: UnsafeRow): UnsafeRow = map.get(key) - override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = { - map.put(key.copy(), newValue.copy()) - } - override def remove(key: UnsafeRow): Unit = { map.remove(key) } - override def commit(): Long = version + 1 - override def abort(): Unit = { } - override def id: StateStoreId = null - override def version: Long = 0 - override def metrics: StateStoreMetrics = new StateStoreMetrics(map.size, 0, Map.empty) - override def hasCommitted: Boolean = true - } - def 
assertCanGetProcessingTime(predicate: => Boolean): Unit = { if (!predicate) throw new TestFailedException("Could not get processing time", 20) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala index e45f9d3e2e97b..fb5d13d09fb0e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala @@ -31,33 +31,37 @@ trait StateStoreMetricsTest extends StreamTest { def assertNumStateRows(total: Seq[Long], updated: Seq[Long]): AssertOnQuery = AssertOnQuery(s"Check total state rows = $total, updated state rows = $updated") { q => - val recentProgress = q.recentProgress - require(recentProgress.nonEmpty, "No progress made, cannot check num state rows") - require(recentProgress.length < spark.sessionState.conf.streamingProgressRetention, - "This test assumes that all progresses are present in q.recentProgress but " + - "some may have been dropped due to retention limits") + // This assumes that the streaming query will not make any progress while the eventually + // is being executed. + eventually(timeout(streamingTimeout)) { + val recentProgress = q.recentProgress + require(recentProgress.nonEmpty, "No progress made, cannot check num state rows") + require(recentProgress.length < spark.sessionState.conf.streamingProgressRetention, + "This test assumes that all progresses are present in q.recentProgress but " + + "some may have been dropped due to retention limits") - if (q.ne(lastQuery)) lastCheckedRecentProgressIndex = -1 - lastQuery = q + if (q.ne(lastQuery)) lastCheckedRecentProgressIndex = -1 + lastQuery = q - val numStateOperators = recentProgress.last.stateOperators.length - val progressesSinceLastCheck = recentProgress - .slice(lastCheckedRecentProgressIndex + 1, recentProgress.length) - .filter(_.stateOperators.length == numStateOperators) + val numStateOperators = recentProgress.last.stateOperators.length + val progressesSinceLastCheck = recentProgress + .slice(lastCheckedRecentProgressIndex + 1, recentProgress.length) + .filter(_.stateOperators.length == numStateOperators) - val allNumUpdatedRowsSinceLastCheck = - progressesSinceLastCheck.map(_.stateOperators.map(_.numRowsUpdated)) + val allNumUpdatedRowsSinceLastCheck = + progressesSinceLastCheck.map(_.stateOperators.map(_.numRowsUpdated)) - lazy val debugString = "recent progresses:\n" + - progressesSinceLastCheck.map(_.prettyJson).mkString("\n\n") + lazy val debugString = "recent progresses:\n" + + progressesSinceLastCheck.map(_.prettyJson).mkString("\n\n") - val numTotalRows = recentProgress.last.stateOperators.map(_.numRowsTotal) - assert(numTotalRows === total, s"incorrect total rows, $debugString") + val numTotalRows = recentProgress.last.stateOperators.map(_.numRowsTotal) + assert(numTotalRows === total, s"incorrect total rows, $debugString") - val numUpdatedRows = arraySum(allNumUpdatedRowsSinceLastCheck, numStateOperators) - assert(numUpdatedRows === updated, s"incorrect updates rows, $debugString") + val numUpdatedRows = arraySum(allNumUpdatedRowsSinceLastCheck, numStateOperators) + assert(numUpdatedRows === updated, s"incorrect updates rows, $debugString") - lastCheckedRecentProgressIndex = recentProgress.length - 1 + lastCheckedRecentProgressIndex = recentProgress.length - 1 + } true } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala 
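The StateStoreMetricsTest change above makes the row-count assertions retry instead of failing on the first read, since progress reporting is asynchronous relative to the test thread. A generic sketch of the retry pattern, assuming ScalaTest's Eventually support (the helper name is illustrative):

import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._

// Retry an assertion over asynchronously updated state until it holds or times out.
def assertEventually(readTotals: () => Seq[Long], expected: Seq[Long]): Unit = {
  eventually(timeout(60.seconds), interval(100.milliseconds)) {
    assert(readTotals() == expected, "state row counts not yet reported")
  }
}

The caveat in the added comment still applies: the pattern is only safe if the query makes no further progress while the block is retrying, otherwise counters computed since the last check can drift between retries.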
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index c1ec1eba69fb2..f55ddb5419d20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -27,14 +27,17 @@ import scala.util.control.ControlThrowable import com.google.common.util.concurrent.UncheckedExecutionException import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration +import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SparkConf, SparkContext, TaskContext} import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.Range import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -513,6 +516,120 @@ class StreamSuite extends StreamTest { } } + test("explain-continuous") { + val inputData = ContinuousMemoryStream[Int] + val df = inputData.toDS().map(_ * 2).filter(_ > 5) + + // Test `df.explain` + val explain = ExplainCommand(df.queryExecution.logical, extended = false) + val explainString = + spark.sessionState + .executePlan(explain) + .executedPlan + .executeCollect() + .map(_.getString(0)) + .mkString("\n") + assert(explainString.contains("Filter")) + assert(explainString.contains("MapElements")) + assert(!explainString.contains("LocalTableScan")) + + // Test StreamingQuery.display + val q = df.writeStream.queryName("memory_continuous_explain") + .outputMode(OutputMode.Update()).format("memory") + .trigger(Trigger.Continuous("1 seconds")) + .start() + .asInstanceOf[StreamingQueryWrapper] + .streamingQuery + try { + // in continuous mode, the query will be run even if there's no data + // sleep a bit to ensure initialization + eventually(timeout(2.seconds), interval(100.milliseconds)) { + assert(q.lastExecution != null) + } + + val explainWithoutExtended = q.explainInternal(false) + + // `extended = false` only displays the physical plan. + assert("Streaming RelationV2 ContinuousMemoryStream".r + .findAllMatchIn(explainWithoutExtended).size === 0) + assert("ScanV2 ContinuousMemoryStream".r + .findAllMatchIn(explainWithoutExtended).size === 1) + + val explainWithExtended = q.explainInternal(true) + // `extended = true` displays 3 logical plans (Parsed/Analyzed/Optimized) and 1 physical + // plan. + assert("Streaming RelationV2 ContinuousMemoryStream".r + .findAllMatchIn(explainWithExtended).size === 3) + assert("ScanV2 ContinuousMemoryStream".r + .findAllMatchIn(explainWithExtended).size === 1) + } finally { + q.stop() + } + } + + test("codegen-microbatch") { + val inputData = MemoryStream[Int] + val df = inputData.toDS().map(_ * 2).filter(_ > 5) + + // Test StreamingQuery.codegen + val q = df.writeStream.queryName("memory_microbatch_codegen") + .outputMode(OutputMode.Update) + .format("memory") + .trigger(Trigger.ProcessingTime("1 seconds")) + .start() + + try { + import org.apache.spark.sql.execution.debug._ + assert("No physical plan.
Waiting for data." === codegenString(q)) + assert(codegenStringSeq(q).isEmpty) + + inputData.addData(1, 2, 3, 4, 5) + q.processAllAvailable() + + assertDebugCodegenResult(q) + } finally { + q.stop() + } + } + + test("codegen-continuous") { + val inputData = ContinuousMemoryStream[Int] + val df = inputData.toDS().map(_ * 2).filter(_ > 5) + + // Test StreamingQuery.codegen + val q = df.writeStream.queryName("memory_continuous_codegen") + .outputMode(OutputMode.Update) + .format("memory") + .trigger(Trigger.Continuous("1 seconds")) + .start() + + try { + // in continuous mode, the query will be run even if there's no data + // sleep a bit to ensure initialization + eventually(timeout(2.seconds), interval(100.milliseconds)) { + assert(q.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution != null) + } + + assertDebugCodegenResult(q) + } finally { + q.stop() + } + } + + private def assertDebugCodegenResult(query: StreamingQuery): Unit = { + import org.apache.spark.sql.execution.debug._ + + val codegenStr = codegenString(query) + assert(codegenStr.contains("Found 1 WholeStageCodegen subtrees.")) + // assuming that code is generated for the test query + assert(codegenStr.contains("Generated code:")) + + val codegenStrSeq = codegenStringSeq(query) + assert(codegenStrSeq.nonEmpty) + assert(codegenStrSeq.head._1.contains("*(1)")) + assert(codegenStrSeq.head._2.contains("codegenStageId=1")) + } + test("SPARK-19065: dropDuplicates should not create expressions using the same id") { withTempPath { testPath => val data = Seq((1, 2), (2, 3), (3, 4)) @@ -672,7 +789,7 @@ class StreamSuite extends StreamTest { val query = input .toDS() .map { i => - while (!org.apache.spark.TaskContext.get().isInterrupted()) { + while (!TaskContext.get().isInterrupted()) { // keep looping till interrupted by query.stop() Thread.sleep(100) } @@ -805,6 +922,142 @@ class StreamSuite extends StreamTest { } } + test("streaming limit without state") { + val inputData1 = MemoryStream[Int] + testStream(inputData1.toDF().limit(0))( + AddData(inputData1, 1 to 8: _*), + CheckAnswer()) + + val inputData2 = MemoryStream[Int] + testStream(inputData2.toDF().limit(4))( + AddData(inputData2, 1 to 8: _*), + CheckAnswer(1 to 4: _*)) + } + + test("streaming limit with state") { + val inputData = MemoryStream[Int] + testStream(inputData.toDF().limit(4))( + AddData(inputData, 1 to 2: _*), + CheckAnswer(1 to 2: _*), + AddData(inputData, 3 to 6: _*), + CheckAnswer(1 to 4: _*), + AddData(inputData, 7 to 9: _*), + CheckAnswer(1 to 4: _*)) + } + + test("streaming limit with other operators") { + val inputData = MemoryStream[Int] + testStream(inputData.toDF().where("value % 2 = 1").limit(4))( + AddData(inputData, 1 to 5: _*), + CheckAnswer(1, 3, 5), + AddData(inputData, 6 to 9: _*), + CheckAnswer(1, 3, 5, 7), + AddData(inputData, 10 to 12: _*), + CheckAnswer(1, 3, 5, 7)) + } + + test("streaming limit with multiple limits") { + val inputData1 = MemoryStream[Int] + testStream(inputData1.toDF().limit(4).limit(2))( + AddData(inputData1, 1), + CheckAnswer(1), + AddData(inputData1, 2 to 8: _*), + CheckAnswer(1, 2)) + + val inputData2 = MemoryStream[Int] + testStream(inputData2.toDF().limit(4).limit(100).limit(3))( + AddData(inputData2, 1, 2), + CheckAnswer(1, 2), + AddData(inputData2, 3 to 8: _*), + CheckAnswer(1 to 3: _*)) + } + + test("streaming limit in complete mode") { + val inputData = MemoryStream[Int] + val limited = inputData.toDF().limit(5).groupBy("value").count() + testStream(limited, OutputMode.Complete())( + AddData(inputData, 1 to 3:
_*), + CheckAnswer(Row(1, 1), Row(2, 1), Row(3, 1)), + AddData(inputData, 1 to 9: _*), + CheckAnswer(Row(1, 2), Row(2, 2), Row(3, 2), Row(4, 1), Row(5, 1))) + } + + test("streaming limits in complete mode") { + val inputData = MemoryStream[Int] + val limited = inputData.toDF().limit(4).groupBy("value").count().orderBy("value").limit(3) + testStream(limited, OutputMode.Complete())( + AddData(inputData, 1 to 9: _*), + CheckAnswer(Row(1, 1), Row(2, 1), Row(3, 1)), + AddData(inputData, 2 to 6: _*), + CheckAnswer(Row(1, 1), Row(2, 2), Row(3, 2))) + } + + test("streaming limit in update mode") { + val inputData = MemoryStream[Int] + val e = intercept[AnalysisException] { + testStream(inputData.toDF().limit(5), OutputMode.Update())( + AddData(inputData, 1 to 3: _*) + ) + } + assert(e.getMessage.contains( + "Limits are not supported on streaming DataFrames/Datasets in Update output mode")) + } + + test("streaming limit in multiple partitions") { + val inputData = MemoryStream[Int] + testStream(inputData.toDF().repartition(2).limit(7))( + AddData(inputData, 1 to 10: _*), + CheckAnswerRowsByFunc( + rows => assert(rows.size == 7 && rows.forall(r => r.getInt(0) <= 10)), + false), + AddData(inputData, 11 to 20: _*), + CheckAnswerRowsByFunc( + rows => assert(rows.size == 7 && rows.forall(r => r.getInt(0) <= 10)), + false)) + } + + test("streaming limit in multiple partitions by column") { + val inputData = MemoryStream[(Int, Int)] + val df = inputData.toDF().repartition(2, $"_2").limit(7) + testStream(df)( + AddData(inputData, (1, 0), (2, 0), (3, 1), (4, 1)), + CheckAnswerRowsByFunc( + rows => assert(rows.size == 4 && rows.forall(r => r.getInt(0) <= 4)), + false), + AddData(inputData, (5, 0), (6, 0), (7, 1), (8, 1)), + CheckAnswerRowsByFunc( + rows => assert(rows.size == 7 && rows.forall(r => r.getInt(0) <= 8)), + false)) + } + + test("is_continuous_processing property should be false for microbatch processing") { + val input = MemoryStream[Int] + val df = input.toDS() + .map(i => TaskContext.get().getLocalProperty(StreamExecution.IS_CONTINUOUS_PROCESSING)) + testStream(df) ( + AddData(input, 1), + CheckAnswer("false") + ) + } + + test("is_continuous_processing property should be true for continuous processing") { + val input = ContinuousMemoryStream[Int] + val stream = input.toDS() + .map(i => TaskContext.get().getLocalProperty(StreamExecution.IS_CONTINUOUS_PROCESSING)) + .writeStream.format("memory") + .queryName("output") + .trigger(Trigger.Continuous("1 seconds")) + .start() + try { + input.addData(1) + stream.processAllAvailable() + } finally { + stream.stop() + } + + checkAnswer(spark.sql("select * from output"), Row("true")) + } + for (e <- Seq( new InterruptedException, new InterruptedIOException, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index e41b4534ed51d..d878c345c2988 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -45,7 +45,6 @@ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext import 
org.apache.spark.util.{Clock, SystemClock, Utils} @@ -80,8 +79,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be implicit val defaultSignaler: Signaler = ThreadSignaler override def afterAll(): Unit = { - super.afterAll() - StateStore.stop() // stop the state store maintenance thread and unload store providers + try { + super.afterAll() + } finally { + StateStore.stop() // stop the state store maintenance thread and unload store providers + } } protected val defaultTrigger = Trigger.ProcessingTime(0) @@ -292,8 +294,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be /** Execute arbitrary code */ object Execute { - def apply(func: StreamExecution => Any): AssertOnQuery = - AssertOnQuery(query => { func(query); true }, "Execute") + def apply(name: String)(func: StreamExecution => Any): AssertOnQuery = + AssertOnQuery(query => { func(query); true }, name) + + def apply(func: StreamExecution => Any): AssertOnQuery = apply("Execute")(func) } object AwaitEpoch { @@ -338,8 +342,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be var currentStream: StreamExecution = null var lastStream: StreamExecution = null val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for - val sink = if (useV2Sink) new MemorySinkV2 - else new MemorySink(stream.schema, outputMode, DataSourceOptions.empty()) + val sink = if (useV2Sink) new MemorySinkV2 else new MemorySink(stream.schema, outputMode) val resetConfValues = mutable.Map[String, Option[String]]() val defaultCheckpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath @@ -467,7 +470,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be // Block until all data added has been processed for all the source awaiting.foreach { case (sourceIndex, offset) => failAfter(streamingTimeout) { - currentStream.awaitOffset(sourceIndex, offset) + currentStream.awaitOffset(sourceIndex, offset, streamingTimeout.toMillis) // Make sure all processing including no-data-batches have been executed if (!currentStream.triggerClock.isInstanceOf[StreamManualClock]) { currentStream.processAllAvailable() @@ -514,7 +517,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be logInfo(s"Processing test stream action: $action") action match { case StartStream(trigger, triggerClock, additionalConfs, checkpointLocation) => - verify(currentStream == null, "stream already running") + verify(currentStream == null || !currentStream.isActive, "stream already running") verify(triggerClock.isInstanceOf[SystemClock] || triggerClock.isInstanceOf[StreamManualClock], "Use either SystemClock or StreamManualClock to start the stream") @@ -686,7 +689,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be plan .collect { case r: StreamingExecutionRelation => r.source - case r: StreamingDataSourceV2Relation => r.reader + case r: StreamingDataSourceV2Relation => r.readSupport } .zipWithIndex .find(_._1 == source) @@ -735,7 +738,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } case CheckAnswerRowsByFunc(globalCheckFunction, lastOnly) => - val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) + val sparkAnswer = currentStream match { + case null => fetchStreamAnswer(lastStream, lastOnly) + case s => fetchStreamAnswer(s, lastOnly) + } try { globalCheckFunction(sparkAnswer) } catch { diff --git
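With the Execute fix above, the named overload threads the caller's label into AssertOnQuery, so a failing step is reported under its own name; passing the string literal "name" instead of the name parameter would have labeled every such step "name". A usage sketch, reusing the inputData/result definitions from the state-format-version test earlier in this patch (the label string is illustrative):

testStream(result, Update)(
  AddData(inputData, "a"),
  CheckNewAnswer(("a", "1")),
  Execute("verify state format version") { query =>
    // Same check as the unnamed Execute blocks above, but a failure now names this step.
    val execs = query.lastExecution.executedPlan.collect {
      case f: FlatMapGroupsWithStateExec => f
    }
    assert(execs.size == 1 && execs.head.stateFormatVersion == 2)
  }
)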
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 382da13430781..97dbb9b0360ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.streaming +import java.io.File import java.util.{Locale, TimeZone} -import org.scalatest.Assertions -import org.scalatest.BeforeAndAfterAll +import org.apache.commons.io.FileUtils +import org.scalatest.{Assertions, BeforeAndAfterAll} import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.rdd.BlockRDD @@ -31,29 +32,53 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.exchange.Exchange import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.execution.streaming.state.{StateStore, StreamingAggregationStateManager} import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode._ import org.apache.spark.sql.streaming.util.{MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} +import org.apache.spark.util.Utils object FailureSingleton { var firstTime = true } -class StreamingAggregationSuite extends StateStoreMetricsTest - with BeforeAndAfterAll with Assertions { +class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions { - override def afterAll(): Unit = { - super.afterAll() - StateStore.stop() + import testImplicits._ + + def executeFuncWithStateVersionSQLConf( + stateVersion: Int, + confPairs: Seq[(String, String)], + func: => Any): Unit = { + withSQLConf(confPairs ++ + Seq(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> stateVersion.toString): _*) { + func + } } - import testImplicits._ + def testWithAllStateVersions(name: String, confPairs: (String, String)*) + (func: => Any): Unit = { + for (version <- StreamingAggregationStateManager.supportedVersions) { + test(s"$name - state format version $version") { + executeFuncWithStateVersionSQLConf(version, confPairs, func) + } + } + } + + def testQuietlyWithAllStateVersions(name: String, confPairs: (String, String)*) + (func: => Any): Unit = { + for (version <- StreamingAggregationStateManager.supportedVersions) { + testQuietly(s"$name - state format version $version") { + executeFuncWithStateVersionSQLConf(version, confPairs, func) + } + } + } - test("simple count, update mode") { + testWithAllStateVersions("simple count, update mode") { val inputData = MemoryStream[Int] val aggregated = @@ -77,7 +102,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("count distinct") { + testWithAllStateVersions("count distinct") { val inputData = MemoryStream[(Int, Seq[Int])] val aggregated = @@ -93,7 +118,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("simple count, complete mode") { + testWithAllStateVersions("simple count, complete mode") { val inputData = MemoryStream[Int] val aggregated = @@ -116,7 +141,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - 
test("simple count, append mode") { + testWithAllStateVersions("simple count, append mode") { val inputData = MemoryStream[Int] val aggregated = @@ -133,7 +158,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest } } - test("sort after aggregate in complete mode") { + testWithAllStateVersions("sort after aggregate in complete mode") { val inputData = MemoryStream[Int] val aggregated = @@ -158,7 +183,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("state metrics") { + testWithAllStateVersions("state metrics") { val inputData = MemoryStream[Int] val aggregated = @@ -211,7 +236,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("multiple keys") { + testWithAllStateVersions("multiple keys") { val inputData = MemoryStream[Int] val aggregated = @@ -228,7 +253,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - testQuietly("midbatch failure") { + testQuietlyWithAllStateVersions("midbatch failure") { val inputData = MemoryStream[Int] FailureSingleton.firstTime = true val aggregated = @@ -254,7 +279,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("typed aggregators") { + testWithAllStateVersions("typed aggregators") { val inputData = MemoryStream[(String, Int)] val aggregated = inputData.toDS().groupByKey(_._1).agg(typed.sumLong(_._2)) @@ -264,7 +289,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("prune results by current_time, complete mode") { + testWithAllStateVersions("prune results by current_time, complete mode") { import testImplicits._ val clock = new StreamManualClock val inputData = MemoryStream[Long] @@ -316,7 +341,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("prune results by current_date, complete mode") { + testWithAllStateVersions("prune results by current_date, complete mode") { import testImplicits._ val clock = new StreamManualClock val tz = TimeZone.getDefault.getID @@ -365,7 +390,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("SPARK-19690: do not convert batch aggregation in streaming query to streaming") { + testWithAllStateVersions("SPARK-19690: do not convert batch aggregation in streaming query " + + "to streaming") { val streamInput = MemoryStream[Int] val batchDF = Seq(1, 2, 3, 4, 5) .toDF("value") @@ -429,7 +455,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest true } - test("SPARK-21977: coalesce(1) with 0 partition RDD should be repartitioned to 1") { + testWithAllStateVersions("SPARK-21977: coalesce(1) with 0 partition RDD should be " + + "repartitioned to 1") { val inputSource = new BlockRDDBackedSource(spark) MockSourceProvider.withMockSources(inputSource) { // `coalesce(1)` changes the partitioning of data to `SinglePartition` which by default @@ -467,8 +494,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest } } - test("SPARK-21977: coalesce(1) with aggregation should still be repartitioned when it " + - "has non-empty grouping keys") { + testWithAllStateVersions("SPARK-21977: coalesce(1) with aggregation should still be " + + "repartitioned when it has non-empty grouping keys") { val inputSource = new BlockRDDBackedSource(spark) MockSourceProvider.withMockSources(inputSource) { withTempDir { tempDir => @@ -520,7 +547,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest } } - test("SPARK-22230: last should change with new batches") { + testWithAllStateVersions("SPARK-22230: last should change with new 
batches") { val input = MemoryStream[Int] val aggregated = input.toDF().agg(last('value)) @@ -536,7 +563,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("SPARK-23004: Ensure that TypedImperativeAggregate functions do not throw errors") { + testWithAllStateVersions("SPARK-23004: Ensure that TypedImperativeAggregate functions " + + "do not throw errors", SQLConf.SHUFFLE_PARTITIONS.key -> "1") { // See the JIRA SPARK-23004 for more details. In short, this test reproduces the error // by ensuring the following. // - A streaming query with a streaming aggregation. @@ -545,22 +573,72 @@ class StreamingAggregationSuite extends StateStoreMetricsTest // ObjectHashAggregateExec falls back to sort-based aggregation). This is done by having a // micro-batch with 128 records that shuffle to a single partition. // This test throws the exact error reported in SPARK-23004 without the corresponding fix. - withSQLConf("spark.sql.shuffle.partitions" -> "1") { - val input = MemoryStream[Int] - val df = input.toDF().toDF("value") - .selectExpr("value as group", "value") - .groupBy("group") - .agg(collect_list("value")) - testStream(df, outputMode = OutputMode.Update)( - AddData(input, (1 to spark.sqlContext.conf.objectAggSortBasedFallbackThreshold): _*), - AssertOnQuery { q => - q.processAllAvailable() - true + val input = MemoryStream[Int] + val df = input.toDF().toDF("value") + .selectExpr("value as group", "value") + .groupBy("group") + .agg(collect_list("value")) + testStream(df, outputMode = OutputMode.Update)( + AddData(input, (1 to spark.sqlContext.conf.objectAggSortBasedFallbackThreshold): _*), + AssertOnQuery { q => + q.processAllAvailable() + true + } + ) + } + + + test("simple count, update mode - recovery from checkpoint uses state format version 1") { + val inputData = MemoryStream[Int] + + val aggregated = + inputData.toDF() + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/").toURI + + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. + FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + + inputData.addData(3) + inputData.addData(3, 2) + + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath, + additionalConfs = Map(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "2")), + /* + Note: The checkpoint was generated using the following input in Spark version 2.3.1 + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 3, 2), + CheckLastBatch((3, 2), (2, 1)) + */ + + AddData(inputData, 3, 2, 1), + CheckLastBatch((3, 3), (2, 2), (1, 1)), + + Execute { query => + // Verify state format = 1 + val stateVersions = query.lastExecution.executedPlan.collect { + case f: StateStoreSaveExec => f.stateFormatVersion + case f: StateStoreRestoreExec => f.stateFormatVersion } - ) - } + assert(stateVersions.size == 2) + assert(stateVersions.forall(_ == 1)) + }, + + // By default we run in new tuple mode. + AddData(inputData, 4, 4, 4, 4), + CheckLastBatch((4, 4)) + ) } + /** Add blocks of data to the `BlockRDDBackedSource`. 
*/ case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData { override def addData(query: Option[StreamExecution]): (Source, Offset) = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala index 42ffd472eb843..cfd7204ea2931 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala @@ -26,15 +26,10 @@ import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -class StreamingDeduplicationSuite extends StateStoreMetricsTest with BeforeAndAfterAll { +class StreamingDeduplicationSuite extends StateStoreMetricsTest { import testImplicits._ - override def afterAll(): Unit = { - super.afterAll() - StateStore.stop() - } - test("deduplicate with all columns") { val inputData = MemoryStream[String] val result = inputData.toDS().dropDuplicates() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index b96f2bcbdd644..fe77a1b4469c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -231,7 +231,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { test("event ordering") { val listener = new EventCollector withListenerAdded(listener) { - for (i <- 1 to 100) { + for (i <- 1 to 50) { listener.reset() require(listener.startEvent === null) testStream(MemoryStream[Int].toDS)( @@ -299,9 +299,9 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { try { val input = new MemoryStream[Int](0, sqlContext) { @volatile var numTriggers = 0 - override def getEndOffset: OffsetV2 = { + override def latestOffset(): OffsetV2 = { numTriggers += 1 - super.getEndOffset + super.latestOffset() } } val clock = new StreamManualClock() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala index 79bb827e0de93..7bef687e7e43b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala @@ -58,7 +58,12 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { | "stateOperators" : [ { | "numRowsTotal" : 0, | "numRowsUpdated" : 1, - | "memoryUsedBytes" : 2 + | "memoryUsedBytes" : 3, + | "customMetrics" : { + | "loadedMapCacheHitCount" : 1, + | "loadedMapCacheMissCount" : 0, + | "stateOnCurrentVersionSizeBytes" : 2 + | } | } ], | "sources" : [ { | "description" : "source", @@ -230,7 +235,11 @@ object StreamingQueryStatusAndProgressSuite { "avg" -> "2016-12-05T20:54:20.827Z", "watermark" -> "2016-12-05T20:54:20.827Z").asJava), stateOperators = Array(new StateOperatorProgress( - numRowsTotal = 0, numRowsUpdated = 1, memoryUsedBytes = 2)), + numRowsTotal = 0, numRowsUpdated = 1, memoryUsedBytes = 3, + customMetrics = new java.util.HashMap(Map("stateOnCurrentVersionSizeBytes" -> 2L, + 
"loadedMapCacheHitCount" -> 1L, "loadedMapCacheMissCount" -> 0L) + .mapValues(long2Long).asJava) + )), sources = Array( new SourceProgress( description = "source", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index dcf6cb5d609ee..1dd817545a969 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.streaming -import java.{util => ju} -import java.util.Optional import java.util.concurrent.CountDownLatch +import scala.collection.mutable + import org.apache.commons.lang3.RandomStringUtils import org.scalactic.TolerantNumerics import org.scalatest.BeforeAndAfter @@ -29,13 +29,13 @@ import org.scalatest.mockito.MockitoSugar import org.apache.spark.SparkException import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, Dataset} -import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} +import org.apache.spark.sql.catalyst.expressions.{Literal, Rand, Randn, Shuffle, Uuid} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.reader.InputPartition +import org.apache.spark.sql.sources.v2.reader.{InputPartition, ScanConfig} import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType @@ -212,25 +212,17 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi private def dataAdded: Boolean = currentOffset.offset != -1 - // setOffsetRange should take 50 ms the first time it is called after data is added - override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = { - synchronized { - if (dataAdded) clock.waitTillTime(1050) - super.setOffsetRange(start, end) - } - } - - // getEndOffset should take 100 ms the first time it is called after data is added - override def getEndOffset(): OffsetV2 = synchronized { - if (dataAdded) clock.waitTillTime(1150) - super.getEndOffset() + // latestOffset should take 50 ms the first time it is called after data is added + override def latestOffset(): OffsetV2 = synchronized { + if (dataAdded) clock.waitTillTime(1050) + super.latestOffset() } // getBatch should take 100 ms the first time it is called - override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { synchronized { - clock.waitTillTime(1350) - super.planUnsafeInputPartitions() + clock.waitTillTime(1150) + super.planInputPartitions(config) } } } @@ -271,34 +263,26 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery(_.status.message === "Waiting for next trigger"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - // Test status and progress when setOffsetRange is being called + // Test status and progress when `latestOffset` is being called AddData(inputData, 1, 2), - AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on 
setOffsetRange + AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on `latestOffset` AssertStreamExecThreadIsWaitingForTime(1050), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message.startsWith("Getting offsets from")), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - AdvanceManualClock(50), // time = 1050 to unblock setOffsetRange + AdvanceManualClock(50), // time = 1050 to unblock `latestOffset` AssertClockTime(1050), - AssertStreamExecThreadIsWaitingForTime(1150), // will block on getEndOffset that needs 1150 - AssertOnQuery(_.status.isDataAvailable === false), - AssertOnQuery(_.status.isTriggerActive === true), - AssertOnQuery(_.status.message.startsWith("Getting offsets from")), - AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - - AdvanceManualClock(100), // time = 1150 to unblock getEndOffset - AssertClockTime(1150), - // will block on planInputPartitions that needs 1350 - AssertStreamExecThreadIsWaitingForTime(1350), + // will block on `planInputPartitions` that needs 1150 + AssertStreamExecThreadIsWaitingForTime(1150), AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - AdvanceManualClock(200), // time = 1350 to unblock planInputPartitions - AssertClockTime(1350), + AdvanceManualClock(100), // time = 1150 to unblock `planInputPartitions` + AssertClockTime(1150), AssertStreamExecThreadIsWaitingForTime(1500), // will block on map task that needs 1500 AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), @@ -306,7 +290,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), // Test status and progress while batch processing has completed - AdvanceManualClock(150), // time = 1500 to unblock map task + AdvanceManualClock(350), // time = 1500 to unblock map task AssertClockTime(1500), CheckAnswer(2), AssertStreamExecThreadIsWaitingForTime(2000), // will block until the next trigger @@ -326,17 +310,16 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.numInputRows === 2) assert(progress.processedRowsPerSecond === 4.0) - assert(progress.durationMs.get("setOffsetRange") === 50) - assert(progress.durationMs.get("getEndOffset") === 100) - assert(progress.durationMs.get("queryPlanning") === 200) + assert(progress.durationMs.get("latestOffset") === 50) + assert(progress.durationMs.get("queryPlanning") === 100) assert(progress.durationMs.get("walCommit") === 0) - assert(progress.durationMs.get("addBatch") === 150) + assert(progress.durationMs.get("addBatch") === 350) assert(progress.durationMs.get("triggerExecution") === 500) assert(progress.sources.length === 1) assert(progress.sources(0).description contains "MemoryStream") - assert(progress.sources(0).startOffset === "0") - assert(progress.sources(0).endOffset !== null) + assert(progress.sources(0).startOffset === null) // no prior offset + assert(progress.sources(0).endOffset === "0") assert(progress.sources(0).processedRowsPerSecond === 4.0) // 2 rows processed in 500 ms assert(progress.stateOperators.length === 1) @@ -362,6 +345,8 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
assert(query.lastProgress.batchId === 1) assert(query.lastProgress.inputRowsPerSecond === 2.0) assert(query.lastProgress.sources(0).inputRowsPerSecond === 2.0) + assert(query.lastProgress.sources(0).startOffset === "0") + assert(query.lastProgress.sources(0).endOffset === "1") true }, @@ -462,6 +447,9 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(gauges.get("latency").getValue.asInstanceOf[Long] == 0) assert(gauges.get("processingRate-total").getValue.asInstanceOf[Double] == 0.0) assert(gauges.get("inputRate-total").getValue.asInstanceOf[Double] == 0.0) + assert(gauges.get("eventTime-watermark").getValue.asInstanceOf[Long] == 0) + assert(gauges.get("states-rowsTotal").getValue.asInstanceOf[Long] == 0) + assert(gauges.get("states-usedBytes").getValue.asInstanceOf[Long] == 0) sq.stop() } } @@ -832,6 +820,62 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi CheckLastBatch(("A", 1))) } + test("Uuid in streaming query should not produce same uuids in each execution") { + val uuids = mutable.ArrayBuffer[String]() + def collectUuid: Seq[Row] => Unit = { rows: Seq[Row] => + rows.foreach(r => uuids += r.getString(0)) + } + + val stream = MemoryStream[Int] + val df = stream.toDF().select(new Column(Uuid())) + testStream(df)( + AddData(stream, 1), + CheckAnswer(collectUuid), + AddData(stream, 2), + CheckAnswer(collectUuid) + ) + assert(uuids.distinct.size == 2) + } + + test("Rand/Randn in streaming query should not produce same results in each execution") { + val rands = mutable.ArrayBuffer[Double]() + def collectRand: Seq[Row] => Unit = { rows: Seq[Row] => + rows.foreach { r => + rands += r.getDouble(0) + rands += r.getDouble(1) + } + } + + val stream = MemoryStream[Int] + val df = stream.toDF().select(new Column(new Rand()), new Column(new Randn())) + testStream(df)( + AddData(stream, 1), + CheckAnswer(collectRand), + AddData(stream, 2), + CheckAnswer(collectRand) + ) + assert(rands.distinct.size == 4) + } + + test("Shuffle in streaming query should not produce same results in each execution") { + val rands = mutable.ArrayBuffer[Seq[Int]]() + def collectShuffle: Seq[Row] => Unit = { rows: Seq[Row] => + rows.foreach { r => + rands += r.getSeq[Int](0) + } + } + + val stream = MemoryStream[Int] + val df = stream.toDF().select(new Column(new Shuffle(Literal.create[Seq[Int]](0 until 100)))) + testStream(df)( + AddData(stream, 1), + CheckAnswer(collectShuffle), + AddData(stream, 2), + CheckAnswer(collectShuffle) + ) + assert(rands.distinct.size == 2) + } + test("StreamingRelationV2/StreamingExecutionRelation/ContinuousExecutionRelation.toJSON " + "should not fail") { val df = spark.readStream.format("rate").load() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala index 0223812600961..c5b95fa9b64a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala @@ -74,7 +74,7 @@ class ContinuousAggregationSuite extends ContinuousSuiteBase { val df = input.toDF() .select('value as 'copy, 'value) .where('copy =!= 1) - .planWithBarrier + .logicalPlan .coalesce(1) .where('copy =!= 2) .agg(max('value)) @@ -95,7 +95,7 @@ class ContinuousAggregationSuite extends ContinuousSuiteBase { val df = input.toDF() .coalesce(1) - 
.planWithBarrier + .logicalPlan .coalesce(1) .select('value as 'copy, 'value) .agg(max('value)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala index 0e7e6febb53df..d6819eacd07ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -19,19 +19,18 @@ package org.apache.spark.sql.streaming.continuous import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue} -import org.mockito.{ArgumentCaptor, Matchers} import org.mockito.Mockito._ import org.scalatest.mockito.MockitoSugar -import org.apache.spark.{SparkEnv, SparkFunSuite, TaskContext} -import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv} +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.sources.v2.reader.InputPartition -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, PartitionOffset} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReader, ContinuousReadSupport, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.sql.types.{DataType, IntegerType, StructType} class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { case class LongPartitionOffset(offset: Long) extends PartitionOffset @@ -44,8 +43,8 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { override def beforeEach(): Unit = { super.beforeEach() epochEndpoint = EpochCoordinatorRef.create( - mock[StreamWriter], - mock[ContinuousReader], + mock[StreamingWriteSupport], + mock[ContinuousReadSupport], mock[ContinuousExecution], coordinatorId, startEpoch, @@ -73,26 +72,26 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { */ private def setup(): (BlockingQueue[UnsafeRow], ContinuousQueuedDataReader) = { val queue = new ArrayBlockingQueue[UnsafeRow](1024) - val factory = new InputPartition[UnsafeRow] { - override def createPartitionReader() = new ContinuousInputPartitionReader[UnsafeRow] { - var index = -1 - var curr: UnsafeRow = _ - - override def next() = { - curr = queue.take() - index += 1 - true - } + val partitionReader = new ContinuousPartitionReader[InternalRow] { + var index = -1 + var curr: UnsafeRow = _ + + override def next() = { + curr = queue.take() + index += 1 + true + } - override def get = curr + override def get = curr - override def getOffset = LongPartitionOffset(index) + override def getOffset = LongPartitionOffset(index) - override def close() = {} - } + override def close() = {} } val reader = new ContinuousQueuedDataReader( - new ContinuousDataSourceRDDPartition(0, factory), + 0, + partitionReader, + new StructType().add("i", "int"), mockContext, dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize, epochPollIntervalMs = 
sqlContext.conf.continuousStreamingExecutorPollIntervalMs) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 4980b0cd41f81..3d21bc63e0cc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -41,7 +41,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, _, _, _, r: RateStreamContinuousReader) => r + case DataSourceV2ScanExec(_, _, _, _, r: RateStreamContinuousReadSupport, _) => r }.get val deltaMs = numTriggers * 1000 + 300 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala index 82836dced9df7..3c973d8ebc704 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala @@ -27,9 +27,9 @@ import org.apache.spark._ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.LocalSparkSession import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.sql.test.TestSparkSession class EpochCoordinatorSuite @@ -40,20 +40,20 @@ class EpochCoordinatorSuite private var epochCoordinator: RpcEndpointRef = _ - private var writer: StreamWriter = _ + private var writeSupport: StreamingWriteSupport = _ private var query: ContinuousExecution = _ private var orderVerifier: InOrder = _ override def beforeEach(): Unit = { - val reader = mock[ContinuousReader] - writer = mock[StreamWriter] + val reader = mock[ContinuousReadSupport] + writeSupport = mock[StreamingWriteSupport] query = mock[ContinuousExecution] - orderVerifier = inOrder(writer, query) + orderVerifier = inOrder(writeSupport, query) spark = new TestSparkSession() epochCoordinator - = EpochCoordinatorRef.create(writer, reader, query, "test", 1, spark, SparkEnv.get) + = EpochCoordinatorRef.create(writeSupport, reader, query, "test", 1, spark, SparkEnv.get) } test("single epoch") { @@ -209,12 +209,12 @@ class EpochCoordinatorSuite } private def verifyCommit(epoch: Long): Unit = { - orderVerifier.verify(writer).commit(eqTo(epoch), any()) + orderVerifier.verify(writeSupport).commit(eqTo(epoch), any()) orderVerifier.verify(query).commit(epoch) } private def verifyNoCommitFor(epoch: Long): Unit = { - verify(writer, never()).commit(eqTo(epoch), any()) + verify(writeSupport, never()).commit(eqTo(epoch), any()) verify(query, never()).commit(epoch) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala 
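EpochCoordinatorSuite above relies on Mockito's in-order verification to assert that the sink commits an epoch before the query records it as committed. A stripped-down sketch of that pattern with stand-in traits (Sink and Query here are illustrative placeholders, not the patch's types):

import org.mockito.Mockito.{inOrder, mock}

trait Sink { def commit(epoch: Long): Unit }
trait Query { def commit(epoch: Long): Unit }

val sink = mock(classOf[Sink])
val query = mock(classOf[Query])
val order = inOrder(sink, query)

// Stand-in for the code under test, which must commit to the sink first.
sink.commit(1)
query.commit(1)

order.verify(sink).commit(1)  // fails if query.commit(1) had been recorded first
order.verify(query).commit(1)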
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala index f84f3d49707bf..b42f8267916b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle import java.util.UUID +import scala.language.implicitConversions + import org.apache.spark.{HashPartitioner, Partition, TaskContext, TaskContextImpl} import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index c1a28b9bc75ef..aeef4c8fe9332 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -17,72 +17,74 @@ package org.apache.spark.sql.streaming.sources -import java.util.Optional - -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{RateStreamOffset, Sink, StreamingQueryWrapper} import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.reader.InputPartition -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReaderFactory, ScanConfig, ScanConfigBuilder} +import org.apache.spark.sql.sources.v2.reader.streaming._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils -case class FakeReader() extends MicroBatchReader with ContinuousReader { - def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = {} - def getStartOffset: Offset = RateStreamOffset(Map()) - def getEndOffset: Offset = RateStreamOffset(Map()) - def deserializeOffset(json: String): Offset = RateStreamOffset(Map()) - def commit(end: Offset): Unit = {} - def readSchema(): StructType = StructType(Seq()) - def stop(): Unit = {} - def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) - def setStartOffset(start: Optional[Offset]): Unit = {} - - def planInputPartitions(): java.util.ArrayList[InputPartition[Row]] = { +case class FakeReadSupport() extends MicroBatchReadSupport with ContinuousReadSupport { + override def deserializeOffset(json: String): Offset = RateStreamOffset(Map()) + override def commit(end: Offset): Unit = {} + override def stop(): Unit = {} + override def mergeOffsets(offsets: 
Array[PartitionOffset]): Offset = RateStreamOffset(Map()) + override def fullSchema(): StructType = StructType(Seq()) + override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = null + override def initialOffset(): Offset = RateStreamOffset(Map()) + override def latestOffset(): Offset = RateStreamOffset(Map()) + override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = null + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + throw new IllegalStateException("fake source - cannot actually read") + } + override def createContinuousReaderFactory( + config: ScanConfig): ContinuousPartitionReaderFactory = { + throw new IllegalStateException("fake source - cannot actually read") + } + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { throw new IllegalStateException("fake source - cannot actually read") } } -trait FakeMicroBatchReadSupport extends MicroBatchReadSupport { - override def createMicroBatchReader( - schema: Optional[StructType], +trait FakeMicroBatchReadSupportProvider extends MicroBatchReadSupportProvider { + override def createMicroBatchReadSupport( checkpointLocation: String, - options: DataSourceOptions): MicroBatchReader = FakeReader() + options: DataSourceOptions): MicroBatchReadSupport = FakeReadSupport() } -trait FakeContinuousReadSupport extends ContinuousReadSupport { - override def createContinuousReader( - schema: Optional[StructType], +trait FakeContinuousReadSupportProvider extends ContinuousReadSupportProvider { + override def createContinuousReadSupport( checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = FakeReader() + options: DataSourceOptions): ContinuousReadSupport = FakeReadSupport() } -trait FakeStreamWriteSupport extends StreamWriteSupport { - override def createStreamWriter( +trait FakeStreamingWriteSupportProvider extends StreamingWriteSupportProvider { + override def createStreamingWriteSupport( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamWriter = { + options: DataSourceOptions): StreamingWriteSupport = { throw new IllegalStateException("fake sink - cannot actually write") } } -class FakeReadMicroBatchOnly extends DataSourceRegister with FakeMicroBatchReadSupport { +class FakeReadMicroBatchOnly extends DataSourceRegister with FakeMicroBatchReadSupportProvider { override def shortName(): String = "fake-read-microbatch-only" } -class FakeReadContinuousOnly extends DataSourceRegister with FakeContinuousReadSupport { +class FakeReadContinuousOnly extends DataSourceRegister with FakeContinuousReadSupportProvider { override def shortName(): String = "fake-read-continuous-only" } class FakeReadBothModes extends DataSourceRegister - with FakeMicroBatchReadSupport with FakeContinuousReadSupport { + with FakeMicroBatchReadSupportProvider with FakeContinuousReadSupportProvider { override def shortName(): String = "fake-read-microbatch-continuous" } @@ -90,7 +92,7 @@ class FakeReadNeitherMode extends DataSourceRegister { override def shortName(): String = "fake-read-neither-mode" } -class FakeWrite extends DataSourceRegister with FakeStreamWriteSupport { +class FakeWriteSupportProvider extends DataSourceRegister with FakeStreamingWriteSupportProvider { override def shortName(): String = "fake-write-microbatch-continuous" } @@ -105,8 +107,8 @@ class FakeSink extends Sink { override def addBatch(batchId: Long, data: DataFrame): Unit = {} } -class FakeWriteV1Fallback extends DataSourceRegister - with 
FakeStreamWriteSupport with StreamSinkProvider { +class FakeWriteSupportProviderV1Fallback extends DataSourceRegister + with FakeStreamingWriteSupportProvider with StreamSinkProvider { override def createSink( sqlContext: SQLContext, @@ -189,11 +191,11 @@ class StreamingDataSourceV2Suite extends StreamTest { val v2Query = testPositiveCase( "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) assert(v2Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink - .isInstanceOf[FakeWriteV1Fallback]) + .isInstanceOf[FakeWriteSupportProviderV1Fallback]) // Ensure we create a V1 sink with the config. Note the config is a comma separated // list, including other fake entries. - val fullSinkName = "org.apache.spark.sql.streaming.sources.FakeWriteV1Fallback" + val fullSinkName = classOf[FakeWriteSupportProviderV1Fallback].getName withSQLConf(SQLConf.DISABLED_V2_STREAMING_WRITERS.key -> s"a,b,c,test,$fullSinkName,d,e") { val v1Query = testPositiveCase( "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) @@ -217,35 +219,37 @@ class StreamingDataSourceV2Suite extends StreamTest { val writeSource = DataSource.lookupDataSource(write, spark.sqlContext.conf).newInstance() (readSource, writeSource, trigger) match { // Valid microbatch queries. - case (_: MicroBatchReadSupport, _: StreamWriteSupport, t) + case (_: MicroBatchReadSupportProvider, _: StreamingWriteSupportProvider, t) if !t.isInstanceOf[ContinuousTrigger] => testPositiveCase(read, write, trigger) // Valid continuous queries. - case (_: ContinuousReadSupport, _: StreamWriteSupport, _: ContinuousTrigger) => + case (_: ContinuousReadSupportProvider, _: StreamingWriteSupportProvider, + _: ContinuousTrigger) => testPositiveCase(read, write, trigger) // Invalid - can't read at all case (r, _, _) - if !r.isInstanceOf[MicroBatchReadSupport] - && !r.isInstanceOf[ContinuousReadSupport] => + if !r.isInstanceOf[MicroBatchReadSupportProvider] + && !r.isInstanceOf[ContinuousReadSupportProvider] => testNegativeCase(read, write, trigger, s"Data source $read does not support streamed reading") // Invalid - can't write - case (_, w, _) if !w.isInstanceOf[StreamWriteSupport] => + case (_, w, _) if !w.isInstanceOf[StreamingWriteSupportProvider] => testNegativeCase(read, write, trigger, s"Data source $write does not support streamed writing") // Invalid - trigger is continuous but reader is not - case (r, _: StreamWriteSupport, _: ContinuousTrigger) - if !r.isInstanceOf[ContinuousReadSupport] => + case (r, _: StreamingWriteSupportProvider, _: ContinuousTrigger) + if !r.isInstanceOf[ContinuousReadSupportProvider] => testNegativeCase(read, write, trigger, s"Data source $read does not support continuous processing") // Invalid - trigger is microbatch but reader is not case (r, _, t) - if !r.isInstanceOf[MicroBatchReadSupport] && !t.isInstanceOf[ContinuousTrigger] => + if !r.isInstanceOf[MicroBatchReadSupportProvider] && + !t.isInstanceOf[ContinuousTrigger] => testPostCreationNegativeCase(read, write, trigger, s"Data source $read does not support microbatch processing") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index b65058fffd339..237872585e11d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -805,6 +805,80 @@ class DataFrameReaderWriterSuite 
extends QueryTest with SharedSQLContext with Be } } + test("Insert overwrite table command should output correct schema: basic") { + withTable("tbl", "tbl2") { + withView("view1") { + val df = spark.range(10).toDF("id") + df.write.format("parquet").saveAsTable("tbl") + spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") + spark.sql("CREATE TABLE tbl2(ID long) USING parquet") + spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1") + val identifier = TableIdentifier("tbl2") + val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString + val expectedSchema = StructType(Seq(StructField("ID", LongType, true))) + assert(spark.read.parquet(location).schema == expectedSchema) + checkAnswer(spark.table("tbl2"), df) + } + } + } + + test("Insert overwrite table command should output correct schema: complex") { + withTable("tbl", "tbl2") { + withView("view1") { + val df = spark.range(10).map(x => (x, x.toInt, x.toInt)).toDF("col1", "col2", "col3") + df.write.format("parquet").saveAsTable("tbl") + spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl") + spark.sql("CREATE TABLE tbl2(COL1 long, COL2 int, COL3 int) USING parquet PARTITIONED " + + "BY (COL2) CLUSTERED BY (COL3) INTO 3 BUCKETS") + spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT COL1, COL2, COL3 FROM view1") + val identifier = TableIdentifier("tbl2") + val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString + val expectedSchema = StructType(Seq( + StructField("COL1", LongType, true), + StructField("COL3", IntegerType, true), + StructField("COL2", IntegerType, true))) + assert(spark.read.parquet(location).schema == expectedSchema) + checkAnswer(spark.table("tbl2"), df) + } + } + } + + test("Create table as select command should output correct schema: basic") { + withTable("tbl", "tbl2") { + withView("view1") { + val df = spark.range(10).toDF("id") + df.write.format("parquet").saveAsTable("tbl") + spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") + spark.sql("CREATE TABLE tbl2 USING parquet AS SELECT ID FROM view1") + val identifier = TableIdentifier("tbl2") + val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString + val expectedSchema = StructType(Seq(StructField("ID", LongType, true))) + assert(spark.read.parquet(location).schema == expectedSchema) + checkAnswer(spark.table("tbl2"), df) + } + } + } + + test("Create table as select command should output correct schema: complex") { + withTable("tbl", "tbl2") { + withView("view1") { + val df = spark.range(10).map(x => (x, x.toInt, x.toInt)).toDF("col1", "col2", "col3") + df.write.format("parquet").saveAsTable("tbl") + spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl") + spark.sql("CREATE TABLE tbl2 USING parquet PARTITIONED BY (COL2) " + + "CLUSTERED BY (COL3) INTO 3 BUCKETS AS SELECT COL1, COL2, COL3 FROM view1") + val identifier = TableIdentifier("tbl2") + val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString + val expectedSchema = StructType(Seq( + StructField("COL1", LongType, true), + StructField("COL3", IntegerType, true), + StructField("COL2", IntegerType, true))) + assert(spark.read.parquet(location).schema == expectedSchema) + checkAnswer(spark.table("tbl2"), df) + } + } + } + test("use Spark jobs to list files") { withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "1") { withTempDir { dir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 0cfe260e52152..615923fe02d6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -136,6 +136,19 @@ private[sql] trait SQLTestData { self => df } + protected lazy val lowerCaseDataWithDuplicates: DataFrame = { + val df = spark.sparkContext.parallelize( + LowerCaseData(1, "a") :: + LowerCaseData(2, "b") :: + LowerCaseData(2, "b") :: + LowerCaseData(3, "c") :: + LowerCaseData(3, "c") :: + LowerCaseData(3, "c") :: + LowerCaseData(4, "d") :: Nil).toDF() + df.createOrReplaceTempView("lowerCaseData") + df + } + protected lazy val arrayData: RDD[ArrayData] = { val rdd = spark.sparkContext.parallelize( ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: @@ -255,6 +268,17 @@ private[sql] trait SQLTestData { self => df } + protected lazy val trainingSales: DataFrame = { + val df = spark.sparkContext.parallelize( + TrainingSales("Experts", CourseSales("dotNET", 2012, 10000)) :: + TrainingSales("Experts", CourseSales("JAVA", 2012, 20000)) :: + TrainingSales("Dummies", CourseSales("dotNet", 2012, 5000)) :: + TrainingSales("Experts", CourseSales("dotNET", 2013, 48000)) :: + TrainingSales("Dummies", CourseSales("Java", 2013, 30000)) :: Nil).toDF() + df.createOrReplaceTempView("trainingSales") + df + } + /** * Initialize all test data such that all temp tables are properly registered. */ @@ -310,4 +334,5 @@ private[sql] object SQLTestData { case class Salary(personId: Int, salary: Double) case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) case class CourseSales(course: String, year: Int, earnings: Double) + case class TrainingSales(training: String, sales: CourseSales) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index bc4a120f7042f..2fb8f70a20791 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -76,7 +76,7 @@ private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with /** * Disable stdout and stderr when running the test. To not output the logs to the console, - * ConsoleAppender's `follow` should be set to `true` so that it will honors reassignments of + * ConsoleAppender's `follow` should be set to `true` so that it will honor reassignments of * System.out or System.err. Otherwise, ConsoleAppender will still output to the console even if * we change System.out and System.err. 
*/ @@ -391,6 +391,13 @@ private[sql] trait SQLTestUtilsBase val fs = hadoopPath.getFileSystem(spark.sessionState.newHadoopConf()) fs.makeQualified(hadoopPath).toUri } + + /** + * Returns the full path to the given file in the resource folder. + */ + protected def testFile(fileName: String): String = { + Thread.currentThread().getContextClassLoader.getResource(fileName).toString + } } private[sql] object SQLTestUtils { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index e6c7648c986ae..0dd24d2d56b82 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -35,7 +35,10 @@ trait SharedSQLContext extends SQLTestUtils with SharedSparkSession { } protected override def afterAll(): Unit = { - super.afterAll() - doThreadPostAudit() + try { + super.afterAll() + } finally { + doThreadPostAudit() + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala index 8968dbf36d507..e7e0ce64963a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -24,6 +24,7 @@ import org.scalatest.concurrent.Eventually import org.apache.spark.{DebugFilesystem, SparkConf} import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation import org.apache.spark.sql.internal.SQLConf /** @@ -39,6 +40,11 @@ trait SharedSparkSession .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) .set("spark.unsafe.exceptionOnMemoryLeak", "true") .set(SQLConf.CODEGEN_FALLBACK.key, "false") + // Disable ConvertToLocalRelation for better test coverage. Test cases built on + // LocalRelation exercise the optimization rules better with this rule disabled, + // since it eagerly collapses operators over a LocalRelation and can prevent other + // optimization rules, such as ConstantPropagation, from being tested. + .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName) } /**
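The exclusion just above is worth unpacking. ConvertToLocalRelation eagerly evaluates simple operators over an in-memory LocalRelation inside the optimizer itself, so any rule that should fire on those operators never sees them. A minimal spark-shell sketch of the effect (the printed plan shapes are indicative, not exact):

```scala
// Sketch: how ConvertToLocalRelation hides operators from later rules.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val optimized = Seq(1, 2, 3).toDF("i")
  .select(($"i" + 1).as("j"))
  .queryExecution.optimizedPlan

// With ConvertToLocalRelation active the Project is folded away and the
// optimized plan is a bare LocalRelation; with the rule excluded via
// spark.sql.optimizer.excludedRules (as in the test configs above), the
// Project survives and downstream rules can be exercised against it.
println(optimized)
```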
diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Column.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Column.java index 2e21f18d61268..adb269aa235ea 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Column.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Column.java @@ -349,7 +349,7 @@ public void addValue(Type type, Object field) { break; case FLOAT_TYPE: nulls.set(size, field == null); - doubleVars()[size] = field == null ? 0 : ((Float)field).doubleValue(); + doubleVars()[size] = field == null ? 0 : new Double(field.toString()); break; case DOUBLE_TYPE: nulls.set(size, field == null); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java index f59cdcd3188e6..745f385e87f78 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java @@ -471,7 +471,7 @@ private OperationHandle executeStatementInternal(String statement, Map<String, String> confOverlay, diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/HiveServer2.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/HiveServer2.java + // Register a shutdown hook with Spark's ShutdownHookManager; the callback must be a Scala () => Unit + // function + ShutdownHookManager.addShutdownHook( + new AbstractFunction0<BoxedUnit>() { + public BoxedUnit apply() { + try { + LOG.info("Hive Server Shutdown hook invoked"); + stop(); + } catch (Throwable e) { + LOG.warn("Ignoring Exception while stopping Hive Server from shutdown hook", + e); + } + return BoxedUnit.UNIT; + } + }); } public static boolean isHTTPTransportMode(HiveConf hiveConf) { @@ -95,7 +110,6 @@ public synchronized void start() { @Override public synchronized void stop() { LOG.info("Shutting down HiveServer2"); - HiveConf hiveConf = this.getHiveConf(); super.stop(); } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index d9fd3ebd3c65d..bb96cea2b0ae1 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -258,6 +258,8 @@ private[hive] object SparkSQLCLIDriver extends Logging { def continuedPromptWithDBSpaces: String = continuedPrompt + ReflectionUtils.invokeStatic( classOf[CliDriver], "spacesForString", classOf[String] -> currentDB) + cli.printMasterAndAppId + var currentPrompt = promptWithCurrentDB var line = reader.readLine(currentPrompt + "> ") @@ -323,6 +325,12 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { hiveVariables.asScala.foreach(kv => SparkSQLEnv.sqlContext.conf.setConfString(kv._1, kv._2)) } + def printMasterAndAppId(): Unit = { + val master = SparkSQLEnv.sparkContext.master + val appId = SparkSQLEnv.sparkContext.applicationId + console.printInfo(s"Spark master: $master, Application Id: $appId") + } + override def processCmd(cmd: String): Int = { val cmd_trimmed: String = cmd.trim() val cmd_lower = cmd_trimmed.toLowerCase(Locale.ROOT) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 192f33a45e273..70eb28cdd0c64 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -636,6 +636,14 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { assert(pipeoutFileList(sessionID).length == 0) } } + + test("SPARK-24829 Checks cast as float") { + withJdbcStatement() { statement => + val resultSet = statement.executeQuery("SELECT CAST('4.56' AS FLOAT)") + resultSet.next() + assert(resultSet.getString(1) === "4.56") + } + } } class SingleSessionSuite extends HiveThriftJdbcTest {
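The Column.java change above is the substance of SPARK-24829: the old path widened a Float to double with `doubleValue()`, which exposes the float's binary rounding error, so `CAST('4.56' AS FLOAT)` came back over JDBC as 4.559999942779541. Routing through the string form keeps the value the user saw, which is what the new thrift-server tests in this file assert. A quick illustration:

```scala
// SPARK-24829 in miniature: widening vs. string round-trip for a float.
val f = "4.56".toFloat
println(f.toDouble)          // 4.559999942779541 (old path: ((Float) field).doubleValue())
println(f.toString.toDouble) // 4.56              (new path: value via its string form)
```

@@ -766,6 +774,14 @@ class HiveThriftHttpServerSuite extends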
HiveThriftJdbcTest { assert(resultSet.getString(2) === HiveUtils.builtinHiveVersion) } } + + test("SPARK-24829 Checks cast as float") { + withJdbcStatement() { statement => + val resultSet = statement.executeQuery("SELECT CAST('4.56' AS FLOAT)") + resultSet.next() + assert(resultSet.getString(1) === "4.56") + } + } } object ServerMode extends Enumeration { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala index 4c53dd8f4616c..fef18f147b057 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala @@ -46,10 +46,13 @@ class UISeleniumSuite } override def afterAll(): Unit = { - if (webDriver != null) { - webDriver.quit() + try { + if (webDriver != null) { + webDriver.quit() + } + } finally { + super.afterAll() } - super.afterAll() } override protected def serverStartCommand(port: Int) = { diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index cebaad5b4ad9b..b9b2b7dbf38e8 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -40,6 +40,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalColumnBatchSize = TestHive.conf.columnBatchSize private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled + private val originalLimitFlatGlobalLimit = TestHive.conf.limitFlatGlobalLimit private val originalSessionLocalTimeZone = TestHive.conf.sessionLocalTimeZone def testCases: Seq[(String, File)] = { @@ -59,6 +60,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) // Ensures that cross joins are enabled so that we can test them TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true) + // Ensure that limit operation returns rows in the same order as Hive + TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, false) // Fix session local timezone to America/Los_Angeles for those timezone sensitive tests // (timestamp_*) TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, "America/Los_Angeles") @@ -73,6 +76,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled) + TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, originalLimitFlatGlobalLimit) TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, originalSessionLocalTimeZone) // For debugging dump some statistics about how much time was spent in various optimizer rules diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 011a3ba553cb2..505124ae9e7c8 100644 --- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -28,6 +28,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.ql.metadata.HiveException +import org.apache.hadoop.hive.serde.serdeConstants.SERIALIZATION_FORMAT import org.apache.thrift.TException import org.apache.spark.{SparkConf, SparkException} @@ -114,7 +115,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * should interpret these special data source properties and restore the original table metadata * before returning it. */ - private[hive] def getRawTable(db: String, table: String): CatalogTable = withClient { + private[hive] def getRawTable(db: String, table: String): CatalogTable = { client.getTable(db, table) } @@ -138,17 +139,37 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } /** - * Checks the validity of data column names. Hive metastore disallows the table to use comma in - * data column names. Partition columns do not have such a restriction. Views do not have such - * a restriction. + * Checks the validity of data column names. Hive metastore disallows the table to use some + * special characters (',', ':', and ';') in data column names, including nested column names. + * Partition columns do not have such a restriction. Views do not have such a restriction. */ private def verifyDataSchema( tableName: TableIdentifier, tableType: CatalogTableType, dataSchema: StructType): Unit = { if (tableType != VIEW) { - dataSchema.map(_.name).foreach { colName => - if (colName.contains(",")) { - throw new AnalysisException("Cannot create a table having a column whose name contains " + - s"commas in Hive metastore. Table: $tableName; Column: $colName") + val invalidChars = Seq(",", ":", ";") + def verifyNestedColumnNames(schema: StructType): Unit = schema.foreach { f => + f.dataType match { + case st: StructType => verifyNestedColumnNames(st) + case _ if invalidChars.exists(f.name.contains) => + val invalidCharsString = invalidChars.map(c => s"'$c'").mkString(", ") + val errMsg = "Cannot create a table having a nested column whose name contains " + + s"invalid characters ($invalidCharsString) in Hive metastore. Table: $tableName; " + + s"Column: ${f.name}" + throw new AnalysisException(errMsg) + case _ => + } + } + + dataSchema.foreach { f => + f.dataType match { + // Checks top-level column names + case _ if f.name.contains(",") => + throw new AnalysisException("Cannot create a table having a column whose name " + + s"contains commas in Hive metastore. Table: $tableName; Column: ${f.name}") + // Checks nested column names + case st: StructType => + verifyNestedColumnNames(st) + case _ => } } } @@ -765,9 +786,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // schema we read back is different(ignore case and nullability) from the one in table // properties which was written when creating table, we should respect the table schema // from hive. - logWarning(s"The table schema given by Hive metastore(${table.schema.simpleString}) is " + + logWarning(s"The table schema given by Hive metastore(${table.schema.catalogString}) is " + "different from the schema when this table was created by Spark SQL" + - s"(${schemaFromTableProps.simpleString}). 
We have to fall back to the table schema " + + s"(${schemaFromTableProps.catalogString}). We have to fall back to the table schema " + "from Hive metastore which is not case preserving.") hiveTable.copy(schemaPreservesCase = false) } @@ -786,6 +807,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat updateLocationInStorageProps(table, newPath = None).copy( locationUri = tableLocation.map(CatalogUtils.stringToURI(_))) } + val storageWithoutHiveGeneratedProperties = storageWithLocation.copy( + properties = storageWithLocation.properties.filterKeys(!HIVE_GENERATED_STORAGE_PROPERTIES(_))) val partitionProvider = table.properties.get(TABLE_PARTITION_PROVIDER) val schemaFromTableProps = getSchemaFromTableProperties(table) @@ -794,7 +817,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat table.copy( provider = Some(provider), - storage = storageWithLocation, + storage = storageWithoutHiveGeneratedProperties, schema = reorderedSchema, partitionColumnNames = partColumnNames, bucketSpec = getBucketSpecFromTableProperties(table), @@ -1289,6 +1312,8 @@ object HiveExternalCatalog { val CREATED_SPARK_VERSION = SPARK_SQL_PREFIX + "create.version" + val HIVE_GENERATED_STORAGE_PROPERTIES = Set(SERIALIZATION_FORMAT) + // When storing data source tables in hive metastore, we need to set data schema to empty if the // schema is hive-incompatible. However we need a hack to preserve existing behavior. Before // Spark 2.0, we do not set a default serde here (this was done in Hive), and so if the user diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 94ddeae1bf547..de41bb418181d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -175,6 +175,10 @@ private[sql] class HiveSessionCatalog( super.functionExists(name) || hiveFunctions.contains(name.funcName) } + override def isPersistentFunction(name: FunctionIdentifier): Boolean = { + super.isPersistentFunction(name) || hiveFunctions.contains(name.funcName) + } + /** List of functions we pass over to Hive. Note that over time this list should go to 0. */ // We have a list of Hive built-in functions that we do not support. So, we will check // Hive's function registry and lazily load needed functions into our own function registry. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index a0c197b06ddab..07ee105404311 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -87,7 +87,7 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] { } } - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case c @ CreateTable(t, _, query) if DDLUtils.isHiveTable(t) => // Finds the database name if the name does not exist. 
val dbName = t.identifier.database.getOrElse(session.catalog.currentDatabase) @@ -114,7 +114,7 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] { } class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case relation: HiveTableRelation if DDLUtils.isHiveTable(relation.tableMeta) && relation.tableMeta.stats.isEmpty => val table = relation.tableMeta @@ -145,11 +145,11 @@ class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { * `PreprocessTableInsertion`. */ object HiveAnalysis extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case InsertIntoTable(r: HiveTableRelation, partSpec, query, overwrite, ifPartitionNotExists) if DDLUtils.isHiveTable(r.tableMeta) => InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite, - ifPartitionNotExists, query.output) + ifPartitionNotExists, query.output.map(_.name)) case CreateTable(tableDesc, mode, None) if DDLUtils.isHiveTable(tableDesc) => DDLUtils.checkDataColNames(tableDesc) @@ -157,14 +157,14 @@ object HiveAnalysis extends Rule[LogicalPlan] { case CreateTable(tableDesc, mode, Some(query)) if DDLUtils.isHiveTable(tableDesc) => DDLUtils.checkDataColNames(tableDesc) - CreateHiveTableAsSelectCommand(tableDesc, query, query.output, mode) + CreateHiveTableAsSelectCommand(tableDesc, query, query.output.map(_.name), mode) case InsertIntoDir(isLocal, storage, provider, child, overwrite) if DDLUtils.isHiveTable(provider) => val outputPath = new Path(storage.locationUri.get) if (overwrite) DDLUtils.verifyNotReadPath(child, outputPath) - InsertIntoHiveDirCommand(isLocal, storage, child, overwrite, child.output) + InsertIntoHiveDirCommand(isLocal, storage, child, overwrite, child.output.map(_.name)) } } @@ -225,7 +225,7 @@ case class RelationConversions( } override def apply(plan: LogicalPlan): LogicalPlan = { - plan transformUp { + plan resolveOperators { // Write path case InsertIntoTable(r: HiveTableRelation, partition, query, overwrite, ifPartitionNotExists) // Inserting into partitioned table is not supported in Parquet/Orc data source (yet). diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index b5444a4217924..7d57389947576 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -110,8 +110,9 @@ class HadoopTableReader( deserializerClass: Class[_ <: Deserializer], filterOpt: Option[PathFilter]): RDD[InternalRow] = { - assert(!hiveTable.isPartitioned, """makeRDDForTable() cannot be called on a partitioned table, - since input formats may differ across partitions. Use makeRDDForTablePartitions() instead.""") + assert(!hiveTable.isPartitioned, + "makeRDDForTable() cannot be called on a partitioned table, since input formats may " + + "differ across partitions. Use makeRDDForPartitionedTable() instead.") // Create local references to member variables, so that the entire `this` object won't be // serialized in the closure below. 
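The signature changes above (`outputColumns: Seq[Attribute]` becoming `outputColumnNames: Seq[String]`, with HiveAnalysis passing `query.output.map(_.name)`) mean a write command re-derives its attributes from the analyzed query instead of carrying possibly stale ones. A minimal sketch of the idea, not Spark's exact helper:

```scala
// Sketch: rebuild output attributes from the query plus the user-facing
// column names, keeping the resolved types/exprIds but the caller's casing
// (e.g. "ID" vs "id"), which is what the earlier schema tests check.
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

def outputColumns(query: LogicalPlan, outputColumnNames: Seq[String]): Seq[Attribute] = {
  assert(query.output.length == outputColumnNames.length,
    "output column names must line up with the query output")
  query.output.zip(outputColumnNames).map {
    case (attr, name) => attr.withName(name) // swap only the name
  }
}
```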
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 1df46d7431a21..02c1ed93eb2f8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -353,15 +353,19 @@ private[hive] class HiveClientImpl( client.getDatabasesByPattern(pattern).asScala } + private def getRawTableOption(dbName: String, tableName: String): Option[HiveTable] = { + Option(client.getTable(dbName, tableName, false /* do not throw exception */)) + } + override def tableExists(dbName: String, tableName: String): Boolean = withHiveState { - Option(client.getTable(dbName, tableName, false /* do not throw exception */)).nonEmpty + getRawTableOption(dbName, tableName).nonEmpty } override def getTableOption( dbName: String, tableName: String): Option[CatalogTable] = withHiveState { logDebug(s"Looking up $dbName.$tableName") - Option(client.getTable(dbName, tableName, false)).map { h => + getRawTableOption(dbName, tableName).map { h => // Note: Hive separates partition columns and the schema, but for us the // partition columns are part of the schema val cols = h.getCols.asScala.map(fromHiveColumn) @@ -923,6 +927,9 @@ private[hive] object HiveClientImpl { case CatalogTableType.MANAGED => HiveTableType.MANAGED_TABLE case CatalogTableType.VIEW => HiveTableType.VIRTUAL_VIEW + case t => + throw new IllegalArgumentException( + s"Unknown table type found at toHiveTable: $t") }) // Note: In Hive the schema and partition columns must be disjoint sets val (partCols, schema) = table.schema.map(toHiveColumn).partition { c => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 8620f3f6d99fb..bc9d4cd7f4181 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -45,7 +45,7 @@ import org.apache.spark.sql.catalyst.analysis.NoSuchPermanentFunctionException import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, CatalogTablePartition, CatalogUtils, FunctionResource, FunctionResourceType} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{IntegralType, StringType} +import org.apache.spark.sql.types.{AtomicType, IntegralType, StringType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -598,6 +598,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { object ExtractableLiteral { def unapply(expr: Expression): Option[String] = expr match { + case Literal(null, _) => None // `null`s can be cast as other types; we want to avoid NPEs. case Literal(value, _: IntegralType) => Some(value.toString) case Literal(value, _: StringType) => Some(quoteStringLiteral(value.toString)) case _ => None @@ -606,7 +607,23 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { object ExtractableLiterals { def unapply(exprs: Seq[Expression]): Option[Seq[String]] = { - val extractables = exprs.map(ExtractableLiteral.unapply) + // SPARK-24879: The Hive metastore filter parser does not support "null", but we still want + // to push down as many predicates as we can while still maintaining correctness.
+ // In SQL, the `IN` expression evaluates as follows: + // > `1 in (2, NULL)` -> NULL + // > `1 in (1, NULL)` -> true + // > `1 in (2)` -> false + // Since Hive metastore filters are NULL-intolerant binary operations joined only by + // `AND` and `OR`, we can treat `NULL` as `false` and thus rewrite `1 in (2, NULL)` as + // `1 in (2)`. + // If the Hive metastore begins supporting NULL-tolerant predicates and Spark starts + // pushing down these predicates, then this optimization will become incorrect and need + // to be changed. + val extractables = exprs + .filter { + case Literal(null, _) => false + case _ => true + }.map(ExtractableLiteral.unapply) if (extractables.nonEmpty && extractables.forall(_.isDefined)) { Some(extractables.map(_.get)) } else { @@ -660,7 +677,8 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { def unapply(expr: Expression): Option[Attribute] = { expr match { case attr: Attribute => Some(attr) - case Cast(child, dt, _) if !Cast.mayTruncate(child.dataType, dt) => unapply(child) + case Cast(child @ AtomicType(), dt: AtomicType, _) + if Cast.canSafeCast(child.dataType.asInstanceOf[AtomicType], dt) => unapply(child) case _ => None } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 2f34f69b5cf48..6a90c44a2633d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -182,6 +182,7 @@ private[hive] class IsolatedClientLoader( name.startsWith("org.slf4j") || name.startsWith("org.apache.log4j") || // log4j1.x name.startsWith("org.apache.logging.log4j") || // log4j2 + name.startsWith("org.apache.derby.") || name.startsWith("org.apache.spark.") || (sharesHadoopClasses && isHadoopClass) || name.startsWith("scala.") || diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index 1e801fe1845c4..aa573b54a2b62 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -30,14 +30,14 @@ import org.apache.spark.sql.execution.command.DataWritingCommand /** * Create table and insert the query result into it. * - * @param tableDesc the Table Describe, which may contains serde, storage handler etc. + * @param tableDesc the table description, which may contain serde, storage handler etc. * @param query the query whose result will be inserted into the new relation * @param mode SaveMode */ case class CreateHiveTableAsSelectCommand( tableDesc: CatalogTable, query: LogicalPlan, - outputColumns: Seq[Attribute], + outputColumnNames: Seq[String], mode: SaveMode) extends DataWritingCommand { @@ -63,13 +63,14 @@ case class CreateHiveTableAsSelectCommand( query, overwrite = false, ifPartitionNotExists = false, - outputColumns = outputColumns).run(sparkSession, child) + outputColumnNames = outputColumnNames).run(sparkSession, child) } else { // TODO: ideally, we should get the output data ready first and then // add the relation into the catalog, in case a failure occurs while the data // is being processed.
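The SPARK-24879 comment in HiveShim.scala above compresses nicely into code. A partition filter only keeps partitions where the predicate is true, and a NULL result behaves like false there, so a NULL entry in an IN list can be dropped before building the metastore filter string. A self-contained model of the rewrite (the names here are illustrative, not catalyst's):

```scala
// Model of the NULL-dropping rewrite: `p IN (1, NULL)` prunes exactly like
// `p IN (1)`, while `p IN (NULL)` can only be NULL, so nothing is pushed down.
case class Lit(value: Any)

def pushableInFilter(attr: String, values: Seq[Lit]): Option[String] = {
  val nonNull = values.filterNot(_.value == null)
  if (nonNull.isEmpty) None // no pushable form; filter on the Spark side instead
  else Some(nonNull.map(l => s"$attr = ${l.value}").mkString("(", " or ", ")"))
}

// pushableInFilter("p", Seq(Lit(1), Lit(null)))  => Some("(p = 1)")
// pushableInFilter("p", Seq(Lit(null)))          => None
```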
assert(tableDesc.schema.isEmpty) - catalog.createTable(tableDesc.copy(schema = query.schema), ignoreIfExists = false) + catalog.createTable( + tableDesc.copy(schema = outputColumns.toStructType), ignoreIfExists = false) try { // Read back the metadata of the table which was created just now. @@ -82,7 +83,7 @@ case class CreateHiveTableAsSelectCommand( query, overwrite = true, ifPartitionNotExists = false, - outputColumns = outputColumns).run(sparkSession, child) + outputColumnNames = outputColumnNames).run(sparkSession, child) } catch { case NonFatal(e) => // drop the created table. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 7dcaf170f9693..b3795b4430404 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -62,6 +62,8 @@ case class HiveTableScanExec( override def conf: SQLConf = sparkSession.sessionState.conf + override def nodeName: String = s"Scan hive ${relation.tableMeta.qualifiedName}" + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -78,9 +80,9 @@ case class HiveTableScanExec( // Bind all partition key attribute references in the partition pruning predicate for later // evaluation. private lazy val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred => - require( - pred.dataType == BooleanType, - s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.") + require(pred.dataType == BooleanType, + s"Data type of predicate $pred must be ${BooleanType.catalogString} rather than " + + s"${pred.dataType.catalogString}.") BindReferences.bindReference(pred, relation.partitionCols) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala index cebeca0ce9444..0c694910b06d4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.hive.client.HiveClientImpl +import org.apache.spark.sql.util.SchemaUtils /** * Command for writing the results of `query` to file system. 
@@ -57,16 +58,20 @@ case class InsertIntoHiveDirCommand( storage: CatalogStorageFormat, query: LogicalPlan, overwrite: Boolean, - outputColumns: Seq[Attribute]) extends SaveAsHiveFile { + outputColumnNames: Seq[String]) extends SaveAsHiveFile { override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { assert(storage.locationUri.nonEmpty) + SchemaUtils.checkColumnNameDuplication( + outputColumnNames, + s"when inserting into ${storage.locationUri.get}", + sparkSession.sessionState.conf.caseSensitiveAnalysis) val hiveTable = HiveClientImpl.toHiveTable(CatalogTable( identifier = TableIdentifier(storage.locationUri.get.toString, Some("default")), tableType = org.apache.spark.sql.catalyst.catalog.CatalogTableType.VIEW, storage = storage, - schema = query.schema + schema = outputColumns.toStructType )) hiveTable.getMetadata.put(serdeConstants.SERIALIZATION_LIB, storage.serde.getOrElse(classOf[LazySimpleSerDe].getName)) @@ -104,8 +109,7 @@ case class InsertIntoHiveDirCommand( plan = child, hadoopConf = hadoopConf, fileSinkConf = fileSinkConf, - outputLocation = tmpPath.toString, - allColumns = outputColumns) + outputLocation = tmpPath.toString) val fs = writeToPath.getFileSystem(hadoopConf) if (overwrite && fs.exists(writeToPath)) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 02a60f16b3b3a..0ed464dad91b1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -69,7 +69,7 @@ case class InsertIntoHiveTable( query: LogicalPlan, overwrite: Boolean, ifPartitionNotExists: Boolean, - outputColumns: Seq[Attribute]) extends SaveAsHiveFile { + outputColumnNames: Seq[String]) extends SaveAsHiveFile { /** * Inserts all the rows in the table into Hive. 
Row objects are properly serialized with the @@ -198,7 +198,6 @@ case class InsertIntoHiveTable( hadoopConf = hadoopConf, fileSinkConf = fileSinkConf, outputLocation = tmpLocation.toString, - allColumns = outputColumns, partitionAttributes = partitionAttributes) if (partition.nonEmpty) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index e0f7375387d24..078968ed0145f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -51,7 +51,6 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { hadoopConf: Configuration, fileSinkConf: FileSinkDesc, outputLocation: String, - allColumns: Seq[Attribute], customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty, partitionAttributes: Seq[Attribute] = Nil): Set[String] = { @@ -90,7 +89,7 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { fileFormat = new HiveFileFormat(fileSinkConf), committer = committer, outputSpec = - FileFormatWriter.OutputSpec(outputLocation, customPartitionLocations, allColumns), + FileFormatWriter.OutputSpec(outputLocation, customPartitionLocations, outputColumns), hadoopConf = hadoopConf, partitionColumns = partitionAttributes, bucketSpec = None, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index dd2144c5fcea8..de8085f07db19 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.orc.OrcOptions import org.apache.spark.sql.hive.{HiveInspectors, HiveShim} import org.apache.spark.sql.sources.{Filter, _} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration /** @@ -72,7 +72,6 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - DataSourceUtils.verifyWriteSchema(this, dataSchema) val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf) @@ -123,7 +122,6 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { - DataSourceUtils.verifyReadSchema(this, dataSchema) if (sparkSession.sessionState.conf.orcFilterPushDown) { // Sets pushed predicates @@ -166,7 +164,8 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable } val recordsIterator = new RecordReaderIterator[OrcStruct](orcRecordReader) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => recordsIterator.close())) + Option(TaskContext.get()) + .foreach(_.addTaskCompletionListener[Unit](_ => recordsIterator.close())) // Unwraps `OrcStruct`s to `UnsafeRow`s OrcFileFormat.unwrapOrcStructs( @@ -178,6 +177,23 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable } } } + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case st: 
StructType => st.forall { f => supportDataType(f.dataType, isReadPath) } + + case ArrayType(elementType, _) => supportDataType(elementType, isReadPath) + + case MapType(keyType, valueType, _) => + supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath) + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _: NullType => isReadPath + + case _ => false + } } private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala index 80e44ca504356..713b70f252b6a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -92,11 +92,12 @@ private[hive] object OrcFileOperator extends Logging { : Option[StructType] = { // Take the first file where we can open a valid reader if we can find one. Otherwise just // return None to indicate we can't infer the schema. - paths.flatMap(getFileReader(_, conf, ignoreCorruptFiles)).headOption.map { reader => - val readerInspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector] - val schema = readerInspector.getTypeName - logDebug(s"Reading schema from file $paths, got Hive schema string: $schema") - CatalystSqlParser.parseDataType(schema).asInstanceOf[StructType] + paths.toIterator.map(getFileReader(_, conf, ignoreCorruptFiles)).collectFirst { + case Some(reader) => + val readerInspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector] + val schema = readerInspector.getTypeName + logDebug(s"Reading schema from file $paths, got Hive schema string: $schema") + CatalystSqlParser.parseDataType(schema).asInstanceOf[StructType] } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index d9efd0cb457cd..aee9cb58a031e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.hive.orc -import org.apache.hadoop.hive.ql.io.sarg.{SearchArgument, SearchArgumentFactory} +import org.apache.hadoop.hive.ql.io.sarg.SearchArgument import org.apache.hadoop.hive.ql.io.sarg.SearchArgument.Builder +import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory.newBuilder import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.datasources.orc.OrcFilters.buildTree import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._
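The `buildTree` import above replaces a left-deep `reduceOption(And)` in the hunk that follows. The real helper lives in org.apache.spark.sql.execution.datasources.orc.OrcFilters; sketched here, under the assumption that it folds filters into a roughly balanced conjunction, over a stand-in predicate type:

```scala
// Hedged sketch of a balanced conjunction builder: splitting in half and
// recursing keeps the And tree O(log n) deep instead of O(n), so very long
// predicate lists do not produce deeply nested conjunctions.
sealed trait Pred
case class Leaf(name: String) extends Pred
case class Conj(left: Pred, right: Pred) extends Pred

def buildTree(preds: Seq[Pred]): Option[Pred] = preds match {
  case Seq()       => None
  case Seq(single) => Some(single)
  case _ =>
    val (l, r) = preds.splitAt(preds.length / 2)
    for (lt <- buildTree(l); rt <- buildTree(r)) yield Conj(lt, rt)
}
```

@@ -62,14 +64,14 @@ private[orc] object OrcFilters extends Logging { // collect all convertible ones to build the final `SearchArgument`.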
val convertibleFilters = for { filter <- filters - _ <- buildSearchArgument(dataTypeMap, filter, SearchArgumentFactory.newBuilder()) + _ <- buildSearchArgument(dataTypeMap, filter, newBuilder) } yield filter for { // Combines all convertible filters using `And` to produce a single conjunction - conjunction <- convertibleFilters.reduceOption(And) + conjunction <- buildTree(convertibleFilters) // Then tries to build a single ORC `SearchArgument` for the conjunction predicate - builder <- buildSearchArgument(dataTypeMap, conjunction, SearchArgumentFactory.newBuilder()) + builder <- buildSearchArgument(dataTypeMap, conjunction, newBuilder) } yield builder.build() } @@ -77,8 +79,6 @@ private[orc] object OrcFilters extends Logging { dataTypeMap: Map[String, DataType], expression: Filter, builder: Builder): Option[Builder] = { - def newBuilder = SearchArgumentFactory.newBuilder() - def isSearchableType(dataType: DataType): Boolean = dataType match { // Only the values in the Spark types below can be recognized by // the `SearchArgumentImpl.BuilderImpl.boxLiteral()` method. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index ee3f99ab7e9bb..71f15a45d162a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -36,6 +36,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, ExternalCatalogWithListener} +import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} import org.apache.spark.sql.execution.command.CacheTableCommand @@ -59,7 +60,12 @@ object TestHive .set("spark.sql.warehouse.dir", TestHiveContext.makeWarehouseDir().toURI.getPath) // SPARK-8910 .set("spark.ui.enabled", "false") - .set("spark.unsafe.exceptionOnMemoryLeak", "true"))) + .set("spark.unsafe.exceptionOnMemoryLeak", "true") + // Disable ConvertToLocalRelation for better test coverage. Test cases built on + // LocalRelation exercise the optimization rules better with this rule disabled, + // since it eagerly collapses operators over a LocalRelation and can prevent other + // optimization rules, such as ConstantPropagation, from being tested.
+ .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName))) case class TestHiveVersion(hiveClient: HiveClient) diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java index a8cbd4fab15bb..48891fdcb1d80 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java @@ -676,7 +676,7 @@ public int compareTo(Complex other) { } int lastComparison = 0; - Complex typedOther = (Complex)other; + Complex typedOther = other; lastComparison = Boolean.valueOf(isSetAint()).compareTo(typedOther.isSetAint()); if (lastComparison != 0) { diff --git a/sql/hive/src/test/resources/data/scripts/dumpdata_script.py b/sql/hive/src/test/resources/data/scripts/dumpdata_script.py index 341a1b40e07af..5b360208d36f6 100644 --- a/sql/hive/src/test/resources/data/scripts/dumpdata_script.py +++ b/sql/hive/src/test/resources/data/scripts/dumpdata_script.py @@ -18,6 +18,9 @@ # import sys +if sys.version_info[0] >= 3: + xrange = range + for i in xrange(50): for j in xrange(5): for k in xrange(20022): diff --git a/sql/hive/src/test/resources/golden/udf_instr-1-2e76f819563dbaba4beb51e3a130b922 b/sql/hive/src/test/resources/golden/udf_instr-1-2e76f819563dbaba4beb51e3a130b922 index 06461b525b058..967e2d3956414 100644 --- a/sql/hive/src/test/resources/golden/udf_instr-1-2e76f819563dbaba4beb51e3a130b922 +++ b/sql/hive/src/test/resources/golden/udf_instr-1-2e76f819563dbaba4beb51e3a130b922 @@ -1 +1 @@ -instr(str, substr) - Returns the index of the first occurance of substr in str +instr(str, substr) - Returns the index of the first occurrence of substr in str diff --git a/sql/hive/src/test/resources/golden/udf_instr-2-32da357fc754badd6e3898dcc8989182 b/sql/hive/src/test/resources/golden/udf_instr-2-32da357fc754badd6e3898dcc8989182 index 5a8c34271f443..0a745342a4ce9 100644 --- a/sql/hive/src/test/resources/golden/udf_instr-2-32da357fc754badd6e3898dcc8989182 +++ b/sql/hive/src/test/resources/golden/udf_instr-2-32da357fc754badd6e3898dcc8989182 @@ -1,4 +1,4 @@ -instr(str, substr) - Returns the index of the first occurance of substr in str +instr(str, substr) - Returns the index of the first occurrence of substr in str Example: > SELECT instr('Facebook', 'boo') FROM src LIMIT 1; 5 diff --git a/sql/hive/src/test/resources/golden/udf_locate-1-6e41693c9c6dceea4d7fab4c02884e4e b/sql/hive/src/test/resources/golden/udf_locate-1-6e41693c9c6dceea4d7fab4c02884e4e index 84bea329540d1..8e70b0c89b594 100644 --- a/sql/hive/src/test/resources/golden/udf_locate-1-6e41693c9c6dceea4d7fab4c02884e4e +++ b/sql/hive/src/test/resources/golden/udf_locate-1-6e41693c9c6dceea4d7fab4c02884e4e @@ -1 +1 @@ -locate(substr, str[, pos]) - Returns the position of the first occurance of substr in str after position pos +locate(substr, str[, pos]) - Returns the position of the first occurrence of substr in str after position pos diff --git a/sql/hive/src/test/resources/golden/udf_locate-2-d9b5934457931447874d6bb7c13de478 b/sql/hive/src/test/resources/golden/udf_locate-2-d9b5934457931447874d6bb7c13de478 index 092e12586b9e8..e103255a31f03 100644 --- a/sql/hive/src/test/resources/golden/udf_locate-2-d9b5934457931447874d6bb7c13de478 +++ b/sql/hive/src/test/resources/golden/udf_locate-2-d9b5934457931447874d6bb7c13de478 @@ -1,4 +1,4 @@ -locate(substr, str[, pos]) - Returns the position of the first occurance of substr in str after position pos +locate(substr, str[, pos]) - 
Returns the position of the first occurrence of substr in str after position pos Example: > SELECT locate('bar', 'foobarbar', 5) FROM src LIMIT 1; 7 diff --git a/sql/hive/src/test/resources/golden/udf_translate-2-f7aa38a33ca0df73b7a1e6b6da4b7fe8 b/sql/hive/src/test/resources/golden/udf_translate-2-f7aa38a33ca0df73b7a1e6b6da4b7fe8 index 9ced4ee32cf0b..6caa4b679111d 100644 --- a/sql/hive/src/test/resources/golden/udf_translate-2-f7aa38a33ca0df73b7a1e6b6da4b7fe8 +++ b/sql/hive/src/test/resources/golden/udf_translate-2-f7aa38a33ca0df73b7a1e6b6da4b7fe8 @@ -6,8 +6,8 @@ translate('abcdef', 'adc', '19') returns '1b9ef' replacing 'a' with '1', 'd' wit translate('a b c d', ' ', '') return 'abcd' removing all spaces from the input string -If the same character is present multiple times in the input string, the first occurence of the character is the one that's considered for matching. However, it is not recommended to have the same character more than once in the from string since it's not required and adds to confusion. +If the same character is present multiple times in the input string, the first occurrence of the character is the one that's considered for matching. However, it is not recommended to have the same character more than once in the from string since it's not required and adds to confusion. For example, -translate('abcdef', 'ada', '192') returns '1bc9ef' replaces 'a' with '1' and 'd' with '9' ignoring the second occurence of 'a' in the from string mapping it to '2' +translate('abcdef', 'ada', '192') returns '1bc9ef' replaces 'a' with '1' and 'd' with '9' ignoring the second occurrence of 'a' in the from string mapping it to '2' diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/annotate_stats_join.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/annotate_stats_join.q index 965b0b7ed0a3e..633150b5cf544 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/annotate_stats_join.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/annotate_stats_join.q @@ -43,7 +43,7 @@ analyze table loc_orc compute statistics for columns state,locid,zip,year; -- dept_orc - 4 -- loc_orc - 8 --- count distincts for relevant columns (since count distinct values are approximate in some cases count distint values will be greater than number of rows) +-- count distincts for relevant columns (since count distinct values are approximate in some cases count distinct values will be greater than number of rows) -- emp_orc.deptid - 3 -- emp_orc.lastname - 7 -- dept_orc.deptid - 6 diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/auto_sortmerge_join_11.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/auto_sortmerge_join_11.q index da2e26fde7069..e8289772e7544 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/auto_sortmerge_join_11.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/auto_sortmerge_join_11.q @@ -26,7 +26,7 @@ set hive.optimize.bucketmapjoin.sortedmerge=true; -- Since size is being used to find the big table, the order of the tables in the join does not matter -- The tables are only bucketed and not sorted, the join should not be converted --- Currenly, a join is only converted to a sort-merge join without a hint, automatic conversion to +-- Currently, a join is only converted to a sort-merge join without a hint, automatic conversion to -- bucketized mapjoin is not done explain extended select count(*) FROM bucket_small a JOIN 
bucket_big b ON a.key = b.key; select count(*) FROM bucket_small a JOIN bucket_big b ON a.key = b.key; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/avro_partitioned.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/avro_partitioned.q index 6fe5117026ce8..e4ed7195a0575 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/avro_partitioned.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/avro_partitioned.q @@ -69,5 +69,5 @@ SELECT * FROM episodes_partitioned WHERE doctor_pt > 6 ORDER BY air_date; SELECT * FROM episodes_partitioned ORDER BY air_date LIMIT 5; -- Fetch w/filter to specific partition SELECT * FROM episodes_partitioned WHERE doctor_pt = 6; --- Fetch w/non-existant partition +-- Fetch w/non-existent partition SELECT * FROM episodes_partitioned WHERE doctor_pt = 7 LIMIT 5; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/decimal_udf.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/decimal_udf.q index 0c9f1b86a9e97..39d2d248a311f 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/decimal_udf.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/decimal_udf.q @@ -22,7 +22,7 @@ SELECT key + (value/2) FROM DECIMAL_UDF; EXPLAIN SELECT key + '1.0' FROM DECIMAL_UDF; SELECT key + '1.0' FROM DECIMAL_UDF; --- substraction +-- subtraction EXPLAIN SELECT key - key FROM DECIMAL_UDF; SELECT key - key FROM DECIMAL_UDF; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_multi_distinct.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_multi_distinct.q index 3aeae0d5c33d6..d677fe65245ed 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_multi_distinct.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_multi_distinct.q @@ -13,7 +13,7 @@ INSERT OVERWRITE TABLE dest1 SELECT substr(src.key,1,1), count(DISTINCT substr(s SELECT dest1.* FROM dest1 ORDER BY key; --- HIVE-5560 when group by key is used in distinct funtion, invalid result are returned +-- HIVE-5560 when group by key is used in distinct function, invalid result are returned EXPLAIN FROM src diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_8.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_8.q index f53295e4b2435..69d671aa47116 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_8.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_8.q @@ -12,7 +12,7 @@ LOAD DATA LOCAL INPATH '../../data/files/T1.txt' INTO TABLE T1 PARTITION (ds='1' INSERT OVERWRITE TABLE T1 PARTITION (ds='1') select key, val from T1 where ds = '1'; -- The plan is not converted to a map-side, since although the sorting columns and grouping --- columns match, the user is issueing a distinct. +-- columns match, the user is issuing a distinct. 
-- However, after HIVE-4310, partial aggregation is performed on the mapper EXPLAIN select count(distinct key) from T1; diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8489/test-2.10.jar b/sql/hive/src/test/resources/regression-test-SPARK-8489/test-2.10.jar deleted file mode 100644 index 3f28d37b93150..0000000000000 Binary files a/sql/hive/src/test/resources/regression-test-SPARK-8489/test-2.10.jar and /dev/null differ diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8489/test-2.12.jar b/sql/hive/src/test/resources/regression-test-SPARK-8489/test-2.12.jar new file mode 100644 index 0000000000000..b0d3fd17a41cb Binary files /dev/null and b/sql/hive/src/test/resources/regression-test-SPARK-8489/test-2.12.jar differ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala index 4550d350f6db2..30204d1223846 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala @@ -122,7 +122,7 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo """.stripMargin) } - private def writeDateToTableUsingCTAS( + private def writeDataToTableUsingCTAS( rootDir: File, tableName: String, partitionValue: Option[String], @@ -152,7 +152,7 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo usingCTAS: Boolean): String = { val partitionValue = if (isPartitioned) Some("test") else None if (usingCTAS) { - writeDateToTableUsingCTAS(tmpDir, tableName, partitionValue, format, compressionCodec) + writeDataToTableUsingCTAS(tmpDir, tableName, partitionValue, format, compressionCodec) } else { createTable(tmpDir, tableName, isPartitioned, format, compressionCodec) writeDataToTable(tableName, partitionValue) @@ -258,8 +258,7 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo def checkForTableWithCompressProp(format: String, compressCodecs: List[String]): Unit = { Seq(true, false).foreach { isPartitioned => Seq(true, false).foreach { convertMetastore => - // TODO: Also verify CTAS(usingCTAS=true) cases when the bug(SPARK-22926) is fixed. - Seq(false).foreach { usingCTAS => + Seq(true, false).foreach { usingCTAS => checkTableCompressionCodecForCodecs( format, isPartitioned, @@ -281,8 +280,7 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo def checkForTableWithoutCompressProp(format: String, compressCodecs: List[String]): Unit = { Seq(true, false).foreach { isPartitioned => Seq(true, false).foreach { convertMetastore => - // TODO: Also verify CTAS(usingCTAS=true) cases when the bug(SPARK-22926) is fixed. 
- Seq(false).foreach { usingCTAS => + Seq(true, false).foreach { usingCTAS => checkTableCompressionCodecForCodecs( format, isPartitioned, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index 0a522b6a11c80..1de258f060943 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -113,4 +113,10 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite { catalog.createDatabase(newDb("dbWithNullDesc").copy(description = null), ignoreIfExists = false) assert(catalog.getDatabase("dbWithNullDesc").description == "") } + + test("SPARK-23831: Add org.apache.derby to IsolatedClientLoader") { + val client1 = HiveUtils.newClientForMetadata(new SparkConf, new Configuration) + val client2 = HiveUtils.newClientForMetadata(new SparkConf, new Configuration) + assert(!client1.equals(client2)) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 514921875f1f9..a7d6972fa71f7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -49,21 +49,31 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { private val unusedJar = TestUtils.createJarWithClasses(Seq.empty) override def afterAll(): Unit = { - Utils.deleteRecursively(wareHousePath) - Utils.deleteRecursively(tmpDataDir) - Utils.deleteRecursively(sparkTestingDir) - super.afterAll() + try { + Utils.deleteRecursively(wareHousePath) + Utils.deleteRecursively(tmpDataDir) + Utils.deleteRecursively(sparkTestingDir) + } finally { + super.afterAll() + } } private def tryDownloadSpark(version: String, path: String): Unit = { - // Try mirrors a few times until one succeeds - for (i <- 0 until 3) { - // we don't retry on a failure to get mirror url. If we can't get a mirror url, - // the test fails (getStringFromUrl will throw an exception) - val preferredMirror = - getStringFromUrl("https://www.apache.org/dyn/closer.lua?preferred=true") + // Try a few mirrors first; fall back to Apache archive + val mirrors = + (0 until 2).flatMap { _ => + try { + Some(getStringFromUrl("https://www.apache.org/dyn/closer.lua?preferred=true")) + } catch { + // If we can't get a mirror URL, skip it. No retry. + case _: Exception => None + } + } + val sites = mirrors.distinct :+ "https://archive.apache.org/dist" + logInfo(s"Trying to download Spark $version from $sites") + for (site <- sites) { val filename = s"spark-$version-bin-hadoop2.7.tgz" - val url = s"$preferredMirror/spark/spark-$version/$filename" + val url = s"$site/spark/spark-$version/$filename" logInfo(s"Downloading Spark $version from $url") try { getFileFromUrl(url, path, filename) @@ -83,7 +93,8 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { Seq("rm", "-rf", targetDir).! 
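The rewritten `tryDownloadSpark` above collects up to two mirror URLs (skipping any it cannot resolve), deduplicates them, appends the Apache archive as a last resort, and then tries each site in order until one download succeeds; the `catch` clause that follows simply logs and moves on to the next site. A compact sketch of that first-success loop, with `fetch` as a hypothetical stand-in for the suite's `getFileFromUrl(url, path, filename)` helper:

```scala
import scala.util.Try

// Sketch under stated assumptions: walk candidate sites in order and stop at
// the first successful download; `fetch` is a hypothetical stand-in for the
// suite's getFileFromUrl helper and is expected to throw on failure.
def downloadFromFirstWorkingSite(sites: Seq[String], fetch: String => Unit): Boolean =
  sites.exists(site => Try(fetch(site)).isSuccess)

// Usage mirroring the code above: distinct mirrors first, archive last.
// downloadFromFirstWorkingSite(mirrors.distinct :+ "https://archive.apache.org/dist", fetch)
```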
} } catch { - case ex: Exception => logWarning(s"Failed to download Spark $version from $url", ex) + case ex: Exception => + logWarning(s"Failed to download Spark $version from $url: ${ex.getMessage}") } } fail(s"Unable to download Spark $version") @@ -173,7 +184,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { "--conf", s"spark.sql.test.version.index=$index", "--driver-java-options", s"-Dderby.system.home=${wareHousePath.getCanonicalPath}", tempPyFile.getCanonicalPath) - runSparkSubmit(args, Some(sparkHome.getCanonicalPath)) + runSparkSubmit(args, Some(sparkHome.getCanonicalPath), false) } tempPyFile.delete() @@ -195,7 +206,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. - val testingVersions = Seq("2.0.2", "2.1.2", "2.2.1", "2.3.1") + val testingVersions = Seq("2.1.3", "2.2.2", "2.3.1") protected var spark: SparkSession = _ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index ba9b944e4a055..688b619cd1bb5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{QueryTest, Row, SaveMode} -import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.CatalogTableType import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias @@ -62,7 +62,7 @@ class HiveMetastoreCatalogSuite extends TestHiveSingleton with SQLTestUtils { spark.sql("create view vw1 as select 1 as id") val plan = spark.sql("select id from vw1").queryExecution.analyzed val aliases = plan.collect { - case x @ SubqueryAlias("vw1", _) => x + case x @ SubqueryAlias(AliasIdentifier("vw1", Some("default")), _) => x } assert(aliases.size == 1) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index 09c15473b21c1..e5c9df05d5674 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf case class Cases(lower: String, UPPER: String) @@ -76,4 +77,19 @@ class HiveParquetSuite extends QueryTest with ParquetTest with TestHiveSingleton } } } + + test("SPARK-25206: wrong records are returned by filter pushdown " + + "when Hive metastore schema and parquet schema are in different letter cases") { + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> true.toString) { + withTempPath { path => + val data = spark.range(1, 10).toDF("id") + data.write.parquet(path.getCanonicalPath) + withTable("SPARK_25206") { + sql("CREATE TABLE SPARK_25206 (ID LONG) USING parquet LOCATION " + + s"'${path.getCanonicalPath}'") + checkAnswer(sql("select id from SPARK_25206 where id > 0"), data) + } + } + } + } } 
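The SPARK-25206 test just above pins down a case-sensitivity hazard: the metastore schema declares `ID` while the parquet files carry `id`, and a pushed-down filter that matches column names case-sensitively would silently return wrong (empty) results. A minimal reproduction sketch outside the test harness; the path and table name are illustrative:

```scala
// Filter pushdown must resolve `ID` against the file's `id` case-insensitively;
// otherwise the pushed predicate matches no column and the query returns
// nothing instead of rows 1..9.
spark.conf.set("spark.sql.parquet.filterPushdown", "true")
spark.range(1, 10).toDF("id").write.parquet("/tmp/spark_25206")  // files use `id`
spark.sql("CREATE TABLE SPARK_25206 (ID LONG) USING parquet " +
  "LOCATION '/tmp/spark_25206'")                                 // metastore uses `ID`
spark.sql("SELECT id FROM SPARK_25206 WHERE id > 0").show()
```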
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index aa5b531992613..a676cf6ce6925 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive import java.io.{BufferedWriter, File, FileWriter} -import scala.tools.nsc.Properties +import scala.util.Properties import org.apache.hadoop.fs.Path import org.scalatest.{BeforeAndAfterEach, Matchers} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala index ab91727049ff5..5879748d05b2b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.{QueryTest, _} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -750,4 +751,27 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter } } } + + Seq("LOCAL", "").foreach { local => + Seq(true, false).foreach { caseSensitivity => + Seq("orc", "parquet").foreach { format => + test(s"SPARK-25389 INSERT OVERWRITE $local DIRECTORY ... STORED AS with duplicated names " + + s"(caseSensitivity=$caseSensitivity, format=$format)") { + withTempDir { dir => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> s"$caseSensitivity") { + val m = intercept[AnalysisException] { + sql( + s""" + |INSERT OVERWRITE $local DIRECTORY '${dir.toURI}' + |STORED AS $format + |SELECT 'id', 'id2' ${if (caseSensitivity) "id" else "ID"} + """.stripMargin) + }.getMessage + assert(m.contains("Found duplicate column(s) when inserting into")) + } + } + } + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala index 473bbced41b31..34ca790299859 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala @@ -288,6 +288,21 @@ class ShowCreateTableSuite extends QueryTest with SQLTestUtils with TestHiveSing } } + test("SPARK-24911: keep quotes for nested fields") { + withTable("t1") { + val createTable = "CREATE TABLE `t1`(`a` STRUCT<`b`: STRING>)" + sql(createTable) + val shownDDL = sql(s"SHOW CREATE TABLE t1") + .head() + .getString(0) + .split("\n") + .head + assert(shownDDL == createTable) + + checkCreateTable("t1") + } + } + private def createRawHiveTable(ddl: String): Unit = { hiveContext.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] .client.runSqlHive(ddl) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SparkSubmitTestUtils.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SparkSubmitTestUtils.scala index 68ed97d6d1f5a..889f81b056397 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SparkSubmitTestUtils.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SparkSubmitTestUtils.scala @@ -38,7 +38,10 @@ trait
SparkSubmitTestUtils extends SparkFunSuite with TimeLimits { // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. // This is copied from org.apache.spark.deploy.SparkSubmitSuite - protected def runSparkSubmit(args: Seq[String], sparkHomeOpt: Option[String] = None): Unit = { + protected def runSparkSubmit( + args: Seq[String], + sparkHomeOpt: Option[String] = None, + isSparkTesting: Boolean = true): Unit = { val sparkHome = sparkHomeOpt.getOrElse( sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))) val history = ArrayBuffer.empty[String] @@ -53,7 +56,14 @@ trait SparkSubmitTestUtils extends SparkFunSuite with TimeLimits { val builder = new ProcessBuilder(commands: _*).directory(new File(sparkHome)) val env = builder.environment() - env.put("SPARK_TESTING", "1") + if (isSparkTesting) { + env.put("SPARK_TESTING", "1") + } else { + env.remove("SPARK_TESTING") + env.remove("SPARK_SQL_TESTING") + env.remove("SPARK_PREPEND_CLASSES") + env.remove("SPARK_DIST_CLASSPATH") + } env.put("SPARK_HOME", sparkHome) def captureOutput(source: String)(line: String): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 61cec82984795..d8ffb29a59317 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -25,13 +25,14 @@ import scala.util.matching.Regex import org.apache.hadoop.hive.common.StatsSetupConst +import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics, HiveTableRelation} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, HistogramBin, HistogramSerializer} import org.apache.spark.sql.catalyst.util.{DateTimeUtils, StringUtils} -import org.apache.spark.sql.execution.command.DDLUtils +import org.apache.spark.sql.execution.command.{CommandUtils, DDLUtils} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.HiveExternalCatalog._ @@ -148,6 +149,26 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } + test("SPARK-24626 parallel file listing in Stats computation") { + withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "2", + SQLConf.PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION.key -> "True") { + val checkSizeTable = "checkSizeTable" + withTable(checkSizeTable) { + sql(s"CREATE TABLE $checkSizeTable (key STRING, value STRING) PARTITIONED BY (ds STRING)") + sql(s"INSERT INTO TABLE $checkSizeTable PARTITION (ds='2010-01-01') SELECT * FROM src") + sql(s"INSERT INTO TABLE $checkSizeTable PARTITION (ds='2010-01-02') SELECT * FROM src") + sql(s"INSERT INTO TABLE $checkSizeTable PARTITION (ds='2010-01-03') SELECT * FROM src") + val tableMeta = spark.sessionState.catalog + .getTableMetadata(TableIdentifier(checkSizeTable)) + HiveCatalogMetrics.reset() + assert(HiveCatalogMetrics.METRIC_PARALLEL_LISTING_JOB_COUNT.getCount() == 0) + val size = CommandUtils.calculateTotalSize(spark, tableMeta) + assert(HiveCatalogMetrics.METRIC_PARALLEL_LISTING_JOB_COUNT.getCount() == 1) + assert(size === BigInt(17436)) + } + } + } + test("analyze non hive compatible 
datasource tables") { val table = "parquet_tab" withTable(table) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 88cc42efd0fe3..a56c6f73989a7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -141,11 +141,10 @@ class UDFSuite withTempDatabase { dbName => withUserDefinedFunction(functionName -> false) { sql(s"CREATE FUNCTION $dbName.$functionName AS '$functionClass'") - // TODO: Re-enable it after can distinguish qualified and unqualified function name - // checkAnswer( - // sql(s"SELECT $dbName.myuPPer(value) from $testTableName"), - // expectedDF - // ) + checkAnswer( + sql(s"SELECT $dbName.$functionName(value) from $testTableName"), + expectedDF + ) checkAnswer( sql(s"SHOW FUNCTIONS like $dbName.$functionNameUpper"), @@ -174,11 +173,10 @@ class UDFSuite // For this block, drop function command uses default.functionName as the function name. withUserDefinedFunction(s"$dbName.$functionNameUpper" -> false) { sql(s"CREATE FUNCTION $dbName.$functionName AS '$functionClass'") - // TODO: Re-enable it after can distinguish qualified and unqualified function name - // checkAnswer( - // sql(s"SELECT $dbName.myupper(value) from $testTableName"), - // expectedDF - // ) + checkAnswer( + sql(s"SELECT $dbName.$functionName(value) from $testTableName"), + expectedDF + ) sql(s"USE $dbName") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala index 19765695fbcb4..2a4efd0cce6e0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala @@ -72,6 +72,20 @@ class FiltersSuite extends SparkFunSuite with Logging with PlanTest { (Literal("p2\" and q=\"q2") === a("stringcol", StringType)) :: Nil, """stringcol = 'p1" and q="q1' and 'p2" and q="q2' = stringcol""") + filterTest("SPARK-24879 null literals should be ignored for IN constructs", + (a("intcol", IntegerType) in (Literal(1), Literal(null))) :: Nil, + "(intcol = 1)") + + // Applying the predicate `x IN (NULL)` should return an empty set, but since this optimization + // will be applied by Catalyst, this filter converter does not need to account for this. 
+ filterTest("SPARK-24879 IN predicates with only NULLs will not cause a NPE", + (a("intcol", IntegerType) in Literal(null)) :: Nil, + "") + + filterTest("typecast null literals should not be pushed down in simple predicates", + (a("intcol", IntegerType) === Literal(null, IntegerType)) :: Nil, + "") + private def filterTest(name: String, filters: Seq[Expression], result: String) = { test(name) { withSQLConf(SQLConf.ADVANCED_PARTITION_PREDICATE_PUSHDOWN.key -> "true") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index 55275f6b37945..fa9f753795f65 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType} // TODO: Refactor this to `HivePartitionFilteringSuite` class HiveClientSuite(version: String) @@ -122,6 +122,22 @@ class HiveClientSuite(version: String) "aa" :: Nil) } + test("getPartitionsByFilter: cast(chunk as int)=1 (not a valid partition predicate)") { + testMetastorePartitionFiltering( + attr("chunk").cast(IntegerType) === 1, + 20170101 to 20170103, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + + test("getPartitionsByFilter: cast(chunk as boolean)=true (not a valid partition predicate)") { + testMetastorePartitionFiltering( + attr("chunk").cast(BooleanType) === true, + 20170101 to 20170103, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + test("getPartitionsByFilter: 20170101=ds") { testMetastorePartitionFiltering( Literal(20170101) === attr("ds"), @@ -138,7 +154,7 @@ class HiveClientSuite(version: String) "aa" :: "ab" :: "ba" :: "bb" :: Nil) } - test("getPartitionsByFilter: chunk in cast(ds as long)=20170101L") { + test("getPartitionsByFilter: cast(ds as long)=20170101L and h=10") { testMetastorePartitionFiltering( attr("ds").cast(LongType) === 20170101L && attr("h") === 10, 20170101 to 20170101, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index ae675149df5e2..c65bf7c14c7a5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -1005,6 +1005,19 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te ) ) } + + test("SPARK-24957: average with decimal followed by aggregation returning wrong result") { + val df = Seq(("a", BigDecimal("12.0")), + ("a", BigDecimal("12.0")), + ("a", BigDecimal("11.9999999988")), + ("a", BigDecimal("12.0")), + ("a", BigDecimal("12.0")), + ("a", BigDecimal("11.9999999988")), + ("a", BigDecimal("11.9999999988"))).toDF("text", "number") + val agg1 = df.groupBy($"text").agg(avg($"number").as("avg_res")) + val agg2 = agg1.groupBy($"text").agg(sum($"avg_res")) + checkAnswer(agg2, Row("a", BigDecimal("11.9999999994857142860000"))) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 0341c3b378918..be1aa83d682b2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.net.URI +import java.util.Date import scala.language.existentials @@ -33,6 +34,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.execution.command.{DDLSuite, DDLUtils} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.HiveExternalCatalog import org.apache.spark.sql.hive.HiveUtils.{CONVERT_METASTORE_ORC, CONVERT_METASTORE_PARQUET} import org.apache.spark.sql.hive.orc.OrcFileOperator @@ -58,7 +60,8 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA protected override def generateTable( catalog: SessionCatalog, name: TableIdentifier, - isDataSource: Boolean): CatalogTable = { + isDataSource: Boolean, + partitionCols: Seq[String] = Seq("a", "b")): CatalogTable = { val storage = if (isDataSource) { val serde = HiveSerDe.sourceToSerDe("parquet") @@ -69,7 +72,7 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA outputFormat = serde.get.outputFormat, serde = serde.get.serde, compressed = false, - properties = Map("serialization.format" -> "1")) + properties = Map.empty) } else { CatalogStorageFormat( locationUri = Some(catalog.defaultTablePath(name)), @@ -82,17 +85,17 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA val metadata = new MetadataBuilder() .putString("key", "value") .build() + val schema = new StructType() + .add("col1", "int", nullable = true, metadata = metadata) + .add("col2", "string") CatalogTable( identifier = name, tableType = CatalogTableType.EXTERNAL, storage = storage, - schema = new StructType() - .add("col1", "int", nullable = true, metadata = metadata) - .add("col2", "string") - .add("a", "int") - .add("b", "int"), + schema = schema.copy( + fields = schema.fields ++ partitionCols.map(StructField(_, IntegerType))), provider = if (isDataSource) Some("parquet") else Some("hive"), - partitionColumnNames = Seq("a", "b"), + partitionColumnNames = partitionCols, createTime = 0L, createVersion = org.apache.spark.SPARK_VERSION, tracksPartitionsInCatalog = true) @@ -752,6 +755,73 @@ class HiveDDLSuite } } + test("Insert overwrite Hive table should output correct schema") { + withSQLConf(CONVERT_METASTORE_PARQUET.key -> "false") { + withTable("tbl", "tbl2") { + withView("view1") { + spark.sql("CREATE TABLE tbl(id long)") + spark.sql("INSERT OVERWRITE TABLE tbl VALUES 4") + spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") + withTempPath { path => + sql( + s""" + |CREATE TABLE tbl2(ID long) USING hive + |OPTIONS(fileFormat 'parquet') + |LOCATION '${path.toURI}' + """.stripMargin) + spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1") + val expectedSchema = StructType(Seq(StructField("ID", LongType, true))) + assert(spark.read.parquet(path.toString).schema == expectedSchema) + checkAnswer(spark.table("tbl2"), Seq(Row(4))) + } + } + } + } + } + + test("Create Hive table as select should output correct schema") { + withSQLConf(CONVERT_METASTORE_PARQUET.key -> "false") { + withTable("tbl", 
"tbl2") { + withView("view1") { + spark.sql("CREATE TABLE tbl(id long)") + spark.sql("INSERT OVERWRITE TABLE tbl VALUES 4") + spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") + withTempPath { path => + sql( + s""" + |CREATE TABLE tbl2 USING hive + |OPTIONS(fileFormat 'parquet') + |LOCATION '${path.toURI}' + |AS SELECT ID FROM view1 + """.stripMargin) + val expectedSchema = StructType(Seq(StructField("ID", LongType, true))) + assert(spark.read.parquet(path.toString).schema == expectedSchema) + checkAnswer(spark.table("tbl2"), Seq(Row(4))) + } + } + } + } + } + + test("SPARK-25313 Insert overwrite directory should output correct schema") { + withSQLConf(CONVERT_METASTORE_PARQUET.key -> "false") { + withTable("tbl") { + withView("view1") { + spark.sql("CREATE TABLE tbl(id long)") + spark.sql("INSERT OVERWRITE TABLE tbl VALUES 4") + spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") + withTempPath { path => + spark.sql(s"INSERT OVERWRITE LOCAL DIRECTORY '${path.getCanonicalPath}' " + + "STORED AS PARQUET SELECT ID FROM view1") + val expectedSchema = StructType(Seq(StructField("ID", LongType, true))) + assert(spark.read.parquet(path.toString).schema == expectedSchema) + checkAnswer(spark.read.parquet(path.toString), Seq(Row(4))) + } + } + } + } + } + test("alter table partition - storage information") { sql("CREATE TABLE boxes (height INT, length INT) PARTITIONED BY (width INT)") sql("INSERT OVERWRITE TABLE boxes PARTITION (width=4) SELECT 4, 4") @@ -781,7 +851,7 @@ class HiveDDLSuite val part1 = Map("a" -> "1", "b" -> "5") val part2 = Map("a" -> "2", "b" -> "6") val root = new Path(catalog.getTableMetadata(tableIdent).location) - val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) + val fs = root.getFileSystem(spark.sessionState.newHadoopConf()) // valid fs.mkdirs(new Path(new Path(root, "a=1"), "b=5")) fs.createNewFile(new Path(new Path(root, "a=1/b=5"), "a.csv")) // file @@ -2248,4 +2318,34 @@ class HiveDDLSuite checkAnswer(spark.table("t4"), Row(0, 0)) } } + + test("SPARK-24812: desc formatted table for last access verification") { + withTable("t1") { + sql( + "CREATE TABLE IF NOT EXISTS t1 (c1_int INT, c2_string STRING, c3_float FLOAT)") + val desc = sql("DESC FORMATTED t1").filter($"col_name".startsWith("Last Access")) + .select("data_type") + // check if the last access time doesnt have the default date of year + // 1970 as its a wrong access time + assert(!(desc.first.toString.contains("1970"))) + } + } + + test("SPARK-24681 checks if nested column names do not include ',', ':', and ';'") { + val expectedMsg = "Cannot create a table having a nested column whose name contains invalid " + + "characters (',', ':', ';') in Hive metastore." 
+ + Seq("nested,column", "nested:column", "nested;column").foreach { nestedColumnName => + withTable("t") { + val e = intercept[AnalysisException] { + spark.range(1) + .select(struct(lit(0).as(nestedColumnName)).as("toplevel")) + .write + .format("hive") + .saveAsTable("t") + }.getMessage + assert(e.contains(expectedMsg)) + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 5d56f89c2271c..c349a327694bf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -171,20 +171,15 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto } } - test("SPARK-23021 AnalysisBarrier should not cut off explain output for parsed logical plans") { - val df = Seq((1, 1)).toDF("a", "b").groupBy("a").count().limit(1) - val outputStream = new java.io.ByteArrayOutputStream() - Console.withOut(outputStream) { - df.explain(true) + test("SPARK-23034 show relation names in Hive table scan nodes") { + val tableName = "tab" + withTable(tableName) { + sql(s"CREATE TABLE $tableName(c1 int) USING hive") + val output = new java.io.ByteArrayOutputStream() + Console.withOut(output) { + spark.table(tableName).explain(extended = false) + } + assert(output.toString.contains(s"Scan hive default.$tableName")) } - assert(outputStream.toString.replaceAll("""#\d+""", "#0").contains( - s"""== Parsed Logical Plan == - |GlobalLimit 1 - |+- LocalLimit 1 - | +- AnalysisBarrier - | +- Aggregate [a#0], [a#0, count(1) AS count#0L] - | +- Project [_1#0 AS a#0, _2#0 AS b#0] - | +- LocalRelation [_1#0, _2#0] - |""".stripMargin)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 2ea51791d0f79..b9c32e789a410 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -84,7 +84,7 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd } // Testing the Broadcast based join for cartesian join (cross join) - // We assume that the Broadcast Join Threshold will works since the src is a small table + // We assume that the Broadcast Join Threshold will work since the src is a small table private val spark_10484_1 = """ | SELECT a.key, b.key | FROM src a LEFT JOIN src b WHERE a.key > b.key + 300 @@ -1177,13 +1177,18 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd assert(spark.table("with_parts").filter($"p" === 2).collect().head == Row(1, 2)) } - val originalValue = spark.sparkContext.hadoopConfiguration.get(modeConfKey, "nonstrict") + // Turn off style check since the following test is to modify hadoop configuration on purpose. 
+ // scalastyle:off hadoopconfiguration + val hadoopConf = spark.sparkContext.hadoopConfiguration + // scalastyle:on hadoopconfiguration + + val originalValue = hadoopConf.get(modeConfKey, "nonstrict") try { - spark.sparkContext.hadoopConfiguration.set(modeConfKey, "nonstrict") + hadoopConf.set(modeConfKey, "nonstrict") sql("INSERT OVERWRITE TABLE with_parts partition(p) select 3, 4") assert(spark.table("with_parts").filter($"p" === 4).collect().head == Row(3, 4)) } finally { - spark.sparkContext.hadoopConfiguration.set(modeConfKey, originalValue) + hadoopConf.set(modeConfKey, originalValue) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala index 7402c9626873c..fe3deceb08067 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala @@ -37,6 +37,7 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { import testImplicits._ protected override def beforeAll(): Unit = { + super.beforeAll() sql(s"CREATE TEMPORARY FUNCTION mock AS '${classOf[MockUDAF].getName}'") sql(s"CREATE TEMPORARY FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala index 8dbcd24cd78de..0ef630bbd3670 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala @@ -43,6 +43,7 @@ class ObjectHashAggregateSuite import testImplicits._ protected override def beforeAll(): Unit = { + super.beforeAll() sql(s"CREATE TEMPORARY FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index cc592cf6ca629..16541295eb453 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -22,21 +22,29 @@ import scala.collection.JavaConverters._ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.hive.test.{TestHive, TestHiveQueryExecution} +import org.apache.spark.sql.internal.SQLConf /** * A set of test cases that validate partition and column pruning. */ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { + private val originalLimitFlatGlobalLimit = TestHive.conf.limitFlatGlobalLimit + override def beforeAll(): Unit = { super.beforeAll() TestHive.setCacheTables(false) + TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, false) // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, // need to reset the environment to ensure all referenced tables in this suites are // not cached in-memory. Refer to https://issues.apache.org/jira/browse/SPARK-2283 // for details. 
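`PruningSuite` snapshots the original `limitFlatGlobalLimit` value before flipping it in `beforeAll`, and the `afterAll` just below restores it so other suites sharing `TestHive` are unaffected. The same save-and-restore idiom can be hardened with the try/finally guard this patch applies elsewhere; a sketch, reusing the names above:

```scala
import org.apache.spark.sql.internal.SQLConf

// Sketch only: run super.afterAll() even if restoring the shared session conf
// throws, mirroring the try/finally cleanup pattern used in
// HiveExternalCatalogVersionsSuite earlier in this patch.
private val originalLimitFlatGlobalLimit = TestHive.conf.limitFlatGlobalLimit

override def afterAll(): Unit = {
  try {
    TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, originalLimitFlatGlobalLimit)
  } finally {
    super.afterAll()
  }
}
```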
TestHive.reset() } + override def afterAll() { + TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, originalLimitFlatGlobalLimit) + super.afterAll() + } // Column pruning tests diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 828c18a770c80..20c4c36c05091 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1912,11 +1912,60 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { sql("LOAD DATA LOCAL INPATH '/non-exist-folder/*part*' INTO TABLE load_t") }.getMessage assert(m.contains("LOAD DATA input path does not exist")) + } + } + } - val m2 = intercept[AnalysisException] { - sql(s"LOAD DATA LOCAL INPATH '$path*/*part*' INTO TABLE load_t") + test("Support wildcard character in folder level for LOAD DATA LOCAL INPATH") { + withTempDir { dir => + val path = dir.toURI.toString.stripSuffix("/") + val dirPath = dir.getAbsoluteFile + for (i <- 1 to 3) { + Files.write(s"$i", new File(dirPath, s"part-r-0000$i"), StandardCharsets.UTF_8) + } + withTable("load_t") { + sql("CREATE TABLE load_t (a STRING)") + sql(s"LOAD DATA LOCAL INPATH '${ + path.substring(0, path.length - 1) + .concat("*") + }/' INTO TABLE load_t") + checkAnswer(sql("SELECT * FROM load_t"), Seq(Row("1"), Row("2"), Row("3"))) + val m = intercept[AnalysisException] { + sql(s"LOAD DATA LOCAL INPATH '${ + path.substring(0, path.length - 1).concat("_invalid_dir").concat("*") + }/' INTO TABLE load_t") }.getMessage - assert(m2.contains("LOAD DATA input path allows only filename wildcard")) + assert(m.contains("LOAD DATA input path does not exist")) + } + } + } + + test("SPARK-17796 Support wildcard '?' char in middle as part of local file path") { + withTempDir { dir => + val path = dir.toURI.toString.stripSuffix("/") + val dirPath = dir.getAbsoluteFile + for (i <- 1 to 3) { + Files.write(s"$i", new File(dirPath, s"part-r-0000$i"), StandardCharsets.UTF_8) + } + withTable("load_t1") { + sql("CREATE TABLE load_t1 (a STRING)") + sql(s"LOAD DATA LOCAL INPATH '$path/part-r-0000?'
INTO TABLE load_t1") + checkAnswer(sql("SELECT * FROM load_t1"), Seq(Row("1"), Row("2"), Row("3"))) + } + } + } + + test("SPARK-17796 Support wildcard '?'char in start as part of local file path") { + withTempDir { dir => + val path = dir.toURI.toString.stripSuffix("/") + val dirPath = dir.getAbsoluteFile + for (i <- 1 to 3) { + Files.write(s"$i", new File(dirPath, s"part-r-0000$i"), StandardCharsets.UTF_8) + } + withTable("load_t2") { + sql("CREATE TABLE load_t2 (a STRING)") + sql(s"LOAD DATA LOCAL INPATH '$path/?art-r-00001' INTO TABLE load_t2") + checkAnswer(sql("SELECT * FROM load_t2"), Seq(Row("1"))) } } } @@ -1967,6 +2016,22 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("column resolution scenarios with hive table") { + val currentDb = spark.catalog.currentDatabase + withTempDatabase { db1 => + try { + spark.catalog.setCurrentDatabase(db1) + spark.sql("CREATE TABLE t1(i1 int) STORED AS parquet") + spark.sql("INSERT INTO t1 VALUES(1)") + checkAnswer(spark.sql(s"SELECT $db1.t1.i1 FROM t1"), Row(1)) + checkAnswer(spark.sql(s"SELECT $db1.t1.i1 FROM $db1.t1"), Row(1)) + checkAnswer(spark.sql(s"SELECT $db1.t1.* FROM $db1.t1"), Row(1)) + } finally { + spark.catalog.setCurrentDatabase(currentDb) + } + } + } + test("SPARK-17409: Do Not Optimize Query in CTAS (Hive Serde Table) More Than Once") { withTable("bar") { withTempView("foo") { @@ -2053,7 +2118,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { val deleteOnExitField = classOf[FileSystem].getDeclaredField("deleteOnExit") deleteOnExitField.setAccessible(true) - val fs = FileSystem.get(spark.sparkContext.hadoopConfiguration) + val fs = FileSystem.get(spark.sessionState.newHadoopConf()) val setOfPath = deleteOnExitField.get(fs).asInstanceOf[Set[Path]] val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).toDF() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala index 5318b4650b01f..5f73b7170c612 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -136,6 +136,25 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton { } assert(e.getMessage.contains("Subprocess exited with status")) } + + test("SPARK-24339 verify the result after pruning the unused columns") { + val rowsDf = Seq( + ("Bob", 16, 176), + ("Alice", 32, 164), + ("David", 60, 192), + ("Amy", 24, 180)).toDF("name", "age", "height") + + checkAnswer( + rowsDf, + (child: SparkPlan) => new ScriptTransformationExec( + input = Seq(rowsDf.col("name").expr), + script = "cat", + output = Seq(AttributeReference("name", StringType)()), + child = child, + ioschema = serdeIOSchema + ), + rowsDf.select("name").collect()) + } } private case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala index 69009e1b520c2..d84f9a3828207 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala @@ -146,33 +146,26 @@ class HiveOrcSourceSuite extends OrcSuite with 
TestHiveSingleton { }.getMessage assert(msg.contains("Cannot save interval data type into external storage.")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { sql("select null").write.mode("overwrite").orc(orcDir) }.getMessage assert(msg.contains("ORC data source does not support null data type.")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { spark.udf.register("testType", () => new IntervalData()) sql("select testType()").write.mode("overwrite").orc(orcDir) }.getMessage assert(msg.contains("ORC data source does not support calendarinterval data type.")) // read path - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) spark.range(1).write.mode("overwrite").orc(orcDir) spark.read.schema(schema).orc(orcDir).collect() }.getMessage assert(msg.contains("ORC data source does not support calendarinterval data type.")) - msg = intercept[UnsupportedOperationException] { - val schema = StructType(StructField("a", NullType, true) :: Nil) - spark.range(1).write.mode("overwrite").orc(orcDir) - spark.read.schema(schema).orc(orcDir).collect() - }.getMessage - assert(msg.contains("ORC data source does not support null data type.")) - - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) spark.range(1).write.mode("overwrite").orc(orcDir) spark.read.schema(schema).orc(orcDir).collect() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 2327d83a1b4f6..e82d457eee394 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -1068,7 +1068,7 @@ abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with test(s"SPARK-5775 read array from $table") { checkAnswer( sql(s"SELECT arrayField, p FROM $table WHERE p = 1"), - (1 to 10).map(i => Row(1 to i, 1))) + (1 to 10).map(i => Row((1 to i).toArray, 1))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala index 53397991e59dc..b9ec940ac4925 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala @@ -666,7 +666,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes assert(expectedResult.isRight, s"Was not expecting error with $path: " + e) assert( e.getMessage.contains(expectedResult.right.get), - s"Did not find expected error message wiht $path") + s"Did not find expected error message with $path") } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index e23edfa506517..4a4d2c5d9d8c8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -940,6 +940,11 @@ abstract class DStream[T: ClassTag] ( object DStream { + private val SPARK_CLASS_REGEX = """^org\.apache\.spark""".r + private val SPARK_STREAMING_TESTCLASS_REGEX = 
"""^org\.apache\.spark\.streaming\.test""".r + private val SPARK_EXAMPLES_CLASS_REGEX = """^org\.apache\.spark\.examples""".r + private val SCALA_CLASS_REGEX = """^scala""".r + // `toPairDStreamFunctions` was in SparkContext before 1.3 and users had to // `import StreamingContext._` to enable it. Now we move it here to make the compiler find // it automatically. However, we still keep the old function in StreamingContext for backward @@ -953,11 +958,6 @@ object DStream { /** Get the creation site of a DStream from the stack trace of when the DStream is created. */ private[streaming] def getCreationSite(): CallSite = { - val SPARK_CLASS_REGEX = """^org\.apache\.spark""".r - val SPARK_STREAMING_TESTCLASS_REGEX = """^org\.apache\.spark\.streaming\.test""".r - val SPARK_EXAMPLES_CLASS_REGEX = """^org\.apache\.spark\.examples""".r - val SCALA_CLASS_REGEX = """^scala""".r - /** Filtering function that excludes non-user classes for a streaming application */ def streamingExclustionFunction(className: String): Boolean = { def doesMatch(r: Regex): Boolean = r.findFirstIn(className).isDefined diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala index 7b29b40668def..8717555dea491 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala @@ -26,7 +26,7 @@ import org.apache.spark.streaming.util.RecurringTimer import org.apache.spark.util.{Clock, Utils} /** - * Class that manages executor allocated to a StreamingContext, and dynamically request or kill + * Class that manages executors allocated to a StreamingContext, and dynamically requests or kills * executors based on the statistics of the streaming computation. This is different from the core * dynamic allocation policy; the core policy relies on executors being idle for a while, but the * micro-batch model of streaming prevents any particular executors from being idle for a long @@ -43,6 +43,10 @@ import org.apache.spark.util.{Clock, Utils} * * This features should ideally be used in conjunction with backpressure, as backpressure ensures * system stability, while executors are being readjusted. + * + * Note that an initial set of executors (spark.executor.instances) was allocated when the + * SparkContext was created. This class scales executors up/down after the StreamingContext + * has started. 
*/ private[streaming] class ExecutorAllocationManager( client: ExecutorAllocationClient, @@ -202,12 +206,7 @@ private[streaming] object ExecutorAllocationManager extends Logging { val MAX_EXECUTORS_KEY = "spark.streaming.dynamicAllocation.maxExecutors" def isDynamicAllocationEnabled(conf: SparkConf): Boolean = { - val numExecutor = conf.getInt("spark.executor.instances", 0) val streamingDynamicAllocationEnabled = conf.getBoolean(ENABLED_KEY, false) - if (numExecutor != 0 && streamingDynamicAllocationEnabled) { - throw new IllegalArgumentException( - "Dynamic Allocation for streaming cannot be enabled while spark.executor.instances is set.") - } if (Utils.isDynamicAllocationEnabled(conf) && streamingDynamicAllocationEnabled) { throw new IllegalArgumentException( """ @@ -217,7 +216,7 @@ private[streaming] object ExecutorAllocationManager extends Logging { """.stripMargin) } val testing = conf.getBoolean("spark.streaming.dynamicAllocation.testing", false) - numExecutor == 0 && streamingDynamicAllocationEnabled && (!Utils.isLocalMaster(conf) || testing) + streamingDynamicAllocationEnabled && (!Utils.isLocalMaster(conf) || testing) } def createIfEnabled( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index 2e8599026ea1d..f0161e1465c29 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -312,6 +312,7 @@ private[streaming] object FileBasedWriteAheadLog { handler: I => Iterator[O]): Iterator[O] = { val taskSupport = new ExecutionContextTaskSupport(executionContext) val groupSize = taskSupport.parallelismLevel.max(8) + source.grouped(groupSize).flatMap { group => val parallelCollection = group.par parallelCollection.tasksupport = taskSupport diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index f2204a1870933..957feca2e552d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -77,7 +77,12 @@ class UISeleniumSuite inputStream.foreachRDD { rdd => rdd.foreach(_ => {}) try { - rdd.foreach(_ => throw new RuntimeException("Oops")) + rdd.foreach { _ => + // Failing the task with id 15 to ensure only one task fails + if (TaskContext.get.taskAttemptId() % 15 == 0) { + throw new RuntimeException("Oops") + } + } } catch { case e: SparkException if e.getMessage.contains("Oops") => } @@ -166,7 +171,7 @@ class UISeleniumSuite // Check job progress findAll(cssSelector(""".progress-cell""")).map(_.text).toList should be ( - List("4/4", "4/4", "4/4", "0/4 (1 failed)")) + List("4/4", "4/4", "4/4", "3/4 (1 failed)")) // Check stacktrace val errorCells = findAll(cssSelector(""".stacktrace-details""")).map(_.underlying).toSeq
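With the `spark.executor.instances` check removed above, streaming dynamic allocation no longer conflicts with a fixed initial executor count: the initial set comes from `spark.executor.instances`, and `ExecutorAllocationManager` scales from there once the StreamingContext starts. A hedged configuration sketch using the keys this file defines; core dynamic allocation (`spark.dynamicAllocation.enabled`) must stay off, as the remaining guard still enforces:

```scala
import org.apache.spark.SparkConf

// Sketch: start with 4 executors, then let the streaming allocation manager
// scale within [2, 10] based on batch processing statistics. The min/max keys
// follow the MAX_EXECUTORS_KEY constant shown above.
val conf = new SparkConf()
  .set("spark.executor.instances", "4")
  .set("spark.streaming.dynamicAllocation.enabled", "true")
  .set("spark.streaming.dynamicAllocation.minExecutors", "2")
  .set("spark.streaming.dynamicAllocation.maxExecutors", "10")
```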